Skip to content

Commit 13e7abb

Browse files
committed
fix: harden tensor elements inputs and layout
1 parent 202d083 commit 13e7abb

9 files changed

Lines changed: 262 additions & 47 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ to switch between them. The interactive controls are grouped: `basic` (`elements
162162
`diagnostic` (`sign`, `signed_value`, `sparsity`, `nan_inf`, `singular_values`, `eigen_real`,
163163
`eigen_imag`).
164164

165-
- `data`: single tensor, iterable of tensors, supported backend-native tensor collections, or an
165+
- `data`: direct numeric tensor input (for example a NumPy array), direct iterables of tensors
166+
preserving order and duplicates, supported backend-native tensor collections, or an
166167
`EinsumTrace` with live tensor values.
167168
- `engine`: optional backend override. If omitted, the library auto-detects it.
168169
- `config`: tensor-inspection behavior lives here.

docs/guide.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ show_tensor_elements(
6666

6767
| Parameter | Meaning |
6868
| --- | --- |
69-
| `data` | Supported tensor object, tensor collection, backend-native tensor container, or `EinsumTrace` with live tensors. |
69+
| `data` | Direct numeric tensor input, direct iterable of tensors (order preserved, duplicates allowed), backend-native tensor container, or `EinsumTrace` with live tensors. |
7070
| `engine` | Optional explicit backend: `"tensorkrowch"`, `"tensornetwork"`, `"quimb"`, `"tenpy"`, or `"einsum"`. |
7171
| `config` | A `TensorElementsConfig` instance. If omitted, `TensorElementsConfig()` is used. |
7272
| `ax` | Existing Matplotlib axis for single-tensor rendering only. |
@@ -292,6 +292,12 @@ summary without leaving the same figure.
292292
- explicit `TenPyTensorNetwork`
293293
- single TeNPy tensor exposing `to_ndarray()` and `get_leg_labels()`
294294

295+
### Direct tensor inputs
296+
297+
- single NumPy / array-like tensor input
298+
- direct iterables of tensors preserve order and duplicates; they are treated as inspection data,
299+
not as backend container objects
300+
295301
### `einsum`
296302

297303
- `EinsumTrace`

src/tensor_network_viz/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def show_tensor_elements(
7878
"""Lazily dispatch to :func:`tensor_network_viz.tensor_elements.show_tensor_elements`.
7979
8080
Args:
81-
data: Tensor data accepted by the public tensor-elements entry point.
81+
data: Tensor data accepted by the public tensor-elements entry point, including direct
82+
numeric arrays and iterables of tensors.
8283
engine: Optional backend override.
8384
config: Optional tensor-inspection configuration.
8485
ax: Optional Matplotlib axes for single-tensor rendering.

src/tensor_network_viz/_core/renderer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ def _plot_graph(
491491
build_scene_state: bool = True,
492492
) -> tuple[Figure, RenderedAxes]:
493493
style = config or PlotConfig()
494+
created_figure = ax is None
494495
fig, resolved_ax = _prepare_axes(
495496
ax=ax,
496497
figsize=style.figsize,
@@ -520,13 +521,14 @@ def _plot_graph(
520521
register_contraction_controls_on_figure=register_contraction_controls_on_figure,
521522
build_scene_state=build_scene_state,
522523
)
523-
reserved_bottom = get_reserved_bottom(fig)
524-
fig.subplots_adjust(
525-
left=_FIGURE_ADJUST_LEFT,
526-
right=_FIGURE_ADJUST_RIGHT,
527-
bottom=reserved_bottom,
528-
top=0.98,
529-
)
524+
if created_figure:
525+
reserved_bottom = get_reserved_bottom(fig)
526+
fig.subplots_adjust(
527+
left=_FIGURE_ADJUST_LEFT,
528+
right=_FIGURE_ADJUST_RIGHT,
529+
bottom=reserved_bottom,
530+
top=0.98,
531+
)
530532
return fig, resolved_ax
531533

532534

src/tensor_network_viz/_tensor_elements_data.py

Lines changed: 146 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from itertools import combinations
88
from math import inf, log
9-
from typing import Any, TypeAlias
9+
from typing import Any, Literal, TypeAlias
1010

1111
import numpy as np
1212

@@ -26,6 +26,7 @@
2626
from .tensorkrowch._history import _recover_contraction_history
2727

2828
NumericArray: TypeAlias = np.ndarray[Any, Any]
29+
TensorElementsSourceName: TypeAlias = EngineName | Literal["numpy"]
2930

3031

3132
@dataclass(frozen=True)
@@ -41,7 +42,7 @@ class _TensorRecord:
4142

4243
array: NumericArray
4344
axis_names: tuple[str, ...]
44-
engine: EngineName
45+
engine: TensorElementsSourceName
4546
name: str
4647

4748

@@ -129,6 +130,70 @@ def _detect_tensor_elements_engine(data: Any) -> tuple[EngineName, Any]:
129130
return _detect_tensor_engine_with_input(data)
130131

131132

133+
def _looks_like_backend_tensor_input(value: Any) -> bool:
134+
return (
135+
isinstance(value, EinsumTrace)
136+
or _is_tenpy_tensor(value)
137+
or hasattr(value, "tensors")
138+
or hasattr(value, "leaf_nodes")
139+
or hasattr(value, "nodes")
140+
or hasattr(value, "axes_names")
141+
or (hasattr(value, "axis_names") and hasattr(value, "tensor"))
142+
or (hasattr(value, "inds") and hasattr(value, "data"))
143+
)
144+
145+
146+
def _is_direct_array_like_tensor(value: Any) -> bool:
147+
if isinstance(value, (str, bytes, bytearray, dict)):
148+
return False
149+
if _looks_like_backend_tensor_input(value):
150+
return False
151+
if not hasattr(value, "shape") and not hasattr(value, "__array__"):
152+
return False
153+
try:
154+
array = np.asarray(value)
155+
except Exception:
156+
return False
157+
return array.dtype != np.dtype("O")
158+
159+
160+
def _direct_array_record_name(index: int, *, total: int) -> str:
161+
if total <= 1:
162+
return "Tensor"
163+
return f"Tensor {index + 1}"
164+
165+
166+
def _extract_direct_array_records(data: Any) -> list[_TensorRecord] | None:
167+
if _is_direct_array_like_tensor(data):
168+
return [
169+
_TensorRecord(
170+
array=_to_numpy_array(data),
171+
axis_names=(),
172+
engine="numpy",
173+
name=_direct_array_record_name(0, total=1),
174+
)
175+
]
176+
if isinstance(data, (str, bytes, bytearray, dict)) or _looks_like_backend_tensor_input(data):
177+
return None
178+
if not isinstance(data, Iterable):
179+
return None
180+
try:
181+
items = list(data)
182+
except TypeError:
183+
return None
184+
if not items or not all(_is_direct_array_like_tensor(item) for item in items):
185+
return None
186+
return [
187+
_TensorRecord(
188+
array=_to_numpy_array(item),
189+
axis_names=(),
190+
engine="numpy",
191+
name=_direct_array_record_name(index, total=len(items)),
192+
)
193+
for index, item in enumerate(items)
194+
]
195+
196+
132197
def _normalize_axis_selector(
133198
selector: TensorAxisSelector,
134199
*,
@@ -344,6 +409,8 @@ def _matrixize_tensor(
344409

345410
def _downsample_axis_mean(array: NumericArray, *, axis: int, target_size: int) -> NumericArray:
346411
source = np.asarray(array)
412+
if int(target_size) <= 0:
413+
raise ValueError("max_matrix_shape must contain exactly two positive integers.")
347414
length = int(source.shape[axis])
348415
if length <= target_size:
349416
return source
@@ -409,30 +476,36 @@ def _collect_items(
409476
if isinstance(source, dict):
410477
raw_items = list(source.values())
411478
should_sort = False
479+
deduplicate = True
412480
elif attr_sources and any(hasattr(source, attr) for attr in attr_sources):
413481
raw_items = _iter_attr_values(source, attr_sources)
414482
should_sort = False
483+
deduplicate = True
415484
elif isinstance(source, (str, bytes, bytearray)):
416485
raise TypeError("Tensor collection input must be iterable.")
417486
elif isinstance(source, Iterable):
418487
raw_items = list(source)
419488
should_sort = _is_unordered_collection(source)
489+
deduplicate = False
420490
else:
421491
raise TypeError("Tensor collection input must be a supported tensor object or iterable.")
422492

423-
unique: list[Any] = []
424-
seen: set[int] = set()
425-
for item in raw_items:
426-
if item is None:
427-
continue
428-
item_id = id(item)
429-
if item_id in seen:
430-
continue
431-
seen.add(item_id)
432-
unique.append(item)
493+
if deduplicate:
494+
items: list[Any] = []
495+
seen: set[int] = set()
496+
for item in raw_items:
497+
if item is None:
498+
continue
499+
item_id = id(item)
500+
if item_id in seen:
501+
continue
502+
seen.add(item_id)
503+
items.append(item)
504+
else:
505+
items = [item for item in raw_items if item is not None]
433506
if should_sort:
434-
unique.sort(key=_name_sort_key)
435-
return unique
507+
items.sort(key=_name_sort_key)
508+
return items
436509

437510

438511
def _to_numpy_array(value: Any) -> NumericArray:
@@ -728,7 +801,7 @@ def _extract_tensor_records(
728801
data: Any,
729802
*,
730803
engine: EngineName | None,
731-
) -> tuple[EngineName, list[_TensorRecord]]:
804+
) -> tuple[TensorElementsSourceName, list[_TensorRecord]]:
732805
"""Extract normalized tensor records from supported public inputs.
733806
734807
Args:
@@ -746,6 +819,9 @@ def _extract_tensor_records(
746819
resolved_engine = engine
747820
prepared_input = data
748821
if resolved_engine is None:
822+
direct_array_records = _extract_direct_array_records(data)
823+
if direct_array_records is not None:
824+
return "numpy", direct_array_records
749825
resolved_engine, prepared_input = _detect_tensor_elements_engine(data)
750826

751827
package_logger.debug("Extracting tensor records with engine=%r.", resolved_engine)
@@ -780,6 +856,45 @@ def _format_float(value: float) -> str:
780856
return f"{value:.4g}"
781857

782858

859+
def _non_nan_values(values: NumericArray) -> NumericArray:
860+
array = np.asarray(values, dtype=float)
861+
return array[~np.isnan(array)]
862+
863+
864+
def _safe_min_max_mean_std(values: NumericArray) -> tuple[float, float, float, float]:
865+
non_nan = _non_nan_values(values)
866+
if non_nan.size == 0:
867+
nan = float("nan")
868+
return nan, nan, nan, nan
869+
with np.errstate(invalid="ignore", divide="ignore", over="ignore", under="ignore"):
870+
return (
871+
float(np.min(non_nan)),
872+
float(np.max(non_nan)),
873+
float(np.mean(non_nan)),
874+
float(np.std(non_nan)),
875+
)
876+
877+
878+
def _safe_nanmean_axis(values: NumericArray, *, axis: int | tuple[int, ...]) -> NumericArray:
879+
array = np.asarray(values, dtype=float)
880+
valid_mask = ~np.isnan(array)
881+
counts = np.sum(valid_mask, axis=axis)
882+
totals = np.sum(np.where(valid_mask, array, 0.0), axis=axis, dtype=float)
883+
result = np.full(np.shape(totals), np.nan, dtype=float)
884+
np.divide(totals, counts, out=result, where=counts > 0)
885+
return result
886+
887+
888+
def _safe_nan_norm_axis(values: NumericArray, *, axis: int | tuple[int, ...]) -> NumericArray:
889+
array = np.asarray(values, dtype=float)
890+
valid_mask = ~np.isnan(array)
891+
squared = np.square(np.where(valid_mask, array, 0.0))
892+
totals = np.sum(squared, axis=axis, dtype=float)
893+
counts = np.sum(valid_mask, axis=axis)
894+
norms = np.sqrt(totals)
895+
return np.where(counts > 0, norms, np.nan)
896+
897+
783898
def _format_scalar(value: complex | float) -> str:
784899
if np.iscomplexobj(value):
785900
complex_value = complex(value)
@@ -826,8 +941,8 @@ def _build_axis_summary_lines(record: _TensorRecord) -> list[str]:
826941
for axis_index, (axis_name, axis_size) in enumerate(zip(axis_names, shape, strict=True)):
827942
other_axes = tuple(index for index in range(array.ndim) if index != axis_index)
828943
if other_axes:
829-
marginal_mean = np.nanmean(metrics, axis=other_axes)
830-
marginal_norm = np.sqrt(np.nansum(np.square(metrics), axis=other_axes))
944+
marginal_mean = _safe_nanmean_axis(metrics, axis=other_axes)
945+
marginal_norm = _safe_nan_norm_axis(metrics, axis=other_axes)
831946
else:
832947
marginal_mean = metrics
833948
marginal_norm = np.abs(metrics)
@@ -1088,31 +1203,27 @@ def _build_stats(record: _TensorRecord) -> _TensorStats:
10881203

10891204
if np.iscomplexobj(array):
10901205
magnitude = np.abs(flat)
1206+
min_mag, max_mag, mean_mag, std_mag = _safe_min_max_mean_std(magnitude)
1207+
real_min, real_max, _, _ = _safe_min_max_mean_std(np.real(flat))
1208+
imag_min, imag_max, _, _ = _safe_min_max_mean_std(np.imag(flat))
10911209
lines.append(
10921210
"magnitude: "
1093-
f"min={_format_float(float(np.nanmin(magnitude)))}, "
1094-
f"max={_format_float(float(np.nanmax(magnitude)))}, "
1095-
f"mean={_format_float(float(np.nanmean(magnitude)))}, "
1096-
f"std={_format_float(float(np.nanstd(magnitude)))}"
1097-
)
1098-
lines.append(
1099-
"real range: "
1100-
f"{_format_float(float(np.nanmin(np.real(flat))))} .. "
1101-
f"{_format_float(float(np.nanmax(np.real(flat))))}"
1102-
)
1103-
lines.append(
1104-
"imag range: "
1105-
f"{_format_float(float(np.nanmin(np.imag(flat))))} .. "
1106-
f"{_format_float(float(np.nanmax(np.imag(flat))))}"
1211+
f"min={_format_float(min_mag)}, "
1212+
f"max={_format_float(max_mag)}, "
1213+
f"mean={_format_float(mean_mag)}, "
1214+
f"std={_format_float(std_mag)}"
11071215
)
1216+
lines.append(f"real range: {_format_float(real_min)} .. {_format_float(real_max)}")
1217+
lines.append(f"imag range: {_format_float(imag_min)} .. {_format_float(imag_max)}")
11081218
else:
11091219
values = np.real(flat)
1220+
value_min, value_max, value_mean, value_std = _safe_min_max_mean_std(values)
11101221
lines.append(
11111222
"stats: "
1112-
f"min={_format_float(float(np.nanmin(values)))}, "
1113-
f"max={_format_float(float(np.nanmax(values)))}, "
1114-
f"mean={_format_float(float(np.nanmean(values)))}, "
1115-
f"std={_format_float(float(np.nanstd(values)))}"
1223+
f"min={_format_float(value_min)}, "
1224+
f"max={_format_float(value_max)}, "
1225+
f"mean={_format_float(value_mean)}, "
1226+
f"std={_format_float(value_std)}"
11161227
)
11171228

11181229
return _TensorStats(

src/tensor_network_viz/tensor_elements.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ def show_tensor_elements(
9595
"""Render tensor values in a single Matplotlib figure.
9696
9797
Args:
98-
data: Single tensor, iterable of tensors, supported backend-native tensor container,
99-
or playback-aware inputs such as ``EinsumTrace``.
98+
data: Direct numeric tensor/array-like input, iterable of tensors (preserving order and
99+
duplicates), supported backend-native tensor container, or playback-aware inputs such
100+
as ``EinsumTrace``.
100101
engine: Optional backend override. When omitted, the backend is inferred from ``data``.
101102
config: Optional tensor-inspection configuration. When omitted,
102103
``TensorElementsConfig()`` is used. The config is ordered from mode/axis selection

src/tensor_network_viz/tensor_elements_config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,28 @@ class TensorElementsConfig:
7272

7373
def __post_init__(self) -> None:
7474
"""Validate numeric configuration values and normalize percentile input."""
75+
try:
76+
max_rows_raw, max_cols_raw = self.max_matrix_shape
77+
except (TypeError, ValueError) as exc:
78+
raise ValueError(
79+
"max_matrix_shape must contain exactly two positive integers."
80+
) from exc
81+
max_rows = int(max_rows_raw)
82+
max_cols = int(max_cols_raw)
83+
if max_rows <= 0 or max_cols <= 0:
84+
raise ValueError("max_matrix_shape must contain exactly two positive integers.")
85+
object.__setattr__(self, "max_matrix_shape", (max_rows, max_cols))
86+
87+
histogram_bins = int(self.histogram_bins)
88+
if histogram_bins <= 0:
89+
raise ValueError("histogram_bins must be positive.")
90+
object.__setattr__(self, "histogram_bins", histogram_bins)
91+
92+
histogram_max_samples = int(self.histogram_max_samples)
93+
if histogram_max_samples <= 0:
94+
raise ValueError("histogram_max_samples must be positive.")
95+
object.__setattr__(self, "histogram_max_samples", histogram_max_samples)
96+
7597
if int(self.topk_count) <= 0:
7698
raise ValueError("topk_count must be positive.")
7799
if float(self.zero_threshold) <= 0.0:

0 commit comments

Comments
 (0)