Status: Open
Labels: bug (Something isn't working), feature request (New feature request)
Description
I am encountering an error when calling the triangle_multiplicative_update function from the cuequivariance-torch library. Even though triton==3.3.0 is installed, as the error message recommends, the call still fails with "Failed to import Triton-based component".
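The error message points at the Triton version, so it can help to confirm which distributions the interpreter actually resolves in this environment, since both conda and pip install packages into it. The sketch below uses only the standard library's `importlib.metadata`. The helper `installed_versions` and the constants are hypothetical names, and the pinned version is copied from the error message.

```python
from importlib.metadata import PackageNotFoundError, version

# Required Triton version, as stated in the cuequivariance error message.
REQUIRED_TRITON = "3.3.0"

# Distribution names taken from the environment listed in this report.
PACKAGES = [
    "torch",
    "triton",
    "cuequivariance-torch",
    "cuequivariance-ops-torch-cu12",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    versions = installed_versions(PACKAGES)
    for name, ver in versions.items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
    if versions["triton"] != REQUIRED_TRITON:
        print(f"triton {versions['triton']} does not match required {REQUIRED_TRITON}")
```

In this report the script output already shows Triton 3.3.0, so the pin itself appears to be satisfied; the "Not Supported" line in the message may be the more informative part.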
Environment
GPU and CUDA version
- NVIDIA GeForce RTX 3090
- CUDA 12.6 (installed from conda)
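The PyTorch wheels in the environment carry a local version tag (`+cu126`) that should correspond to the CUDA 12.6 toolkit installed from conda. A minimal sketch for checking that correspondence from the version strings alone follows. The helpers `cuda_tag_from_torch_version` and `tags_match` are hypothetical, and the parsing assumes the `+cuXYZ` tag encodes the major version in all but the last digit and the minor version in the last digit, which holds for tags such as `cu118` and `cu126`.

```python
import re


def cuda_tag_from_torch_version(torch_version):
    """Return the CUDA version encoded in a torch local tag, e.g. '2.7.0+cu126' -> '12.6'.

    Returns None for builds without a '+cuXYZ' tag (e.g. CPU-only wheels).
    Assumes the last digit is the minor version and the rest is the major version.
    """
    match = re.search(r"\+cu(\d+)$", torch_version)
    if match is None:
        return None
    digits = match.group(1)
    if len(digits) < 2:
        return None
    return f"{int(digits[:-1])}.{int(digits[-1])}"


def tags_match(torch_version, toolkit_version):
    """True if the torch wheel's CUDA tag matches the toolkit's major.minor version."""
    tag = cuda_tag_from_torch_version(torch_version)
    major_minor = ".".join(toolkit_version.split(".")[:2])
    return tag == major_minor


print(cuda_tag_from_torch_version("2.7.0+cu126"))  # 12.6
print(tags_match("2.7.0+cu126", "12.6.0"))  # True
```

For the versions in this report both sides resolve to 12.6, so the wheel tag and the toolkit agree.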
conda environment
```yaml
channels:
  - conda-forge
  - nvidia/label/cuda-12.6.0
dependencies:
  - nvidia/label/cuda-12.6.0::cuda-toolkit
  - python=3.12
  - pip:
    - cuequivariance-ops-torch-cu12==0.6.0
    - cuequivariance-torch==0.6.0
    - torch==2.7.0+cu126
    - torchaudio==2.7.0+cu126
    - torchvision==0.22.0+cu126
```

Minimal Reproducible code
```python
import torch
import triton
from cuequivariance_torch import triangle_multiplicative_update

print(f"PyTorch version: {torch.__version__}\nTriton version: {triton.__version__}")
if torch.cuda.is_available():
    device = torch.device("cuda")
    batch_size, seq_len, hidden_dim = 1, 128, 128
    # Create input tensor
    x = torch.randn(batch_size, seq_len, seq_len, hidden_dim, requires_grad=True, device=device)
    # Create mask (1 for valid positions, 0 for masked)
    mask = torch.ones(batch_size, seq_len, seq_len, device=device)
    # Perform triangular multiplication
    output = triangle_multiplicative_update(
        x=x,
        direction="outgoing",  # or "incoming"
        mask=mask,
    )
    print(output.shape)  # torch.Size([1, 128, 128, 128])
    # Create gradient tensor and perform backward pass
    grad_out = torch.randn_like(output)
    output.backward(grad_out)
    # Access gradients
    print(x.grad.shape)
```

Output

```
PyTorch version: 2.7.0+cu126
Triton version: 3.3.0
Traceback (most recent call last):
  File "/path/to/cuequivariance-test/test.py", line 14, in <module>
    output = triangle_multiplicative_update(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/lib/python3.12/site-packages/cuequivariance_torch/primitives/triangle.py", line 231, in triangle_multiplicative_update
    return f(
           ^^
  File "/path/to/env/lib/python3.12/site-packages/cuequivariance_ops_torch/__init__.py", line 72, in triangle_multiplicative_update
    raise Exception(
Exception: Failed to import Triton-based component: triangle_multiplicative_update:
Not Supported
Please make sure to install triton==3.3.0. Other versions may not work!
```
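For context on the `direction` argument used in the reproduction: assuming cuequivariance follows the AlphaFold2 convention for the triangle multiplicative update (Algorithms 11 and 12 in that paper), the core of the operation is a contraction over a third residue index k of two projected pair tensors a and b. The pure-Python sketch below shows only that contraction for a single channel. It omits the layer norms, projections, gating, and masking that the library function applies, it is not a workaround for the kernel error, and the name `triangle_contract` is hypothetical.

```python
def triangle_contract(a, b, direction):
    """Core triangle contraction for one channel of square N x N matrices a and b.

    outgoing: z[i][j] = sum_k a[i][k] * b[j][k]  (edges i->k and j->k)
    incoming: z[i][j] = sum_k a[k][i] * b[k][j]  (edges k->i and k->j)
    """
    if direction not in ("outgoing", "incoming"):
        raise ValueError(f"unknown direction: {direction!r}")
    n = len(a)
    z = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if direction == "outgoing":
                z[i][j] = sum(a[i][k] * b[j][k] for k in range(n))
            else:
                z[i][j] = sum(a[k][i] * b[k][j] for k in range(n))
    return z


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
print(triangle_contract(a, b, "outgoing"))  # [[17.0, 23.0], [39.0, 53.0]]
print(triangle_contract(a, b, "incoming"))  # [[26.0, 30.0], [38.0, 44.0]]
```

In matrix terms, "outgoing" is a @ b^T and "incoming" is a^T @ b, applied independently per channel.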