Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/doc_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y graphviz
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/doc_build_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y graphviz
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
7 changes: 5 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y graphviz
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade uv
python -m uv pip install pytest
python -m uv pip install -e ./cuequivariance
python -m uv pip install -e ./cuequivariance[dev]
- name: Test with pytest
run: |
pytest --doctest-modules -x -m "not slow" cuequivariance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .operation import Operation
from .segmented_polynomial import SegmentedPolynomial
from .visualization import visualize_polynomial


__all__ = [
Expand All @@ -37,4 +38,5 @@
"dispatch",
"Operation",
"SegmentedPolynomial",
"visualize_polynomial",
]
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import copy
import dataclasses
import itertools
from typing import Any, Callable, Sequence
from typing import Any, Callable, Optional, Sequence

import numpy as np

Expand Down Expand Up @@ -998,7 +998,10 @@ def jvp(
self, has_tangent: list[bool]
) -> tuple[
SegmentedPolynomial,
Callable[[tuple[list[Any], list[Any]]], tuple[list[Any], list[Any]]],
Callable[
[tuple[list[Any], list[Any]], Optional[Callable[[Any], Any]]],
tuple[list[Any], list[Any]],
],
]:
"""Compute the Jacobian-vector product of the polynomial.

Expand All @@ -1023,14 +1026,20 @@ def jvp(
):
new_operations.append((ope, multiplicator * stp))

def mapping(x: tuple[list[Any], list[Any]]) -> tuple[list[Any], list[Any]]:
def mapping(
x: tuple[list[Any], list[Any]],
into_grad: Optional[Callable[[Any], Any]] = None,
) -> tuple[list[Any], list[Any]]:
inputs, outputs = x
inputs, outputs = list(inputs), list(outputs)
assert len(inputs) == self.num_inputs
assert len(outputs) == self.num_outputs
into_grad = into_grad if callable(into_grad) else lambda x: x

new_inputs = inputs + [x for has, x in zip(has_tangent, inputs) if has]
new_outputs = outputs
new_inputs = inputs + [
into_grad(x) for has, x in zip(has_tangent, inputs) if has
]
new_outputs = [into_grad(x) for x in outputs]

return new_inputs, new_outputs

Expand Down Expand Up @@ -1088,7 +1097,10 @@ def backward(
self, requires_gradient: list[bool], has_cotangent: list[bool]
) -> tuple[
SegmentedPolynomial,
Callable[[tuple[list[Any], list[Any]]], tuple[list[Any], list[Any]]],
Callable[
[tuple[list[Any], list[Any]], Optional[Callable[[Any], Any]]],
tuple[list[Any], list[Any]],
],
]:
"""Compute the backward pass of the polynomial for gradient computation.

Expand All @@ -1106,7 +1118,10 @@ def backward(
has_cotangent,
)

def mapping(x: tuple[list[Any], list[Any]]) -> tuple[list[Any], list[Any]]:
return map2(map1(x))
def mapping(
x: tuple[list[Any], list[Any]],
into_grad: Optional[Callable[[Any], Any]] = None,
) -> tuple[list[Any], list[Any]]:
return map2(map1(x, into_grad))

return p, mapping
102 changes: 102 additions & 0 deletions cuequivariance/cuequivariance/segmented_polynomials/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import graphviz

import cuequivariance as cue


def visualize_polynomial(
poly: "cue.SegmentedPolynomial",
input_names: list[str],
output_names: list[str],
) -> "graphviz.Digraph":
"""
Create a graphviz diagram showing the dataflow from inputs through STPs to outputs.

Args:
poly: The SegmentedPolynomial to visualize.
input_names: Names for each input operand (length must match poly.num_inputs).
output_names: Names for each output operand (length must match poly.num_outputs).

Returns:
A graphviz.Digraph object that can be rendered, saved, or displayed.

Example:
>>> import cuequivariance as cue
>>> from cuequivariance.segmented_polynomials.visualization import visualize_polynomial
>>> poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2, 3]).polynomial
>>> graph = visualize_polynomial(poly, ["x"], ["Y"])
>>> graph.render("spherical_harmonics", format="png", cleanup=True) # doctest: +SKIP
>>> # Or in Jupyter:
>>> # graph # Displays inline

Raises:
ValueError: If the number of names doesn't match the number of inputs/outputs.
ImportError: If graphviz is not installed.
"""
# Validate parameters first
if len(input_names) != poly.num_inputs:
raise ValueError(
f"Expected {poly.num_inputs} input names, got {len(input_names)}"
)
if len(output_names) != poly.num_outputs:
raise ValueError(
f"Expected {poly.num_outputs} output names, got {len(output_names)}"
)

# Import graphviz (checked after parameter validation)
try:
import graphviz
except ImportError as e:
raise ImportError(
"graphviz is required for visualization. Install it with: pip install graphviz"
) from e

# Create directed graph
dot = graphviz.Digraph(comment="Segmented Polynomial Flow")
dot.attr(rankdir="LR") # Left to right layout
dot.attr("node", shape="box")

# Create input nodes
for i, (name, operand) in enumerate(zip(input_names, poly.inputs)):
label = f"{name}\\n{operand.num_segments} segments\\nsize={operand.size}"
dot.node(f"input_{i}", label, style="filled", fillcolor="lightblue")

# Create output nodes
for i, (name, operand) in enumerate(zip(output_names, poly.outputs)):
label = f"{name}\\n{operand.num_segments} segments\\nsize={operand.size}"
dot.node(f"output_{i}", label, style="filled", fillcolor="lightgreen")

# Create STP nodes and edges
for stp_idx, (operation, stp) in enumerate(poly.operations):
# Create STP node
stp_label = f"{stp.subscripts}\\n{stp.num_paths} paths"
dot.node(f"stp_{stp_idx}", stp_label, style="filled", fillcolor="lightyellow")

# Create edges from inputs to this STP
for operand_idx in operation.input_buffers(poly.num_inputs):
dot.edge(f"input_{operand_idx}", f"stp_{stp_idx}")

# Create edge from this STP to output
output_buffer = operation.output_buffer(poly.num_inputs)
output_idx = output_buffer - poly.num_inputs
dot.edge(f"stp_{stp_idx}", f"output_{output_idx}")

return dot
6 changes: 6 additions & 0 deletions cuequivariance/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]

[project.optional-dependencies]
dev = [
"pytest",
"graphviz",
]

[tool.hatch.version]
path = "cuequivariance/VERSION"
pattern = "(?P<version>\\d+\\.\\d+\\.\\d+(?:[a-z]+\\d+)?)"
Expand Down
3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ ipykernel
matplotlib
jupyter-sphinx
e3nn
flax
flax
graphviz
3 changes: 2 additions & 1 deletion docs/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Tutorials

irreps
layout
poly
stp
poly
pytorch/index
jax/index
23 changes: 23 additions & 0 deletions docs/tutorials/jax/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

JAX Examples
============

.. toctree::
:maxdepth: 1

poly

Loading