Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/doc_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y graphviz
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/doc_build_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y graphviz
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
7 changes: 5 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y graphviz
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade uv
python -m uv pip install pytest
python -m uv pip install -e ./cuequivariance
python -m uv pip install -e ./cuequivariance[dev]
- name: Test with pytest
run: |
pytest --doctest-modules -x -m "not slow" cuequivariance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .operation import Operation
from .segmented_polynomial import SegmentedPolynomial
from .visualization import visualize_polynomial


__all__ = [
Expand All @@ -37,4 +38,5 @@
"dispatch",
"Operation",
"SegmentedPolynomial",
"visualize_polynomial",
]
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import copy
import dataclasses
import itertools
from typing import Any, Callable, Sequence
from typing import Any, Callable, Optional, Sequence

import numpy as np

Expand Down Expand Up @@ -998,7 +998,10 @@ def jvp(
self, has_tangent: list[bool]
) -> tuple[
SegmentedPolynomial,
Callable[[tuple[list[Any], list[Any]]], tuple[list[Any], list[Any]]],
Callable[
[tuple[list[Any], list[Any]], Optional[Callable[[Any], Any]]],
tuple[list[Any], list[Any]],
],
]:
"""Compute the Jacobian-vector product of the polynomial.

Expand All @@ -1023,14 +1026,20 @@ def jvp(
):
new_operations.append((ope, multiplicator * stp))

def mapping(x: tuple[list[Any], list[Any]]) -> tuple[list[Any], list[Any]]:
def mapping(
x: tuple[list[Any], list[Any]],
into_grad: Optional[Callable[[Any], Any]] = None,
) -> tuple[list[Any], list[Any]]:
inputs, outputs = x
inputs, outputs = list(inputs), list(outputs)
assert len(inputs) == self.num_inputs
assert len(outputs) == self.num_outputs
into_grad = into_grad if callable(into_grad) else lambda x: x

new_inputs = inputs + [x for has, x in zip(has_tangent, inputs) if has]
new_outputs = outputs
new_inputs = inputs + [
into_grad(x) for has, x in zip(has_tangent, inputs) if has
]
new_outputs = [into_grad(x) for x in outputs]

return new_inputs, new_outputs

Expand Down Expand Up @@ -1088,7 +1097,10 @@ def backward(
self, requires_gradient: list[bool], has_cotangent: list[bool]
) -> tuple[
SegmentedPolynomial,
Callable[[tuple[list[Any], list[Any]]], tuple[list[Any], list[Any]]],
Callable[
[tuple[list[Any], list[Any]], Optional[Callable[[Any], Any]]],
tuple[list[Any], list[Any]],
],
]:
"""Compute the backward pass of the polynomial for gradient computation.

Expand All @@ -1106,7 +1118,10 @@ def backward(
has_cotangent,
)

def mapping(x: tuple[list[Any], list[Any]]) -> tuple[list[Any], list[Any]]:
return map2(map1(x))
def mapping(
x: tuple[list[Any], list[Any]],
into_grad: Optional[Callable[[Any], Any]] = None,
) -> tuple[list[Any], list[Any]]:
return map2(map1(x, into_grad))

return p, mapping
102 changes: 102 additions & 0 deletions cuequivariance/cuequivariance/segmented_polynomials/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import graphviz

import cuequivariance as cue


