The tensor_frames package implements the message passing class described in the paper Beyond Canonicalization: How Tensorial Messages Improve Equivariant Message Passing and depicted in the figure above. This class generalizes the typical message passing algorithm by transforming features from one node's local frame to another node's frame. This transformation results in an expressive, fully equivariant message passing scheme.
The TFMessagePassing class is introduced to efficiently implement these layers, abstracting the transformation behavior of the parameters. For predicting local frames, the LearnedLFrames module is available, which calculates the local frame based on a local neighborhood, as described in the paper. Additionally, we provide input and output layers to build fully end-to-end equivariant models, adhering to the guidelines outlined in the referenced paper.
Furthermore, we implemented different representation classes, which define the transformation behavior of features. The TensorReps class allows for the definition of arbitrary Cartesian tensor representations, while the Irreps class uses the irreducible representations of the rotation group.
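For instance, a feature type with 16 scalar (order-0) and 8 vector (order-1) channels can be declared as follows. This is a minimal sketch: the representation string and the dim attribute are taken from the message passing example below, and the printed dimension assumes standard 3D Cartesian tensors.

```python
from tensor_frames.reps.tensorreps import TensorReps

# 16 scalar (order-0) and 8 vector (order-1) channels
reps = TensorReps("16x0n+8x1n")

# flattened feature dimension, assuming 3D Cartesian tensors: 16 * 1 + 8 * 3 = 40
print(reps.dim)
```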
Lastly, we also implemented a simple attention-based message passing architecture, called LoCaFormer. This architecture utilizes the TFMessagePassing class and serves as a practical example of how to build equivariant models using the tensor_frames package.
The frame-to-frame transitions (implemented in the TFMessagePassing class) enable expressive equivariant message passing between nodes with different local frames. Without frame-to-frame transitions, the communication between nodes with different local frames is limited to the exchange of scalar messages.
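Schematically, and sketched here only for the vector case, a feature $x_j$ stored in the local frame $R_j$ of node $j$ is re-expressed in the frame $R_i$ of the receiving node $i$ via the relative rotation between the two frames; higher-order tensor features transform with the corresponding representation $\rho$:

$$
x_{j \to i} = \rho\left(R_i R_j^\top\right) x_j
$$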
All transformations are abstracted away by the TFMessagePassing class, where every parameter is transformed into the right frame. This class inherits from the MessagePassing class in PyTorch Geometric, which allows for easy integration with existing PyG models. A simple GCNConv-like module could look like the following:
```python
import torch

from tensor_frames.nn.tfmessage_passing import TFMessagePassing
from tensor_frames.reps.tensorreps import TensorReps


class GCNConv(TFMessagePassing):
    def __init__(self, in_reps: TensorReps, out_reps: TensorReps):
        # declare how each argument of propagate() transforms:
        # "x" carries features of representation in_reps, given in the local frame
        super().__init__(params_dict={"x": {"type": "local", "rep": in_reps}})
        self.linear = torch.nn.Linear(in_reps.dim, out_reps.dim)

    def forward(self, edge_index, x, lframes):
        return self.propagate(edge_index, x=x, lframes=lframes)

    def message(self, x_j):
        # x_j has already been transformed into the local frame of node i
        return self.linear(x_j)


module = GCNConv(TensorReps("16x0n+8x1n"), TensorReps("4x0n+1x1n"))
```

where the PyG equivalent would look like:
```python
import torch

from torch_geometric.nn import MessagePassing


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="add")
        self.linear = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return self.linear(x_j)
```

Through the TFMessagePassing class, the feature x_j is automatically transformed into the local frame of node i. The transformation behavior of the parameters passed to the propagate function is determined by the params_dict. In this way, every message passing layer written with the PyG library can be implemented as an efficient equivariant layer.
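A hypothetical forward pass through the module defined above might look as follows. The tensor shapes are assumptions for illustration: lframes is treated as a (num_nodes, 3, 3) batch of rotation matrices, whereas the package may use its own frame container (e.g. one produced by LearnedLFrames).

```python
import torch

num_nodes, num_edges = 10, 30

edge_index = torch.randint(0, num_nodes, (2, num_edges))  # random connectivity
x = torch.randn(num_nodes, 40)                            # features of type 16x0n+8x1n (dim 40)
lframes = torch.eye(3).repeat(num_nodes, 1, 1)            # placeholder: identity frame per node

out = module(edge_index, x=x, lframes=lframes)            # expected shape: (num_nodes, 7)
```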
If you use mamba or micromamba, replace conda with mamba or micromamba below (micromamba is recommended).
```bash
# create the conda environment and install the dependencies
conda env create -f environment.yaml -n tensor_frames

# activate the conda environment
conda activate tensor_frames

# install as an editable package (the config setting keeps VS Code autocompletion working)
pip install -e . --config-settings editable_mode=strict
```

If you find this code useful in your research, please consider citing the following paper:
```bibtex
@inproceedings{lippmann2025beyond,
  title={Beyond Canonicalization: How Tensorial Messages Improve Equivariant Message Passing},
  author={Lippmann, Peter and Gerhartz, Gerrit and Remme, Roman and Hamprecht, Fred A},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=vDp6StrKIq}
}
```
This package is successfully being used in SOTA machine-learned Orbital-Free Density Functional Theory and in SOTA machine learning for particle physics:
- Stable and Accurate Orbital-Free Density Functional Theory Powered by Machine Learning
- Lorentz Local Canonicalization: How to Make Any Network Lorentz-Equivariant
Before starting to commit, run

```bash
pre-commit install
```

After installing, pre-commit will check the code formatting and much more when you try to commit.
The exact settings can be found in .pre-commit-config.yaml.
If pre-commit fails during committing, check for changed files, stage them again, and commit again.
If it still fails, read the error output and fix the issues manually. If you want to commit anyway, run:
```bash
git commit --no-verify -m "message"
```

To run the tests of the package, simply run:

```bash
pytest
```