Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,18 @@ def run(args):
)
logging.info("=" * 60)

# Enable/disable stress based on stress_loss_weight
# None = not specified, keep model default; 0 = explicitly disable; >0 = explicitly enable
if args.stress_loss_weight is not None:
if args.stress_loss_weight > 0:
model.enable_stress()
logging.info(
"Stress training ENABLED (stress_loss_weight=%.4f)", args.stress_loss_weight
)
elif model.has_stress:
model.disable_stress()
logging.info("Stress training DISABLED (stress_loss_weight=0.0)")

model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info(f"Model has {model_params:,} trainable parameters.")

Expand Down
38 changes: 18 additions & 20 deletions orb_models/forcefield/models/conservative_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,11 @@ def __init__(
self.forces_target = PROPERTIES[self.forces_name]
self.grad_forces_name = f"{self.grad_prefix}_{self.forces_name}"

# Stress is optional since only periodic systems have it
# Stress names are always derived (from level_of_theory); has_stress toggles computation
self.stress_name: str = f"stress-{level_of_theory}" if level_of_theory else "stress"
self.stress_target: PropertyDefinition = PROPERTIES[self.stress_name]
self.grad_stress_name: str = f"{self.grad_prefix}_{self.stress_name}"
self.has_stress = has_stress
if self.has_stress:
self.stress_name: str | None = (
f"stress-{level_of_theory}" if level_of_theory else "stress"
)
self.stress_target: PropertyDefinition | None = PROPERTIES[self.stress_name]
self.grad_stress_name: str | None = f"{self.grad_prefix}_{self.stress_name}"
else:
self.stress_name = None
self.stress_target = None
self.grad_stress_name = None
assert self.has_stress == (self.grad_stress_name is not None), (
"grad_stress_name must be set if has_stress is True"
)

self.grad_rotation_name = "rotational_grad"

Expand All @@ -117,6 +107,14 @@ def __init__(
if heads[name] is not None:
self.extra_properties.append(heads[name].target.fullname)

def enable_stress(self) -> None:
"""Enable stress computation."""
self.has_stress = True

def disable_stress(self) -> None:
"""Disable stress computation."""
self.has_stress = False

@property
def properties(self):
"""List of names of predicted properties."""
Expand All @@ -126,9 +124,12 @@ def properties(self):
self.grad_forces_name,
self.grad_rotation_name,
]
if self.grad_stress_name is not None:
if self.has_stress:
props.append(self.grad_stress_name)
props.extend(self.extra_properties)
for name in self.extra_properties:
if not self.has_stress and "stress" in name:
continue
props.append(name)
return props

def forward(self, batch: AtomGraphs) -> dict[str, torch.Tensor]:
Expand Down Expand Up @@ -179,7 +180,7 @@ def predict(self, batch: AtomGraphs, split: bool = False) -> dict[str, torch.Ten
energy_head = cast(ForcefieldHead, self.heads[self.energy_name])
out[self.energy_name] = energy_head.denormalize(preds[self.energy_name], batch)
out[self.grad_forces_name] = preds[self.grad_forces_name]
if self.grad_stress_name:
if self.has_stress:
out[self.grad_stress_name] = preds[self.grad_stress_name]
out[self.grad_rotation_name] = preds[self.grad_rotation_name]
for name in self.extra_properties:
Expand Down Expand Up @@ -242,8 +243,6 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:

# Conservative stress (optional)
if self.has_stress and self.grad_stress_name in out:
assert self.stress_name is not None
assert self.grad_stress_name is not None
raw_grad_stress_pred = out[self.grad_stress_name]
grad_stress_pred = self.grad_stress_normalizer(raw_grad_stress_pred, online=False)
loss_out = stress_loss_function(
Expand All @@ -267,7 +266,6 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
if self.has_stress and self.grad_stress_name in out
else []
):
assert grad_name is not None
direct_name = grad_name.replace(self.grad_prefix + "_", "")
if direct_name in self.extra_properties:
direct_head = cast(ForcefieldHead, self.heads[direct_name])
Expand Down
27 changes: 25 additions & 2 deletions orb_models/forcefield/models/direct_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
_validate_heads_and_loss_weights(heads, loss_weights)

self.heads = torch.nn.ModuleDict(heads)
self._stress_disabled = False
self.loss_weights = loss_weights
self.model_requires_grad = model_requires_grad
self.cutoff_layers = cutoff_layers
Expand All @@ -82,20 +83,34 @@ def __init__(

@property
def has_stress(self) -> bool:
"""Check if the model has stress prediction."""
return "stress" in self.heads
"""Check if the model has stress prediction and it is enabled."""
return "stress" in self.heads and not self._stress_disabled

def enable_stress(self) -> None:
"""Enable stress computation."""
if "stress" not in self.heads:
raise ValueError("Cannot enable stress: no stress head exists.")
self._stress_disabled = False

def disable_stress(self) -> None:
"""Disable stress computation."""
self._stress_disabled = True

def forward(self, batch: AtomGraphs) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Forward pass of DirectForcefieldRegressor."""
out = self.model(batch)
node_features = out["node_features"]
for name, head in self.heads.items():
if self._stress_disabled and "stress" in name:
continue
res = head(node_features, batch)
out[name] = res

if self.pair_repulsion:
out_pair_repulsion = self.pair_repulsion_fn(batch)
for name, head in self.heads.items():
if self._stress_disabled and "stress" in name:
continue
raw_repulsion = self._get_raw_repulsion(name, out_pair_repulsion)
if raw_repulsion is not None:
head = cast(ForcefieldHead, head)
Expand All @@ -109,11 +124,15 @@ def predict(self, batch: AtomGraphs, split: bool = False) -> dict[str, torch.Ten
node_features = out["node_features"]
output = {}
for name, head in self.heads.items():
if self._stress_disabled and "stress" in name:
continue
output[name] = cast(ForcefieldHead | ConfidenceHead, head).predict(node_features, batch)

if self.pair_repulsion:
out_pair_repulsion = self.pair_repulsion_fn(batch)
for name, head in self.heads.items():
if self._stress_disabled and "stress" in name:
continue
raw_repulsion = self._get_raw_repulsion(name, out_pair_repulsion)
if raw_repulsion is not None:
output[name] = output[name] + raw_repulsion
Expand All @@ -138,6 +157,8 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
for name, head in self.heads.items():
if name == "confidence":
continue
if self._stress_disabled and "stress" in name:
continue
head = cast(ForcefieldHead, head)
head_out = head.loss(out[name], batch)
weight = self.loss_weights[name]
Expand Down Expand Up @@ -191,6 +212,8 @@ def load_state_dict(
def properties(self):
"""List of names of predicted properties."""
heads = list(self.heads.keys())
if self._stress_disabled:
heads = [head for head in heads if "stress" not in head]
if "energy" in heads:
heads.append("free_energy")
return heads
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"torch>=2.8.0, <3.0.0",
"dm-tree==0.1.8", # Pinned because of https://github.com/google-deepmind/tree/issues/128
"tqdm>=4.67.1",
"nvalchemi-toolkit-ops>=0.2.0",
"nvalchemi-toolkit-ops>=0.2.0,<0.3.0",
]

[project.optional-dependencies]
Expand Down
54 changes: 54 additions & 0 deletions tests/forcefield/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,57 @@ def test_calc_non_conservative_defaults(direct_regressor):
"forces",
"stress",
}


class TestStressToggle:
"""Test that enable_stress/disable_stress controls stress in calculator results."""

def test_conservative_stress_disabled(self, conservative_regressor, mptraj_10_systems_db):
conservative_regressor.disable_stress()
calc = ORBCalculator(
model=conservative_regressor,
atoms_adapter=ForcefieldAtomsAdapter(6.0, 20),
)
assert "stress" not in calc.implemented_properties
atoms = mptraj_10_systems_db.get_atoms(1)
calc.calculate(atoms)
assert "stress" not in calc.results
assert "forces" in calc.results

def test_conservative_stress_enabled(self, conservative_regressor, mptraj_10_systems_db):
conservative_regressor.disable_stress()
conservative_regressor.enable_stress()
calc = ORBCalculator(
model=conservative_regressor,
atoms_adapter=ForcefieldAtomsAdapter(6.0, 20),
)
assert "stress" in calc.implemented_properties
atoms = mptraj_10_systems_db.get_atoms(1)
calc.calculate(atoms)
assert "stress" in calc.results
assert "forces" in calc.results

def test_direct_stress_disabled(self, direct_regressor, mptraj_10_systems_db):
direct_regressor.disable_stress()
calc = ORBCalculator(
model=direct_regressor,
atoms_adapter=ForcefieldAtomsAdapter(6.0, 20),
)
assert "stress" not in calc.implemented_properties
atoms = mptraj_10_systems_db.get_atoms(1)
calc.calculate(atoms)
assert "stress" not in calc.results
assert "forces" in calc.results

def test_direct_stress_enabled(self, direct_regressor, mptraj_10_systems_db):
direct_regressor.disable_stress()
direct_regressor.enable_stress()
calc = ORBCalculator(
model=direct_regressor,
atoms_adapter=ForcefieldAtomsAdapter(6.0, 20),
)
assert "stress" in calc.implemented_properties
atoms = mptraj_10_systems_db.get_atoms(1)
calc.calculate(atoms)
assert "stress" in calc.results
assert "forces" in calc.results
46 changes: 46 additions & 0 deletions tests/forcefield/test_orb_torchsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,49 @@ def test_orb_torchsim_interface(edge_method, conservative_regressor, mptraj_10_s
atol=1e-5,
equal_nan=True,
)


class TestOrbTorchSimStressToggle:
"""Test that enable_stress/disable_stress controls stress in OrbTorchSimModel results."""

def test_conservative_stress_disabled(self, conservative_regressor, mptraj_10_systems_db):
conservative_regressor.disable_stress()
atoms_list = [mptraj_10_systems_db.get_atoms(1)]
adapter = ForcefieldAtomsAdapter(6.0, 120)
sim_state = ts.io.atoms_to_state(atoms_list, "cpu", torch.get_default_dtype())
sim_model = OrbTorchSimModel(conservative_regressor, adapter)
results = sim_model(sim_state)
assert "stress" not in results
assert "forces" in results

def test_conservative_stress_enabled(self, conservative_regressor, mptraj_10_systems_db):
conservative_regressor.disable_stress()
conservative_regressor.enable_stress()
atoms_list = [mptraj_10_systems_db.get_atoms(1)]
adapter = ForcefieldAtomsAdapter(6.0, 120)
sim_state = ts.io.atoms_to_state(atoms_list, "cpu", torch.get_default_dtype())
sim_model = OrbTorchSimModel(conservative_regressor, adapter)
results = sim_model(sim_state)
assert "stress" in results
assert "forces" in results

def test_direct_stress_disabled(self, direct_regressor, mptraj_10_systems_db):
direct_regressor.disable_stress()
atoms_list = [mptraj_10_systems_db.get_atoms(1)]
adapter = ForcefieldAtomsAdapter(6.0, 120)
sim_state = ts.io.atoms_to_state(atoms_list, "cpu", torch.get_default_dtype())
sim_model = OrbTorchSimModel(direct_regressor, adapter)
results = sim_model(sim_state)
assert "stress" not in results
assert "forces" in results

def test_direct_stress_enabled(self, direct_regressor, mptraj_10_systems_db):
direct_regressor.disable_stress()
direct_regressor.enable_stress()
atoms_list = [mptraj_10_systems_db.get_atoms(1)]
adapter = ForcefieldAtomsAdapter(6.0, 120)
sim_state = ts.io.atoms_to_state(atoms_list, "cpu", torch.get_default_dtype())
sim_model = OrbTorchSimModel(direct_regressor, adapter)
results = sim_model(sim_state)
assert "stress" in results
assert "forces" in results
Loading