Conversation
971dd29 to
b2937bc
Compare
timduignan
left a comment
There was a problem hiding this comment.
Great thanks Vaidas,
Sorry I forgot about fine-tuning. I think there may be an issue though still:
pair_repulsion + disabled stress in direct_regressor.py
In forward() and predict(), the new continue guards only protect the first head iteration loop. The pair_repulsion block still iterates over
all heads including stress:
direct_regressor.py forward() lines 96-103 (current main):
if self.pair_repulsion:
out_pair_repulsion = self.pair_repulsion_fn(batch)
for name, head in self.heads.items(): # <-- iterates stress head
raw_repulsion = self._get_raw_repulsion(name, out_pair_repulsion)
if raw_repulsion is not None:
head = cast(ForcefieldHead, head)
raw = head.denormalize(out[name], batch) # <-- KeyError: stress not in out
Same issue in predict() lines 114-119. If pair_repulsion=True and stress is disabled, this will raise KeyError.
While pair_repulsion defaults to False and likely isn't used with omol models today, it's a latent bug. The fix is straightforward — add the
same guard in the pair_repulsion loops:
for name, head in self.heads.items():
if self._stress_disabled and "stress" in name:
continue
...
Some minor observations
- properties property not updated in DirectForcefieldRegressor: It returns list(self.heads.keys()) unconditionally, so "stress" appears even
when disabled. If downstream code (e.g., ASE calculator) uses properties to decide what the model predicts, this could cause confusion or
errors. Consider filtering:
heads = [k for k in self.heads.keys() if not (self._stress_disabled and "stress" in k)] - Redundant asserts in conservative_regressor.py loss(): Lines 245-246 (assert self.stress_name is not None / assert self.grad_stress_name
is not None) are now always true since the names are always set. Not harmful, but they no longer serve as useful guards. Could be removed for
clarity. - Negative stress_loss_weight: In finetune.py, only > 0 and == 0 are handled. A negative value would silently fall through (neither enabling
nor disabling). Could add a validation or note in the argparse help.
|
Thanks, good catch! Fixed it and added tests for ase calculator and torchsim integration. We do need to clean up the Note: I've limited nvalchemiops to <0.3.0. They've released a new version today, which breaks the old API. We will make the updates for the newer version in a followup. |
Minor improvement allowing enabling/disabling of stress computation for conservative models, and just disabling it for direct models.
In
finetune.pysetting--stress_loss_weight=0setshas_stress=False, which preventsAseSqliteDatasetfrom trying to retrieve stress data.Fixes #153.