From 2a02db2fe279fc8d0a8441f926ba270526e8973b Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Tue, 24 Feb 2026 00:46:08 +0000
Subject: [PATCH] Add virology ERA gradient routing experiments

Co-Authored-By: Claude Opus 4.6
---
 projects/virology_era/__init__.py             |   0
 projects/virology_era/evaluate_virology.py    | 311 +++++++++++++
 .../outputs/dry_run/ablation_comparison.json  |  10 +
 .../virology_era/outputs/dry_run/results.json |  15 +
 projects/virology_era/virology_data.py        | 107 +++++
 projects/virology_era/virology_era.py         | 433 ++++++++++++++++++
 projects/virology_era/virology_settings.py    |  73 +++
 runs/virology_era.sbatch                      |  92 ++++
 runs/virology_era_dryrun.sbatch               |  30 ++
 9 files changed, 1071 insertions(+)
 create mode 100644 projects/virology_era/__init__.py
 create mode 100644 projects/virology_era/evaluate_virology.py
 create mode 100644 projects/virology_era/outputs/dry_run/ablation_comparison.json
 create mode 100644 projects/virology_era/outputs/dry_run/results.json
 create mode 100644 projects/virology_era/virology_data.py
 create mode 100644 projects/virology_era/virology_era.py
 create mode 100644 projects/virology_era/virology_settings.py
 create mode 100644 runs/virology_era.sbatch
 create mode 100644 runs/virology_era_dryrun.sbatch

diff --git a/projects/virology_era/__init__.py b/projects/virology_era/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/projects/virology_era/evaluate_virology.py b/projects/virology_era/evaluate_virology.py
new file mode 100644
index 0000000..d99a960
--- /dev/null
+++ b/projects/virology_era/evaluate_virology.py
@@ -0,0 +1,311 @@
+"""
+Evaluation for virology ERA models.
+
+Two approaches:
+1. Convert TransformerLens state dict back to HF GPT-NeoX format, save, and
+   run lm_eval via subprocess (for multi-GPU eval).
+2. Custom lm_eval.LM wrapper for TransformerLens HookedTransformer (fallback).
+
+Usage:
+    cd /home/a6a/lucia.a6a/gradient-routing
+    python projects/virology_era/evaluate_virology.py <state_dict_path> [--hf-convert | --direct]
+"""
+
+import argparse
+import json
+import os
+import shutil
+import subprocess
+import sys
+from contextlib import contextmanager
+from typing import List, Tuple
+
+import einops
+import lm_eval
+import torch
+import torch.nn.functional as F
+from lm_eval.api.instance import Instance
+from lm_eval.api.model import LM
+from tqdm import tqdm
+from transformer_lens import HookedTransformer
+
+from projects.virology_era.virology_settings import VirologyERAConfig
+
+LM_EVAL_TASKS_PATH = "/home/a6a/lucia.a6a/unlearn/unlearn/lm_eval_tasks"
+
+
+# ============================================================
+# Approach 1: Convert TL weights -> HF GPT-NeoX and use lm_eval CLI
+# ============================================================
+
+def convert_tl_to_neox_state_dict(tl_state_dict, cfg):
+    """Reverse the TransformerLens convert_neox_weights transformation.
+
+    Maps TransformerLens parameter names back to HuggingFace GPT-NeoX names.
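+
+    Illustrative mappings for layer 0, taken from the conversion below
+    ("transposed" means the HF tensor is the transpose of the TL one):
+
+        embed.W_E           -> gpt_neox.embed_in.weight
+        blocks.0.ln1.w      -> gpt_neox.layers.0.input_layernorm.weight
+        blocks.0.attn.W_Q/K/V (stacked, flattened)
+                            -> gpt_neox.layers.0.attention.query_key_value.weight
+        blocks.0.mlp.W_in   -> gpt_neox.layers.0.mlp.dense_h_to_4h.weight (transposed)
+        unembed.W_U         -> embed_out.weight (transposed)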
+ """ + hf_state_dict = {} + + hf_state_dict["gpt_neox.embed_in.weight"] = tl_state_dict["embed.W_E"] + + for l in range(cfg.n_layers): + # Layer norms + hf_state_dict[f"gpt_neox.layers.{l}.input_layernorm.weight"] = tl_state_dict[f"blocks.{l}.ln1.w"] + hf_state_dict[f"gpt_neox.layers.{l}.input_layernorm.bias"] = tl_state_dict[f"blocks.{l}.ln1.b"] + hf_state_dict[f"gpt_neox.layers.{l}.post_attention_layernorm.weight"] = tl_state_dict[f"blocks.{l}.ln2.w"] + hf_state_dict[f"gpt_neox.layers.{l}.post_attention_layernorm.bias"] = tl_state_dict[f"blocks.{l}.ln2.b"] + + # QKV: reverse of rearrange(W, "(i qkv h) m -> qkv i m h", i=n_heads, qkv=3) + W_Q = tl_state_dict[f"blocks.{l}.attn.W_Q"] # [n_heads, d_model, d_head] + W_K = tl_state_dict[f"blocks.{l}.attn.W_K"] + W_V = tl_state_dict[f"blocks.{l}.attn.W_V"] + W_qkv = torch.stack([W_Q, W_K, W_V], dim=0) # [3, n_heads, d_model, d_head] + W_qkv_flat = einops.rearrange(W_qkv, "qkv i m h -> (i qkv h) m") + hf_state_dict[f"gpt_neox.layers.{l}.attention.query_key_value.weight"] = W_qkv_flat + + # QKV bias: reverse of rearrange(bias, "(index qkv head) -> qkv index head") + b_Q = tl_state_dict[f"blocks.{l}.attn.b_Q"] # [n_heads, d_head] + b_K = tl_state_dict[f"blocks.{l}.attn.b_K"] + b_V = tl_state_dict[f"blocks.{l}.attn.b_V"] + b_qkv = torch.stack([b_Q, b_K, b_V], dim=0) # [3, n_heads, d_head] + b_qkv_flat = einops.rearrange(b_qkv, "qkv index head -> (index qkv head)") + hf_state_dict[f"gpt_neox.layers.{l}.attention.query_key_value.bias"] = b_qkv_flat + + # W_O: reverse of rearrange(W_O, "m (i h) -> i h m", i=n_heads) + W_O = tl_state_dict[f"blocks.{l}.attn.W_O"] # [n_heads, d_head, d_model] + W_O_flat = einops.rearrange(W_O, "i h m -> m (i h)") + hf_state_dict[f"gpt_neox.layers.{l}.attention.dense.weight"] = W_O_flat + + hf_state_dict[f"gpt_neox.layers.{l}.attention.dense.bias"] = tl_state_dict[f"blocks.{l}.attn.b_O"] + + # MLP: reverse .T + hf_state_dict[f"gpt_neox.layers.{l}.mlp.dense_h_to_4h.weight"] = tl_state_dict[f"blocks.{l}.mlp.W_in"].T + hf_state_dict[f"gpt_neox.layers.{l}.mlp.dense_h_to_4h.bias"] = tl_state_dict[f"blocks.{l}.mlp.b_in"] + hf_state_dict[f"gpt_neox.layers.{l}.mlp.dense_4h_to_h.weight"] = tl_state_dict[f"blocks.{l}.mlp.W_out"].T + hf_state_dict[f"gpt_neox.layers.{l}.mlp.dense_4h_to_h.bias"] = tl_state_dict[f"blocks.{l}.mlp.b_out"] + + # Final layer norm + hf_state_dict["gpt_neox.final_layer_norm.weight"] = tl_state_dict["ln_final.w"] + hf_state_dict["gpt_neox.final_layer_norm.bias"] = tl_state_dict["ln_final.b"] + + # Unembed: reverse .T + hf_state_dict["embed_out.weight"] = tl_state_dict["unembed.W_U"].T + + return hf_state_dict + + +def save_as_hf_model(tl_state_dict_path: str, output_dir: str, cfg=None): + """Load a TL state dict, convert to HF format, and save as an HF model.""" + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + from projects.virology_era.virology_era import load_neox_model + + if cfg is None: + cfg = VirologyERAConfig() + + print(f"Loading TL state dict from {tl_state_dict_path}") + tl_state_dict = torch.load(tl_state_dict_path, map_location="cpu") + + # Load original HF model config and tokenizer + hf_config = AutoConfig.from_pretrained(cfg.model_path) + tokenizer = AutoTokenizer.from_pretrained(cfg.model_path) + + # Load the TL model config to get dimensions + tl_model = load_neox_model(cfg.model_path, device="cpu", dtype=torch.bfloat16) + tl_cfg = tl_model.cfg + del tl_model + + # Convert weights + print("Converting TL -> HF state dict") + hf_state_dict = 
convert_tl_to_neox_state_dict(tl_state_dict, tl_cfg) + + # Create HF model and load weights + print("Creating HF model and loading converted weights") + hf_model = AutoModelForCausalLM.from_config(hf_config, torch_dtype=torch.bfloat16) + missing, unexpected = hf_model.load_state_dict(hf_state_dict, strict=False) + if missing: + print(f"Warning: missing keys: {missing}") + if unexpected: + print(f"Warning: unexpected keys: {unexpected}") + + # Save + os.makedirs(output_dir, exist_ok=True) + hf_model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + print(f"HF model saved to {output_dir}") + return output_dir + + +def run_lm_eval_subprocess(model_path: str, tasks: list[str], batch_size: int = 32, num_fewshot: int = 0): + """Run lm_eval via subprocess for multi-GPU evaluation.""" + results = {} + for task in tasks: + cmd = [ + sys.executable, "-m", "lm_eval", + "--model", "hf", + "--model_args", f"pretrained={model_path}", + "--tasks", task, + "--batch_size", str(batch_size), + ] + if task in ("mmlu",): + cmd.extend(["--num_fewshot", str(num_fewshot or 1)]) + if task in ("wmdp_bio_robust",): + cmd.extend(["--include_path", LM_EVAL_TASKS_PATH]) + cmd.extend(["--verbosity", "WARNING"]) + + # Write results to temp dir + output_dir = f"/tmp/lm_eval_results_{task}" + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + cmd.extend(["--output_path", output_dir]) + + print(f"Running: {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True) + print(result.stdout[-2000:] if result.stdout else "") + if result.returncode != 0: + print(f"lm_eval failed for {task}:") + print(result.stderr[-2000:] if result.stderr else "") + + # Parse results from output dir + task_results = _parse_lm_eval_results(output_dir, task) + results[task] = task_results + + return results + + +def _parse_lm_eval_results(output_dir: str, task_name: str) -> dict: + """Parse lm_eval JSON results from output directory.""" + results = {} + try: + for root, dirs, files in os.walk(output_dir): + for f in files: + if f.endswith(".json"): + with open(os.path.join(root, f)) as fh: + data = json.load(fh) + if "results" in data: + results = data["results"] + break + except Exception as e: + print(f"Warning: could not parse results for {task_name}: {e}") + return results + + +# ============================================================ +# Approach 2: Custom lm_eval.LM wrapper for HookedTransformer +# ============================================================ + +class HookedTransformerLM(LM): + """lm_eval wrapper for TransformerLens HookedTransformer models.""" + + def __init__(self, model: HookedTransformer, device: torch.device): + super().__init__() + self.model = model + self.device = device + self.tokenizer = model.tokenizer + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, int]]: + outputs = [] + for req in tqdm(requests, desc="loglikelihood"): + context, continuation = req.args + context_enc = self.tokenizer.encode(context) + continuation_enc = self.tokenizer.encode(continuation) + full_enc = context_enc + continuation_enc + + inp = torch.tensor(full_enc).to(self.device).unsqueeze(0) + + with torch.inference_mode(): + logits = self.model(inp[:, :-1]) + log_probs = F.log_softmax(logits, dim=-1) + + cont_log_probs = log_probs[:, -len(continuation_enc):] + greedy_tokens = cont_log_probs.argmax(dim=-1) + cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0).to(self.device) + + is_top = (greedy_tokens == cont_toks).all() + gathered = 
torch.gather(cont_log_probs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
+            ll = gathered.sum()
+
+            outputs.append((ll.item(), int(is_top.item())))
+        return outputs
+
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[Tuple[float]]:
+        raise NotImplementedError
+
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        raise NotImplementedError
+
+    @contextmanager
+    def eval_mode(self):
+        was_training = self.model.training
+        self.model.eval()
+        yield
+        self.model.train(was_training)
+
+
+def eval_tl_model_directly(
+    state_dict_path: str,
+    tasks: list[str],
+    device: torch.device,
+    cfg: VirologyERAConfig = None,
+) -> dict:
+    """Evaluate a TL model directly using the HookedTransformerLM wrapper."""
+    from projects.virology_era.virology_era import load_neox_model
+
+    if cfg is None:
+        cfg = VirologyERAConfig()
+
+    print("Loading TL model for direct evaluation")
+    model = load_neox_model(cfg.model_path, device=device, dtype=torch.bfloat16)
+    state_dict = torch.load(state_dict_path, map_location=device)
+    model.load_state_dict(state_dict, strict=False)
+
+    wrapped = HookedTransformerLM(model, device)
+    with wrapped.eval_mode():
+        results = lm_eval.simple_evaluate(
+            model=wrapped,
+            tasks=tasks,
+            include_path=LM_EVAL_TASKS_PATH,
+        )
+    return results.get("results", {})
+
+
+# ============================================================
+# Main
+# ============================================================
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate virology ERA model")
+    parser.add_argument("state_dict_path", help="Path to TL state dict .pt file")
+    parser.add_argument("--hf-convert", action="store_true",
+                        help="Convert to HF and use subprocess lm_eval (for multi-GPU)")
+    parser.add_argument("--output-dir", default=None,
+                        help="Directory for HF model output (default: hf_model/ next to the state dict)")
+    parser.add_argument("--tasks", nargs="+", default=["wmdp_bio_robust", "mmlu"],
+                        help="lm_eval tasks to run")
+    parser.add_argument("--direct", action="store_true",
+                        help="Use direct TL evaluation (single GPU, no conversion)")
+    args = parser.parse_args()
+
+    cfg = VirologyERAConfig()
+
+    if args.direct:
+        from factored_representations.utils import get_gpu_with_most_memory
+        device = get_gpu_with_most_memory()
+        results = eval_tl_model_directly(args.state_dict_path, args.tasks, device, cfg)
+        print(json.dumps(results, indent=2))
+        return
+
+    if args.hf_convert:
+        output_dir = args.output_dir or os.path.join(
+            os.path.dirname(args.state_dict_path), "hf_model"
+        )
+        save_as_hf_model(args.state_dict_path, output_dir, cfg)
+        results = run_lm_eval_subprocess(output_dir, args.tasks)
+        print(json.dumps(results, indent=2))
+
+        results_path = os.path.join(os.path.dirname(args.state_dict_path), "eval_results.json")
+        with open(results_path, "w") as f:
+            json.dump(results, f, indent=2)
+        print(f"Results saved to {results_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/projects/virology_era/outputs/dry_run/ablation_comparison.json b/projects/virology_era/outputs/dry_run/ablation_comparison.json
new file mode 100644
index 0000000..2d75a93
--- /dev/null
+++ b/projects/virology_era/outputs/dry_run/ablation_comparison.json
@@ -0,0 +1,10 @@
+{
+  "pre_ablation": {
+    "forget_loss": 1.9230653964556181,
+    "retain_loss": 1.913012724656325
+  },
+  "post_ablation": {
+    "forget_loss": 1.9231660549457257,
+    "retain_loss": 1.9132063113726103
+  }
+}
\ No newline at end of file
diff --git a/projects/virology_era/outputs/dry_run/results.json 
b/projects/virology_era/outputs/dry_run/results.json new file mode 100644 index 0000000..3c7ec41 --- /dev/null +++ b/projects/virology_era/outputs/dry_run/results.json @@ -0,0 +1,15 @@ +{ + "pre_ablation": { + "forget_loss": 1.9230653964556181, + "retain_loss": 1.913012724656325 + }, + "post_ablation": { + "forget_loss": 1.9231660549457257, + "retain_loss": 1.9132063113726103 + }, + "final": { + "forget_loss": 1.9231660549457257, + "retain_loss": 1.9132063113726103 + }, + "best_coherence_step": 0 +} \ No newline at end of file diff --git a/projects/virology_era/virology_data.py b/projects/virology_era/virology_data.py new file mode 100644 index 0000000..b45308a --- /dev/null +++ b/projects/virology_era/virology_data.py @@ -0,0 +1,107 @@ +"""Data loading for virology ERA unlearning. + +Forget: WMDP Bio Remove Dataset (virology/bio papers) +Retain: WikiText-103 (general text) + +Produces (text, label) pairs where label=0 is forget, label=1 is retain. +""" + +import numpy as np +from datasets import load_dataset + + +def load_forget_data(max_examples: int | None = None) -> list[tuple[str, int]]: + """Load WMDP Bio Remove Dataset as forget data (label=0). + + Concatenates title + abstract + text for each document. + """ + split = f"train[:{max_examples}]" if max_examples else "train" + ds = load_dataset("Unlearning/WMDP-Bio-Remove-Dataset", split=split) + + forget_data = [] + for row in ds: + parts = [] + if row["title"]: + parts.append(row["title"]) + if row["abstract"]: + parts.append(row["abstract"]) + if row["text"]: + parts.append(row["text"]) + text = "\n\n".join(parts) + if text.strip(): + forget_data.append((text, 0)) + + return forget_data + + +def load_retain_data(max_examples: int | None = None) -> list[tuple[str, int]]: + """Load WikiText-103 as retain data (label=1). + + Uses the document-level version with column 'page'. + """ + split = f"train[:{max_examples}]" if max_examples else "train" + ds = load_dataset( + "EleutherAI/wikitext_document_level", + "wikitext-103-raw-v1", + split=split, + ) + + retain_data = [] + for row in ds: + text = row["page"] + if text and text.strip(): + retain_data.append((text, 1)) + + return retain_data + + +def load_training_data( + max_forget: int | None = None, + max_retain: int | None = None, + seed: int = 42, +) -> list[tuple[str, int]]: + """Load and shuffle combined forget + retain training data.""" + forget = load_forget_data(max_forget) + retain = load_retain_data(max_retain) + + print(f"Loaded {len(forget)} forget examples, {len(retain)} retain examples") + + combined = forget + retain + rng = np.random.default_rng(seed) + rng.shuffle(combined) + return combined + + +def load_validation_data( + n_forget: int = 200, + n_retain: int = 200, + seed: int = 123, +) -> tuple[list[tuple[str, int]], list[tuple[str, int]]]: + """Load validation splits for forget and retain. + + Uses the end of the forget dataset and WikiText validation split. 
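+
+    Note: the forget split is drawn from a seeded shuffle of the full forget
+    pool, so it overlaps the training data when training uses the whole
+    dataset; the retain split comes from the held-out WikiText validation set.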
+    """
+    # Forget validation: seeded shuffle of the full forget pool, then take the tail
+    all_forget = load_forget_data()
+    rng = np.random.default_rng(seed)
+    rng.shuffle(all_forget)
+    forget_val = all_forget[-n_forget:]
+
+    # Retain validation: use WikiText validation split
+    ds = load_dataset(
+        "EleutherAI/wikitext_document_level",
+        "wikitext-103-raw-v1",
+        split=f"validation[:{n_retain}]",
+    )
+    retain_val = [(row["page"], 1) for row in ds if row["page"] and row["page"].strip()]
+
+    return forget_val, retain_val
+
+
+if __name__ == "__main__":
+    data = load_training_data(max_forget=100, max_retain=100)
+    print(f"Total training examples: {len(data)}")
+    print(f"First example label: {data[0][1]}, text[:80]: {data[0][0][:80]}")
+
+    fval, rval = load_validation_data(n_forget=10, n_retain=10)
+    print(f"Validation: {len(fval)} forget, {len(rval)} retain")
diff --git a/projects/virology_era/virology_era.py b/projects/virology_era/virology_era.py
new file mode 100644
index 0000000..5b99178
--- /dev/null
+++ b/projects/virology_era/virology_era.py
@@ -0,0 +1,433 @@
+"""
+ERA (Expand-Route-Ablate) gradient routing for virology unlearning.
+
+Applies ERA to EleutherAI/deep-ignorance-unfiltered (6.85B GPT-NeoX)
+using WMDP Bio Remove Dataset as forget data.
+
+Usage:
+    cd /home/a6a/lucia.a6a/gradient-routing
+    python projects/virology_era/virology_era.py [dry_run]
+"""
+
+import json
+import math
+import os
+import sys
+from copy import deepcopy
+from typing import Optional, Tuple
+
+import numpy as np
+import torch as t
+import torch.utils.data as data
+import tqdm
+import wandb
+from transformer_lens import HookedTransformer
+from transformer_lens.loading_from_pretrained import (
+    get_pretrained_model_config,
+    get_pretrained_state_dict,
+)
+
+import factored_representations.model_expansion as model_expansion
+import factored_representations.training as training
+from factored_representations import masklib, string_utils
+
+from projects.virology_era.virology_data import (
+    load_retain_data,
+    load_training_data,
+    load_validation_data,
+)
+from projects.virology_era.virology_settings import VirologyERAConfig
+
+
+def load_neox_model(model_path: str, device, dtype=t.bfloat16) -> HookedTransformer:
+    """Load a GPT-NeoX model from a local HF path into HookedTransformer.
+
+    Bypasses the official model name check in HookedTransformer.from_pretrained
+    by calling the config/state_dict loaders directly (which do support local paths).
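+
+    Example (the snapshot path is illustrative):
+
+        model = load_neox_model("/path/to/local/neox/snapshot", device="cuda")
+        loss = model(tokens, return_type="loss")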
+ """ + cfg = get_pretrained_model_config( + model_path, + device=device, + dtype=dtype, + fold_ln=False, + ) + + state_dict = get_pretrained_state_dict(model_path, cfg, hf_model=None, dtype=dtype) + + model = HookedTransformer(cfg, move_to_device=False) + model.load_and_process_state_dict( + state_dict, + fold_ln=False, + center_writing_weights=False, + center_unembed=False, + fold_value_biases=False, + ) + model.move_model_modules_to_device() + return model + + +def full_seq_mask_rule(labels, seq_length, device): + """0 = forget set, 1 = retain set.""" + return labels.unsqueeze(1).repeat(1, seq_length).to(device) + + +def convert_to_labeled_mask_rule(mask_rule, label_setting="retain_always_unmasked"): + """Wrap a token mask rule to handle (input_ids, labels) pairs.""" + if label_setting == "retain_always_unmasked": + def unmasked_retain_mask_rule(input_ids_and_labels): + input_ids, labels = input_ids_and_labels + original_mask = mask_rule(input_ids) + is_retain = full_seq_mask_rule( + labels, input_ids.shape[1] - 1, original_mask.device + ) + return t.maximum(is_retain, original_mask) + return unmasked_retain_mask_rule + else: + raise ValueError(f"Unknown label setting: {label_setting}") + + +@t.inference_mode() +def eval_on_validation( + model, + validation_data: list[tuple], + truncate_at: int, +) -> float: + """Evaluate model loss on validation data.""" + dataloader = data.DataLoader( + string_utils.ListDataset(validation_data), batch_size=4, shuffle=False + ) + device = t.device(model.cfg.device) + batch_losses = [] + for batch in dataloader: + stories, labels = batch + tokens, attention_mask = string_utils.tokenize_batch( + stories, + model.tokenizer, + prepend_bos=True, + truncate_at=truncate_at, + padding_side="right", + device=device, + ) + with t.autocast(device_type="cuda", dtype=t.bfloat16, enabled=device.type == "cuda"): + loss = training.compute_preds_and_get_ce_loss( + model, tokens, attention_mask, None + ) + batch_losses.append(loss.item()) + return sum(batch_losses) / len(batch_losses) if batch_losses else float("inf") + + +def do_virology_era_run( + cfg: VirologyERAConfig, + device: t.device, + save_dir: str, + dry_run: bool = False, +): + """Run full ERA pipeline: expand+route, ablate, coherence finetune.""" + + os.makedirs(save_dir, exist_ok=True) + + project_dir = os.path.dirname(os.path.abspath(__file__)) + wandb.login() + wandb.init( + project=cfg.wandb_project, + mode="disabled" if dry_run else "online", + name=f"virology-era-{'dryrun' if dry_run else 'full'}", + config=cfg.__dict__, + settings=wandb.Settings(code_dir=project_dir), + dir=project_dir, + ) + + # ---- Load model ---- + print("Loading model from", cfg.model_path) + model = load_neox_model(cfg.model_path, device=device, dtype=t.bfloat16) + original_model_config = model.cfg + print(f"Model loaded: {model.cfg.n_layers} layers, d_model={model.cfg.d_model}, d_mlp={model.cfg.d_mlp}") + + # ---- Load data ---- + print("Loading data...") + if dry_run: + training_data = load_training_data(max_forget=50, max_retain=50, seed=42) + else: + training_data = load_training_data(seed=42) + + forget_val, retain_val = load_validation_data( + n_forget=50 if dry_run else 200, + n_retain=50 if dry_run else 200, + ) + + # ---- Build token freq mask rule (needed even for full_seq scheme) ---- + # For full_seq with ddbp, we use the token freq mask rule but override with + # full-sequence labels via convert_to_labeled_mask_rule + forget_texts = [text for text, label in training_data if label == 0] + retain_texts = [text for 
text, label in training_data if label == 1] + + token_freq_kwargs = dict( + retain_stories=retain_texts[:5000], + forget_stories=forget_texts[:5000], + num_stories=min(5000, len(forget_texts), len(retain_texts)), + truncate_at=None, + num_synthetic_tokens_retain=20, + num_synthetic_tokens_forget=1, + scale=1.0, + bias=-4.0, + tokenizer=model.tokenizer, + device=device, + ) + + token_mask_rule, _ = masklib.get_token_freq_masking_rule(**token_freq_kwargs) + mask_rule = convert_to_labeled_mask_rule(token_mask_rule, "retain_always_unmasked") + + # ---- ERA setup: expand model ---- + print(f"Expanding model: d_mlp += {cfg.d_mlp_expansion} on layers {cfg.layers_to_mask}") + model, specs = model_expansion.expand_and_get_mask_specs( + model, + cfg.to_expand, + layers_to_mask=cfg.layers_to_mask, + masking_rule=mask_rule, + suppress_warnings=False, + **cfg.expanded_vs_original_dim_learning_rates, + weight_initialization_coeff=1.0, + ) + # expand_model init_weights() creates new dims in float32; cast to match loaded weights + model = model.to(t.bfloat16) + print(f"Expanded model: d_mlp={model.cfg.d_mlp}") + + mask_applier = masklib.MaskApplier( + model, + specs, + use_partial_boolean_masks=True, # full_seq is in SCHEMES_WITH_PARTIAL_WEIGHTS + ) + # MaskApplier precomputes masks in float32; cast to bfloat16 to match model + mask_applier.mask_lookup_tensors = [ + m.to(t.bfloat16) for m in mask_applier.mask_lookup_tensors + ] + + dataloader = data.DataLoader( + string_utils.ListDataset(training_data), + shuffle=False, + batch_size=cfg.batch_size, + ) + + optim = t.optim.AdamW( + model.parameters(), + lr=cfg.learning_rate, + **cfg.optimizer_kwargs, + ) + + num_training_steps = min( + len(training_data) // cfg.batch_size, + cfg.num_steps_era_training, + ) + + def get_lr(it): + min_lr = cfg.learning_rate / 10 + warmup_iters = 100 + lr_decay_iters = num_training_steps + if it < warmup_iters: + return cfg.learning_rate * it / warmup_iters + if it > lr_decay_iters: + return min_lr + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) + return min_lr + coeff * (cfg.learning_rate - min_lr) + + # ---- PHASE 1: Expand and Route ---- + print("PHASE ONE: EXPAND AND ROUTE") + total_steps = min(cfg.num_steps_era_training, len(dataloader)) + eval_every = 8 if dry_run else 250 + + for step, batch in (pbar := tqdm.tqdm(enumerate(dataloader), total=total_steps)): + if step >= total_steps: + break + + lr = get_lr(step) + for param_group in optim.param_groups: + param_group["lr"] = lr + + stories, labels = batch + input_ids, attention_mask = string_utils.tokenize_batch( + stories, + model.tokenizer, + prepend_bos=True, + truncate_at=cfg.truncate_batch_tokens_at, + padding_side="right", + device=device, + ) + + with mask_applier((input_ids, labels), mask_weight=1.0): + with t.autocast(device_type="cuda", dtype=t.bfloat16): + loss = training.compute_preds_and_get_ce_loss( + model, input_ids, attention_mask, None + ) + + loss = loss / cfg.grad_accum_steps + loss.backward() + + if (step + 1) % cfg.grad_accum_steps == 0: + optim.step() + optim.zero_grad() + + effective_loss = loss.item() * cfg.grad_accum_steps + wandb.log({"loss": effective_loss, "lr": lr, "step": step}) + pbar.set_postfix({"loss": f"{effective_loss:.4f}"}) + + if step % eval_every == 0: + forget_loss = eval_on_validation(model, forget_val[:20], cfg.truncate_batch_tokens_at) + retain_loss = eval_on_validation(model, retain_val[:20], cfg.truncate_batch_tokens_at) + wandb.log({ + 
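+                # Quick in-loop check on 20-example validation slices; the full
+                # validation sets are evaluated at the phase boundaries below.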
"validation_forget_loss": forget_loss, + "validation_retain_loss": retain_loss, + }) + print(f" Step {step}: forget_loss={forget_loss:.4f}, retain_loss={retain_loss:.4f}") + + # Pre-ablation eval + pre_forget = eval_on_validation(model, forget_val, cfg.truncate_batch_tokens_at) + pre_retain = eval_on_validation(model, retain_val, cfg.truncate_batch_tokens_at) + pre_ablation = {"forget_loss": pre_forget, "retain_loss": pre_retain} + print(f"Pre-ablation: {pre_ablation}") + wandb.log({"pre_ablation_forget_loss": pre_forget, "pre_ablation_retain_loss": pre_retain}) + + # Save pre-ablation model + pre_ablation_path = os.path.join(save_dir, "pre_ablation.pt") + t.save(model.state_dict(), pre_ablation_path) + print(f"Pre-ablation model saved to {pre_ablation_path}") + + # ---- PHASE 2: Ablate + Coherence Finetune ---- + print("PHASE TWO: ABLATE AND COHERENCE FINETUNE") + + # Free ERA optimizer/masks before creating contracted model to avoid OOM + del mask_applier + del optim + del specs + del dataloader + t.cuda.empty_cache() + + contracted_model = model_expansion.contract_model(model, original_model_config) + # contract_model creates views into expanded model tensors; clone to break dependency + for param in contracted_model.parameters(): + param.data = param.data.clone() + del model + t.cuda.empty_cache() + + # Post-ablation eval + post_forget = eval_on_validation(contracted_model, forget_val, cfg.truncate_batch_tokens_at) + post_retain = eval_on_validation(contracted_model, retain_val, cfg.truncate_batch_tokens_at) + post_ablation = {"forget_loss": post_forget, "retain_loss": post_retain} + print(f"Post-ablation: {post_ablation}") + wandb.log({"post_ablation_forget_loss": post_forget, "post_ablation_retain_loss": post_retain}) + + # Save ablation comparison + with open(os.path.join(save_dir, "ablation_comparison.json"), "w") as f: + json.dump({"pre_ablation": pre_ablation, "post_ablation": post_ablation}, f, indent=2) + + # Coherence finetuning on retain data (mini-batched for memory) + retain_for_coherence = load_retain_data(max_examples=1000 if not dry_run else 100) + rng = np.random.default_rng(42) + rng.shuffle(retain_for_coherence) + coherence_train = retain_for_coherence[:cfg.num_coherence_retain_train] + coherence_test = retain_for_coherence[ + cfg.num_coherence_retain_train : + cfg.num_coherence_retain_train + cfg.num_coherence_retain_test + ] + + coherence_dataloader = data.DataLoader( + string_utils.ListDataset(coherence_train), + shuffle=True, + batch_size=cfg.batch_size, + ) + + optim = t.optim.AdamW( + contracted_model.parameters(), + lr=cfg.coherence_lr, + **cfg.optimizer_kwargs, + ) + + best_loss = eval_on_validation(contracted_model, coherence_test, cfg.truncate_batch_tokens_at) + # Store best weights on CPU to save GPU memory for optimizer states + best_model_weights = {k: v.cpu().clone() for k, v in contracted_model.state_dict().items()} + best_step = 0 + + coherence_steps = 5 if dry_run else cfg.num_steps_coherence_finetuning + coherence_iter = iter(coherence_dataloader) + for step in (pbar := tqdm.tqdm(range(coherence_steps + 1))): + try: + batch = next(coherence_iter) + except StopIteration: + coherence_iter = iter(coherence_dataloader) + batch = next(coherence_iter) + stories, labels = batch + input_ids, attention_mask = string_utils.tokenize_batch( + stories, + contracted_model.tokenizer, + prepend_bos=True, + truncate_at=cfg.truncate_batch_tokens_at, + padding_side="left", + device=device, + ) + with t.autocast(device_type="cuda", dtype=t.bfloat16, enabled=device.type 
== "cuda"): + loss = training.compute_preds_and_get_ce_loss( + contracted_model, input_ids, attention_mask, None + ) + optim.zero_grad() + loss.backward() + optim.step() + wandb.log({"coherence_train_loss": loss.item()}) + + if step % 10 == 0 or step == coherence_steps: + test_loss = eval_on_validation( + contracted_model, coherence_test, cfg.truncate_batch_tokens_at + ) + if test_loss < best_loss: + best_loss = test_loss + best_model_weights = {k: v.cpu().clone() for k, v in contracted_model.state_dict().items()} + best_step = step + wandb.log({"coherence_test_loss": test_loss}) + + forget_loss = eval_on_validation( + contracted_model, forget_val[:20], cfg.truncate_batch_tokens_at + ) + wandb.log({"coherence_forget_loss": forget_loss}) + + print(f"Best coherence loss {best_loss:.4f} at step {best_step}") + wandb.run.summary["best_coherence_step"] = best_step + + contracted_model.load_state_dict(best_model_weights) + + # Final eval + final_forget = eval_on_validation(contracted_model, forget_val, cfg.truncate_batch_tokens_at) + final_retain = eval_on_validation(contracted_model, retain_val, cfg.truncate_batch_tokens_at) + print(f"Final: forget_loss={final_forget:.4f}, retain_loss={final_retain:.4f}") + wandb.log({"final_forget_loss": final_forget, "final_retain_loss": final_retain}) + + # Save final model + final_path = os.path.join(save_dir, "final.pt") + t.save(contracted_model.state_dict(), final_path) + print(f"Final model saved to {final_path}") + + with open(os.path.join(save_dir, "results.json"), "w") as f: + json.dump({ + "pre_ablation": pre_ablation, + "post_ablation": post_ablation, + "final": {"forget_loss": final_forget, "retain_loss": final_retain}, + "best_coherence_step": best_step, + }, f, indent=2) + + wandb.finish() + return contracted_model + + +if __name__ == "__main__": + from factored_representations.utils import get_gpu_with_most_memory + + DRY_RUN = len(sys.argv) > 1 and sys.argv[1] == "dry_run" + device = get_gpu_with_most_memory() + print(f"Running virology ERA on {device=}, {DRY_RUN=}") + + cfg = VirologyERAConfig() + save_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "outputs", + "dry_run" if DRY_RUN else "full_run", + ) + + do_virology_era_run(cfg, device, save_dir, dry_run=DRY_RUN) diff --git a/projects/virology_era/virology_settings.py b/projects/virology_era/virology_settings.py new file mode 100644 index 0000000..dc2fbe9 --- /dev/null +++ b/projects/virology_era/virology_settings.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass, field +from pathlib import Path + + +# Resolve the model snapshot path via the HF cache symlink +MODEL_PATH = str( + Path.home() + / ".cache/huggingface/hub/models--EleutherAI--deep-ignorance-unfiltered" + / "snapshots/c8df368ff247cb90b62e21e1689260701b3ff25a" +) + +# Model architecture (from config.json) +D_MODEL = 4096 +D_MLP = 16384 +N_LAYERS = 32 +N_HEADS = 32 +D_HEAD = D_MODEL // N_HEADS # 128 + + +@dataclass +class VirologyERAConfig: + # Model + model_path: str = MODEL_PATH + + # Training + batch_size: int = 2 + grad_accum_steps: int = 8 # effective batch = 16 + truncate_batch_tokens_at: int = 512 + learning_rate: float = 5e-5 + optimizer_kwargs: dict = field( + default_factory=lambda: dict(betas=(0.9, 0.95), weight_decay=0.1) + ) + + # ERA + layers_to_mask: list = field( + default_factory=lambda: list(range(8)) # first 8 of 32 + ) + d_mlp_expansion: int = 128 # 128/16384 = 0.78% + num_steps_era_training: int = 2000 + num_steps_coherence_finetuning: int = 500 + coherence_lr: float = 5e-5 + + # 
Masking + masking_scheme: str = "full_seq" + masking_type: str = "ddbp" + expanded_dim_lr_target: float = 1.0 + original_dim_lr_target: float = -0.75 + expanded_dim_lr_off_target: float = 1.0 + original_dim_lr_off_target: float = 1.0 + + # Data + forget_dataset: str = "Unlearning/WMDP-Bio-Remove-Dataset" + retain_dataset: str = "EleutherAI/wikitext_document_level" + + # Eval + num_coherence_retain_train: int = 64 + num_coherence_retain_test: int = 200 + + # W&B + wandb_project: str = "virology-era-unlearning" + + @property + def to_expand(self) -> dict: + return {"d_mlp": self.d_mlp_expansion} + + @property + def expanded_vs_original_dim_learning_rates(self) -> dict: + return dict( + expanded_dim_lr_target=self.expanded_dim_lr_target, + original_dim_lr_target=self.original_dim_lr_target, + expanded_dim_lr_off_target=self.expanded_dim_lr_off_target, + original_dim_lr_off_target=self.original_dim_lr_off_target, + ) diff --git a/runs/virology_era.sbatch b/runs/virology_era.sbatch new file mode 100644 index 0000000..9fc4d4b --- /dev/null +++ b/runs/virology_era.sbatch @@ -0,0 +1,92 @@ +#!/bin/bash +#SBATCH --job-name=virology-era +#SBATCH --nodes=1 +#SBATCH --exclusive +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --time=24:00:00 +#SBATCH --output=/home/a6a/lucia.a6a/gradient-routing/runs/virology-era-%j.out + +source /home/a6a/lucia.a6a/miniforge3/etc/profile.d/conda.sh +conda activate gradient-routing + +module load PrgEnv-cray +module load cuda/12.6 + +export PYTORCH_ALLOC_CONF=expandable_segments:True +export TORCH_HOME=/home/a6a/lucia.a6a/.cache/torch +export TORCHINDUCTOR_CACHE_DIR=/home/a6a/lucia.a6a/.cache/torch_inductor +export HF_HUB_OFFLINE=1 + +cd /home/a6a/lucia.a6a/gradient-routing + +echo "=== Virology ERA Gradient Routing ===" +echo "Time: $(date)" +echo "Node: $(hostname)" +nvidia-smi + +# Step 1: ERA training (single GPU, picked by get_gpu_with_most_memory) +echo "Step 1: ERA training" +python projects/virology_era/virology_era.py + +SAVE_DIR="projects/virology_era/outputs/full_run" + +if [ ! -f "${SAVE_DIR}/final.pt" ]; then + echo "ERROR: Final model not found at ${SAVE_DIR}/final.pt" + exit 1 +fi + +# Step 2: Convert TL model to HF format for evaluation +echo "Step 2: Convert to HF format" +python projects/virology_era/evaluate_virology.py \ + "${SAVE_DIR}/final.pt" \ + --hf-convert \ + --output-dir "${SAVE_DIR}/hf_model" + +HF_MODEL="${SAVE_DIR}/hf_model" + +if [ ! 
-d "$HF_MODEL" ]; then + echo "ERROR: HF model not found at $HF_MODEL" + exit 1 +fi + +# Step 3: Multi-GPU evaluation with lm_eval +echo "Step 3: WMDP Bio Robust evaluation" +python -m lm_eval --model hf \ + --model_args pretrained=$HF_MODEL \ + --tasks wmdp_bio_robust \ + --include_path "/home/a6a/lucia.a6a/unlearn/unlearn/lm_eval_tasks" \ + --batch_size 32 \ + --verbosity WARNING \ + --output_path "${SAVE_DIR}/eval_wmdp" + +echo "Step 4: MMLU evaluation" +python -m lm_eval --model hf \ + --model_args pretrained=$HF_MODEL \ + --tasks mmlu \ + --num_fewshot 1 \ + --batch_size 32 \ + --verbosity WARNING \ + --output_path "${SAVE_DIR}/eval_mmlu" + +# Also eval the baseline for comparison +echo "Step 5: Baseline WMDP Bio Robust" +BASELINE="/home/a6a/lucia.a6a/.cache/huggingface/hub/models--EleutherAI--deep-ignorance-unfiltered/snapshots/c8df368ff247cb90b62e21e1689260701b3ff25a" +python -m lm_eval --model hf \ + --model_args pretrained=$BASELINE \ + --tasks wmdp_bio_robust \ + --include_path "/home/a6a/lucia.a6a/unlearn/unlearn/lm_eval_tasks" \ + --batch_size 32 \ + --verbosity WARNING \ + --output_path "${SAVE_DIR}/eval_wmdp_baseline" + +echo "Step 6: Baseline MMLU" +python -m lm_eval --model hf \ + --model_args pretrained=$BASELINE \ + --tasks mmlu \ + --num_fewshot 1 \ + --batch_size 32 \ + --verbosity WARNING \ + --output_path "${SAVE_DIR}/eval_mmlu_baseline" + +echo "Done: $(date)" diff --git a/runs/virology_era_dryrun.sbatch b/runs/virology_era_dryrun.sbatch new file mode 100644 index 0000000..921f82c --- /dev/null +++ b/runs/virology_era_dryrun.sbatch @@ -0,0 +1,30 @@ +#!/bin/bash +#SBATCH --job-name=vir-era-dry +#SBATCH --nodes=1 +#SBATCH --exclusive +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --time=01:00:00 +#SBATCH --output=/home/a6a/lucia.a6a/gradient-routing/runs/virology-era-dryrun-%j.out + +source /home/a6a/lucia.a6a/miniforge3/etc/profile.d/conda.sh +conda activate gradient-routing + +module load PrgEnv-cray +module load cuda/12.6 + +export PYTORCH_ALLOC_CONF=expandable_segments:True +export TORCH_HOME=/home/a6a/lucia.a6a/.cache/torch +export TORCHINDUCTOR_CACHE_DIR=/home/a6a/lucia.a6a/.cache/torch_inductor +export HF_HUB_OFFLINE=1 + +cd /home/a6a/lucia.a6a/gradient-routing + +echo "=== Virology ERA Dry Run ===" +echo "Time: $(date)" +echo "Node: $(hostname)" +nvidia-smi + +python projects/virology_era/virology_era.py dry_run + +echo "Done: $(date)"