Skip to content

Commit 8dbe714

Browse files
authored
Merge pull request #18 from erwallace/v0.0.10
v0.0.10
2 parents e977ef9 + e3a638f commit 8dbe714

File tree

6 files changed

+251
-11
lines changed

6 files changed

+251
-11
lines changed

README.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,6 @@ class MyCalculator(Calculator):
134134
forces = torch.zeros_like(batch.pos, device=self.device)
135135
# ... fill forces from your model ...
136136
return energies, forces
137-
138-
def to_atomic_data():
139-
pass
140-
141-
def from_atomic_data():
142-
pass
143137
```
144138

145139
### Data Containers
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ._fairchem import FAIRChemCalculator
22
from ._mace import MACECalculator
3+
from ._mmff94 import MMFF94Calculator
34
from ._rand import RandomCalculator
45

5-
__all__ = ["MACECalculator", "FAIRChemCalculator", "RandomCalculator"]
6+
__all__ = ["MACECalculator", "FAIRChemCalculator", "RandomCalculator", "MMFF94Calculator"]
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from collections.abc import Mapping
2+
from typing import Any
3+
4+
import torch
5+
from ase.units import eV, kcal, mol
6+
from rdkit import Chem
7+
from rdkit.Chem import AllChem, rdDetermineBonds
8+
from torch_geometric.data import Batch, Data
9+
10+
from neural_optimiser.calculators.base import Calculator
11+
12+
KCAL_PER_MOL_TO_EV = kcal / mol / eV
13+
14+
15+
class MMFF94Calculator(Calculator):
    """Calculator using RDKit's implementation of the Merck Molecular Force Field (MMFF94(s)).

    Only single-conformer inputs are supported. Bond topology is perceived with
    ``rdDetermineBonds.DetermineBonds`` and cached, so repeated evaluations of the same
    molecule (for example successive optimiser steps) skip bond perception.

    NOTE(review): forces are converted to eV/Å, but energies are returned exactly as
    RDKit reports them (presumably kcal/mol) -- confirm which unit callers expect.
    """

    def __init__(
        self,
        MMFFGetMoleculeProperties: Mapping | None = None,
        MMFFGetMoleculeForceField: Mapping | None = None,
    ):
        """Store the keyword arguments forwarded to RDKit's MMFF set-up functions.

        Args:
            MMFFGetMoleculeProperties: Keyword arguments passed to
                ``AllChem.MMFFGetMoleculeProperties``.
            MMFFGetMoleculeForceField: Keyword arguments passed to
                ``AllChem.MMFFGetMoleculeForceField``.
        """
        # Copy so later mutation of the caller's mappings cannot affect this instance.
        self.mol_prop = dict(MMFFGetMoleculeProperties or {})
        self.mol_ff = dict(MMFFGetMoleculeForceField or {})
        # Cached (begin_idx, end_idx, bond_type) triples of the last perceived molecule.
        self.bond_info = None
        # Key and atomic-number sequence identifying the molecule the cache belongs to.
        self._cache_key = None
        self._cache_atoms = None

    def __repr__(self) -> str:
        """Return ``MMFFCalculator(k=v, ...)`` listing every forwarded RDKit option."""
        # The literal name "MMFFCalculator" is kept as-is: callers may compare against it.
        options = [*self.mol_prop.items(), *self.mol_ff.items()]
        args = ", ".join(f"{k}={v}" for k, v in options)
        return f"MMFFCalculator({args})"

    def _ensure_single_conformer(self, batch: Data | Batch) -> None:
        """Raise ``ValueError`` unless ``batch`` holds exactly one conformer.

        A plain ``Data`` object is always accepted; a ``Batch`` is accepted only when
        its ``batch`` index vector contains exactly one distinct value.
        """
        if isinstance(batch, Batch) and batch.batch.unique().size(0) != 1:
            raise ValueError("MMFFCalculator only supports single-conformer batches.")

    def _to_mol(self, batch: Data | Batch) -> Chem.RWMol:
        """Convert a single-conformer batch to an editable RDKit ``RWMol``.

        For a ``Batch``, ``to_rdkit()`` is indexed and its first entry is used.
        """
        if isinstance(batch, Batch):
            return Chem.RWMol(batch.to_rdkit()[0])
        return Chem.RWMol(batch.to_rdkit())

    def _get_charge(self, mol: Chem.Mol) -> int:
        """Return the molecule's ``charge`` property as an int, defaulting to 0."""
        return int(mol.GetProp("charge")) if mol.HasProp("charge") else 0

    def _molecule_key(self, mol: Chem.Mol) -> str:
        """Return the cache key for ``mol``.

        Preference order: a non-empty ``smiles`` property, then the canonical SMILES
        generated by RDKit, then a comma-separated list of atomic numbers.
        """
        if mol.HasProp("smiles"):
            s = mol.GetProp("smiles")
            if s:
                return s
        try:
            return Chem.MolToSmiles(mol, canonical=True)
        except Exception:
            # Deliberate best-effort: any SMILES failure falls back to the atom signature.
            return ",".join(str(a.GetAtomicNum()) for a in mol.GetAtoms())

    def _prepare_mol(self, mol: Chem.Mol) -> Chem.Mol:
        """Determine or restore bonds on ``mol`` in place, then sanitize and return it.

        Bond perception runs when the molecule key or the atomic-number sequence differs
        from the cached one; otherwise the cached bonds are re-applied.

        NOTE(review): two molecules sharing both key and atomic-number sequence but
        differing in connectivity would still reuse the cached bonds.
        """
        charge = self._get_charge(mol)
        new_key = self._molecule_key(mol)
        atoms = tuple(a.GetAtomicNum() for a in mol.GetAtoms())

        # Cached indices are only meaningful for an identical atom layout, so the
        # atomic-number sequence must match in addition to the key.
        if new_key == self._cache_key and atoms == self._cache_atoms:
            for begin, end, bond_type in self.bond_info or ():
                # Skip bonds that are already present; AddBond would otherwise raise.
                if mol.GetBondBetweenAtoms(begin, end) is None:
                    mol.AddBond(begin, end, bond_type)
        else:
            # Looked up through the module so the function stays patchable.
            rdDetermineBonds.DetermineBonds(mol, charge=charge)
            self.bond_info = [
                (b.GetBeginAtomIdx(), b.GetEndAtomIdx(), b.GetBondType()) for b in mol.GetBonds()
            ]
            # Only record the cache identity once perception has succeeded.
            self._cache_key = new_key
            self._cache_atoms = atoms

        Chem.SanitizeMol(mol)
        return mol

    def _build_forcefield(self, mol: Chem.Mol) -> Any:
        """Build the MMFF force field for ``mol``.

        Raises:
            ValueError: If RDKit returns ``None`` for the molecule properties or for the
                force field, i.e. the molecule cannot be set up with MMFF.
        """
        mp = AllChem.MMFFGetMoleculeProperties(mol, **self.mol_prop)
        if mp is None:
            raise ValueError("MMFF parameters are not available for this molecule.")
        ff = AllChem.MMFFGetMoleculeForceField(mol, mp, **self.mol_ff)
        if ff is None:
            raise ValueError("RDKit could not build an MMFF force field for this molecule.")
        return ff

    def _forcefield_from_batch(self, batch: Data | Batch) -> Any:
        """Validate ``batch``, convert it to a prepared RDKit molecule and build its force field."""
        self._ensure_single_conformer(batch)
        prepared = self._prepare_mol(self._to_mol(batch))
        return self._build_forcefield(prepared)

    def _calculate(self, batch: Data | Batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the energy and forces of a single conformer using MMFF94.

        Returns:
            A tuple ``(energy, forces)``: ``energy`` is a 1-element tensor holding
            ``CalcEnergy()`` unchanged, and ``forces`` is the negated, unit-converted
            gradient reshaped to ``(-1, 3)``. Both live on ``batch.pos.device``.

        Raises:
            ValueError: If ``batch`` contains more than one conformer or the force field
                cannot be built.
        """
        ff = self._forcefield_from_batch(batch)
        device = batch.pos.device
        energy = torch.tensor([ff.CalcEnergy()], device=device)
        # Force is the negative gradient; convert to eV/Å for convergence criteria.
        forces = torch.tensor(ff.CalcGrad(), device=device) * -KCAL_PER_MOL_TO_EV
        return energy, forces.reshape(-1, 3)

    def get_energies(self, batch: Data | Batch) -> torch.Tensor:
        """Compute only the energy of a single conformer using MMFF94.

        Returns:
            A 1-element tensor on ``batch.pos.device`` holding ``CalcEnergy()`` unchanged.

        Raises:
            ValueError: If ``batch`` contains more than one conformer or the force field
                cannot be built.
        """
        ff = self._forcefield_from_batch(batch)
        return torch.tensor([ff.CalcEnergy()], device=batch.pos.device)

src/neural_optimiser/calculators/base.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@ def __call__(self, batch: Data | Batch) -> tuple[torch.Tensor, torch.Tensor]:
1616
self.device = batch.pos.device
1717
return self._calculate(batch)
1818

19-
@abstractmethod
2019
def __repr__(self):
2120
...
2221

23-
@abstractmethod
2422
def get_energies(self, batch: Data | Batch) -> torch.Tensor:
2523
"""Get only energies from the calculator."""
2624
...
@@ -30,7 +28,6 @@ def _calculate(self, batch: Data | Batch) -> tuple[torch.Tensor, torch.Tensor]:
3028
"""Return (energies, forces) from the calculator."""
3129
...
3230

33-
@abstractmethod
3431
def to_atomic_data(self, batch: Data | Batch) -> Batch:
3532
"""Convert to AtomicData format compatible with ML model used."""
3633
...

tests/calculators/test_mmff94.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import pytest
2+
import torch
3+
from ase.units import eV, kcal, mol
4+
from neural_optimiser.calculators import MMFF94Calculator
5+
from neural_optimiser.conformers import Conformer
6+
from rdkit import Chem
7+
from rdkit.Chem import AllChem, rdDetermineBonds
8+
9+
KCAL_PER_MOL_TO_EV = kcal / mol / eV
10+
11+
12+
def test_MMFF94Calculator_calculate(mol):
    """Check energy and forces against a direct RDKit MMFF94 evaluation of the same molecule."""
    # Reference values straight from RDKit, with forces as the negated, converted gradient.
    reference_props = AllChem.MMFFGetMoleculeProperties(mol)
    reference_ff = AllChem.MMFFGetMoleculeForceField(mol, reference_props)
    expected_energy = reference_ff.CalcEnergy()
    expected_forces = torch.tensor(reference_ff.CalcGrad()) * -KCAL_PER_MOL_TO_EV

    conformer = Conformer.from_rdkit(mol)
    energy, forces = MMFF94Calculator()._calculate(conformer)

    assert torch.isclose(energy, torch.tensor(expected_energy))
    assert torch.allclose(forces.flatten(), torch.Tensor(expected_forces))
25+
26+
27+
def test_repr_no_args():
    """A calculator built without options renders an empty argument list in its repr."""
    assert repr(MMFF94Calculator()) == "MMFFCalculator()"
31+
32+
33+
def test_repr_with_args_only_affects_repr_not_behavior():
    """Options given to the constructor show up as ``key=value`` pairs in the repr."""
    property_options = {"mmffVariant": "MMFF94s"}
    forcefield_options = {"nonBondedThresh": 10.0}
    calc = MMFF94Calculator(
        MMFFGetMoleculeProperties=property_options,
        MMFFGetMoleculeForceField=forcefield_options,
    )

    rendered = repr(calc)

    assert "mmffVariant=MMFF94s" in rendered
    assert "nonBondedThresh=10.0" in rendered
42+
43+
44+
def test_single_conformer_energy_and_forces_shape(batch):
    """Energy is a 1-D tensor and forces are shaped [n_atoms, 3]."""
    energy, forces = MMFF94Calculator()._calculate(batch)

    # The energy comes back as a one-dimensional tensor, not a 0-D scalar.
    assert isinstance(energy, torch.Tensor)
    assert energy.ndim == 1

    # One force row of three components per atom position in the batch.
    assert isinstance(forces, torch.Tensor)
    assert forces.ndim == 2 and forces.shape[1] == 3
    assert forces.shape[0] == batch.pos.shape[0]
57+
58+
59+
def test_get_energies_matches_calculate(batch):
    """The energy-only entry point agrees with the energy returned by ``_calculate``."""
    calc = MMFF94Calculator()

    energy_only = calc.get_energies(batch)
    energy_full, _forces = calc._calculate(batch)

    assert torch.allclose(energy_only, energy_full)
66+
67+
68+
def test_multi_conformer_raises(minimised_batch):
    """Both entry points reject a batch holding more than one conformer."""
    calc = MMFF94Calculator()

    # Same order as before: the full calculation first, then the energy-only call.
    for entry_point in (calc._calculate, calc.get_energies):
        with pytest.raises(ValueError):
            entry_point(minimised_batch)
75+
76+
77+
def test_bond_determination_cached(monkeypatch, batch):
    """Bond perception runs once and is reused for a later molecule with the same key."""
    real_determine_bonds = rdDetermineBonds.DetermineBonds
    calls = {"n": 0}

    def counting_determine_bonds(mol, *args, **kwargs):
        # Count the invocation, then defer to the genuine RDKit implementation.
        calls["n"] = calls["n"] + 1
        return real_determine_bonds(mol, *args, **kwargs)

    monkeypatch.setattr(rdDetermineBonds, "DetermineBonds", counting_determine_bonds)
    calc = MMFF94Calculator()

    # Nothing is cached yet, so the first evaluation perceives bonds exactly once.
    calc._calculate(batch)
    assert calls["n"] == 1

    # A fresh molecule built from the same batch must be served from the cache.
    fresh_mol = Chem.RWMol(batch.to_rdkit()[0])
    calc._prepare_mol(fresh_mol)
    assert calls["n"] == 1
98+
99+
100+
def test_smiles_property_controls_cache_key(monkeypatch, batch):
    """The ``smiles`` property decides whether bond perception is repeated."""
    real_determine_bonds = rdDetermineBonds.DetermineBonds
    calls = {"n": 0}

    def counting_determine_bonds(mol, *args, **kwargs):
        # Count the invocation, then defer to the genuine RDKit implementation.
        calls["n"] = calls["n"] + 1
        return real_determine_bonds(mol, *args, **kwargs)

    monkeypatch.setattr(rdDetermineBonds, "DetermineBonds", counting_determine_bonds)
    calc = MMFF94Calculator()

    # (smiles property, expected cumulative DetermineBonds calls):
    # a new key triggers perception, a repeated key reuses the cache.
    scenarios = (("KEY", 1), ("KEY", 1), ("KEY2", 2))
    for smiles_value, expected_calls in scenarios:
        candidate = Chem.RWMol(batch.to_rdkit()[0])
        candidate.SetProp("smiles", smiles_value)
        calc._prepare_mol(candidate)
        assert calls["n"] == expected_calls

tests/optimise/test_bfgs.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import torch
55
from neural_optimiser import test_dir
6-
from neural_optimiser.calculators import MACECalculator
6+
from neural_optimiser.calculators import MACECalculator, MMFF94Calculator
77
from neural_optimiser.conformers import Conformer, ConformerBatch
88
from neural_optimiser.optimisers import BFGS
99

@@ -115,6 +115,16 @@ def test_bfgs_integration(atoms, atoms2):
115115
assert converged is True
116116

117117

118+
def test_bfgs_integration2(atoms2):
    """Run BFGS with the MMFF94Calculator on a single-structure batch and expect convergence."""
    single_batch = ConformerBatch.from_ase([atoms2])

    bfgs = BFGS(steps=100, fmax=0.05, fexit=500.0)
    bfgs.calculator = MMFF94Calculator()

    assert bfgs.run(single_batch) is True
126+
127+
118128
def test_bfgs_integration_gpu(atoms, atoms2):
119129
"""Test BFGS integration with MACECalculator on GPU."""
120130
pytest.importorskip("mace", reason="MACE not installed")

0 commit comments

Comments
 (0)