def visualize_polynomial(
poly: "cue.SegmentedPolynomial",
input_names: list[str],
output_names: list[str],
) -> "graphviz.Digraph":
"""
Create a graphviz diagram showing the dataflow from inputs through STPs to outputs.

Args:
poly: The SegmentedPolynomial to visualize.
input_names: Names for each input operand (length must match poly.num_inputs).
output_names: Names for each output operand (length must match poly.num_outputs).

Returns:
A graphviz.Digraph object that can be rendered, saved, or displayed.

Example:
>>> import cuequivariance as cue
>>> from cuequivariance.segmented_polynomials.visualization import visualize_polynomial
>>> poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2, 3]).polynomial
>>> graph = visualize_polynomial(poly, ["x"], ["Y"])
>>> graph.render("spherical_harmonics", format="png", cleanup=True) # doctest: +SKIP
>>> # Or in Jupyter:
>>> # graph # Displays inline

Raises:
ValueError: If the number of names doesn't match the number of inputs/outputs.
ImportError: If graphviz is not installed.
"""
# Validate parameters first
if len(input_names) != poly.num_inputs:
raise ValueError(
f"Expected {poly.num_inputs} input names, got {len(input_names)}"
)
if len(output_names) != poly.num_outputs:
raise ValueError(
f"Expected {poly.num_outputs} output names, got {len(output_names)}"
)

# Import graphviz (checked after parameter validation)
try:
import graphviz
except ImportError as e:
raise ImportError(
"graphviz is required for visualization. Install it with: pip install graphviz"
) from e

# Create directed graph
dot = graphviz.Digraph(comment="Segmented Polynomial Flow")
dot.attr(rankdir="LR") # Left to right layout
dot.attr("node", shape="box")

# Create input nodes
for i, (name, operand) in enumerate(zip(input_names, poly.inputs)):
label = f"{name}\\n{operand.num_segments} segments\\nsize={operand.size}"
dot.node(f"input_{i}", label, style="filled", fillcolor="lightblue")

# Create output nodes
for i, (name, operand) in enumerate(zip(output_names, poly.outputs)):
label = f"{name}\\n{operand.num_segments} segments\\nsize={operand.size}"
dot.node(f"output_{i}", label, style="filled", fillcolor="lightgreen")

# Create STP nodes and edges
for stp_idx, (operation, stp) in enumerate(poly.operations):
# Create STP node
stp_label = f"{stp.subscripts}\\n{stp.num_paths} paths"
dot.node(f"stp_{stp_idx}", stp_label, style="filled", fillcolor="lightyellow")

# Create edges from inputs to this STP
for operand_idx in operation.input_buffers(poly.num_inputs):
dot.edge(f"input_{operand_idx}", f"stp_{stp_idx}")

# Create edge from this STP to output
output_buffer = operation.output_buffer(poly.num_inputs)
output_idx = output_buffer - poly.num_inputs
dot.edge(f"stp_{stp_idx}", f"output_{output_idx}")

return dot
6 changes: 6 additions & 0 deletions cuequivariance/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]

[project.optional-dependencies]
dev = [
"pytest",
"graphviz",
]

[tool.hatch.version]
path = "cuequivariance/VERSION"
pattern = "(?P<version>\\d+\\.\\d+\\.\\d+(?:[a-z]+\\d+)?)"
Expand Down
3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ ipykernel
matplotlib
jupyter-sphinx
e3nn
flax
flax
graphviz
84 changes: 83 additions & 1 deletion docs/tutorials/poly.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,86 @@ This hierarchical structure allows for efficient representation and computation
.. jupyter-execute::

p.operations


Visualization
-------------

You can visualize the dataflow of a :class:`cue.SegmentedPolynomial <cuequivariance.SegmentedPolynomial>` using graphviz. This creates a diagram showing how inputs flow through segmented tensor products to produce outputs.

First, install graphviz:

.. code-block:: bash

pip install graphviz

Then create a visualization:

.. jupyter-execute::

from cuequivariance.segmented_polynomials import visualize_polynomial

# Visualize the spherical harmonics polynomial
sh_poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2]).polynomial
graph = visualize_polynomial(sh_poly, input_names=["x"], output_names=["Y"])

# Display the graph (in Jupyter it renders inline)
graph

The diagram shows:

* **Input nodes** (blue): Display the input name, number of segments, and total size
* **STP nodes** (yellow): Show the subscripts and number of computation paths
* **Output nodes** (green): Display the output name, number of segments, and total size
* **Edges**: Represent the dataflow, with multiple edges drawn when an input is used multiple times

