29 changes: 27 additions & 2 deletions cuequivariance/cuequivariance/segmented_polynomials/operation.py
@@ -18,8 +18,33 @@
 import itertools
 from collections import defaultdict
 
-IVARS = "abcdefghijklmnopqrstuvwxyz"
-OVARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+class _IndexableList:
+    """A list-like object that uses a function to map indices to strings."""
+
+    def __init__(self, func):
+        self._func = func
+
+    def __getitem__(self, index: int | slice) -> str | list[str]:
+        if isinstance(index, slice):
+            start = index.start if index.start is not None else 0
+            assert index.stop is not None
+            stop = index.stop
+            step = index.step if index.step is not None else 1
+
+            result = []
+            for i in range(start, stop, step):
+                result.append(self._func(i))
+            return result
+        return self._func(index)
+
+
+IVARS = _IndexableList(
+    lambda i: "abcdefghijklmnopqrstuvwxyz"[i] if 0 <= i < 26 else str(i)
+)
+OVARS = _IndexableList(
+    lambda i: "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[i] if 0 <= i < 26 else str(i)
+)
 
 
 @dataclasses.dataclass(init=False, frozen=True)
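A quick sketch (not part of the diff) of the behavioral change: the old string constants raised IndexError past index 25 and returned a plain substring for slices, while the new wrappers fall back to the decimal form of the index and return lists for slices.

# Sketch of the new behavior, assuming only the definitions above:
assert IVARS[0] == "a" and OVARS[0] == "A"
assert IVARS[2:5] == ["c", "d", "e"]  # slices now yield lists, not substrings
assert IVARS[30] == "30"              # previously raised IndexError
assert OVARS[26] == "26"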
@@ -74,6 +74,7 @@ def __init__(
         assert len(opt.buffers) == stp.num_operands
         for i, operand in zip(opt.buffers, stp.operands):
             assert operand == operands[i]
+        assert all(b < len(inputs) + len(outputs) for b in opt.buffers)
 
         bid = opt.output_buffer(len(inputs))
         perm = list(range(stp.num_operands))
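The added bounds check makes bad operations fail fast: with, say, two inputs and one output, valid buffer ids are 0 through 2, so an operation referencing buffer 3 now trips this assert immediately instead of surfacing later as an opaque lookup error. (The sizes here are illustrative, not from the diff.)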
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
+
 import jax
 import jax.numpy as jnp
 
@@ -30,6 +32,7 @@ def equivariant_polynomial(
     *,
     method: str = "",
     math_dtype: str | None = None,
+    options: dict[str, Any] | None = None,
     name: str | None = None,
     precision: jax.lax.Precision = "undefined",
 ) -> list[cuex.RepArray] | cuex.RepArray:
@@ -51,6 +54,7 @@
             operands. Defaults to None. Note that indices are not supported for all methods.
         method: Method to use for computation. See :func:`cuex.segmented_polynomial <cuequivariance_jax.segmented_polynomial>` for available methods.
         math_dtype: See :func:`cuex.segmented_polynomial <cuequivariance_jax.segmented_polynomial>` for supported options.
+        options: Optional dictionary of method-specific options. See :func:`cuex.segmented_polynomial <cuequivariance_jax.segmented_polynomial>` for supported options.
         name: Optional name for the operation. Defaults to None.
 
     Returns:
@@ -186,8 +190,10 @@ def equivariant_polynomial(
         outputs_shape_dtype,
         indices,
         math_dtype=math_dtype,
+        options=options,
         name=name,
         method=method,
+        precision=precision,
     )
     outputs = [cuex.RepArray(rep, x) for rep, x in zip(poly.outputs, outputs)]
 
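A hypothetical call showing the new keyword threaded through (poly, x0, and x1 are placeholders, not from the diff):

outputs = cuex.equivariant_polynomial(
    poly,                               # a cue.EquivariantPolynomial
    [x0, x1],                           # cuex.RepArray inputs
    method="naive",
    options={"math_dtype": "float32"},  # replaces the deprecated math_dtype=...
)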
@@ -17,6 +17,7 @@
 import os
 import warnings
 from functools import partial
+from typing import Any
 
 import jax
 import jax.core
@@ -55,6 +56,7 @@ def segmented_polynomial(
     *,
     method: str = "",
     math_dtype: str | None = None,
+    options: dict[str, Any] | None = None,
     name: str | None = None,
     precision: jax.lax.Precision = "undefined",
 ) -> list[jax.Array]:
@@ -79,15 +81,19 @@

         .. note::
             The ``"fused_tp"`` method is only available in the PyTorch implementation.
-        math_dtype: Data type for computational operations. If None, automatically determined from input types. Defaults to None.
+        math_dtype: (Deprecated) Data type for computational operations. Prefer using ``options`` instead. Defaults to None.
+        options: Optional dictionary of method-specific options. Defaults to None.
 
-            Supported options vary by method:
+            Common options:
 
-            - ``"naive"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``).
-              Also supports ``"tensor_float32"`` for TensorFloat-32 mode.
-            - ``"uniform_1d"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``).
-            - ``"indexed_linear"``: CUBLAS compute type strings such as ``"CUBLAS_COMPUTE_32F"``, ``"CUBLAS_COMPUTE_32F_FAST_TF32"``,
-              ``"CUBLAS_COMPUTE_32F_PEDANTIC"``, ``"CUBLAS_COMPUTE_64F"``, etc.
+            - ``"math_dtype"``: Data type for computational operations. If None, automatically determined from input types.
+              Supported values vary by method:
+
+              - ``"naive"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``).
+                Also supports ``"tensor_float32"`` for TensorFloat-32 mode.
+              - ``"uniform_1d"``: String dtype names (e.g., ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``).
+              - ``"indexed_linear"``: CUBLAS compute type strings such as ``"CUBLAS_COMPUTE_32F"``, ``"CUBLAS_COMPUTE_32F_FAST_TF32"``,
+                ``"CUBLAS_COMPUTE_32F_PEDANTIC"``, ``"CUBLAS_COMPUTE_64F"``, etc.
 
     name: Optional name for the operation.
 
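To make the option table concrete, a sketch of per-method values drawn from the list above (the dict literals themselves are illustrative):

opts_naive = {"math_dtype": "tensor_float32"}  # TF32 mode, "naive" only
opts_uniform_1d = {"math_dtype": "bfloat16"}
opts_indexed_linear = {"math_dtype": "CUBLAS_COMPUTE_32F_FAST_TF32"}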
@@ -171,9 +177,24 @@
     if name is None:
         name = "segmented_polynomial"
 
-    if math_dtype is not None and not isinstance(math_dtype, str):
-        math_dtype = jnp.dtype(math_dtype).name
-    assert isinstance(math_dtype, str) or math_dtype is None
+    if options is None:
+        options = {}
+    else:
+        options = dict(options)
+
+    if math_dtype is not None:
+        if "math_dtype" in options:
+            raise ValueError(
+                "math_dtype provided both as argument and in options dict. "
+                "Please use only options['math_dtype']."
+            )
+        options["math_dtype"] = math_dtype
+
+    if "math_dtype" in options:
+        math_dtype_val = options["math_dtype"]
+        if math_dtype_val is not None and not isinstance(math_dtype_val, str):
+            options["math_dtype"] = jnp.dtype(math_dtype_val).name
+        assert isinstance(options["math_dtype"], str) or options["math_dtype"] is None
 
     if precision != "undefined":
         raise ValueError(
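A minimal, self-contained restatement of the merge rules above (hypothetical helper, not library code):

import jax.numpy as jnp

def merge_math_dtype(math_dtype, options):
    # Mirrors the logic above: copy options, fold in the deprecated argument,
    # reject duplicates, normalize dtype objects to their string names.
    options = dict(options) if options is not None else {}
    if math_dtype is not None:
        if "math_dtype" in options:
            raise ValueError("math_dtype provided both as argument and in options dict.")
        options["math_dtype"] = math_dtype
    md = options.get("math_dtype")
    if md is not None and not isinstance(md, str):
        options["math_dtype"] = jnp.dtype(md).name
    return options

assert merge_math_dtype(jnp.float32, None) == {"math_dtype": "float32"}
assert merge_math_dtype(None, {"math_dtype": "float64"}) == {"math_dtype": "float64"}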
@@ -297,7 +318,7 @@ def fn(x, n: int):
         index_configuration=index_configuration,
         index_mode=index_mode,
         polynomial=polynomial,
-        math_dtype=math_dtype,
+        options=tuple(sorted(options.items())),
         name=name,
     )

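The dict is flattened to a sorted tuple here because JAX primitive parameters must be hashable for tracing and jit caching; a dict is not, while a tuple of items is, and it round-trips losslessly:

options = {"math_dtype": "float32"}
params = tuple(sorted(options.items()))  # (('math_dtype', 'float32'),)
hash(params)                             # fine; hash(options) raises TypeError
assert dict(params) == options           # the impl restores it with dict(options)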
@@ -360,7 +381,7 @@ def segmented_polynomial_prim(
     index_configuration: list[list[int]],  # maps: buffer index -> unique indices index
     index_mode: list[list[IndexingMode]],  # shared, batched, indexed, repeated
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
     name: str,
     method: str,
     return_none_if_empty: bool = False,
@@ -391,7 +412,7 @@
             x for x, used in zip(outputs_shape_dtype, used_outputs) if used
         ),
         polynomial=polynomial.filter_keep_operands(used_inputs + used_outputs),
-        math_dtype=math_dtype,
+        options=options,
         name=str(name),
         method=method,
     )
@@ -449,7 +470,7 @@ def segmented_polynomial_abstract_eval(
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
     name: str,
     method: str,
 ) -> tuple[jax.core.ShapedArray, ...]:
@@ -465,7 +486,7 @@ def segmented_polynomial_impl(
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
     name: str,
     method: str,
 ) -> tuple[jax.Array, ...]:
@@ -509,7 +530,7 @@ def segmented_polynomial_impl(
         indices=indices,
         index_configuration=index_configuration,
         polynomial=polynomial,
-        math_dtype=math_dtype,
+        options=dict(options),
         name=name,
     )
 
@@ -536,7 +557,7 @@ def segmented_polynomial_jvp(
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
     name: str,
     method: str,
 ) -> tuple[tuple[jax.Array, ...], tuple[jax.Array | ad.Zero, ...]]:
@@ -557,7 +578,7 @@
         index_configuration,
         index_mode,
         polynomial,
-        math_dtype,
+        options,
         name,
         method=method,
     )
@@ -579,7 +600,7 @@
         jvp_index_configuration,
         jvp_index_mode,
         jvp_poly,
-        math_dtype,
+        options,
         name
         + "_jvp"
         + "".join("0" if isinstance(t, ad.Zero) else "1" for t in tangents),
@@ -596,7 +617,7 @@ def segmented_polynomial_transpose(
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
     name: str,
     method: str,
 ) -> tuple[jax.Array | ad.Zero | None, ...]:
@@ -637,7 +658,7 @@
         tr_index_configuration,
         tr_index_mode,
         tr_poly,
-        math_dtype,
+        options,
         name + "_T",
         method=method,
         return_none_if_empty=True,
@@ -660,7 +681,7 @@ def segmented_polynomial_batching(
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: tuple[tuple[str, Any], ...],
     name: str,
     method: str,
 ) -> tuple[tuple[jax.Array, ...], tuple[int, ...]]:
@@ -700,7 +721,7 @@ def prepare(input: jax.Array, axis: int | None) -> jax.Array:
         index_mode=index_mode,
         outputs_shape_dtype=outputs_shape_dtype,
         polynomial=polynomial,
-        math_dtype=math_dtype,
+        options=options,
         name=name + "_batching",
         method=method,
     )
@@ -46,10 +46,11 @@ def execute_indexed_linear(
     index_configuration: tuple[tuple[int, ...], ...],
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options,
     name: str,
     run_kernel: bool = True,
 ) -> list[jax.Array]:  # output buffers
+    math_dtype = options.get("math_dtype")
     num_inputs = len(index_configuration) - len(outputs_shape_dtype)
 
     io_buffers = list(inputs) + [
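Note that options.get("math_dtype") returns None when the key is absent, which reproduces the old math_dtype: str | None = None default, so the previous behavior is preserved.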
@@ -43,7 +43,7 @@ def execute_naive(
     index_configuration: tuple[tuple[int, ...], ...],
     index_mode: tuple[tuple[IndexingMode, ...], ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options,
     name: str,
 ) -> list[jax.Array]:  # output buffers
     if any(mode == IndexingMode.REPEATED for modes in index_mode for mode in modes):
@@ -54,7 +54,7 @@
         index_configuration,
         index_mode,
         polynomial,
-        math_dtype,
+        options,
         name,
         run_kernel=False,
     )
@@ -63,9 +63,8 @@
         jnp.result_type(
             *[x.dtype for x in inputs] + [x.dtype for x in outputs_shape_dtype]
         ),
-        math_dtype,
+        options.get("math_dtype"),
     )
-    del math_dtype
 
     num_inputs = len(index_configuration) - len(outputs_shape_dtype)
 
@@ -38,16 +38,13 @@ def execute_uniform_1d(
     indices: list[jax.Array],
     index_configuration: tuple[tuple[int, ...], ...],
     polynomial: cue.SegmentedPolynomial,
-    math_dtype: str | None,
+    options: dict,
     name: str,
 ) -> list[jax.Array]:
     error_message = f"Failed to execute 'uniform_1d' method for the following polynomial:\n{polynomial}\n"
 
     index_configuration = np.array(index_configuration)
     num_batch_axes = index_configuration.shape[1]
-    assert (
-        polynomial.num_inputs + len(outputs_shape_dtype) == index_configuration.shape[0]
-    )
     assert polynomial.num_outputs == len(outputs_shape_dtype)
 
     try:
@@ -69,9 +66,14 @@ def fn(op, d: cue.SegmentedTensorProduct):

     polynomial = polynomial.apply_fn(fn)
 
-    # We don't use the feature that indices can index themselves
-    index_configuration = np.concatenate(
-        [index_configuration, np.full((len(indices), num_batch_axes), -1, np.int32)]
-    )
+    if polynomial.num_inputs + len(outputs_shape_dtype) == index_configuration.shape[0]:
+        index_configuration = np.concatenate(
+            [index_configuration, np.full((len(indices), num_batch_axes), -1, np.int32)]
+        )
+
+    assert (
+        polynomial.num_inputs + len(outputs_shape_dtype) + len(indices)
+        == index_configuration.shape[0]
+    )
 
     buffers = list(inputs) + list(outputs_shape_dtype)
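A small numeric illustration of the now-conditional padding (sizes are made up): the padding is skipped when the caller already supplied rows for the index buffers, and the assert accepts only the fully padded shape.

import numpy as np

num_batch_axes = 1
index_configuration = np.array([[0], [-1], [-1]], np.int32)  # 2 inputs + 1 output
indices = [np.zeros((5,), np.int32)]                         # one index buffer

# One row of -1 per index buffer: index buffers are never themselves indexed.
padded = np.concatenate(
    [index_configuration, np.full((len(indices), num_batch_axes), -1, np.int32)]
)
assert padded.shape[0] == index_configuration.shape[0] + len(indices)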
@@ -151,8 +153,9 @@ def fn(op, d: cue.SegmentedTensorProduct):
     if len({b.shape[-1] for b in buffers}.union({1})) > 2:
         raise ValueError(f"Buffer shapes not compatible {[b.shape for b in buffers]}")
 
-    if math_dtype is not None:
+    if "math_dtype" in options:
         supported_dtypes = {"float32", "float64", "float16", "bfloat16"}
+        math_dtype = options["math_dtype"]
         if math_dtype not in supported_dtypes:
             raise ValueError(
                 f"method='uniform_1d' only supports math_dtype equal to {supported_dtypes}, got '{math_dtype}'."
@@ -165,12 +168,7 @@ def fn(op, d: cue.SegmentedTensorProduct):
     compute_dtype = jnp.float32
 
     try:
-        from cuequivariance_ops_jax import (
-            Operation,
-            Path,
-            __version__,
-            tensor_product_uniform_1d_jit,
-        )
+        from cuequivariance_ops_jax import Operation, Path, __version__, uniform_1d
     except ImportError as e:
         raise ValueError(f"cuequivariance_ops_jax is not installed: {e}")
 
@@ -186,7 +184,7 @@ def fn(op, d: cue.SegmentedTensorProduct):
     for path in stp.paths:
         paths.append(Path(path.indices, path.coefficients.item()))
 
-    outputs = tensor_product_uniform_1d_jit(
+    outputs = uniform_1d(
         buffers[: polynomial.num_inputs],
         buffers[polynomial.num_inputs :],
         list(indices),