From 75eaca2fc11b85824add7cd805cbab12f16e8296 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Fri, 7 Nov 2025 05:59:18 -0800
Subject: [PATCH 01/11] group_by_index

---
 .../segmented_polynomials/utils.py | 52 +++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py
index 1293e282..58361044 100644
--- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py
+++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py
@@ -116,3 +116,55 @@ def math_dtype_for_naive_method(
         f"method='naive' does not support math_dtype '{math_dtype}'. "
         "Supported options are any JAX dtype (e.g., 'float32', 'float64', 'float16', 'bfloat16') or 'tensor_float32'."
     )
+
+
+def group_by_index(
+    primary_idx: jax.Array,
+    secondary_idx: jax.Array,
+    max_primary_idx: int,
+    axis: int = -1,
+) -> tuple[jax.Array, jax.Array]:
+    """Group and reorder indices in CSR-like format along a specified axis.
+
+    Groups ``secondary_idx`` values by their corresponding ``primary_idx`` values,
+    enabling efficient contiguous access to all elements with the same primary index.
+
+    Args:
+        primary_idx: Indices to group by along ``axis``.
+        secondary_idx: Indices to reorder. Must match the size of ``primary_idx`` along ``axis``.
+        max_primary_idx: Exclusive upper bound on the values in ``primary_idx``.
+        axis: Axis along which to perform grouping. Defaults to -1.
+
+    Returns:
+        tuple: ``(indptr, reordered_indices)`` where:
+        - ``indptr``: Offsets of shape ``(..., max_primary_idx + 1, ...)`` where
+          ``reordered_indices[..., indptr[k]:indptr[k+1], ...]`` contains all elements
+          with ``primary_idx == k``.
+        - ``reordered_indices``: Reordered ``secondary_idx`` with elements grouped by ``primary_idx``.
+ + Example: + >>> primary_idx = jnp.array([1, 0, 2, 1, 0]) + >>> secondary_idx = jnp.array([10, 20, 30, 40, 50]) + >>> indptr, reordered = group_by_index(primary_idx, secondary_idx, max_primary_idx=3) + >>> reordered[indptr[0]:indptr[1]] # Elements where primary_idx == 0: [20, 50] + >>> reordered[indptr[1]:indptr[2]] # Elements where primary_idx == 1: [10, 40] + """ + assert primary_idx.ndim == secondary_idx.ndim + assert primary_idx.shape[axis] == secondary_idx.shape[axis] + + reordered = jnp.take_along_axis( + secondary_idx, jnp.argsort(primary_idx, axis=axis), axis=axis + ) + + def compute_indptr(p): + return jnp.append( + 0, jnp.cumsum(jnp.zeros((max_primary_idx,), jnp.int32).at[p].add(1)) + ) + + p = jnp.moveaxis(primary_idx, axis, -1) + indptr = jax.vmap(compute_indptr)(p.reshape(-1, p.shape[-1])).reshape( + p.shape[:-1] + (max_primary_idx + 1,) + ) + indptr = jnp.moveaxis(indptr, -1, axis) + + return indptr, reordered From 890978c8c3e9860673e22ea24efe333bab7f40cc Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 7 Nov 2025 06:00:48 -0800 Subject: [PATCH 02/11] options --- .../equivariant_polynomial.py | 6 ++ .../segmented_polynomial.py | 67 ++++++++++++------- .../segmented_polynomial_indexed_linear.py | 3 +- .../segmented_polynomial_naive.py | 3 +- .../segmented_polynomial_uniform_1d.py | 5 +- 5 files changed, 57 insertions(+), 27 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py index afbe92c1..d83b0fdf 100644 --- a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import jax import jax.numpy as jnp @@ -30,6 +32,7 @@ def equivariant_polynomial( *, method: str = "", math_dtype: str | None = None, + options: dict[str, Any] | None = None, name: str | None = None, precision: jax.lax.Precision = "undefined", ) -> list[cuex.RepArray] | cuex.RepArray: @@ -51,6 +54,7 @@ def equivariant_polynomial( operands. Defaults to None. Note that indices are not supported for all methods. method: Method to use for computation. See :func:`cuex.segmented_polynomial ` for available methods. math_dtype: See :func:`cuex.segmented_polynomial ` for supported options. + options: Optional dictionary of method-specific options. See :func:`cuex.segmented_polynomial ` for supported options. name: Optional name for the operation. Defaults to None. 
Returns: @@ -186,8 +190,10 @@ def equivariant_polynomial( outputs_shape_dtype, indices, math_dtype=math_dtype, + options=options, name=name, method=method, + precision=precision, ) outputs = [cuex.RepArray(rep, x) for rep, x in zip(poly.outputs, outputs)] diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py index 849d3271..d12acb4c 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py @@ -17,6 +17,7 @@ import os import warnings from functools import partial +from typing import Any import jax import jax.core @@ -55,6 +56,7 @@ def segmented_polynomial( *, method: str = "", math_dtype: str | None = None, + options: dict[str, Any] | None = None, name: str | None = None, precision: jax.lax.Precision = "undefined", ) -> list[jax.Array]: @@ -79,15 +81,19 @@ def segmented_polynomial( .. note:: The ``"fused_tp"`` method is only available in the PyTorch implementation. - math_dtype: Data type for computational operations. If None, automatically determined from input types. Defaults to None. + math_dtype: (Deprecated) Data type for computational operations. Prefer using ``options`` instead. Defaults to None. + options: Optional dictionary of method-specific options. Defaults to None. - Supported options vary by method: + Common options: - - ``"naive"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``). - Also supports ``"tensor_float32"`` for TensorFloat-32 mode. - - ``"uniform_1d"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``). - - ``"indexed_linear"``: CUBLAS compute type strings such as ``"CUBLAS_COMPUTE_32F"``, ``"CUBLAS_COMPUTE_32F_FAST_TF32"``, - ``"CUBLAS_COMPUTE_32F_PEDANTIC"``, ``"CUBLAS_COMPUTE_64F"``, etc. + - ``"math_dtype"``: Data type for computational operations. If None, automatically determined from input types. + Supported values vary by method: + + - ``"naive"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``). + Also supports ``"tensor_float32"`` for TensorFloat-32 mode. + - ``"uniform_1d"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``). + - ``"indexed_linear"``: CUBLAS compute type strings such as ``"CUBLAS_COMPUTE_32F"``, ``"CUBLAS_COMPUTE_32F_FAST_TF32"``, + ``"CUBLAS_COMPUTE_32F_PEDANTIC"``, ``"CUBLAS_COMPUTE_64F"``, etc. name: Optional name for the operation. @@ -163,9 +169,24 @@ def segmented_polynomial( if name is None: name = "segmented_polynomial" - if math_dtype is not None and not isinstance(math_dtype, str): - math_dtype = jnp.dtype(math_dtype).name - assert isinstance(math_dtype, str) or math_dtype is None + if options is None: + options = {} + else: + options = dict(options) + + if math_dtype is not None: + if "math_dtype" in options: + raise ValueError( + "math_dtype provided both as argument and in options dict. " + "Please use only options['math_dtype']." 
+ ) + options["math_dtype"] = math_dtype + + if "math_dtype" in options: + math_dtype_val = options["math_dtype"] + if math_dtype_val is not None and not isinstance(math_dtype_val, str): + options["math_dtype"] = jnp.dtype(math_dtype_val).name + assert isinstance(options["math_dtype"], str) or options["math_dtype"] is None if precision != "undefined": raise ValueError( @@ -289,7 +310,7 @@ def fn(x, n: int): index_configuration=index_configuration, index_mode=index_mode, polynomial=polynomial, - math_dtype=math_dtype, + options=tuple(sorted(options.items())), name=name, ) @@ -352,7 +373,7 @@ def segmented_polynomial_prim( index_configuration: list[list[int]], # maps: buffer index -> unique indices index index_mode: list[list[IndexingMode]], # shared, batched, indexed, repeated polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options: tuple[tuple[str, Any], ...], name: str, method: str, return_none_if_empty: bool = False, @@ -383,7 +404,7 @@ def segmented_polynomial_prim( x for x, used in zip(outputs_shape_dtype, used_outputs) if used ), polynomial=polynomial.filter_keep_operands(used_inputs + used_outputs), - math_dtype=math_dtype, + options=options, name=str(name), method=method, ) @@ -441,7 +462,7 @@ def segmented_polynomial_abstract_eval( index_mode: tuple[tuple[IndexingMode, ...], ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options: tuple[tuple[str, Any], ...], name: str, method: str, ) -> tuple[jax.core.ShapedArray, ...]: @@ -457,7 +478,7 @@ def segmented_polynomial_impl( index_mode: tuple[tuple[IndexingMode, ...], ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options: tuple[tuple[str, Any], ...], name: str, method: str, ) -> tuple[jax.Array, ...]: @@ -501,7 +522,7 @@ def segmented_polynomial_impl( indices=indices, index_configuration=index_configuration, polynomial=polynomial, - math_dtype=math_dtype, + options=dict(options), name=name, ) @@ -528,7 +549,7 @@ def segmented_polynomial_jvp( index_mode: tuple[tuple[IndexingMode, ...], ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options: tuple[tuple[str, Any], ...], name: str, method: str, ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]: @@ -549,7 +570,7 @@ def segmented_polynomial_jvp( index_configuration, index_mode, polynomial, - math_dtype, + options, name, method=method, ) @@ -571,7 +592,7 @@ def segmented_polynomial_jvp( jvp_index_configuration, jvp_index_mode, jvp_poly, - math_dtype, + options, name + "_jvp" + "".join("0" if isinstance(t, ad.Zero) else "1" for t in tangents), @@ -588,7 +609,7 @@ def segmented_polynomial_transpose( index_mode: tuple[tuple[IndexingMode, ...], ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options: tuple[tuple[str, Any], ...], name: str, method: str, ) -> tuple[jax.Array | ad.Zero | None, ...]: @@ -629,7 +650,7 @@ def segmented_polynomial_transpose( tr_index_configuration, tr_index_mode, tr_poly, - math_dtype, + options, name + "_T", method=method, return_none_if_empty=True, @@ -652,7 +673,7 @@ def segmented_polynomial_batching( index_mode: tuple[tuple[IndexingMode, ...], ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options: tuple[tuple[str, Any], ...], name: str, method: 
str, ) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]: @@ -692,7 +713,7 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array: index_mode=index_mode, outputs_shape_dtype=outputs_shape_dtype, polynomial=polynomial, - math_dtype=math_dtype, + options=options, name=name + "_batching", method=method, ) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py index dd23fa87..7c43f4a9 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py @@ -46,10 +46,11 @@ def execute_indexed_linear( index_configuration: tuple[tuple[int, ...], ...], index_mode: tuple[tuple[IndexingMode, ...], ...], polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options, name: str, run_kernel: bool = True, ) -> list[jax.Array]: # output buffers + math_dtype = options.get("math_dtype") num_inputs = len(index_configuration) - len(outputs_shape_dtype) io_buffers = list(inputs) + [ diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py index 00f0a18f..aa01ce9d 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py @@ -43,9 +43,10 @@ def execute_naive( index_configuration: tuple[tuple[int, ...], ...], index_mode: tuple[tuple[IndexingMode, ...], ...], polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options, name: str, ) -> list[jax.Array]: # output buffers + math_dtype = options.get("math_dtype") if any(mode == IndexingMode.REPEATED for modes in index_mode for mode in modes): return execute_indexed_linear( inputs, diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py index cf9bf3a7..b30585a1 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py @@ -38,7 +38,7 @@ def execute_uniform_1d( indices: list[jax.Array], index_configuration: tuple[tuple[int, ...], ...], polynomial: cue.SegmentedPolynomial, - math_dtype: str | None, + options: dict, name: str, ) -> list[jax.Array]: error_message = f"Failed to execute 'uniform_1d' method for the following polynomial:\n{polynomial}\n" @@ -151,8 +151,9 @@ def fn(op, d: cue.SegmentedTensorProduct): if len({b.shape[-1] for b in buffers}.union({1})) > 2: raise ValueError(f"Buffer shapes not compatible {[b.shape for b in buffers]}") - if math_dtype is not None: + if "math_dtype" in options: supported_dtypes = {"float32", "float64", "float16", "bfloat16"} + math_dtype = options["math_dtype"] if math_dtype not in supported_dtypes: raise ValueError( f"method='uniform_1d' only supports math_dtype equal to {supported_dtypes}, got '{math_dtype}'." 
From 41eeaa644236b54c85c1c3a356fb442e74a41425 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 7 Nov 2025 09:32:35 -0800 Subject: [PATCH 03/11] deterministic indexing strategy --- .../segmented_polynomial_uniform_1d.py | 293 +++++++++++++++++- 1 file changed, 279 insertions(+), 14 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py index b30585a1..303a0217 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import os import re import warnings import jax import jax.numpy as jnp import numpy as np -from cuequivariance_jax.segmented_polynomials.utils import reshape +from cuequivariance_jax.segmented_polynomials.utils import group_by_index, reshape from packaging import version import cuequivariance as cue @@ -45,9 +46,6 @@ def execute_uniform_1d( index_configuration = np.array(index_configuration) num_batch_axes = index_configuration.shape[1] - assert ( - polynomial.num_inputs + len(outputs_shape_dtype) == index_configuration.shape[0] - ) assert polynomial.num_outputs == len(outputs_shape_dtype) try: @@ -69,9 +67,14 @@ def fn(op, d: cue.SegmentedTensorProduct): polynomial = polynomial.apply_fn(fn) - # We don't use the feature that indices can index themselves - index_configuration = np.concatenate( - [index_configuration, np.full((len(indices), num_batch_axes), -1, np.int32)] + if polynomial.num_inputs + len(outputs_shape_dtype) == index_configuration.shape[0]: + index_configuration = np.concatenate( + [index_configuration, np.full((len(indices), num_batch_axes), -1, np.int32)] + ) + + assert ( + polynomial.num_inputs + len(outputs_shape_dtype) + len(indices) + == index_configuration.shape[0] ) buffers = list(inputs) + list(outputs_shape_dtype) @@ -187,14 +190,276 @@ def fn(op, d: cue.SegmentedTensorProduct): for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) - outputs = tensor_product_uniform_1d_jit( - buffers[: polynomial.num_inputs], - buffers[polynomial.num_inputs :], - list(indices), + if options.get("indptr"): + outputs = indptr_optimization_outer( + buffers[: polynomial.num_inputs], + buffers[polynomial.num_inputs :], + list(indices), + index_configuration, + operations=operations, + paths=paths, + math_dtype=compute_dtype, + name=sanitize_string(name), + ) + else: + outputs = tensor_product_uniform_1d_jit( + buffers[: polynomial.num_inputs], + buffers[polynomial.num_inputs :], + list(indices), + index_configuration, + operations=operations, + paths=paths, + math_dtype=compute_dtype, + name=sanitize_string(name), + ) + return [jnp.reshape(x, y.shape) for x, y in zip(outputs, outputs_shape_dtype)] + + +def indptr_optimization_outer( + inputs: list[jax.Array], + outputs: list[jax.Array], + indices: list[jax.Array], + index_configuration: np.ndarray, + operations: list, + paths: list, + math_dtype: jnp.dtype, + name: str, +): + from collections import defaultdict + + from cuequivariance_ops_jax import Operation + + ni, no = len(inputs), len(outputs) + assert index_configuration.shape[0] == ni + no + len(indices) + assert np.all(index_configuration[ni + no :] == -1) + + output_configs = 
index_configuration[ni : ni + no] + unique_configs, group_assignments = np.unique( + output_configs, axis=0, return_inverse=True + ) + + groups = defaultdict(list) + for output_idx, group_idx in enumerate(group_assignments): + groups[group_idx].append(output_idx) + + if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): + print(f"\n{'=' * 80}") + print(f"šŸŽÆ indptr_optimization_outer: {name}") + print(f"{'=' * 80}") + print( + f"šŸ“Š {ni} inputs, {no} outputs, {len(indices)} indices, {len(operations)} ops, {len(paths)} paths" + ) + print(f"šŸ”¢ Input shapes: {[tuple(x.shape) for x in inputs]}") + print(f"šŸ“¦ Output shapes: {[tuple(x.shape) for x in outputs]}") + print(f"šŸŽ² Index shapes: {[tuple(x.shape) for x in indices]}") + print(f"\nšŸ“‹ Index Configuration ({index_configuration.shape}):") + print(f" Inputs [{0:2d}:{ni:2d}]: {index_configuration[:ni].tolist()}") + print( + f" Outputs [{ni:2d}:{ni + no:2d}]: {index_configuration[ni : ni + no].tolist()}" + ) + print( + f" Indices [{ni + no:2d}:{len(index_configuration):2d}]: {index_configuration[ni + no :].tolist()}" + ) + print(f"šŸŽØ Found {len(groups)} unique output groups:") + for group_idx, output_indices in groups.items(): + print( + f" Group {group_idx}: {unique_configs[group_idx].tolist()} → outputs {output_indices} ({len(output_indices)} outputs)" + ) + print(f"{'=' * 80}\n") + + result_outputs = [None] * no + + for group_idx, output_indices in groups.items(): + group_outputs = [outputs[i] for i in output_indices] + output_buffer_map = { + ni + old_idx: ni + new_idx for new_idx, old_idx in enumerate(output_indices) + } + + group_operations = [ + Operation( + tuple( + output_buffer_map.get(b, b) if ni <= b < ni + no else b + for b in op.buffers + ), + op.start_path, + op.num_paths, + ) + for op in operations + if any(ni <= b < ni + no and (b - ni) in output_indices for b in op.buffers) + ] + + group_index_config = np.concatenate( + [ + index_configuration[:ni], + index_configuration[ni : ni + no][output_indices], + index_configuration[ni + no :], + ] + ) + + used_indices = sorted(set(group_index_config.flatten()) - {-1}) + if len(used_indices) < len(indices): + index_map = { + old_idx: new_idx for new_idx, old_idx in enumerate(used_indices) + } + group_indices = [indices[i] for i in used_indices] + group_index_config_remapped = group_index_config.copy() + for i in range(group_index_config_remapped.shape[0]): + for j in range(group_index_config_remapped.shape[1]): + val = group_index_config_remapped[i, j] + if val != -1: + group_index_config_remapped[i, j] = index_map[val] + no_group = len(group_outputs) + rows_to_keep = list(range(ni + no_group)) + [ + ni + no_group + i for i in used_indices + ] + group_index_config_remapped = group_index_config_remapped[rows_to_keep] + else: + group_indices = indices + group_index_config_remapped = group_index_config + + assert group_index_config_remapped.shape[0] == ni + len(group_outputs) + len( + group_indices + ) + + group_result = indptr_optimization_inner( + inputs, + group_outputs, + group_indices, + group_index_config_remapped, + operations=group_operations, + paths=paths, + math_dtype=math_dtype, + name=f"{name}_group{group_idx}", + ) + + for new_idx, old_idx in enumerate(output_indices): + result_outputs[old_idx] = group_result[new_idx] + + return result_outputs + + +def indptr_optimization_inner( + inputs: list[jax.Array], + outputs: list[jax.Array], + indices: list[jax.Array], + index_configuration: np.ndarray, + operations: list, + paths: list, + math_dtype: jnp.dtype, + name: 
str, +): + from cuequivariance_ops_jax import tensor_product_uniform_1d_jit + + ni, no = len(inputs), len(outputs) + assert index_configuration.shape[0] == ni + no + len(indices) + + first_output_config = index_configuration[ni] + if np.all(first_output_config == -1): + if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): + print(f"\n{'=' * 80}") + print( + f"šŸŽÆ indptr_optimization_inner: {name} (early return - all outputs unindexed)" + ) + print(f"{'=' * 80}") + print( + f"šŸ“Š {ni} inputs, {no} outputs, {len(indices)} indices, {len(operations)} ops, {len(paths)} paths" + ) + print("ā­ļø Skipping indptr optimization (all outputs have -1 indices)") + print(f"{'=' * 80}\n") + return tensor_product_uniform_1d_jit( + inputs, + outputs, + indices, + index_configuration, + operations=operations, + paths=paths, + math_dtype=math_dtype, + name=name, + ) + + axis = np.argmax(first_output_config != -1) + primary_index = int(first_output_config[axis]) + + buffers = list(inputs) + list(outputs) + extents = { + buffers[i].shape[j] + for i in range(ni + no) + for j in range(index_configuration.shape[1]) + if index_configuration[i, j] == primary_index + } + assert len(extents) == 1, ( + f"Expected unique extent for primary_index={primary_index}, got {extents}" + ) + extent = extents.pop() + + primary_index_array = indices[primary_index] + new_indices = [] + for idx, old_index in enumerate(indices): + indptr, grouped_index = group_by_index( + primary_index_array, old_index, extent, axis + ) + new_indices.append(indptr if idx == primary_index else grouped_index) + + index_configuration = index_configuration.copy() + index_configuration[ni + no + primary_index, axis] = -2 + + unindexed_mask = index_configuration[:ni, axis] == -1 + if np.any(unindexed_mask): + first_unindexed_input = np.argmax(unindexed_mask) + batch_size = inputs[first_unindexed_input].shape[axis] + target_shape = list(primary_index_array.shape) + target_shape[axis] = batch_size + + shape = [1] * len(target_shape) + shape[axis] = batch_size + sequential_index = jnp.broadcast_to( + jnp.arange(batch_size).reshape(shape), + target_shape, + ) + _, grouped_sequential = group_by_index( + primary_index_array, sequential_index, extent, axis + ) + new_indices.append(grouped_sequential) + + new_index_id = len(new_indices) - 1 + index_configuration[:ni, axis] = np.where( + unindexed_mask, new_index_id, index_configuration[:ni, axis] + ) + index_configuration = np.concatenate( + [ + index_configuration, + np.full((1, index_configuration.shape[1]), -1, dtype=np.int32), + ], + axis=0, + ) + + if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): + print(f"\n{'=' * 80}") + print(f"šŸŽÆ indptr_optimization_inner: {name}") + print(f"{'=' * 80}") + print( + f"šŸ“Š {ni} inputs, {no} outputs, {len(new_indices)} indices, {len(operations)} ops, {len(paths)} paths" + ) + print(f"šŸ”¢ Input shapes: {[tuple(x.shape) for x in inputs]}") + print(f"šŸ“¦ Output shapes: {[tuple(x.shape) for x in outputs]}") + print(f"šŸŽ² Index shapes: {[tuple(x.shape) for x in new_indices]}") + print(f"\nšŸ“‹ Index Configuration ({index_configuration.shape}):") + print(f" Inputs [{0:2d}:{ni:2d}]: {index_configuration[:ni].tolist()}") + print( + f" Outputs [{ni:2d}:{ni + no:2d}]: {index_configuration[ni : ni + no].tolist()}" + ) + print( + f" Indices [{ni + no:2d}:{len(index_configuration):2d}]: {index_configuration[ni + no :].tolist()}" + ) + print(f"{'=' * 80}\n") + + return tensor_product_uniform_1d_jit( + inputs, + outputs, + new_indices, index_configuration, 
operations=operations, paths=paths, - math_dtype=compute_dtype, - name=sanitize_string(name), + math_dtype=math_dtype, + name=name, ) - return [jnp.reshape(x, y.shape) for x, y in zip(outputs, outputs_shape_dtype)] From 8e52fbd8656ff306d2e0cad0fece60689237751a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 11 Nov 2025 06:52:40 -0800 Subject: [PATCH 04/11] rename --- .../segmented_polynomial_uniform_1d.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py index 303a0217..8316f464 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py @@ -190,8 +190,8 @@ def fn(op, d: cue.SegmentedTensorProduct): for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) - if options.get("indptr"): - outputs = indptr_optimization_outer( + if options.get("auto_deterministic_indexing"): + outputs = deterministic_indexing_grouped( buffers[: polynomial.num_inputs], buffers[polynomial.num_inputs :], list(indices), @@ -215,7 +215,7 @@ def fn(op, d: cue.SegmentedTensorProduct): return [jnp.reshape(x, y.shape) for x, y in zip(outputs, outputs_shape_dtype)] -def indptr_optimization_outer( +def deterministic_indexing_grouped( inputs: list[jax.Array], outputs: list[jax.Array], indices: list[jax.Array], @@ -244,7 +244,7 @@ def indptr_optimization_outer( if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): print(f"\n{'=' * 80}") - print(f"šŸŽÆ indptr_optimization_outer: {name}") + print(f"šŸŽÆ deterministic_indexing_grouped: {name}") print(f"{'=' * 80}") print( f"šŸ“Š {ni} inputs, {no} outputs, {len(indices)} indices, {len(operations)} ops, {len(paths)} paths" @@ -321,7 +321,7 @@ def indptr_optimization_outer( group_indices ) - group_result = indptr_optimization_inner( + group_result = deterministic_indexing( inputs, group_outputs, group_indices, @@ -338,7 +338,7 @@ def indptr_optimization_outer( return result_outputs -def indptr_optimization_inner( +def deterministic_indexing( inputs: list[jax.Array], outputs: list[jax.Array], indices: list[jax.Array], @@ -358,13 +358,13 @@ def indptr_optimization_inner( if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): print(f"\n{'=' * 80}") print( - f"šŸŽÆ indptr_optimization_inner: {name} (early return - all outputs unindexed)" + f"šŸŽÆ deterministic_indexing: {name} (early return - all outputs unindexed)" ) print(f"{'=' * 80}") print( f"šŸ“Š {ni} inputs, {no} outputs, {len(indices)} indices, {len(operations)} ops, {len(paths)} paths" ) - print("ā­ļø Skipping indptr optimization (all outputs have -1 indices)") + print("ā­ļø Skipping deterministic indexing (all outputs have -1 indices)") print(f"{'=' * 80}\n") return tensor_product_uniform_1d_jit( inputs, @@ -435,7 +435,7 @@ def indptr_optimization_inner( if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): print(f"\n{'=' * 80}") - print(f"šŸŽÆ indptr_optimization_inner: {name}") + print(f"šŸŽÆ deterministic_indexing: {name}") print(f"{'=' * 80}") print( f"šŸ“Š {ni} inputs, {no} outputs, {len(new_indices)} indices, {len(operations)} ops, {len(paths)} paths" From 8ceb887731a6e625a932e34f7e051ae8b9d2cfa6 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 13 Nov 2025 09:56:34 -0800 Subject: [PATCH 05/11] IVARS and OVARS 
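
IVARS and OVARS were plain 26-character strings, so indexing past the
26th operand raised an IndexError. Back them with `_IndexableList` so a
label is defined for every buffer index. A sketch of the resulting
behavior (derived from the mapping in the diff below, not taken from
the test suite):

    IVARS[0], IVARS[25], IVARS[26]  # 'a', 'z', '26' (numeric fallback past 26)
    OVARS[0:3]                      # ['A', 'B', 'C'] (slices need an explicit stop)
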
--- .../segmented_polynomials/operation.py | 29 +++++++++++++++++-- .../segmented_polynomial.py | 1 + 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/operation.py b/cuequivariance/cuequivariance/segmented_polynomials/operation.py index f26c724e..2dddd0ee 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/operation.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/operation.py @@ -18,8 +18,33 @@ import itertools from collections import defaultdict -IVARS = "abcdefghijklmnopqrstuvwxyz" -OVARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + +class _IndexableList: + """A list-like object that uses a function to map indices to strings.""" + + def __init__(self, func): + self._func = func + + def __getitem__(self, index: int | slice) -> str | list[str]: + if isinstance(index, slice): + start = index.start if index.start is not None else 0 + assert index.stop is not None + stop = index.stop + step = index.step if index.step is not None else 1 + + result = [] + for i in range(start, stop, step): + result.append(self._func(i)) + return result + return self._func(index) + + +IVARS = _IndexableList( + lambda i: "abcdefghijklmnopqrstuvwxyz"[i] if 0 <= i < 26 else str(i) +) +OVARS = _IndexableList( + lambda i: "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[i] if 0 <= i < 26 else str(i) +) @dataclasses.dataclass(init=False, frozen=True) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index cba9429b..7685f0c0 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -74,6 +74,7 @@ def __init__( assert len(opt.buffers) == stp.num_operands for i, operand in zip(opt.buffers, stp.operands): assert operand == operands[i] + assert all(b < len(inputs) + len(outputs) for b in opt.buffers) bid = opt.output_buffer(len(inputs)) perm = list(range(stp.num_operands)) From d87f0ad0acfc7a1ad5b6bf385608e09cbb08a24c Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 19 Nov 2025 08:23:48 -0800 Subject: [PATCH 06/11] torch - add options arg --- .../primitives/segmented_polynomial.py | 9 +++++++-- .../primitives/segmented_polynomial_uniform_1d.py | 13 ++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py index 099f6981..baaf8936 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py @@ -14,7 +14,7 @@ # limitations under the License. import warnings -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn as nn @@ -84,6 +84,9 @@ class SegmentedPolynomial(nn.Module): -1 means the math_dtype is used. Default is 0 if there are input tensors, otherwise -1. name: Optional name for the operation. Defaults to "segmented_polynomial". + options: Optional dictionary of method-specific options. Currently supported: + - ``auto_deterministic_indexing`` (bool): For method ``"uniform_1d"``, enables + automatic deterministic indexing optimization. Defaults to False. 
Examples: Basic usage with spherical harmonics: @@ -143,6 +146,7 @@ def __init__( math_dtype: str | torch.dtype = None, output_dtype_map: List[int] = None, name: str = "segmented_polynomial", + options: Optional[Dict[str, Any]] = None, ): super().__init__() @@ -150,6 +154,7 @@ def __init__( self.num_outputs = polynomial.num_outputs self.method = method self.repr = polynomial.__repr__() + self.options = options if options is not None else {} if method == "": warnings.warn( @@ -179,7 +184,7 @@ def __init__( if method == "uniform_1d": self.m = SegmentedPolynomialFromUniform1dJit( - polynomial, math_dtype, output_dtype_map, name + polynomial, math_dtype, output_dtype_map, name, self.options ) self.fallback = self.m elif method == "naive": diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py index 45c57114..5f9b01be 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py @@ -14,7 +14,7 @@ # limitations under the License. from itertools import accumulate -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn as nn @@ -58,6 +58,7 @@ def tensor_product_uniform_1d_jit( path_coefficients: List[float], batch_size: int, tensors: List[torch.Tensor], + auto_deterministic_indexing: bool = False, ) -> List[torch.Tensor]: return torch.ops.cuequivariance.tensor_product_uniform_1d_jit( name, @@ -81,6 +82,7 @@ def tensor_product_uniform_1d_jit( path_coefficients, batch_size, tensors, + auto_deterministic_indexing=auto_deterministic_indexing, ) except ImportError: tensor_product_uniform_1d_jit = None @@ -93,6 +95,7 @@ def __init__( math_dtype: Optional[str | torch.dtype] = None, output_dtype_map: List[int] = None, name: str = "segmented_polynomial", + options: Optional[Dict[str, Any]] = None, ): super().__init__() @@ -199,6 +202,13 @@ def __init__( self.BATCH_DIM_BATCHED = BATCH_DIM_BATCHED self.BATCH_DIM_INDEXED = BATCH_DIM_INDEXED + # Extract auto_deterministic_indexing from options + self.auto_deterministic_indexing = False + if options is not None: + self.auto_deterministic_indexing = options.get( + "auto_deterministic_indexing", False + ) + def forward( self, inputs: List[torch.Tensor], @@ -273,4 +283,5 @@ def forward( self.path_coefficients, batch_size, tensors, + auto_deterministic_indexing=self.auto_deterministic_indexing, ) From 33f71c7e5ff9a3f7cfacfa7b421b9adf635fc7d2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 15 Dec 2025 09:50:47 -0800 Subject: [PATCH 07/11] fix --- .../segmented_polynomials/segmented_polynomial_naive.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py index aa01ce9d..c2a88f09 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py @@ -46,7 +46,6 @@ def execute_naive( options, name: str, ) -> list[jax.Array]: # output buffers - math_dtype = options.get("math_dtype") if any(mode == IndexingMode.REPEATED for modes in index_mode for mode in modes): return execute_indexed_linear( inputs, @@ -55,7 +54,7 
@@ def execute_naive( index_configuration, index_mode, polynomial, - math_dtype, + options, name, run_kernel=False, ) @@ -64,9 +63,8 @@ def execute_naive( jnp.result_type( *[x.dtype for x in inputs] + [x.dtype for x in outputs_shape_dtype] ), - math_dtype, + options.get("math_dtype"), ) - del math_dtype num_inputs = len(index_configuration) - len(outputs_shape_dtype) From 4ff3cbff8f8dfb824915519ec297f3b997ab5778 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 15 Dec 2025 10:13:50 -0800 Subject: [PATCH 08/11] fix --- .../cuequivariance_jax/segmented_polynomials/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py index 58361044..d9299a6c 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py @@ -146,8 +146,10 @@ def group_by_index( >>> primary_idx = jnp.array([1, 0, 2, 1, 0]) >>> secondary_idx = jnp.array([10, 20, 30, 40, 50]) >>> indptr, reordered = group_by_index(primary_idx, secondary_idx, max_primary_idx=3) - >>> reordered[indptr[0]:indptr[1]] # Elements where primary_idx == 0: [20, 50] - >>> reordered[indptr[1]:indptr[2]] # Elements where primary_idx == 1: [10, 40] + >>> print(reordered[indptr[0]:indptr[1]]) # Elements where primary_idx == 0 + [20 50] + >>> print(reordered[indptr[1]:indptr[2]]) # Elements where primary_idx == 1 + [10 40] """ assert primary_idx.ndim == secondary_idx.ndim assert primary_idx.shape[axis] == secondary_idx.shape[axis] From 86a6a519ebd666f75a9f782ede152335a3e3739a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 19 Dec 2025 01:39:27 -0800 Subject: [PATCH 09/11] revert --- .../segmented_polynomial_uniform_1d.py | 279 +----------------- 1 file changed, 8 insertions(+), 271 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py index 8316f464..acbf6850 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -import os import re import warnings import jax import jax.numpy as jnp import numpy as np -from cuequivariance_jax.segmented_polynomials.utils import group_by_index, reshape +from cuequivariance_jax.segmented_polynomials.utils import reshape from packaging import version import cuequivariance as cue @@ -190,276 +189,14 @@ def fn(op, d: cue.SegmentedTensorProduct): for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) - if options.get("auto_deterministic_indexing"): - outputs = deterministic_indexing_grouped( - buffers[: polynomial.num_inputs], - buffers[polynomial.num_inputs :], - list(indices), - index_configuration, - operations=operations, - paths=paths, - math_dtype=compute_dtype, - name=sanitize_string(name), - ) - else: - outputs = tensor_product_uniform_1d_jit( - buffers[: polynomial.num_inputs], - buffers[polynomial.num_inputs :], - list(indices), - index_configuration, - operations=operations, - paths=paths, - math_dtype=compute_dtype, - name=sanitize_string(name), - ) - return [jnp.reshape(x, y.shape) for x, y in zip(outputs, outputs_shape_dtype)] - - -def deterministic_indexing_grouped( - inputs: list[jax.Array], - outputs: list[jax.Array], - indices: list[jax.Array], - index_configuration: np.ndarray, - operations: list, - paths: list, - math_dtype: jnp.dtype, - name: str, -): - from collections import defaultdict - - from cuequivariance_ops_jax import Operation - - ni, no = len(inputs), len(outputs) - assert index_configuration.shape[0] == ni + no + len(indices) - assert np.all(index_configuration[ni + no :] == -1) - - output_configs = index_configuration[ni : ni + no] - unique_configs, group_assignments = np.unique( - output_configs, axis=0, return_inverse=True - ) - - groups = defaultdict(list) - for output_idx, group_idx in enumerate(group_assignments): - groups[group_idx].append(output_idx) - - if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): - print(f"\n{'=' * 80}") - print(f"šŸŽÆ deterministic_indexing_grouped: {name}") - print(f"{'=' * 80}") - print( - f"šŸ“Š {ni} inputs, {no} outputs, {len(indices)} indices, {len(operations)} ops, {len(paths)} paths" - ) - print(f"šŸ”¢ Input shapes: {[tuple(x.shape) for x in inputs]}") - print(f"šŸ“¦ Output shapes: {[tuple(x.shape) for x in outputs]}") - print(f"šŸŽ² Index shapes: {[tuple(x.shape) for x in indices]}") - print(f"\nšŸ“‹ Index Configuration ({index_configuration.shape}):") - print(f" Inputs [{0:2d}:{ni:2d}]: {index_configuration[:ni].tolist()}") - print( - f" Outputs [{ni:2d}:{ni + no:2d}]: {index_configuration[ni : ni + no].tolist()}" - ) - print( - f" Indices [{ni + no:2d}:{len(index_configuration):2d}]: {index_configuration[ni + no :].tolist()}" - ) - print(f"šŸŽØ Found {len(groups)} unique output groups:") - for group_idx, output_indices in groups.items(): - print( - f" Group {group_idx}: {unique_configs[group_idx].tolist()} → outputs {output_indices} ({len(output_indices)} outputs)" - ) - print(f"{'=' * 80}\n") - - result_outputs = [None] * no - - for group_idx, output_indices in groups.items(): - group_outputs = [outputs[i] for i in output_indices] - output_buffer_map = { - ni + old_idx: ni + new_idx for new_idx, old_idx in enumerate(output_indices) - } - - group_operations = [ - Operation( - tuple( - output_buffer_map.get(b, b) if ni <= b < ni + no else b - for b in op.buffers - ), - op.start_path, - op.num_paths, - ) - for op in operations - if any(ni <= b < ni + no and (b - ni) in output_indices for b in op.buffers) - ] - - group_index_config = np.concatenate( - [ - 
index_configuration[:ni], - index_configuration[ni : ni + no][output_indices], - index_configuration[ni + no :], - ] - ) - - used_indices = sorted(set(group_index_config.flatten()) - {-1}) - if len(used_indices) < len(indices): - index_map = { - old_idx: new_idx for new_idx, old_idx in enumerate(used_indices) - } - group_indices = [indices[i] for i in used_indices] - group_index_config_remapped = group_index_config.copy() - for i in range(group_index_config_remapped.shape[0]): - for j in range(group_index_config_remapped.shape[1]): - val = group_index_config_remapped[i, j] - if val != -1: - group_index_config_remapped[i, j] = index_map[val] - no_group = len(group_outputs) - rows_to_keep = list(range(ni + no_group)) + [ - ni + no_group + i for i in used_indices - ] - group_index_config_remapped = group_index_config_remapped[rows_to_keep] - else: - group_indices = indices - group_index_config_remapped = group_index_config - - assert group_index_config_remapped.shape[0] == ni + len(group_outputs) + len( - group_indices - ) - - group_result = deterministic_indexing( - inputs, - group_outputs, - group_indices, - group_index_config_remapped, - operations=group_operations, - paths=paths, - math_dtype=math_dtype, - name=f"{name}_group{group_idx}", - ) - - for new_idx, old_idx in enumerate(output_indices): - result_outputs[old_idx] = group_result[new_idx] - - return result_outputs - - -def deterministic_indexing( - inputs: list[jax.Array], - outputs: list[jax.Array], - indices: list[jax.Array], - index_configuration: np.ndarray, - operations: list, - paths: list, - math_dtype: jnp.dtype, - name: str, -): - from cuequivariance_ops_jax import tensor_product_uniform_1d_jit - - ni, no = len(inputs), len(outputs) - assert index_configuration.shape[0] == ni + no + len(indices) - - first_output_config = index_configuration[ni] - if np.all(first_output_config == -1): - if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): - print(f"\n{'=' * 80}") - print( - f"šŸŽÆ deterministic_indexing: {name} (early return - all outputs unindexed)" - ) - print(f"{'=' * 80}") - print( - f"šŸ“Š {ni} inputs, {no} outputs, {len(indices)} indices, {len(operations)} ops, {len(paths)} paths" - ) - print("ā­ļø Skipping deterministic indexing (all outputs have -1 indices)") - print(f"{'=' * 80}\n") - return tensor_product_uniform_1d_jit( - inputs, - outputs, - indices, - index_configuration, - operations=operations, - paths=paths, - math_dtype=math_dtype, - name=name, - ) - - axis = np.argmax(first_output_config != -1) - primary_index = int(first_output_config[axis]) - - buffers = list(inputs) + list(outputs) - extents = { - buffers[i].shape[j] - for i in range(ni + no) - for j in range(index_configuration.shape[1]) - if index_configuration[i, j] == primary_index - } - assert len(extents) == 1, ( - f"Expected unique extent for primary_index={primary_index}, got {extents}" - ) - extent = extents.pop() - - primary_index_array = indices[primary_index] - new_indices = [] - for idx, old_index in enumerate(indices): - indptr, grouped_index = group_by_index( - primary_index_array, old_index, extent, axis - ) - new_indices.append(indptr if idx == primary_index else grouped_index) - - index_configuration = index_configuration.copy() - index_configuration[ni + no + primary_index, axis] = -2 - - unindexed_mask = index_configuration[:ni, axis] == -1 - if np.any(unindexed_mask): - first_unindexed_input = np.argmax(unindexed_mask) - batch_size = inputs[first_unindexed_input].shape[axis] - target_shape = list(primary_index_array.shape) - 
target_shape[axis] = batch_size - - shape = [1] * len(target_shape) - shape[axis] = batch_size - sequential_index = jnp.broadcast_to( - jnp.arange(batch_size).reshape(shape), - target_shape, - ) - _, grouped_sequential = group_by_index( - primary_index_array, sequential_index, extent, axis - ) - new_indices.append(grouped_sequential) - - new_index_id = len(new_indices) - 1 - index_configuration[:ni, axis] = np.where( - unindexed_mask, new_index_id, index_configuration[:ni, axis] - ) - index_configuration = np.concatenate( - [ - index_configuration, - np.full((1, index_configuration.shape[1]), -1, dtype=np.int32), - ], - axis=0, - ) - - if os.environ.get("CUEQUIVARIANCE_DEBUG_UNIFORM_1D"): - print(f"\n{'=' * 80}") - print(f"šŸŽÆ deterministic_indexing: {name}") - print(f"{'=' * 80}") - print( - f"šŸ“Š {ni} inputs, {no} outputs, {len(new_indices)} indices, {len(operations)} ops, {len(paths)} paths" - ) - print(f"šŸ”¢ Input shapes: {[tuple(x.shape) for x in inputs]}") - print(f"šŸ“¦ Output shapes: {[tuple(x.shape) for x in outputs]}") - print(f"šŸŽ² Index shapes: {[tuple(x.shape) for x in new_indices]}") - print(f"\nšŸ“‹ Index Configuration ({index_configuration.shape}):") - print(f" Inputs [{0:2d}:{ni:2d}]: {index_configuration[:ni].tolist()}") - print( - f" Outputs [{ni:2d}:{ni + no:2d}]: {index_configuration[ni : ni + no].tolist()}" - ) - print( - f" Indices [{ni + no:2d}:{len(index_configuration):2d}]: {index_configuration[ni + no :].tolist()}" - ) - print(f"{'=' * 80}\n") - - return tensor_product_uniform_1d_jit( - inputs, - outputs, - new_indices, + outputs = tensor_product_uniform_1d_jit( + buffers[: polynomial.num_inputs], + buffers[polynomial.num_inputs :], + list(indices), index_configuration, operations=operations, paths=paths, - math_dtype=math_dtype, - name=name, + math_dtype=compute_dtype, + name=sanitize_string(name), ) + return [jnp.reshape(x, y.shape) for x, y in zip(outputs, outputs_shape_dtype)] From 7bffb9c069c24c21e85220c02a4806ff387dd8cc Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 19 Dec 2025 01:44:43 -0800 Subject: [PATCH 10/11] revert --- .../primitives/segmented_polynomial.py | 4 +--- .../primitives/segmented_polynomial_uniform_1d.py | 10 ---------- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py index baaf8936..1b09bb7d 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py @@ -84,9 +84,7 @@ class SegmentedPolynomial(nn.Module): -1 means the math_dtype is used. Default is 0 if there are input tensors, otherwise -1. name: Optional name for the operation. Defaults to "segmented_polynomial". - options: Optional dictionary of method-specific options. Currently supported: - - ``auto_deterministic_indexing`` (bool): For method ``"uniform_1d"``, enables - automatic deterministic indexing optimization. Defaults to False. + options: Optional dictionary of method-specific options. 
Examples: Basic usage with spherical harmonics: diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py index c02dbd7c..c5290a53 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py @@ -58,7 +58,6 @@ def tensor_product_uniform_1d_jit( path_coefficients: List[float], batch_size: int, tensors: List[torch.Tensor], - auto_deterministic_indexing: bool = False, ) -> List[torch.Tensor]: return torch.ops.cuequivariance.tensor_product_uniform_1d_jit( name, @@ -82,7 +81,6 @@ def tensor_product_uniform_1d_jit( path_coefficients, batch_size, tensors, - auto_deterministic_indexing=auto_deterministic_indexing, ) except ImportError: tensor_product_uniform_1d_jit = None @@ -227,13 +225,6 @@ def __init__( self.BATCH_DIM_BATCHED = BATCH_DIM_BATCHED self.BATCH_DIM_INDEXED = BATCH_DIM_INDEXED - # Extract auto_deterministic_indexing from options - self.auto_deterministic_indexing = False - if options is not None: - self.auto_deterministic_indexing = options.get( - "auto_deterministic_indexing", False - ) - def forward( self, inputs: List[torch.Tensor], @@ -308,5 +299,4 @@ def forward( self.path_coefficients, batch_size, tensors, - auto_deterministic_indexing=self.auto_deterministic_indexing, ) From d08bf8fc010eb061954d11d32885a7427b7b35e2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 24 Dec 2025 04:26:51 -0800 Subject: [PATCH 11/11] use new api --- .../segmented_polynomial_uniform_1d.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py index acbf6850..646bd16d 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py @@ -168,12 +168,7 @@ def fn(op, d: cue.SegmentedTensorProduct): compute_dtype = jnp.float32 try: - from cuequivariance_ops_jax import ( - Operation, - Path, - __version__, - tensor_product_uniform_1d_jit, - ) + from cuequivariance_ops_jax import Operation, Path, __version__, uniform_1d except ImportError as e: raise ValueError(f"cuequivariance_ops_jax is not installed: {e}") @@ -189,7 +184,7 @@ def fn(op, d: cue.SegmentedTensorProduct): for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) - outputs = tensor_product_uniform_1d_jit( + outputs = uniform_1d( buffers[: polynomial.num_inputs], buffers[polynomial.num_inputs :], list(indices),
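
For reference, a minimal sketch of the `group_by_index` helper added in
PATCH 01 (doctest fixed in PATCH 08), which PATCH 03 used for its
deterministic indexing strategy before the revert in PATCH 09. The
values below are illustrative; the expected results follow the CSR
semantics documented in the helper's docstring.

    import jax.numpy as jnp
    from cuequivariance_jax.segmented_polynomials.utils import group_by_index

    # Which output row each of six contributions targets, and their payloads.
    primary_idx = jnp.array([2, 0, 0, 1, 2, 0])
    secondary_idx = jnp.array([7, 3, 9, 4, 1, 5])

    indptr, reordered = group_by_index(primary_idx, secondary_idx, max_primary_idx=3)
    # indptr    == [0, 3, 4, 6]
    # reordered == [3, 9, 5, 4, 7, 1]
    # Contributions to row k sit in reordered[indptr[k]:indptr[k+1]], so they
    # can be accumulated in a fixed order; that fixed order is what makes the
    # resulting output deterministic.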