diff --git a/finetune.py b/finetune.py
index b1b7a12..71b4cdf 100644
--- a/finetune.py
+++ b/finetune.py
@@ -514,6 +514,20 @@ def run(args):
         )
     logging.info("=" * 60)
 
+    # Enable/disable stress based on stress_loss_weight
+    # None = not specified, keep model default; <=0 = explicitly disable; >0 = explicitly enable
+    if args.stress_loss_weight is not None:
+        if args.stress_loss_weight > 0:
+            model.enable_stress()
+            logging.info(
+                "Stress training ENABLED (stress_loss_weight=%.4f)", args.stress_loss_weight
+            )
+        elif model.has_stress:
+            model.disable_stress()
+            logging.info(
+                "Stress training DISABLED (stress_loss_weight=%.4f)", args.stress_loss_weight
+            )
+
     model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
     logging.info(f"Model has {model_params:,} trainable parameters.")
 
diff --git a/orb_models/forcefield/models/conservative_regressor.py b/orb_models/forcefield/models/conservative_regressor.py
index 115995f..834e5a6 100644
--- a/orb_models/forcefield/models/conservative_regressor.py
+++ b/orb_models/forcefield/models/conservative_regressor.py
@@ -94,21 +94,11 @@ def __init__(
         self.forces_target = PROPERTIES[self.forces_name]
         self.grad_forces_name = f"{self.grad_prefix}_{self.forces_name}"
 
-        # Stress is optional since only periodic systems have it
+        # Stress names are always derived (from level_of_theory); has_stress toggles computation
+        self.stress_name: str = f"stress-{level_of_theory}" if level_of_theory else "stress"
+        self.stress_target: PropertyDefinition = PROPERTIES[self.stress_name]
+        self.grad_stress_name: str = f"{self.grad_prefix}_{self.stress_name}"
         self.has_stress = has_stress
-        if self.has_stress:
-            self.stress_name: str | None = (
-                f"stress-{level_of_theory}" if level_of_theory else "stress"
-            )
-            self.stress_target: PropertyDefinition | None = PROPERTIES[self.stress_name]
-            self.grad_stress_name: str | None = f"{self.grad_prefix}_{self.stress_name}"
-        else:
-            self.stress_name = None
-            self.stress_target = None
-            self.grad_stress_name = None
-        assert self.has_stress == (self.grad_stress_name is not None), (
-            "grad_stress_name must be set if has_stress is True"
-        )
 
         self.grad_rotation_name = "rotational_grad"
 
@@ -117,6 +107,14 @@ def __init__(
             if heads[name] is not None:
                 self.extra_properties.append(heads[name].target.fullname)
 
+    def enable_stress(self) -> None:
+        """Enable stress computation."""
+        self.has_stress = True
+
+    def disable_stress(self) -> None:
+        """Disable stress computation."""
+        self.has_stress = False
+
     @property
     def properties(self):
         """List of names of predicted properties."""
@@ -126,9 +124,12 @@ def properties(self):
             self.grad_forces_name,
             self.grad_rotation_name,
         ]
-        if self.grad_stress_name is not None:
+        if self.has_stress:
             props.append(self.grad_stress_name)
-        props.extend(self.extra_properties)
+        for name in self.extra_properties:
+            if not self.has_stress and "stress" in name:
+                continue
+            props.append(name)
         return props
 
     def forward(self, batch: AtomGraphs) -> dict[str, torch.Tensor]:
@@ -179,7 +180,7 @@ def predict(self, batch: AtomGraphs, split: bool = False) -> dict[str, torch.Ten
         energy_head = cast(ForcefieldHead, self.heads[self.energy_name])
         out[self.energy_name] = energy_head.denormalize(preds[self.energy_name], batch)
         out[self.grad_forces_name] = preds[self.grad_forces_name]
-        if self.grad_stress_name:
+        if self.has_stress:
             out[self.grad_stress_name] = preds[self.grad_stress_name]
         out[self.grad_rotation_name] = preds[self.grad_rotation_name]
         for name in self.extra_properties:
@@ -242,8 +243,6 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
 
         # Conservative stress (optional)
         if self.has_stress and self.grad_stress_name in out:
-            assert self.stress_name is not None
-            assert self.grad_stress_name is not None
             raw_grad_stress_pred = out[self.grad_stress_name]
             grad_stress_pred = self.grad_stress_normalizer(raw_grad_stress_pred, online=False)
             loss_out = stress_loss_function(
@@ -267,7 +266,6 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
             if self.has_stress and self.grad_stress_name in out
             else []
         ):
-            assert grad_name is not None
             direct_name = grad_name.replace(self.grad_prefix + "_", "")
             if direct_name in self.extra_properties:
                 direct_head = cast(ForcefieldHead, self.heads[direct_name])
diff --git a/orb_models/forcefield/models/direct_regressor.py b/orb_models/forcefield/models/direct_regressor.py
index 681c1b5..4d4e727 100644
--- a/orb_models/forcefield/models/direct_regressor.py
+++ b/orb_models/forcefield/models/direct_regressor.py
@@ -57,6 +57,7 @@ def __init__(
         _validate_heads_and_loss_weights(heads, loss_weights)
 
         self.heads = torch.nn.ModuleDict(heads)
+        self._stress_disabled = False
         self.loss_weights = loss_weights
         self.model_requires_grad = model_requires_grad
         self.cutoff_layers = cutoff_layers
@@ -82,20 +83,34 @@ def __init__(
 
     @property
     def has_stress(self) -> bool:
-        """Check if the model has stress prediction."""
-        return "stress" in self.heads
+        """Check if the model has stress prediction and it is enabled."""
+        return "stress" in self.heads and not self._stress_disabled
+
+    def enable_stress(self) -> None:
+        """Enable stress computation."""
+        if "stress" not in self.heads:
+            raise ValueError("Cannot enable stress: no stress head exists.")
+        self._stress_disabled = False
+
+    def disable_stress(self) -> None:
+        """Disable stress computation."""
+        self._stress_disabled = True
 
     def forward(self, batch: AtomGraphs) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
         """Forward pass of DirectForcefieldRegressor."""
         out = self.model(batch)
         node_features = out["node_features"]
         for name, head in self.heads.items():
+            if self._stress_disabled and "stress" in name:
+                continue
             res = head(node_features, batch)
             out[name] = res
 
         if self.pair_repulsion:
             out_pair_repulsion = self.pair_repulsion_fn(batch)
             for name, head in self.heads.items():
+                if self._stress_disabled and "stress" in name:
+                    continue
                 raw_repulsion = self._get_raw_repulsion(name, out_pair_repulsion)
                 if raw_repulsion is not None:
                     head = cast(ForcefieldHead, head)
@@ -109,11 +124,15 @@ def predict(self, batch: AtomGraphs, split: bool = False) -> dict[str, torch.Ten
         node_features = out["node_features"]
         output = {}
         for name, head in self.heads.items():
+            if self._stress_disabled and "stress" in name:
+                continue
             output[name] = cast(ForcefieldHead | ConfidenceHead, head).predict(node_features, batch)
 
         if self.pair_repulsion:
             out_pair_repulsion = self.pair_repulsion_fn(batch)
             for name, head in self.heads.items():
+                if self._stress_disabled and "stress" in name:
+                    continue
                 raw_repulsion = self._get_raw_repulsion(name, out_pair_repulsion)
                 if raw_repulsion is not None:
                     output[name] = output[name] + raw_repulsion
@@ -138,6 +157,8 @@ def loss(self, batch: AtomGraphs) -> base.ModelOutput:
         for name, head in self.heads.items():
             if name == "confidence":
                 continue
+            if self._stress_disabled and "stress" in name:
+                continue
             head = cast(ForcefieldHead, head)
             head_out = head.loss(out[name], batch)
             weight = self.loss_weights[name]
@@ -191,6 +212,8 @@ def load_state_dict(
     def properties(self):
         """List of names of predicted properties."""
         heads = list(self.heads.keys())
+        if self._stress_disabled:
+            heads = [head for head in heads if "stress" not in head]
         if "energy" in heads:
             heads.append("free_energy")
         return heads
diff --git a/pyproject.toml b/pyproject.toml
index 94f6d60..1ea19e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
     "torch>=2.8.0, <3.0.0",
     "dm-tree==0.1.8", # Pinned because of https://github.com/google-deepmind/tree/issues/128
     "tqdm>=4.67.1",
-    "nvalchemi-toolkit-ops>=0.2.0",
+    "nvalchemi-toolkit-ops>=0.2.0,<0.3.0",
 ]
 
 [project.optional-dependencies]
diff --git a/tests/forcefield/test_calculator.py b/tests/forcefield/test_calculator.py
index b1f0035..e69baa3 100644
--- a/tests/forcefield/test_calculator.py
+++ b/tests/forcefield/test_calculator.py
@@ -52,3 +52,57 @@ def test_calc_non_conservative_defaults(direct_regressor):
         "forces",
         "stress",
     }
+
+
+class TestStressToggle:
+    """Test that enable_stress/disable_stress controls stress in calculator results."""
+
+    def test_conservative_stress_disabled(self, conservative_regressor, mptraj_10_systems_db):
+        conservative_regressor.disable_stress()
+        calc = ORBCalculator(
+            model=conservative_regressor,
+            atoms_adapter=ForcefieldAtomsAdapter(6.0, 20),
+        )
+        assert "stress" not in calc.implemented_properties
+        atoms = mptraj_10_systems_db.get_atoms(1)
+        calc.calculate(atoms)
+        assert "stress" not in calc.results
+        assert "forces" in calc.results
+
+    def test_conservative_stress_enabled(self, conservative_regressor, mptraj_10_systems_db):
+        conservative_regressor.disable_stress()
+        conservative_regressor.enable_stress()
+        calc = ORBCalculator(
+            model=conservative_regressor,
+            atoms_adapter=ForcefieldAtomsAdapter(6.0, 20),
+        )
+        assert "stress" in calc.implemented_properties
+        atoms = mptraj_10_systems_db.get_atoms(1)
+        calc.calculate(atoms)
+        assert "stress" in calc.results
+        assert "forces" in calc.results
+
+    def test_direct_stress_disabled(self, direct_regressor, mptraj_10_systems_db):
+        direct_regressor.disable_stress()
+        calc = ORBCalculator(
+            model=direct_regressor,
+            atoms_adapter=ForcefieldAtomsAdapter(6.0, 20),
+        )
+        assert "stress" not in calc.implemented_properties
+        atoms = mptraj_10_systems_db.get_atoms(1)
+        calc.calculate(atoms)
+        assert "stress" not in calc.results
+        assert "forces" in calc.results
+
+    def test_direct_stress_enabled(self, direct_regressor, mptraj_10_systems_db):
+        direct_regressor.disable_stress()
+        direct_regressor.enable_stress()
+        calc = ORBCalculator(
+            model=direct_regressor,
+            atoms_adapter=ForcefieldAtomsAdapter(6.0, 20),
+        )
+        assert "stress" in calc.implemented_properties
+        atoms = mptraj_10_systems_db.get_atoms(1)
+        calc.calculate(atoms)
+        assert "stress" in calc.results
+        assert "forces" in calc.results
diff --git a/tests/forcefield/test_orb_torchsim.py b/tests/forcefield/test_orb_torchsim.py
index 3486230..ac37c7d 100644
--- a/tests/forcefield/test_orb_torchsim.py
+++ b/tests/forcefield/test_orb_torchsim.py
@@ -70,3 +70,49 @@ def test_orb_torchsim_interface(edge_method, conservative_regressor, mptraj_10_s
         atol=1e-5,
         equal_nan=True,
     )
+
+
+class TestOrbTorchSimStressToggle:
+    """Test that enable_stress/disable_stress controls stress in OrbTorchSimModel results."""
+
+    def test_conservative_stress_disabled(self, conservative_regressor, mptraj_10_systems_db):
+        conservative_regressor.disable_stress()
+        atoms_list = [mptraj_10_systems_db.get_atoms(1)]
+        adapter = ForcefieldAtomsAdapter(6.0, 120)
+        sim_state = ts.io.atoms_to_state(atoms_list, "cpu", torch.get_default_dtype())
+        sim_model = OrbTorchSimModel(conservative_regressor, adapter)
+        results = sim_model(sim_state)
+        assert "stress" not in results
+        assert "forces" in results
+
+    def test_conservative_stress_enabled(self, conservative_regressor, mptraj_10_systems_db):
+        conservative_regressor.disable_stress()
+        conservative_regressor.enable_stress()
+        atoms_list = [mptraj_10_systems_db.get_atoms(1)]
+        adapter = ForcefieldAtomsAdapter(6.0, 120)
+        sim_state = ts.io.atoms_to_state(atoms_list, "cpu", torch.get_default_dtype())
+        sim_model = OrbTorchSimModel(conservative_regressor, adapter)
+        results = sim_model(sim_state)
+        assert "stress" in results
+        assert "forces" in results
+
+    def test_direct_stress_disabled(self, direct_regressor, mptraj_10_systems_db):
+        direct_regressor.disable_stress()
+        atoms_list = [mptraj_10_systems_db.get_atoms(1)]
+        adapter = ForcefieldAtomsAdapter(6.0, 120)
+        sim_state = ts.io.atoms_to_state(atoms_list, "cpu", torch.get_default_dtype())
+        sim_model = OrbTorchSimModel(direct_regressor, adapter)
+        results = sim_model(sim_state)
+        assert "stress" not in results
+        assert "forces" in results
+
+    def test_direct_stress_enabled(self, direct_regressor, mptraj_10_systems_db):
+        direct_regressor.disable_stress()
+        direct_regressor.enable_stress()
+        atoms_list = [mptraj_10_systems_db.get_atoms(1)]
+        adapter = ForcefieldAtomsAdapter(6.0, 120)
+        sim_state = ts.io.atoms_to_state(atoms_list, "cpu", torch.get_default_dtype())
+        sim_model = OrbTorchSimModel(direct_regressor, adapter)
+        results = sim_model(sim_state)
+        assert "stress" in results
+        assert "forces" in results