Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
324 changes: 324 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
"""
Generate text from a trained autoresearch checkpoint.

Usage:
uv run generate.py # interactive mode
uv run generate.py "Once upon a time" # single prompt
uv run generate.py --tokens 200 "Hello" # control length
"""

import os
import sys
import argparse
from dataclasses import dataclass

os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

from prepare import Tokenizer


@dataclass
class GPTConfig:
    """Hyperparameters that define the GPT architecture built in this file.

    Instances are created from the ``"config"`` mapping stored in a
    checkpoint, so field names must match the keys saved there.
    """

    # Context length; also drives the attention window sizes and the size of
    # the precomputed rotary cache in GPT.
    sequence_len: int = 2048
    # Number of rows in the token embedding, value embeddings and lm_head.
    vocab_size: int = 32768
    # Number of transformer blocks.
    n_layer: int = 12
    # Number of query heads; n_embd must be divisible by it.
    n_head: int = 6
    # Number of key/value heads; must divide n_head.
    n_kv_head: int = 6
    # Width of the residual stream.
    n_embd: int = 768
    # Per-layer window pattern made of "S" (short) / "L" (long) characters,
    # repeated cyclically over the layers.
    window_pattern: str = "SSSL"


def norm(x):
return F.rms_norm(x, (x.size(-1),))


def has_ve(layer_idx, n_layer):
    """Return True if layer ``layer_idx`` gets a value embedding.

    A layer qualifies when its index has the same parity as the index of the
    last layer (``n_layer - 1``), so the final layer always qualifies.
    """
    last_layer_parity = (n_layer - 1) % 2
    return layer_idx % 2 == last_layer_parity


def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by (cos, sin).

    ``x`` must be 4-dimensional. Its last dimension is split at the midpoint
    into (first, second); the result is
    ``[first*cos + second*sin, -first*sin + second*cos]`` concatenated back
    along the last dimension. ``cos``/``sin`` must broadcast against each half.
    """
    assert x.ndim == 4
    half = x.shape[3] // 2
    first = x[..., :half]
    second = x[..., half:]
    neg_sin = -sin
    rotated_first = first * cos + second * sin
    rotated_second = first * neg_sin + second * cos
    return torch.cat([rotated_first, rotated_second], 3)


class CausalSelfAttention(nn.Module):
    """Causal multi-head attention with grouped KV heads, rotary positions,
    QK normalization, an optional gated value embedding and an optional
    sliding window.
    """

    def __init__(self, config, layer_idx):
        """Create the projections; layers selected by ``has_ve`` also get a
        per-KV-head gate for mixing in the value embedding."""
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0

        q_width = self.n_head * self.head_dim
        kv_width = self.n_kv_head * self.head_dim
        self.c_q = nn.Linear(self.n_embd, q_width, bias=False)
        self.c_k = nn.Linear(self.n_embd, kv_width, bias=False)
        self.c_v = nn.Linear(self.n_embd, kv_width, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

        # The gate reads only the first `ve_gate_channels` input channels.
        self.ve_gate_channels = 32
        if has_ve(layer_idx, config.n_layer):
            self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False)
        else:
            self.ve_gate = None

    def forward(self, x, ve, cos_sin, window_size):
        """Attend over ``x`` of shape (B, T, C) and return a (B, T, C) tensor.

        ``ve`` is an optional value embedding reshaped to
        (B, T, n_kv_head, head_dim); ``cos_sin`` is the (cos, sin) rotary pair;
        ``window_size[0]`` is the sliding-window length.
        """
        batch, seq, _ = x.size()
        q = self.c_q(x).view(batch, seq, self.n_head, self.head_dim)
        k = self.c_k(x).view(batch, seq, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(batch, seq, self.n_kv_head, self.head_dim)

        if ve is not None:
            ve = ve.view(batch, seq, self.n_kv_head, self.head_dim)
            # Per-token, per-KV-head gate in (0, 2), broadcast over head_dim.
            gate_in = x[..., : self.ve_gate_channels]
            gate = 2 * torch.sigmoid(self.ve_gate(gate_in))
            v = v + gate.unsqueeze(-1) * ve

        # Rotary position encoding first, then RMS-normalize queries and keys.
        cos, sin = cos_sin
        q = norm(apply_rotary_emb(q, cos, sin))
        k = norm(apply_rotary_emb(k, cos, sin))

        # Expand KV heads so every query head has a matching key/value head.
        group_size = self.n_head // self.n_kv_head
        k = k.repeat_interleave(group_size, dim=2)
        v = v.repeat_interleave(group_size, dim=2)

        # (B, T, H, D) -> (B, H, T, D), the layout SDPA expects.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        window = window_size[0]
        if 0 < window < seq:
            # Banded causal mask: query i may attend keys i-window+1 .. i.
            # NOTE(review): this admits `window` keys including the query
            # itself; a flash-attn style window_size=(left, 0) would admit
            # left + 1 keys. Confirm against the attention used in training.
            causal = torch.ones(seq, seq, dtype=torch.bool, device=q.device).tril()
            banded = causal.triu(diagonal=1 - window)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=banded)
        else:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        merged = y.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.c_proj(merged)


class MLP(nn.Module):
    """Bias-free feed-forward block: expand 4x, apply squared ReLU, project back."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Return ``c_proj(relu(c_fc(x)) ** 2)``."""
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)


