66from dataclasses import dataclass
77from itertools import combinations
88from math import inf , log
9- from typing import Any , TypeAlias
9+ from typing import Any , Literal , TypeAlias
1010
1111import numpy as np
1212
2626from .tensorkrowch ._history import _recover_contraction_history
2727
2828NumericArray : TypeAlias = np .ndarray [Any , Any ]
29+ TensorElementsSourceName : TypeAlias = EngineName | Literal ["numpy" ]
2930
3031
3132@dataclass (frozen = True )
@@ -41,7 +42,7 @@ class _TensorRecord:
4142
4243 array : NumericArray
4344 axis_names : tuple [str , ...]
44- engine : EngineName
45+ engine : TensorElementsSourceName
4546 name : str
4647
4748
@@ -129,6 +130,70 @@ def _detect_tensor_elements_engine(data: Any) -> tuple[EngineName, Any]:
129130 return _detect_tensor_engine_with_input (data )
130131
131132
def _looks_like_backend_tensor_input(value: Any) -> bool:
    """Return whether ``value`` structurally resembles a backend tensor input.

    The check is duck-typed. ``value`` is treated as backend input when any of
    the following holds:

    * it is an ``EinsumTrace`` instance;
    * ``_is_tenpy_tensor`` accepts it;
    * it exposes one of the attributes ``tensors``, ``leaf_nodes``, ``nodes``
      or ``axes_names``;
    * it exposes both ``axis_names`` and ``tensor``;
    * it exposes both ``inds`` and ``data``.

    Args:
        value: Arbitrary object supplied by the caller.

    Returns:
        ``True`` when one of the conditions above matches, ``False`` otherwise.
    """
    if isinstance(value, EinsumTrace) or _is_tenpy_tensor(value):
        return True

    # Any single one of these attributes is enough to classify the object.
    standalone_markers = ("tensors", "leaf_nodes", "nodes", "axes_names")
    if any(hasattr(value, marker) for marker in standalone_markers):
        return True

    # The remaining markers only count when they appear as a pair.
    if hasattr(value, "axis_names") and hasattr(value, "tensor"):
        return True
    return hasattr(value, "inds") and hasattr(value, "data")
144+
145+
def _is_direct_array_like_tensor(value: Any) -> bool:
    """Return whether ``value`` can be treated as a plain numeric array.

    A value qualifies when all of the following hold:

    * it is not a ``str``, ``bytes``, ``bytearray`` or ``dict``;
    * ``_looks_like_backend_tensor_input`` rejects it;
    * it exposes ``shape`` or ``__array__``;
    * ``np.asarray`` converts it without raising and the resulting dtype is
      not the object dtype.

    Args:
        value: Arbitrary object supplied by the caller.

    Returns:
        ``True`` when ``value`` passes every check above, ``False`` otherwise.
    """
    if isinstance(value, (str, bytes, bytearray, dict)):
        return False
    if _looks_like_backend_tensor_input(value):
        return False
    if not (hasattr(value, "shape") or hasattr(value, "__array__")):
        return False
    try:
        converted = np.asarray(value)
    except Exception:
        # Deliberate best-effort probe: any conversion failure simply means
        # the value is not usable as a direct array.
        return False
    # An object-dtype result means NumPy could not build a numeric array.
    return converted.dtype != np.dtype("O")
158+
159+
160+ def _direct_array_record_name (index : int , * , total : int ) -> str :
161+ if total <= 1 :
162+ return "Tensor"
163+ return f"Tensor { index + 1 } "
164+
165+
def _extract_direct_array_records(data: Any) -> list[_TensorRecord] | None:
    """Build tensor records from a plain array or an iterable of plain arrays.

    Args:
        data: Either a single array-like object accepted by
            ``_is_direct_array_like_tensor`` or an iterable whose items are
            all accepted by it.

    Returns:
        A list of ``_TensorRecord`` objects with ``engine="numpy"`` and empty
        ``axis_names``, one per array, or ``None`` when ``data`` is neither a
        direct array nor a non-empty iterable made up only of direct arrays.
    """

    def build_record(payload: Any, position: int, count: int) -> _TensorRecord:
        # Single construction site shared by the one-array and many-array paths.
        return _TensorRecord(
            array=_to_numpy_array(payload),
            axis_names=(),
            engine="numpy",
            name=_direct_array_record_name(position, total=count),
        )

    if _is_direct_array_like_tensor(data):
        return [build_record(data, 0, 1)]

    if isinstance(data, (str, bytes, bytearray, dict)):
        return None
    if _looks_like_backend_tensor_input(data):
        return None
    if not isinstance(data, Iterable):
        return None

    # NOTE(review): list() consumes one-shot iterators such as generators. If
    # None is returned below, the caller is left holding an exhausted
    # iterator - confirm callers do not reuse `data` in that case.
    try:
        candidates = list(data)
    except TypeError:
        return None

    if not candidates:
        return None
    for candidate in candidates:
        if not _is_direct_array_like_tensor(candidate):
            return None

    count = len(candidates)
    return [
        build_record(candidate, position, count)
        for position, candidate in enumerate(candidates)
    ]
