# Enabling Robust In-Context Memory and Rapid Task Adaptation in Transformers with Hebbian and Gradient-Based Plasticity
This repository contains reference implementations and experiment scripts for the paper Enabling Robust In-Context Memory and Rapid Task Adaptation in Transformers with Hebbian and Gradient-Based Plasticity. The code equips decoder-only Transformers with fast-weight components that are updated via neuromodulated Hebbian or gradient-based plasticity rules and evaluates them on the suite of tasks introduced by Duan et al. (2023).
## Repository structure

```
.
├── requirements.txt   # Python dependencies
├── src/
│   ├── models/        # Plastic Transformer and Conv-4 encoder
│   ├── tasks/         # Task generators and datasets
│   └── experiments/   # CLI entry points for each benchmark
├── experiments/       # JSON logs written by the runners
├── scripts/           # Result aggregation and plotting utilities
├── figures/           # Generated figures (after running plot_results.py)
└── README.md
```
## Installation

```bash
python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt
```

> **Tip:** The experiment entry points use package imports (`src.*`). Run all commands from the repository root so that Python can resolve the package correctly.
## Running experiments

Each experiment script accepts a consistent set of flags:

- `--rule {none, hebbian, gradient}` selects the plasticity rule.
- `--base-seed` sets the first random seed (default `123`).
- `--seeds` controls how many seeds to run (`base_seed + i` for `i in [0, seeds-1]`).
- `--output-path` writes a JSON log containing per-seed histories and aggregate statistics.
- `--device {auto, cpu, cuda, mps}` selects the hardware backend (`auto` prefers CUDA → MPS → CPU).
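For reference, the seed enumeration implied by `--base-seed` and `--seeds` amounts to the following (a minimal sketch, not the repository's actual code):

```python
def enumerate_seeds(base_seed: int = 123, seeds: int = 3) -> list[int]:
    """Seeds used across repeated runs: base_seed + i for i in [0, seeds - 1]."""
    return [base_seed + i for i in range(seeds)]
```

For example, the defaults yield `[123, 124, 125]`.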
Example usages are shown below.
**Copying**

```bash
python -m src.experiments.copying \
    --rule gradient \
    --seq-length 5 \
    --delay 20 \
    --seeds 3 \
    --output-path experiments/results/copy_delay20_rule-gradient.json
```

**Cue–reward association**

```bash
python -m src.experiments.cue_reward \
    --rule hebbian \
    --num-pairs 8 \
    --cue-dim 20 \
    --seeds 3 \
    --output-path experiments/results/cue_rule-hebbian.json
```

**Few-shot regression**

```bash
python -m src.experiments.few_shot_regression \
    --rule hebbian \
    --k-support 10 \
    --k-query 10 \
    --seeds 3 \
    --output-path experiments/results/regression_rule-hebbian.json
```

**One-shot classification**

```bash
python -m src.experiments.one_shot_classification \
    --rule hebbian \
    --dataset cifarfs \
    --ways 5 --shots 1 --queries 15 \
    --epochs 20 --episodes-per-epoch 200 \
    --seeds 3 \
    --output-path experiments/results/classification_cifarfs_rule-hebbian.json
```

Set `--dataset omniglot` to train on Omniglot. Torchvision automatically downloads CIFAR-100 and Omniglot into `--data-root` (default `./data`).
## Analysis scripts

- **Aggregate multiple runs**: `python scripts/aggregate_results.py` scans `experiments/results/*.json` and produces `experiments/results/summary.json` with per-task, per-rule aggregates.
- **Regenerate figures**: `python scripts/plot_results.py` reads from `experiments/results/` and writes publication-ready plots to `figures/`.
- **Build auxiliary tables**: `python scripts/build_tables.py` collects metrics into CSV/Markdown tables; `python scripts/compile_baseline_table.py` reproduces the cross-architecture comparison table.
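The aggregation step boils down to scanning the result logs and averaging per-seed metrics. A minimal sketch of that idea (this is not `aggregate_results.py` itself; the `runs[*].final.loss` field is assumed from the schema, and the real script computes more statistics):

```python
import json
from pathlib import Path
from statistics import mean

def aggregate(results_dir: str = "experiments/results") -> dict:
    """Collect mean final loss per result file (illustrative only)."""
    summary = {}
    for path in sorted(Path(results_dir).glob("*.json")):
        log = json.loads(path.read_text())
        # One entry per seed; "final" holds the last-epoch metrics.
        losses = [run["final"]["loss"] for run in log["runs"]]
        summary[path.stem] = {"loss_mean": mean(losses), "n_seeds": len(losses)}
    return summary
```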
## Result format

All result JSON files follow the same schema:

```json
{
  "config": { "model": {...}, "task": {...}, "training": {...} },
  "runs": [
    {"seed": 3000, "history": [...], "final": {...}},
    ...
  ],
  "aggregate": {
    "loss_mean": 0.352,
    "loss_std": 0.021,
    ...
  }
}
```

## Reproducing the paper results

- Run each task for the three rules (`none`, `hebbian`, `gradient`) with three seeds, saving outputs under `experiments/results/`.
- Execute `python scripts/aggregate_results.py`.
- Generate plots via `python scripts/plot_results.py`.
- Update the manuscript or slides using the refreshed tables (`scripts/build_tables.py`) and figures (`figures/`).
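The task-by-rule sweep in the first step can be scripted. A minimal sketch, assuming the task module names from the usage examples above and leaving task-specific flags at their defaults (the real runs pass more options; execute each command list with `subprocess.run`):

```python
from pathlib import Path

# Task modules and plasticity rules to sweep (names assumed from the examples above).
TASKS = [
    "src.experiments.copying",
    "src.experiments.cue_reward",
    "src.experiments.few_shot_regression",
    "src.experiments.one_shot_classification",
]
RULES = ["none", "hebbian", "gradient"]

def build_command(task: str, rule: str, out_dir: str = "experiments/results") -> list[str]:
    """Assemble one CLI invocation as an argument list."""
    name = task.rsplit(".", 1)[-1]
    out = Path(out_dir) / f"{name}_rule-{rule}.json"
    return ["python", "-m", task, "--rule", rule, "--seeds", "3",
            "--output-path", str(out)]

commands = [build_command(t, r) for t in TASKS for r in RULES]
```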
## Compute budget

The complete campaign (copying, cue–reward, regression, CIFAR-FS, and Omniglot with three seeds each) consumes roughly 25 GPU-hours on a single NVIDIA A100 (40 GB), including diagnostics.
## Troubleshooting

- `ModuleNotFoundError: No module named 'src'`: ensure commands are launched from the repository root, or set `PYTHONPATH=.`.
- CUDA requested but not available: use `--device cpu` or install the appropriate CUDA toolkit/driver.
- Dataset download stalls: Torchvision downloads CIFAR-100 and Omniglot automatically; if mirrors are blocked, manually place the archives in `./data` and rerun.
For any other issues, inspect the JSON logs in experiments/results/ (they capture per-epoch losses, neuromodulation statistics, and plastic weight norms for each seed).
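A log can be summarized directly from the schema shown earlier; a minimal sketch (assumes the `aggregate` keys from that schema; the one-line report format is illustrative):

```python
import json
from pathlib import Path

def summarize_log(path: str) -> str:
    """Return a one-line summary of a result log (schema as sketched above)."""
    log = json.loads(Path(path).read_text())
    agg = log["aggregate"]
    return (f"loss {agg['loss_mean']:.3f} ± {agg['loss_std']:.3f} "
            f"over {len(log['runs'])} seeds")
```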