class Block(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each added residually."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size):
        """Apply attention and MLP to normalized inputs, adding each result to ``x``."""
        attn_out = self.attn(norm(x), ve, cos_sin, window_size)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out


class GPT(nn.Module):
    """Decoder-only transformer with per-layer residual mixing, value
    embeddings on selected layers, rotary positions and soft-capped logits.
    """

    def __init__(self, config):
        """Build all modules and the rotary cache from ``config``."""
        super().__init__()
        self.config = config
        self.window_sizes = self._compute_window_sizes(config)

        token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        blocks = nn.ModuleList([Block(config, i) for i in range(config.n_layer)])
        self.transformer = nn.ModuleDict({"wte": token_embedding, "h": blocks})
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Per-layer scalars: weight on the running stream and on the
        # normalized token embedding (x0) that is re-injected before each block.
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))

        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        # Value embeddings exist only for layers selected by has_ve, keyed by
        # the layer index as a string.
        value_embeds = {}
        for i in range(config.n_layer):
            if has_ve(i, config.n_layer):
                value_embeds[str(i)] = nn.Embedding(config.vocab_size, kv_dim)
        self.value_embeds = nn.ModuleDict(value_embeds)

        # Rotary tables cover 10x the configured sequence length; they are
        # non-persistent buffers, so they are not part of the state_dict.
        self.rotary_seq_len = config.sequence_len * 10
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        """Return bfloat16 (cos, sin) tables of shape (1, seq_len, 1, head_dim // 2).

        When ``device`` is None the tables are created on the device of the
        token embedding weight.
        """
        if device is None:
            device = self.transformer.wte.weight.device
        even_channels = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (even_channels / head_dim))
        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        angles = torch.outer(positions, inv_freq)
        # Insert singleton batch and head axes so the tables broadcast
        # against (B, T, H, D/2) tensors.
        cos = angles.cos().bfloat16()[None, :, None, :]
        sin = angles.sin().bfloat16()[None, :, None, :]
        return cos, sin

    def _compute_window_sizes(self, config):
        """Return one ``(window, 0)`` tuple per layer.

        ``config.window_pattern`` is cycled over the layers: "L" maps to the
        full ``sequence_len`` and "S" to half of it. The last layer is always
        forced to the long window.
        """
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern)
        long_window = config.sequence_len
        short_window = long_window // 2
        by_char = {"L": (long_window, 0), "S": (short_window, 0)}
        sizes = [by_char[pattern[i % len(pattern)]] for i in range(config.n_layer)]
        sizes[-1] = (long_window, 0)
        return sizes

    def forward(self, idx, targets=None, reduction="mean"):
        """Run the model on token ids ``idx`` of shape (B, T).

        Returns float logits of shape (B, T, vocab_size) when ``targets`` is
        None; otherwise returns the cross-entropy loss against ``targets``
        (entries equal to -1 are ignored) with the given ``reduction``.
        """
        _, seq = idx.size()
        assert seq <= self.cos.size(1)
        cos_sin = (self.cos[:, :seq], self.sin[:, :seq])

        x0 = norm(self.transformer.wte(idx))
        x = x0
        for i, block in enumerate(self.transformer.h):
            layer_key = str(i)
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            if layer_key in self.value_embeds:
                ve = self.value_embeds[layer_key](idx)
            else:
                ve = None
            x = block(x, ve, cos_sin, self.window_sizes[i])
        x = norm(x)

        # Squash logits smoothly into (-softcap, softcap).
        softcap = 15
        logits = self.lm_head(x).float()
        logits = softcap * torch.tanh(logits / softcap)

        if targets is None:
            return logits
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
            reduction=reduction,
        )


