add H3, H3Conv and Hyena as model architecture#289
add H3, H3Conv and Hyena as model architecture#289GuptaVishu2002 wants to merge 28 commits intomasterfrom
Conversation
There was a problem hiding this comment.
Pull request overview
Adds support for the H3, H3Conv, and Hyena architectures across the Snakemake-driven training/sampling pipeline, exposing their configuration via new CLI/config parameters and extending the test suite for training.
Changes:
- Added H3/H3Conv/Hyena model implementations and wired them into
train_models_RNN.pyandsample_molecules_RNN.py. - Extended workflow config and Snakemake rules to pass new model parameters (
bias,use_fast_fftconv,order,filter_order,inner_factor). - Expanded unit tests to cover training runs for the new model types.
Reviewed changes
Copilot reviewed 8 out of 9 changed files in this pull request and generated 13 comments.
Show a summary per file
| File | Description |
|---|---|
| workflow/config/config_fast.yaml | Documents/configures new model types and their parameters for “fast” runs. |
| workflow/config/config.yaml | Documents/configures new model types and their parameters for standard runs. |
| workflow/Snakefile_data | Threads new model parameters through Snakemake CLI invocations for training/sampling. |
| tests/test_snakemake_steps.py | Adds training tests for H3/H3Conv/Hyena and updates existing calls to include new args. |
| src/clm/models.py | Adds H3/H3Conv/Hyena model classes using safari implementations. |
| src/clm/commands/train_models_RNN.py | Adds CLI args and model selection branches for H3/H3Conv/Hyena; forwards new params. |
| src/clm/commands/sample_molecules_RNN.py | Adds CLI args and model selection branches for H3/H3Conv/Hyena; forwards new params. |
| requirements.txt | Adds safari and pins pytorch-lightning/hydra-core. |
| pyproject.toml | Adds safari, hydra-core, and pytorch-lightning as dependencies. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| "s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging", | ||
| "safari @ git+https://github.com/GuptaVishu2002/safari.git@fix-setup", |
There was a problem hiding this comment.
The dependency safari @ git+https://github.com/GuptaVishu2002/safari.git@fix-setup pulls third-party code directly from a Git repository using a mutable ref (fix-setup), which enables supply-chain attacks if that branch/tag is ever compromised or force-moved. An attacker who gains control of that repo could silently change the code at the same ref and have malicious code executed in any environment that installs this project. To mitigate this, pin Git-based dependencies to immutable commit hashes (or published versions on a trusted index) and periodically update them intentionally, rather than tracking branches/tags.
skinnider
left a comment
There was a problem hiding this comment.
@GuptaVishu2002 see a couple potential issues below
| '--n_ssm {MODEL_PARAMS[n_ssm]} ' | ||
| '--n_heads {MODEL_PARAMS[n_heads]} ' | ||
| '--exp_factor {MODEL_PARAMS[exp_factor]} ' | ||
| f'{"--bias" if MODEL_PARAMS["bias"] else ""} ' |
There was a problem hiding this comment.
@GuptaVishu2002 can you double-check that boolean arguments are parsed correctly by the existing combination of config.yaml -> Snakefile_data, if you haven't already? i.e., can the user definitely set bias to True or False within train_models_RNN?
| # ) | ||
|
|
||
| elif model_type == "H3": | ||
| assert ( |
There was a problem hiding this comment.
@GuptaVishu2002 I think there must be an if conditional: missing here, no? Otherwise, why does the H3 require a heldout file? (Is a conditional H3 even implemented? Maybe the assertion should be the opposite, i.e., for all models other than RNN, assert that conditional is not True)
|
@GuptaVishu2002 I read through the PR more carefully today and noticed some stuff I hadn’t before. Sorry to not catch it the first time, but I think some points would be good to clarify and others seem like they definitely need to be addressed:
... and I'll create a separate issue to address the RNN.
I also noticed a few other things that we can address as separate PRs - will create separate issues for those. |
…efault values, update params, remove dead code, correct padding_idx, correct RNN sample
|
Made the following changes as mentioned above
For #293, fix For #294, can assign max_len in the config file now For #295, added eval()/train() for RNN classes to solve dropout issue |
This pull request introduces support for the H3, H3Conv, and Hyena model types in both training and sampling scripts, making them configurable through new command-line arguments. It also adds corresponding test coverage and updates dependencies to facilitate these changes.
Model Support and Argument Handling:
train_models_RNN.pyandsample_molecules_RNN.py, including new arguments (bias,use_fast_fftconv,order,filter_order,inner_factor) to configure these models from the command line.Testing Enhancements:
test_snakemake_steps.pyto include dedicated tests for H3, H3Conv, and Hyena models, as well as updating existing tests to use the new arguments.Dependency Updates:
pyproject.tomlto addsafari,hydra-core, andpytorch-lightningas dependencies, supporting the new models and configuration management.