diff --git a/cuequivariance/cuequivariance/segmented_polynomials/operation.py b/cuequivariance/cuequivariance/segmented_polynomials/operation.py
index f26c724e..2dddd0ee 100644
--- a/cuequivariance/cuequivariance/segmented_polynomials/operation.py
+++ b/cuequivariance/cuequivariance/segmented_polynomials/operation.py
@@ -18,8 +18,33 @@
 import itertools
 from collections import defaultdict
 
-IVARS = "abcdefghijklmnopqrstuvwxyz"
-OVARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+class _IndexableList:
+    """A list-like object that uses a function to map indices to strings."""
+
+    def __init__(self, func):
+        self._func = func
+
+    def __getitem__(self, index: int | slice) -> str | list[str]:
+        if isinstance(index, slice):
+            start = index.start if index.start is not None else 0
+            assert index.stop is not None
+            stop = index.stop
+            step = index.step if index.step is not None else 1
+
+            result = []
+            for i in range(start, stop, step):
+                result.append(self._func(i))
+            return result
+        return self._func(index)
+
+
+IVARS = _IndexableList(
+    lambda i: "abcdefghijklmnopqrstuvwxyz"[i] if 0 <= i < 26 else str(i)
+)
+OVARS = _IndexableList(
+    lambda i: "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[i] if 0 <= i < 26 else str(i)
+)
 
 
 @dataclasses.dataclass(init=False, frozen=True)
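A quick illustration (not part of the patch) of what the `_IndexableList` wrappers change: integer indexing below 26 behaves like the old string constants, slices now return a list of strings instead of a substring, and indices of 26 or more fall back to the decimal string rather than raising `IndexError`. The import path is taken from the diff header.

```python
# Illustration only: behavior of the new IVARS/OVARS wrappers.
from cuequivariance.segmented_polynomials.operation import IVARS, OVARS

assert IVARS[0] == "a" and OVARS[25] == "Z"  # unchanged for the first 26 operands
assert IVARS[3:6] == ["d", "e", "f"]         # slice yields a list, not a substring
assert IVARS[26] == "26"                     # numeric fallback instead of IndexError
```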
diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
index cba9429b..7685f0c0 100644
--- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
+++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
@@ -74,6 +74,7 @@ def __init__(
             assert len(opt.buffers) == stp.num_operands
             for i, operand in zip(opt.buffers, stp.operands):
                 assert operand == operands[i]
+            assert all(b < len(inputs) + len(outputs) for b in opt.buffers)
 
             bid = opt.output_buffer(len(inputs))
             perm = list(range(stp.num_operands))
diff --git a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py
index afbe92c1..d83b0fdf 100644
--- a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py
+++ b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
+
 import jax
 import jax.numpy as jnp
 
@@ -30,6 +32,7 @@ def equivariant_polynomial(
     *,
     method: str = "",
     math_dtype: str | None = None,
+    options: dict[str, Any] | None = None,
     name: str | None = None,
     precision: jax.lax.Precision = "undefined",
 ) -> list[cuex.RepArray] | cuex.RepArray:
@@ -51,6 +54,7 @@ def equivariant_polynomial(
             operands. Defaults to None. Note that indices are not supported for all methods.
         method: Method to use for computation. See :func:`cuex.segmented_polynomial <cuequivariance_jax.segmented_polynomial>` for available methods.
         math_dtype: See :func:`cuex.segmented_polynomial <cuequivariance_jax.segmented_polynomial>` for supported options.
+        options: Optional dictionary of method-specific options. See :func:`cuex.segmented_polynomial <cuequivariance_jax.segmented_polynomial>` for supported options.
         name: Optional name for the operation. Defaults to None.
 
     Returns:
@@ -186,8 +190,10 @@ def equivariant_polynomial(
         outputs_shape_dtype,
         indices,
         math_dtype=math_dtype,
+        options=options,
         name=name,
         method=method,
+        precision=precision,
     )
 
     outputs = [cuex.RepArray(rep, x) for rep, x in zip(poly.outputs, outputs)]
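A hedged usage sketch of the new `options` pass-through at the `cuex.equivariant_polynomial` level; `poly` and `x` are placeholders for an equivariant polynomial and its input `RepArray`. Note the hunk above also starts forwarding `precision`, which previously was accepted but dropped.

```python
# Hypothetical call: options is forwarded verbatim to cuex.segmented_polynomial.
import cuequivariance_jax as cuex

out = cuex.equivariant_polynomial(
    poly,                               # placeholder polynomial
    [x],                                # placeholder input RepArray
    method="uniform_1d",
    options={"math_dtype": "float32"},  # method-specific options dict
)
```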
diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py
index 37d3f9cc..d27cf95c 100644
--- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py
+++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py
@@ -17,6 +17,7 @@
 import os
 import warnings
 from functools import partial
+from typing import Any
 
 import jax
 import jax.core
@@ -55,6 +56,7 @@ def segmented_polynomial(
     *,
     method: str = "",
     math_dtype: str | None = None,
+    options: dict[str, Any] | None = None,
     name: str | None = None,
     precision: jax.lax.Precision = "undefined",
 ) -> list[jax.Array]:
@@ -79,15 +81,19 @@ def segmented_polynomial(
 
             .. note::
                 The ``"fused_tp"`` method is only available in the PyTorch implementation.
 
-        math_dtype: Data type for computational operations. If None, automatically determined from input types. Defaults to None.
+        math_dtype: (Deprecated) Data type for computational operations. Prefer ``options["math_dtype"]`` instead. Defaults to None.
+        options: Optional dictionary of method-specific options. Defaults to None.
 
-            Supported options vary by method:
+            Common options:
 
-            - ``"naive"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``).
-              Also supports ``"tensor_float32"`` for TensorFloat-32 mode.
-            - ``"uniform_1d"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``).
-            - ``"indexed_linear"``: CUBLAS compute type strings such as ``"CUBLAS_COMPUTE_32F"``, ``"CUBLAS_COMPUTE_32F_FAST_TF32"``,
-              ``"CUBLAS_COMPUTE_32F_PEDANTIC"``, ``"CUBLAS_COMPUTE_64F"``, etc.
+            - ``"math_dtype"``: Data type for computational operations. If None, automatically determined from input types.
+              Supported values vary by method:
+
+              - ``"naive"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``).
+                Also supports ``"tensor_float32"`` for TensorFloat-32 mode.
+              - ``"uniform_1d"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``).
+              - ``"indexed_linear"``: CUBLAS compute type strings such as ``"CUBLAS_COMPUTE_32F"``, ``"CUBLAS_COMPUTE_32F_FAST_TF32"``,
+                ``"CUBLAS_COMPUTE_32F_PEDANTIC"``, ``"CUBLAS_COMPUTE_64F"``, etc.
 
         name: Optional name for the operation.
@@ -171,9 +177,24 @@ def segmented_polynomial(
     if name is None:
         name = "segmented_polynomial"
 
-    if math_dtype is not None and not isinstance(math_dtype, str):
-        math_dtype = jnp.dtype(math_dtype).name
-    assert isinstance(math_dtype, str) or math_dtype is None
+    if options is None:
+        options = {}
+    else:
+        options = dict(options)
+
+    if math_dtype is not None:
+        if "math_dtype" in options:
+            raise ValueError(
+                "math_dtype provided both as argument and in options dict. "
+                "Please use only options['math_dtype']."
+            )
+        options["math_dtype"] = math_dtype
+
+    if "math_dtype" in options:
+        math_dtype_val = options["math_dtype"]
+        if math_dtype_val is not None and not isinstance(math_dtype_val, str):
+            options["math_dtype"] = jnp.dtype(math_dtype_val).name
+        assert isinstance(options["math_dtype"], str) or options["math_dtype"] is None
 
     if precision != "undefined":
         raise ValueError(
@@ -297,7 +318,7 @@ def fn(x, n: int):
         index_configuration=index_configuration,
         index_mode=index_mode,
         polynomial=polynomial,
-        math_dtype=math_dtype,
+        options=tuple(sorted(options.items())),
         name=name,
     )
 
@@ -360,7 +381,7 @@ def segmented_polynomial_prim(
     index_configuration: list[list[int]],  # maps: buffer index -> unique indices index
     index_mode: list[list[IndexingMode]],  # shared, batched, indexed, repeated
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
     name: str,
     method: str,
     return_none_if_empty: bool = False,
@@ -391,7 +412,7 @@ def segmented_polynomial_prim(
             x for x, used in zip(outputs_shape_dtype, used_outputs) if used
         ),
         polynomial=polynomial.filter_keep_operands(used_inputs + used_outputs),
-        math_dtype=math_dtype,
+        options=options,
         name=str(name),
         method=method,
     )
@@ -449,7 +470,7 @@ def segmented_polynomial_abstract_eval(
    index_mode: tuple[tuple[IndexingMode, ...], ...],
    outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
    polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
    name: str,
    method: str,
 ) -> tuple[jax.core.ShapedArray, ...]:
@@ -465,7 +486,7 @@ def segmented_polynomial_impl(
    index_mode: tuple[tuple[IndexingMode, ...], ...],
    outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
    polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
    name: str,
    method: str,
 ) -> tuple[jax.Array, ...]:
@@ -509,7 +530,7 @@ def segmented_polynomial_impl(
         indices=indices,
         index_configuration=index_configuration,
         polynomial=polynomial,
-        math_dtype=math_dtype,
+        options=dict(options),
         name=name,
     )
 
@@ -536,7 +557,7 @@ def segmented_polynomial_jvp(
    index_mode: tuple[tuple[IndexingMode, ...], ...],
    outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
    polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
    name: str,
    method: str,
 ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
@@ -557,7 +578,7 @@ def segmented_polynomial_jvp(
         index_configuration,
         index_mode,
         polynomial,
-        math_dtype,
+        options,
         name,
         method=method,
     )
@@ -579,7 +600,7 @@ def segmented_polynomial_jvp(
         jvp_index_configuration,
         jvp_index_mode,
         jvp_poly,
-        math_dtype,
+        options,
         name
         + "_jvp"
         + "".join("0" if isinstance(t, ad.Zero) else "1" for t in tangents),
@@ -596,7 +617,7 @@ def segmented_polynomial_transpose(
    index_mode: tuple[tuple[IndexingMode, ...], ...],
    outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
    polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
    name: str,
    method: str,
 ) -> tuple[jax.Array | ad.Zero | None, ...]:
@@ -637,7 +658,7 @@ def segmented_polynomial_transpose(
         tr_index_configuration,
         tr_index_mode,
         tr_poly,
-        math_dtype,
+        options,
         name + "_T",
         method=method,
         return_none_if_empty=True,
@@ -660,7 +681,7 @@ def segmented_polynomial_batching(
    index_mode: tuple[tuple[IndexingMode, ...], ...],
    outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
    polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
    name: str,
    method: str,
 ) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]:
@@ -700,7 +721,7 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
         index_mode=index_mode,
         outputs_shape_dtype=outputs_shape_dtype,
         polynomial=polynomial,
-        math_dtype=math_dtype,
+        options=options,
         name=name + "_batching",
         method=method,
     )
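The normalization block above folds the deprecated `math_dtype` keyword into `options` and rejects conflicting usage; `options` is then lowered to a sorted tuple of items before being attached to the primitive, since JAX primitive parameters must be hashable. A sketch of the resulting call-site behavior (`poly`, `inputs`, and `outs` are placeholders):

```python
# Old style: still accepted, becomes options={"math_dtype": "float32"} internally.
cuex.segmented_polynomial(poly, inputs, outs, math_dtype="float32")

# New style: preferred.
cuex.segmented_polynomial(poly, inputs, outs, options={"math_dtype": "float32"})

# Both at once: raises ValueError per the check above.
cuex.segmented_polynomial(
    poly, inputs, outs,
    math_dtype="float32",
    options={"math_dtype": "float32"},
)
```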
diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py
index dd23fa87..7c43f4a9 100644
--- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py
+++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_indexed_linear.py
@@ -46,10 +46,11 @@ def execute_indexed_linear(
     index_configuration: tuple[tuple[int, ...], ...],
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options,
     name: str,
     run_kernel: bool = True,
 ) -> list[jax.Array]:  # output buffers
+    math_dtype = options.get("math_dtype")
     num_inputs = len(index_configuration) - len(outputs_shape_dtype)
 
     io_buffers = list(inputs) + [
diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py
index 00f0a18f..c2a88f09 100644
--- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py
+++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_naive.py
@@ -43,7 +43,7 @@ def execute_naive(
     index_configuration: tuple[tuple[int, ...], ...],
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options,
     name: str,
 ) -> list[jax.Array]:  # output buffers
     if any(mode == IndexingMode.REPEATED for modes in index_mode for mode in modes):
@@ -54,7 +54,7 @@ def execute_naive(
             index_configuration,
             index_mode,
             polynomial,
-            math_dtype,
+            options,
             name,
             run_kernel=False,
         )
@@ -63,9 +63,8 @@ def execute_naive(
         jnp.result_type(
             *[x.dtype for x in inputs] + [x.dtype for x in outputs_shape_dtype]
         ),
-        math_dtype,
+        options.get("math_dtype"),
     )
-    del math_dtype
 
     num_inputs = len(index_configuration) - len(outputs_shape_dtype)
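Both backends now receive the whole `options` dict and read only the keys they understand. For the naive path, the sketch below (an assumption based on the diff, not the library's exact code, and covering only plain dtype names rather than `"tensor_float32"`) shows what the resolution via `math_dtype_for_naive_method` effectively does: an explicit `options["math_dtype"]` overrides the dtype inferred from the buffers.

```python
# Sketch of the naive backend's dtype resolution (assumption based on the diff).
import jax.numpy as jnp

def resolve_math_dtype(buffer_dtypes: list, options: dict):
    inferred = jnp.result_type(*buffer_dtypes)  # what the inputs/outputs imply
    requested = options.get("math_dtype")       # optional override from options
    return inferred if requested is None else jnp.dtype(requested)

print(resolve_math_dtype([jnp.float32, jnp.float64], {}))                         # float64
print(resolve_math_dtype([jnp.float32, jnp.float64], {"math_dtype": "float32"}))  # float32
```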
diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py
index cf9bf3a7..646bd16d 100644
--- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py
+++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_uniform_1d.py
@@ -38,16 +38,13 @@ def execute_uniform_1d(
     indices: list[jax.Array],
     index_configuration: tuple[tuple[int, ...], ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: dict,
     name: str,
 ) -> list[jax.Array]:
     error_message = f"Failed to execute 'uniform_1d' method for the following polynomial:\n{polynomial}\n"
 
     index_configuration = np.array(index_configuration)
     num_batch_axes = index_configuration.shape[1]
-    assert (
-        polynomial.num_inputs + len(outputs_shape_dtype) == index_configuration.shape[0]
-    )
     assert polynomial.num_outputs == len(outputs_shape_dtype)
 
     try:
@@ -69,9 +66,14 @@ def fn(op, d: cue.SegmentedTensorProduct):
 
     polynomial = polynomial.apply_fn(fn)
 
-    # We don't use the feature that indices can index themselves
-    index_configuration = np.concatenate(
-        [index_configuration, np.full((len(indices), num_batch_axes), -1, np.int32)]
-    )
+    if polynomial.num_inputs + len(outputs_shape_dtype) == index_configuration.shape[0]:
+        index_configuration = np.concatenate(
+            [index_configuration, np.full((len(indices), num_batch_axes), -1, np.int32)]
+        )
+
+    assert (
+        polynomial.num_inputs + len(outputs_shape_dtype) + len(indices)
+        == index_configuration.shape[0]
+    )
 
     buffers = list(inputs) + list(outputs_shape_dtype)
@@ -151,8 +153,9 @@ def fn(op, d: cue.SegmentedTensorProduct):
     if len({b.shape[-1] for b in buffers}.union({1})) > 2:
         raise ValueError(f"Buffer shapes not compatible {[b.shape for b in buffers]}")
 
-    if math_dtype is not None:
+    if "math_dtype" in options:
         supported_dtypes = {"float32", "float64", "float16", "bfloat16"}
+        math_dtype = options["math_dtype"]
         if math_dtype not in supported_dtypes:
             raise ValueError(
                 f"method='uniform_1d' only supports math_dtype equal to {supported_dtypes}, got '{math_dtype}'."
@@ -165,12 +168,7 @@ def fn(op, d: cue.SegmentedTensorProduct):
         compute_dtype = jnp.float32
 
     try:
-        from cuequivariance_ops_jax import (
-            Operation,
-            Path,
-            __version__,
-            tensor_product_uniform_1d_jit,
-        )
+        from cuequivariance_ops_jax import Operation, Path, __version__, uniform_1d
     except ImportError as e:
         raise ValueError(f"cuequivariance_ops_jax is not installed: {e}")
@@ -186,7 +184,7 @@ def fn(op, d: cue.SegmentedTensorProduct):
         for path in stp.paths:
             paths.append(Path(path.indices, path.coefficients.item()))
 
-    outputs = tensor_product_uniform_1d_jit(
+    outputs = uniform_1d(
         buffers[: polynomial.num_inputs],
         buffers[polynomial.num_inputs :],
         list(indices),
diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py
index 1293e282..d9299a6c 100644
--- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py
+++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py
@@ -116,3 +116,57 @@ def math_dtype_for_naive_method(
         f"method='naive' does not support math_dtype '{math_dtype}'. "
         "Supported options are any JAX dtype (e.g., 'float32', 'float64', 'float16', 'bfloat16') or 'tensor_float32'."
     )
+
+
+def group_by_index(
+    primary_idx: jax.Array,
+    secondary_idx: jax.Array,
+    max_primary_idx: int,
+    axis: int = -1,
+) -> tuple[jax.Array, jax.Array]:
+    """Group and reorder indices in CSR-like format along a specified axis.
+
+    Groups ``secondary_idx`` values by their corresponding ``primary_idx`` values,
+    enabling efficient contiguous access to all elements with the same primary index.
+
+    Args:
+        primary_idx: Indices to group by along ``axis``.
+        secondary_idx: Indices to reorder. Must have matching size with ``primary_idx`` along ``axis``.
+        max_primary_idx: Maximum value in ``primary_idx`` (exclusive).
+        axis: Axis along which to perform grouping. Defaults to -1.
+
+    Returns:
+        tuple: ``(indptr, reordered_indices)`` where:
+        - ``indptr``: Offsets of shape ``(..., max_primary_idx + 1, ...)`` where
+          ``reordered_indices[..., indptr[k]:indptr[k+1], ...]`` contains all elements
+          with ``primary_idx == k``.
+        - ``reordered_indices``: Reordered ``secondary_idx`` with elements grouped by ``primary_idx``.
+
+    Example:
+        >>> primary_idx = jnp.array([1, 0, 2, 1, 0])
+        >>> secondary_idx = jnp.array([10, 20, 30, 40, 50])
+        >>> indptr, reordered = group_by_index(primary_idx, secondary_idx, max_primary_idx=3)
+        >>> print(reordered[indptr[0] : indptr[1]])  # Elements where primary_idx == 0
+        [20 50]
+        >>> print(reordered[indptr[1] : indptr[2]])  # Elements where primary_idx == 1
+        [10 40]
+    """
+    assert primary_idx.ndim == secondary_idx.ndim
+    assert primary_idx.shape[axis] == secondary_idx.shape[axis]
+
+    reordered = jnp.take_along_axis(
+        secondary_idx, jnp.argsort(primary_idx, axis=axis), axis=axis
+    )
+
+    def compute_indptr(p):
+        return jnp.append(
+            0, jnp.cumsum(jnp.zeros((max_primary_idx,), jnp.int32).at[p].add(1))
+        )
+
+    p = jnp.moveaxis(primary_idx, axis, -1)
+    indptr = jax.vmap(compute_indptr)(p.reshape(-1, p.shape[-1])).reshape(
+        p.shape[:-1] + (max_primary_idx + 1,)
+    )
+    indptr = jnp.moveaxis(indptr, -1, axis)
+
+    return indptr, reordered
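Because `group_by_index` vmaps the `indptr` computation over all leading axes, it also handles batched index arrays. A small worked example (values arbitrary; the import path mirrors the file location in the diff and may not be part of the public API):

```python
import jax.numpy as jnp
from cuequivariance_jax.segmented_polynomials.utils import group_by_index

primary = jnp.array([[1, 0, 2, 1, 0],
                     [2, 2, 0, 1, 1]])
secondary = jnp.array([[10, 20, 30, 40, 50],
                       [60, 70, 80, 90, 99]])
indptr, reordered = group_by_index(primary, secondary, max_primary_idx=3)
print(indptr)     # [[0 2 4 5]
                  #  [0 1 3 5]]
print(reordered)  # [[20 50 10 40 30]
                  #  [80 90 99 60 70]]
```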
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py
index 099f6981..1b09bb7d 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import warnings
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -84,6 +84,7 @@ class SegmentedPolynomial(nn.Module):
             -1 means the math_dtype is used. Default is 0 if there are input
             tensors, otherwise -1.
         name: Optional name for the operation. Defaults to "segmented_polynomial".
+        options: Optional dictionary of method-specific options.
 
     Examples:
         Basic usage with spherical harmonics:
@@ -143,6 +144,7 @@ def __init__(
         math_dtype: str | torch.dtype = None,
         output_dtype_map: List[int] = None,
         name: str = "segmented_polynomial",
+        options: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
 
@@ -150,6 +152,7 @@ def __init__(
         self.num_outputs = polynomial.num_outputs
         self.method = method
         self.repr = polynomial.__repr__()
+        self.options = options if options is not None else {}
 
         if method == "":
             warnings.warn(
@@ -179,7 +182,7 @@ def __init__(
 
         if method == "uniform_1d":
             self.m = SegmentedPolynomialFromUniform1dJit(
-                polynomial, math_dtype, output_dtype_map, name
+                polynomial, math_dtype, output_dtype_map, name, self.options
             )
             self.fallback = self.m
         elif method == "naive":
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py
index 7f58f766..c5290a53 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/segmented_polynomial_uniform_1d.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from itertools import accumulate
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -93,6 +93,7 @@ def __init__(
         math_dtype: Optional[str | torch.dtype] = None,
         output_dtype_map: List[int] = None,
         name: str = "segmented_polynomial",
+        options: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
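Finally, a hedged sketch of the new constructor argument on the Torch side; `poly` is a placeholder `cue.SegmentedPolynomial`, the keyword usage is an assumption, and the options dict is forwarded to the `uniform_1d` backend:

```python
# Hypothetical torch usage: options is threaded through to the uniform_1d module.
import cuequivariance_torch as cuet

m = cuet.SegmentedPolynomial(
    poly,                   # placeholder polynomial
    method="uniform_1d",
    math_dtype="float32",
    options={},             # method-specific options (none required here)
)
```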