@torch.no_grad()
def generate(model, tokenizer, prompt_ids, max_new_tokens, temperature, top_k, device):
    """Autoregressively extend ``prompt_ids`` and return the full token list.

    Args:
        model: Callable mapping a (1, T) LongTensor to (1, T, vocab) logits;
            must expose ``model.config.sequence_len``.
        tokenizer: Unused; kept so existing callers keep working.
        prompt_ids: Non-empty sequence of token ids to condition on.
        max_new_tokens: Number of tokens to append (values <= 0 append none).
        temperature: Values > 0 sample from the scaled distribution; any
            other value selects the arg-max token (greedy decoding).
        top_k: When > 0 and sampling, restrict sampling to the ``top_k``
            highest logits (ties with the k-th logit are kept).
        device: Device on which the token tensor is created.

    Returns:
        The prompt ids followed by the generated ids, as a list of ints.

    Raises:
        ValueError: If ``prompt_ids`` is empty. Without this check the call
            fails later with an opaque IndexError when the last position of
            an empty sequence is selected.
    """
    if len(prompt_ids) == 0:
        raise ValueError("prompt_ids must contain at least one token (e.g. BOS)")

    idx = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    seq_len = model.config.sequence_len

    for _ in range(max_new_tokens):
        # Feed at most the last `seq_len` tokens; slicing is a no-op while
        # the sequence is still shorter than the context length.
        idx_cond = idx[:, -seq_len:]
        logits = model(idx_cond)[:, -1, :]

        if temperature > 0:
            logits = logits / temperature
            if top_k > 0:
                k = min(top_k, logits.size(-1))
                kth_largest = torch.topk(logits, k).values[:, [-1]]
                logits = logits.masked_fill(logits < kth_largest, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        else:
            next_id = logits.argmax(dim=-1, keepdim=True)

        idx = torch.cat([idx, next_id], dim=1)

    return idx[0].tolist()


def load_model(ckpt_path, device):
    """Load a checkpoint and return the GPT model on ``device`` in eval mode.

    The checkpoint must be a mapping with a ``"config"`` entry (keyword
    arguments for ``GPTConfig``) and a ``"model"`` entry (a state dict).

    Args:
        ckpt_path: Path to the checkpoint file.
        device: Target device for the returned model.

    Returns:
        The ``GPT`` instance with weights loaded, moved to ``device`` and
        switched to eval mode.
    """
    # SECURITY: weights_only=False runs the full pickle machinery and can
    # execute arbitrary code from the file. Only load checkpoints you trust.
    # Load onto CPU first: the model is constructed on CPU anyway, and this
    # avoids keeping a second copy of the weights on the accelerator while
    # the single transfer in `model.to(device)` happens.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    config = GPTConfig(**ckpt["config"])
    model = GPT(config)

    # A state dict saved from a torch.compile-wrapped module carries a
    # leading "_orig_mod." on every key; strip it if present so such
    # checkpoints load into the plain module. No-op for ordinary checkpoints.
    compiled_prefix = "_orig_mod."
    state_dict = {
        (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key): value
        for key, value in ckpt["model"].items()
    }

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def main():
    """Command-line entry point.

    Parses arguments, loads the checkpoint and tokenizer, then either
    completes the single prompt given on the command line or runs an
    interactive prompt loop until Ctrl+C / EOF at the input prompt.
    Exits with status 1 if the checkpoint file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Generate text from autoresearch checkpoint"
    )
    parser.add_argument(
        "prompt",
        nargs="?",
        default=None,
        help="Text prompt (omit for interactive mode)",
    )
    # Optional flags, registered in the order they appear in --help.
    option_specs = (
        ("--checkpoint", {"default": "checkpoint.pt", "help": "Path to checkpoint file"}),
        ("--tokens", {"type": int, "default": 128, "help": "Max tokens to generate"}),
        (
            "--temperature",
            {"type": float, "default": 0.8, "help": "Sampling temperature (0 = greedy)"},
        ),
        ("--top-k", {"type": int, "default": 40, "help": "Top-k sampling (0 = disabled)"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    if not os.path.exists(args.checkpoint):
        print(f"No checkpoint found at '{args.checkpoint}'.")
        print("Run 'uv run train.py' first to train a model and save a checkpoint.")
        sys.exit(1)

    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    print(f"Loading checkpoint from {args.checkpoint}...")
    model = load_model(args.checkpoint, device)
    tokenizer = Tokenizer.from_directory()
    bos = tokenizer.get_bos_token_id()
    print(f"Model: {sum(p.numel() for p in model.parameters()):,} params on {device}")
    print()

    def run_prompt(text):
        """Encode ``text`` (BOS only when blank), generate, and decode."""
        prompt_tokens = tokenizer.encode(text) if text.strip() else []
        ids = [bos] + prompt_tokens
        output_ids = generate(
            model, tokenizer, ids, args.tokens, args.temperature, args.top_k, device
        )
        return tokenizer.decode(output_ids)

    # Single-shot mode: a prompt was supplied on the command line.
    if args.prompt is not None:
        print(run_prompt(args.prompt))
        return

    print("Interactive mode. Type a prompt and press Enter. Ctrl+C to quit.")
    print()
    while True:
        try:
            text = input("> ")
        except (KeyboardInterrupt, EOFError):
            # Finish the prompt line before leaving the loop.
            print()
            return
        print(run_prompt(text))
        print()


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
10 changes: 10 additions & 0 deletions program.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ Each experiment runs on a single GPU. The training script runs for a **fixed tim

**Simplicity criterion**: All else being equal, simpler is better. A small improvement that adds ugly complexity is not worth it. Conversely, removing something and getting equal or better results is a great outcome — that's a simplification win. When evaluating whether to keep a change, weigh the complexity cost against the improvement magnitude. A 0.001 val_bpb improvement that adds 20 lines of hacky code? Probably not worth it. A 0.001 val_bpb improvement from deleting code? Definitely keep. An improvement of ~0 but much simpler code? Keep.

**Research style: GO WILD.** Don't be conservative. Try radical architectural changes, unconventional optimizers, weird activation functions, dramatic model size swings, aggressive hyperparameter sweeps. If an experiment crashes, that's fine — log it and move on. Favor bold bets over incremental tweaks. Some ideas to explore early:
- Dramatically different depth/width ratios (e.g. 2 layers very wide, or 16 layers very narrow)
- SwiGLU, GELU, or exotic activations instead of ReLU²
- Aggressive learning rate experiments (2x, 5x, 0.1x)
- Removing components (value embeddings, residual lambdas) to see if they actually help
- Different attention patterns (all local, all global, alternating)
- Halving or doubling batch size
- Weight tying (embedding = unembedding)
- Different warmup/cooldown schedules

**The first run**: Your very first run should always be to establish the baseline, so you will run the training script as is.

## Output format
Expand Down
24 changes: 24 additions & 0 deletions results.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
commit val_bpb memory_gb status description
d2298f6 1.924103 0.0 keep baseline
3b502ec 1.765497 0.0 keep depth 4->5 (24.6M params)
uncommitted 1.830374 0.0 discard depth 5->6 (26.3M params)
38a56f3 1.766993 0.0 discard SwiGLU activation
5595b39 1.782992 0.0 discard GELU activation
adf23d4 1.747570 0.0 keep 2x MATRIX_LR 0.04->0.08
bc2201f 1.796074 0.0 discard 3x MATRIX_LR 0.04->0.12
a7d287a 3.215849 0.0 discard weight tying (catastrophic)
d844272 1.750274 0.0 discard 2x EMBEDDING_LR 0.6->1.2
815735f 1.825159 0.0 discard wider ASPECT_RATIO 64->80
0e6c5e0 1.793629 0.0 discard remove value embeddings
c3bdac1 1.648321 0.0 keep halve TOTAL_BATCH_SIZE 65K->32K (-5.7%)
bab9d9e 1.532485 0.0 keep quarter batch 32K->16K (-7.0%)
c7b92df 0.000000 0.0 crash eighth batch 8K (assertion: batch < microbatch)
ba16b77 1.535071 0.0 discard warmdown 0.5->0.3
700a8b9 1.557813 0.0 discard remove weight decay
f46bad9 1.627648 0.0 discard add warmup 0.05
0f96deb 1.542159 0.0 discard DEVICE_BATCH=4 TOTAL=8K (too noisy)
34658ab 1.461426 0.0 keep DEPTH=4 with optimized HPs (11.5M 443 steps -4.6%)
6dee92c 1.454289 0.0 keep DEPTH=3 (10.7M 533 steps -0.5%)
b4e9a83 1.520656 0.0 discard DEPTH=2 (3.5M too small)
6465c77 1.472244 0.0 discard wider ASPECT_RATIO=128 at DEPTH=3
e66b0a3 1.453113 0.0 keep MATRIX_LR 0.08->0.06 at DEPTH=3
Loading