
tensor_frames: Expressive Equivariant Message Passing via Local Canonicalization


Description

The tensor_frames package implements the message passing class described in the paper Beyond Canonicalization: How Tensorial Messages Improve Equivariant Message Passing. This class generalizes the standard message passing algorithm by transforming features from one node's local frame to another node's local frame. This transformation yields an $\mathrm{O}(n)$-invariant layer, which can be used to construct fully equivariant architectures:

$$ f_i^{(k)}=\psi^{(k)}\bigg( f_i^{(k-1)}, \bigoplus_{j\in\mathcal{N}(i)}\phi^{(k)}\left(f_i^{(k-1)},\rho(g_i g_j^{-1})f_j^{(k-1)}, \rho_e(g_i)e_{ji}, R_i(\mathbf x_i - \mathbf x_j)\right) \bigg) $$

Here $g_i$ denotes the local frame of node $i$, $\rho$ and $\rho_e$ are the representations acting on node and edge features, $e_{ji}$ are the edge features, and $R_i$ is the rotation that expresses the relative position $\mathbf x_i - \mathbf x_j$ in node $i$'s local frame.
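To see why messages built from $\rho(g_i g_j^{-1})$ are invariant under a global transformation $g \in \mathrm{O}(n)$ of the input, note that equivariantly predicted frames co-rotate with the input. With the convention that frames then transform as $g_i \mapsto g_i g^{-1}$ (an assumption about the notation; the paper's convention may differ), the relative transformation cancels $g$:

$$ g_i g_j^{-1} \;\mapsto\; \left(g_i g^{-1}\right)\left(g_j g^{-1}\right)^{-1} = g_i\, g^{-1} g\, g_j^{-1} = g_i g_j^{-1} $$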

The TFMessagePassing class is introduced to efficiently implement these layers, abstracting away the transformation behavior of the parameters. For predicting local frames, the LearnedLFrames module is available, which computes a local frame from a node's local neighborhood, as described in the paper. Additionally, we provide input and output layers for building fully end-to-end equivariant models, following the guidelines outlined in the paper.
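A minimal sketch of how frame prediction could fit into a pipeline; note that the import path, constructor, and call signature of LearnedLFrames below are assumptions for illustration, not the package's documented API:

import torch

# Import path and call signature are assumptions; consult the package source.
from tensor_frames.lframes import LearnedLFrames

pos = torch.randn(100, 3)                     # node positions
edge_index = torch.randint(0, 100, (2, 400))  # graph connectivity

lframes_module = LearnedLFrames()             # hypothetical default construction
lframes = lframes_module(pos=pos, edge_index=edge_index)  # one local frame per node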

Furthermore, we implemented different representation classes, which define the transformation behavior of features. The TensorReps class allows for the definition of arbitrary Cartesian tensor representations, while the Irreps class uses the irreducible representations of $\mathrm{O}(3)$. Both classes are efficiently implemented and can be used interchangeably.
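As a quick illustration, using the reps string that also appears in the GCNConv example below (our reading: 16x0n means 16 order-0 tensors and 8x1n means 8 order-1 tensors; the exact semantics of the trailing n is an assumption):

from tensor_frames.reps.tensorreps import TensorReps

# 16 scalar (order-0) and 8 vector (order-1) Cartesian tensor features.
reps = TensorReps("16x0n+8x1n")

# .dim is the total feature dimension, used below for torch.nn.Linear;
# in 3D we would expect 16 * 1 + 8 * 3 = 40, assuming this reading is correct.
print(reps.dim)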

Lastly, we also implemented a simple attention-based message passing architecture called LoCaFormer. This architecture utilizes the TFMessagePassing class and serves as a practical example of how to build equivariant models with the tensor_frames package.

The frame-to-frame transitions $\rho(g_i g_j^{-1})f_j^{(k-1)}$ of tensorial messages (implemented by TFMessagePassing) enable expressive equivariant message passing between nodes with different local frames. Without frame-to-frame transitions, the communication between nodes with different local frames is limited to the exchange of scalar messages.
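As a standalone illustration of such a transition for a vector (order-1) feature, assuming frames are stored as rotation matrices whose rows are the frame axes, re-expressing a vector from frame $j$ in frame $i$ amounts to multiplying by $R_i R_j^\top$:

import torch

def random_rotation():
    # QR decomposition of a random matrix yields an orthonormal frame.
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    return q

R_i, R_j = random_rotation(), random_rotation()  # local frames (rows = axes)

v_global = torch.randn(3)      # a vector feature in global coordinates
v_j = R_j @ v_global           # the same vector expressed in frame j
v_i = (R_i @ R_j.T) @ v_j      # frame-to-frame transition rho(g_i g_j^{-1})

assert torch.allclose(v_i, R_i @ v_global, atol=1e-6)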

Create your own module

All transformations are abstracted away by the TFMessagePassing class, which transforms every parameter into the right frame. This class inherits from the MessagePassing class in PyTorch Geometric, which allows for easy integration with existing PyG models. A simple GCNConv-like module could look like the following:

import torch

from tensor_frames.nn.tfmessage_passing import TFMessagePassing
from tensor_frames.reps.tensorreps import TensorReps

class GCNConv(TFMessagePassing):
    def __init__(self, in_reps: TensorReps, out_reps: TensorReps):
        # Declare that the parameter "x" transforms as a local feature
        # with representation in_reps.
        super().__init__(
            params_dict={
                "x": {"type": "local", "rep": in_reps}
            }
        )
        self.linear = torch.nn.Linear(in_reps.dim, out_reps.dim)

    def forward(self, edge_index, x, lframes):
        return self.propagate(edge_index, x=x, lframes=lframes)

    def message(self, x_j):
        # x_j has already been transformed into the local frame of node i.
        return self.linear(x_j)

module = GCNConv(TensorReps("16x0n+8x1n"), TensorReps("4x0n+1x1n"))

For comparison, the plain PyG equivalent would look like:

import torch

from torch_geometric.nn import MessagePassing

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.linear = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return self.linear(x_j)

Through the TFMessagePassing class, the feature x_j is automatically transformed into the local frame of node i. The transformation behavior of the parameters that are passed to the propagate function is determined by the params_dict. In this way, every message passing layer written with the PyG library can be implemented as an equivariant layer in an efficient way.
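A hedged usage sketch for the module defined above; the expected format of lframes (here assumed to be one orthonormal 3x3 matrix per node) is an assumption, so treat this as illustration rather than the package's documented calling convention:

import torch

num_nodes = 100
x = torch.randn(num_nodes, 40)  # feature dim assumed to match TensorReps("16x0n+8x1n").dim
edge_index = torch.randint(0, num_nodes, (2, 400))

# Assumption: lframes holds one orthonormal frame per node; in practice it may
# be a dedicated LFrames object, e.g. predicted by LearnedLFrames.
lframes, _ = torch.linalg.qr(torch.randn(num_nodes, 3, 3))

out = module(edge_index, x=x, lframes=lframes)  # expected shape: (num_nodes, out_reps.dim)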

Installation

Install using Conda/Mamba/Micromamba

For mamba or micromamba, replace conda with mamba or micromamba in the commands below (micromamba is recommended).

# create conda environment and install dependencies
conda env create -f environment.yaml -n tensor_frames

# activate conda environment
conda activate tensor_frames

# install as an editable package (strict editable mode improves IDE autocompletion, e.g. in VS Code)
pip install -e . --config-settings editable_mode=strict

Citation

If you find this code useful in your research, please consider citing the following paper:

@inproceedings{lippmann2025beyond,
  title={Beyond Canonicalization: How Tensorial Messages Improve Equivariant Message Passing},
  author={Lippmann, Peter and Gerhartz, Gerrit and Remme, Roman and Hamprecht, Fred A},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=vDp6StrKIq}
}

Interesting Applications of tensor_frames

This package is successfully used in state-of-the-art machine-learned Orbital-Free Density Functional Theory and in state-of-the-art machine learning for particle physics:

Stable and Accurate Orbital-Free Density Functional Theory Powered by Machine Learning
Lorentz Local Canonicalization: How to Make Any Network Lorentz-Equivariant

Developer Info

Pre-commit

Before starting to commit, run

pre-commit install

After installing, pre-commit will check code formatting and more whenever you commit. The exact settings can be found in .pre-commit-config.yaml. If pre-commit fails during a commit, check for changed files, stage them again, and commit again. If it still fails, read the error output and fix the issues manually. If you want to commit anyway, run:

git commit --no-verify -m "message"

Testing

To run the package's test suite, simply run:

pytest