You can save the diagram to a file:

.. jupyter-execute::
:hide-output:

# Save as PNG (or 'svg', 'pdf', etc.)
graph.render('spherical_harmonics', format='png', cleanup=True)

For more complex examples:

.. jupyter-execute::

# Visualize a linear layer
irreps_in = cue.Irreps("O3", "8x0e + 8x1o")
irreps_out = cue.Irreps("O3", "4x0e + 4x1o")
linear_poly = cue.descriptors.linear(irreps_in, irreps_out).polynomial

graph = visualize_polynomial(linear_poly, input_names=["weights", "input"], output_names=["output"])
graph

.. jupyter-execute::

# Visualize a tensor product
irreps = cue.Irreps("O3", "0e + 1o")
tp_poly = cue.descriptors.channelwise_tensor_product(irreps, irreps, irreps).polynomial

graph = visualize_polynomial(tp_poly, input_names=["weights", "x1", "x2"], output_names=["y"])
graph

Visualizing Backward Pass
^^^^^^^^^^^^^^^^^^^^^^^^^^

You can also visualize the backward pass of a polynomial. The mapping function returned by :meth:`cue.SegmentedPolynomial.backward <cuequivariance.SegmentedPolynomial.backward>` accepts an optional `into_grad` parameter that can transform operand names, which is useful for labeling gradients:

.. jupyter-execute::

# Create a polynomial and compute its backward pass
irreps = cue.Irreps("O3", "0e + 1o")
tp_poly = cue.descriptors.channelwise_tensor_product(irreps, irreps, irreps).polynomial

# Compute backward pass (all inputs require gradients, output has cotangent)
poly_bwd, m = tp_poly.backward([True, True, True], [True])

# Transform operand names using the mapping function with into_grad
# The mapping function takes (inputs, outputs) and returns (new_inputs, new_outputs)
operand_names = (["weights", "x1", "x2"], ["y"])
operand_names_bwd = m(operand_names, lambda n: f"d{n}")

# Visualize the backward polynomial
graph = visualize_polynomial(poly_bwd, input_names=operand_names_bwd[0], output_names=operand_names_bwd[1])
graph
49 changes: 49 additions & 0 deletions test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
"""Test script for the visualize_polynomial function (works without graphviz)."""

import cuequivariance as cue


def test_visualization_api():
"""Test that the visualization function has the correct API without rendering."""
from cuequivariance.segmented_polynomials import visualize_polynomial

# Create a simple polynomial
sh_poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2]).polynomial

print("Testing visualize_polynomial API...")
print(f"Polynomial: {sh_poly}")
print(f" num_inputs: {sh_poly.num_inputs}")
print(f" num_outputs: {sh_poly.num_outputs}")
print(f" num_operations: {len(sh_poly.operations)}")
print()

# Test error handling for wrong number of names
try:
visualize_polynomial(sh_poly, ["x", "y"], ["Y"]) # Too many input names
print("❌ Should have raised ValueError for wrong number of inputs")
except ValueError as e:
print(f"✓ Correctly raised ValueError: {e}")

try:
visualize_polynomial(sh_poly, ["x"], ["Y", "Z"]) # Too many output names
print("❌ Should have raised ValueError for wrong number of outputs")
except ValueError as e:
print(f"✓ Correctly raised ValueError: {e}")

# Test that it raises ImportError if graphviz is not installed
try:
graph = visualize_polynomial(sh_poly, ["x"], ["Y"])
print("✓ graphviz is installed, graph created successfully")
print(f" Graph type: {type(graph)}")
# Print the DOT source
print("\nGenerated DOT source:")
print(graph.source)
except ImportError as e:
print(f"✓ Correctly raised ImportError when graphviz not installed: {e}")

print("\n✓ All API tests passed!")


if __name__ == "__main__":
test_visualization_api()