diff --git a/finetune.py b/finetune.py index 71b4cdf..3e7a27a 100644 --- a/finetune.py +++ b/finetune.py @@ -732,4 +732,8 @@ def main(): if __name__ == "__main__": + import multiprocessing + + # Spawn workers instead of fork to prevent inherited CUDA contexts causing Warp errors. + multiprocessing.set_start_method("spawn", force=True) main() diff --git a/orb_models/common/atoms/graph_featurization.py b/orb_models/common/atoms/graph_featurization.py index a83c853..1a540f2 100644 --- a/orb_models/common/atoms/graph_featurization.py +++ b/orb_models/common/atoms/graph_featurization.py @@ -5,9 +5,9 @@ import numpy as np import torch -from nvalchemiops.neighborlist import estimate_max_neighbors -from nvalchemiops.neighborlist import neighbor_list as nva_neighbor_list -from nvalchemiops.neighborlist.neighbor_utils import get_neighbor_list_from_neighbor_matrix +from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors +from nvalchemiops.torch.neighbors import neighbor_list as nva_neighbor_list +from nvalchemiops.torch.neighbors.neighbor_utils import get_neighbor_list_from_neighbor_matrix from scipy.spatial import KDTree as SciKDTree try: @@ -748,6 +748,7 @@ def _compute_neighbor_list_with_fallback( batch_ptr=batch_ptr, fill_value=fill_value, max_neighbors=max_num_neighbors_alchemi, + wrap_positions=False, # we handle wrapping externally ) if max_num_neighbors_alchemi >= num_neighbors.max().item(): return neighbor_matrix, num_neighbors, neighbor_shift_matrix diff --git a/orb_models/forcefield/inference/d3_model.py b/orb_models/forcefield/inference/d3_model.py index e6c22b5..0a8ad29 100644 --- a/orb_models/forcefield/inference/d3_model.py +++ b/orb_models/forcefield/inference/d3_model.py @@ -3,7 +3,7 @@ import scipy.constants import torch -from nvalchemiops.interactions.dispersion.dftd3 import D3Parameters, dftd3 +from nvalchemiops.torch.interactions.dispersion import D3Parameters, dftd3 from orb_models.common.atoms.batch.graph_batch import AtomGraphs from orb_models.common.atoms.graph_featurization import _compute_neighbor_list_with_fallback diff --git a/pyproject.toml b/pyproject.toml index 1ea19e8..092a6ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "torch>=2.8.0, <3.0.0", "dm-tree==0.1.8", # Pinned because of https://github.com/google-deepmind/tree/issues/128 "tqdm>=4.67.1", - "nvalchemi-toolkit-ops>=0.2.0,<0.3.0", + "nvalchemi-toolkit-ops[torch]>=0.3.0", ] [project.optional-dependencies]