uvm-neurobotics-lab/stitching

What is Possible with Neural Network Stitching?

In this work, we seek to push the boundaries of neural network stitching to see what it can and can't do.

Setup

You can install necessary dependencies using the provided environment file:

conda env create -f environment.yml
conda activate stitch

However, many users will need to install PyTorch manually, based on their specific system configuration. In that case:

  1. Create an environment (using your preferred Python version): conda create -n stitch python=3.11
  2. Activate: conda activate stitch
  3. Install PyTorch and Torchvision first.
  4. Manually install the rest of the packages listed in environment.yml. Install conda packages before pip packages.
    • Note: The wandb package comes from the conda-forge channel: conda install wandb -c conda-forge
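The manual route above can be sketched as follows. This is only a sketch: the exact PyTorch install command depends on your CUDA/system configuration, so substitute the command recommended for your setup at pytorch.org.

```shell
# Sketch of the manual install route; the PyTorch line is a placeholder --
# use the command pytorch.org recommends for your system.
conda create -n stitch python=3.11
conda activate stitch
pip install torch torchvision        # or the conda/pip command from pytorch.org
conda install wandb -c conda-forge   # wandb comes from the conda-forge channel
# ...then install the remaining conda packages listed in environment.yml,
# followed by the pip packages.
```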

For convenience, you may want to set up a symlink to the folder that contains your datasets. Otherwise, you must specify --data-path when you run. For instance:

cd stitching
ln -s ~/datasets ./data

Organization

  • The executable for a single stitching job is src/stitch_train.py.
    • Example configs can be found in tests/.
  • There is also a script which tests stitching across different layer depths and different architecture combinations: src/launch_scaling_experiments.py.
    • This script requires access to a Slurm cluster as it will launch an array of jobs with different configurations.
    • Example configs can be found in across-scales/.
    • Once all experimental results are generated, we use src/across-scales.ipynb to post-process the results and generate all our plots.

Run a Stitching Job

To test model stitching, you can run src/stitch_train.py.

  • This will run a single stitching job. See examples at the top of the file.
  • It will generate a pickled dataframe (result.pkl) which logs each training step, and (optionally) model checkpoints.
  • See tests/ for a list of example configs that can be run with stitch_train.py. They give a sense of the wide range of possible configurations.

A single stitching job consists of the following steps:

  1. Load a configured set of subnets using utils.subgraphs.create_sub_network().
  2. Construct a network with configured stitching modules in between each subnet.
  3. Train the stitching module(s) for a configured number of epochs using a configured optimizer.
  4. Write the training trajectory to a dataframe on disk (result.pkl).
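The core of steps 1–3 can be sketched with toy models. This is a minimal, self-contained illustration of the stitching idea, not the repo's actual code: the names `front`, `back`, and `adapter` are invented here (the real pipeline extracts subnets with utils.subgraphs.create_sub_network() and is driven by the config).

```python
# A minimal sketch of stitching with toy MLPs. All names here are
# illustrative stand-ins, not the repo's API.
import torch
import torch.nn as nn

# 1. Two "subnets": the front of one model and the back of another.
front = nn.Sequential(nn.Linear(8, 16), nn.ReLU())  # stand-in for model A's early layers
back = nn.Sequential(nn.Linear(32, 4))              # stand-in for model B's late layers

# 2. A trainable stitching module mapping A's 16-dim features to B's 32-dim input.
adapter = nn.Linear(16, 32)
stitched = nn.Sequential(front, adapter, back)

# Freeze the subnets; only the stitching module is trained.
for p in front.parameters():
    p.requires_grad_(False)
for p in back.parameters():
    p.requires_grad_(False)

# 3. Train the stitching module with a configured optimizer.
opt = torch.optim.SGD(adapter.parameters(), lr=0.1)
x, y = torch.randn(64, 8), torch.randn(64, 4)
for _ in range(5):  # stands in for "a configured number of epochs"
    opt.zero_grad()
    loss = nn.functional.mse_loss(stitched(x), y)
    loss.backward()
    opt.step()
```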

We recommend creating a subfolder experiments/<my-experiment-name> for each experiment. Copy the config into this folder and edit as needed. Then run from this folder (e.g., python ../../src/stitch_train.py -c ./config.yml). This keeps the results and all checkpoints neatly packaged together with the config that generated them.

You can also run on a Slurm cluster by customizing one of our example *.sbatch files. From the experiment folder, run: sbatch /<full-path-to>/stitching/nvtrain.sbatch stitchup /<full-path-to>/stitching/src/stitch_train.py --config config.yml

Run a Sweep Over Stitching Gaps and Adapters

Each config in across-scales/ defines all the jobs for a single pair of architectures. The two given architectures (src_stages and dest_stages) are stitched in a number of different ways. A Slurm job is launched for each different way. See examples at the top of src/launch_scaling_experiments.py.
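A hypothetical sketch of such a config is below. Only src_stages and dest_stages are named above; the values shown are placeholders, not the script's actual schema. Consult the example configs in across-scales/ for the real format.

```yaml
# Hypothetical illustration only -- see across-scales/ for real configs.
src_stages: resnet18   # architecture supplying the front subnets (placeholder value)
dest_stages: resnet34  # architecture supplying the back subnets (placeholder value)
```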

Citation

If you use this work, please cite as:

@inproceedings{traft2025bridging,
  title={Bridging Large Gaps in Neural Network Representations with Model Stitching},
  author={Traft, Neil and Cheney, Nick},
  booktitle={Proceedings of UniReps: the Third Edition of the Workshop on Unifying Representations in Neural Models},
  year={2025},
  organization={PMLR}
}
