Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,4 +732,8 @@ def main():


if __name__ == "__main__":
import multiprocessing

# Spawn workers instead of fork to prevent inherited CUDA contexts causing Warp errors.
multiprocessing.set_start_method("spawn", force=True)
main()
7 changes: 4 additions & 3 deletions orb_models/common/atoms/graph_featurization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import numpy as np
import torch
from nvalchemiops.neighborlist import estimate_max_neighbors
from nvalchemiops.neighborlist import neighbor_list as nva_neighbor_list
from nvalchemiops.neighborlist.neighbor_utils import get_neighbor_list_from_neighbor_matrix
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors
from nvalchemiops.torch.neighbors import neighbor_list as nva_neighbor_list
from nvalchemiops.torch.neighbors.neighbor_utils import get_neighbor_list_from_neighbor_matrix
from scipy.spatial import KDTree as SciKDTree

try:
Expand Down Expand Up @@ -748,6 +748,7 @@ def _compute_neighbor_list_with_fallback(
batch_ptr=batch_ptr,
fill_value=fill_value,
max_neighbors=max_num_neighbors_alchemi,
wrap_positions=False, # we handle wrapping externally
)
if max_num_neighbors_alchemi >= num_neighbors.max().item():
return neighbor_matrix, num_neighbors, neighbor_shift_matrix
Expand Down
2 changes: 1 addition & 1 deletion orb_models/forcefield/inference/d3_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import scipy.constants
import torch
from nvalchemiops.interactions.dispersion.dftd3 import D3Parameters, dftd3
from nvalchemiops.torch.interactions.dispersion import D3Parameters, dftd3

from orb_models.common.atoms.batch.graph_batch import AtomGraphs
from orb_models.common.atoms.graph_featurization import _compute_neighbor_list_with_fallback
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"torch>=2.8.0, <3.0.0",
"dm-tree==0.1.8", # Pinned because of https://github.com/google-deepmind/tree/issues/128
"tqdm>=4.67.1",
"nvalchemi-toolkit-ops>=0.2.0,<0.3.0",
"nvalchemi-toolkit-ops[torch]>=0.3.0",
]

[project.optional-dependencies]
Expand Down
Loading