195+
196+
132197def _normalize_axis_selector (
133198 selector : TensorAxisSelector ,
134199 * ,
@@ -344,6 +409,8 @@ def _matrixize_tensor(
344409
345410def _downsample_axis_mean (array : NumericArray , * , axis : int , target_size : int ) -> NumericArray :
346411 source = np .asarray (array )
412+ if int (target_size ) <= 0 :
413+ raise ValueError ("max_matrix_shape must contain exactly two positive integers." )
347414 length = int (source .shape [axis ])
348415 if length <= target_size :
349416 return source
@@ -409,30 +476,36 @@ def _collect_items(
409476 if isinstance (source , dict ):
410477 raw_items = list (source .values ())
411478 should_sort = False
479+ deduplicate = True
412480 elif attr_sources and any (hasattr (source , attr ) for attr in attr_sources ):
413481 raw_items = _iter_attr_values (source , attr_sources )
414482 should_sort = False
483+ deduplicate = True
415484 elif isinstance (source , (str , bytes , bytearray )):
416485 raise TypeError ("Tensor collection input must be iterable." )
417486 elif isinstance (source , Iterable ):
418487 raw_items = list (source )
419488 should_sort = _is_unordered_collection (source )
489+ deduplicate = False
420490 else :
421491 raise TypeError ("Tensor collection input must be a supported tensor object or iterable." )
422492
423- unique : list [Any ] = []
424- seen : set [int ] = set ()
425- for item in raw_items :
426- if item is None :
427- continue
428- item_id = id (item )
429- if item_id in seen :
430- continue
431- seen .add (item_id )
432- unique .append (item )
493+ if deduplicate :
494+ items : list [Any ] = []
495+ seen : set [int ] = set ()
496+ for item in raw_items :
497+ if item is None :
498+ continue
499+ item_id = id (item )
500+ if item_id in seen :
501+ continue
502+ seen .add (item_id )
503+ items .append (item )
504+ else :
505+ items = [item for item in raw_items if item is not None ]
433506 if should_sort :
434- unique .sort (key = _name_sort_key )
435- return unique
507+ items .sort (key = _name_sort_key )
508+ return items
436509
437510
438511def _to_numpy_array (value : Any ) -> NumericArray :
@@ -728,7 +801,7 @@ def _extract_tensor_records(
728801 data : Any ,
729802 * ,
730803 engine : EngineName | None ,
731- ) -> tuple [EngineName , list [_TensorRecord ]]:
804+ ) -> tuple [TensorElementsSourceName , list [_TensorRecord ]]:
732805 """Extract normalized tensor records from supported public inputs.
733806
734807 Args:
@@ -746,6 +819,9 @@ def _extract_tensor_records(
746819 resolved_engine = engine
747820 prepared_input = data
748821 if resolved_engine is None :
822+ direct_array_records = _extract_direct_array_records (data )
823+ if direct_array_records is not None :
824+ return "numpy" , direct_array_records
749825 resolved_engine , prepared_input = _detect_tensor_elements_engine (data )
750826
751827 package_logger .debug ("Extracting tensor records with engine=%r." , resolved_engine )
@@ -780,6 +856,45 @@ def _format_float(value: float) -> str:
780856 return f"{ value :.4g} "
781857
782858
859+ def _non_nan_values (values : NumericArray ) -> NumericArray :
860+ array = np .asarray (values , dtype = float )
861+ return array [~ np .isnan (array )]
862+
863+
def _safe_min_max_mean_std(values: NumericArray) -> tuple[float, float, float, float]:
    """Return ``(min, max, mean, std)`` of ``values`` with NaN entries ignored.

    Args:
        values: Array-like input; NaN entries are dropped via
            ``_non_nan_values`` before any statistic is computed.

    Returns:
        A 4-tuple of Python floats. When no non-NaN entry remains, all four
        components are NaN.
    """
    kept = _non_nan_values(values)
    if kept.size == 0:
        return (float("nan"),) * 4

    # Floating-point warnings are silenced while the statistics are computed;
    # the computed values themselves are returned unchanged.
    with np.errstate(invalid="ignore", divide="ignore", over="ignore", under="ignore"):
        lowest = float(np.min(kept))
        highest = float(np.max(kept))
        average = float(np.mean(kept))
        spread = float(np.std(kept))
    return lowest, highest, average, spread
876+
877+
878+ def _safe_nanmean_axis (values : NumericArray , * , axis : int | tuple [int , ...]) -> NumericArray :
879+ array = np .asarray (values , dtype = float )
880+ valid_mask = ~ np .isnan (array )
881+ counts = np .sum (valid_mask , axis = axis )
882+ totals = np .sum (np .where (valid_mask , array , 0.0 ), axis = axis , dtype = float )
883+ result = np .full (np .shape (totals ), np .nan , dtype = float )
884+ np .divide (totals , counts , out = result , where = counts > 0 )
885+ return result
886+
887+
888+ def _safe_nan_norm_axis (values : NumericArray , * , axis : int | tuple [int , ...]) -> NumericArray :
889+ array = np .asarray (values , dtype = float )
890+ valid_mask = ~ np .isnan (array )
891+ squared = np .square (np .where (valid_mask , array , 0.0 ))
892+ totals = np .sum (squared , axis = axis , dtype = float )
893+ counts = np .sum (valid_mask , axis = axis )
894+ norms = np .sqrt (totals )
895+ return np .where (counts > 0 , norms , np .nan )
896+
897+
783898def _format_scalar (value : complex | float ) -> str :
784899 if np .iscomplexobj (value ):
785900 complex_value = complex (value )
@@ -826,8 +941,8 @@ def _build_axis_summary_lines(record: _TensorRecord) -> list[str]:
826941 for axis_index , (axis_name , axis_size ) in enumerate (zip (axis_names , shape , strict = True )):
827942 other_axes = tuple (index for index in range (array .ndim ) if index != axis_index )
828943 if other_axes :
829- marginal_mean = np . nanmean (metrics , axis = other_axes )
830- marginal_norm = np . sqrt ( np . nansum ( np . square ( metrics ) , axis = other_axes ) )
944+ marginal_mean = _safe_nanmean_axis (metrics , axis = other_axes )
945+ marginal_norm = _safe_nan_norm_axis ( metrics , axis = other_axes )
831946 else :
832947 marginal_mean = metrics
833948 marginal_norm = np .abs (metrics )
@@ -1088,31 +1203,27 @@ def _build_stats(record: _TensorRecord) -> _TensorStats:
10881203
10891204 if np .iscomplexobj (array ):
10901205 magnitude = np .abs (flat )
1206+ min_mag , max_mag , mean_mag , std_mag = _safe_min_max_mean_std (magnitude )
1207+ real_min , real_max , _ , _ = _safe_min_max_mean_std (np .real (flat ))
1208+ imag_min , imag_max , _ , _ = _safe_min_max_mean_std (np .imag (flat ))
10911209 lines .append (
10921210 "magnitude: "
1093- f"min={ _format_float (float (np .nanmin (magnitude )))} , "
1094- f"max={ _format_float (float (np .nanmax (magnitude )))} , "
1095- f"mean={ _format_float (float (np .nanmean (magnitude )))} , "
1096- f"std={ _format_float (float (np .nanstd (magnitude )))} "
1097- )
1098- lines .append (
1099- "real range: "
1100- f"{ _format_float (float (np .nanmin (np .real (flat ))))} .. "
1101- f"{ _format_float (float (np .nanmax (np .real (flat ))))} "
1102- )
1103- lines .append (
1104- "imag range: "
1105- f"{ _format_float (float (np .nanmin (np .imag (flat ))))} .. "
1106- f"{ _format_float (float (np .nanmax (np .imag (flat ))))} "
1211+ f"min={ _format_float (min_mag )} , "
1212+ f"max={ _format_float (max_mag )} , "
1213+ f"mean={ _format_float (mean_mag )} , "
1214+ f"std={ _format_float (std_mag )} "
11071215 )
1216+ lines .append (f"real range: { _format_float (real_min )} .. { _format_float (real_max )} " )
1217+ lines .append (f"imag range: { _format_float (imag_min )} .. { _format_float (imag_max )} " )
11081218 else :
11091219 values = np .real (flat )
1220+ value_min , value_max , value_mean , value_std = _safe_min_max_mean_std (values )
11101221 lines .append (
11111222 "stats: "
1112- f"min={ _format_float (float ( np . nanmin ( values )) )} , "
1113- f"max={ _format_float (float ( np . nanmax ( values )) )} , "
1114- f"mean={ _format_float (float ( np . nanmean ( values )) )} , "
1115- f"std={ _format_float (float ( np . nanstd ( values )) )} "
1223+ f"min={ _format_float (value_min )} , "
1224+ f"max={ _format_float (value_max )} , "
1225+ f"mean={ _format_float (value_mean )} , "
1226+ f"std={ _format_float (value_std )} "
11161227 )
11171228
11181229 return _TensorStats (
0 commit comments