-
Notifications
You must be signed in to change notification settings - Fork 461
Feature/issue 275 gemma3 tpu v5e8 #283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3842595
66f072e
234d936
679ce09
831b0b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "2c0670cb", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Optax Lookahead Optimizer: Bug Identification and Fix\n", | ||
| "\n", | ||
| "This notebook demonstrates how to identify, fix, and verify a bug related to the usage of the Optax lookahead optimizer. We will:\n", | ||
| "\n", | ||
| "1. Identify the issue in the code.\n", | ||
| "2. Reproduce the bug.\n", | ||
| "3. Apply the fix.\n", | ||
| "4. Verify the fix with unit tests.\n", | ||
| "5. Check the output.\n", | ||
| "6. Run the fixed code in the integrated terminal." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "dde58b8b", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 1. Identify the Issue\n", | ||
| "\n", | ||
| "Suppose we have a bug in our usage of the Optax lookahead optimizer, such as incorrect initialization or improper application in a training loop. Below is a snippet of the problematic code section:\n", | ||
| "\n", | ||
| "```python\n", | ||
| "import optax\n", | ||
| "base_optimizer = optax.sgd(learning_rate=0.1)\n", | ||
| "lookahead = optax.lookahead(base_optimizer)\n", | ||
| "# ...\n", | ||
| "# Incorrect usage: not updating the lookahead state properly\n", | ||
| "```\n", | ||
| "\n", | ||
| "The issue: The lookahead optimizer state is not being updated correctly during training, leading to suboptimal or incorrect training behavior." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "6d5deef3", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 2. Reproduce the Bug\n", | ||
| "\n", | ||
| "Let's reproduce the bug by running a minimal MNIST training loop using the incorrect lookahead usage. This will show the error or unexpected behavior." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "54741a5a", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Minimal MNIST training loop with incorrect lookahead usage\n", | ||
| "import jax\n", | ||
| "import jax.numpy as jnp\n", | ||
| "import optax\n", | ||
| "import numpy as np\n", | ||
| "\n", | ||
| "# Dummy data for demonstration\n", | ||
| "x = jnp.ones((32, 784))\n", | ||
| "y = jnp.zeros((32,), dtype=jnp.int32)\n", | ||
| "\n", | ||
| "# Simple model\n", | ||
| "def model(params, x):\n", | ||
| " return jnp.dot(x, params['w']) + params['b']\n", | ||
| "\n", | ||
| "def loss_fn(params, x, y):\n", | ||
| " logits = model(params, x)\n", | ||
| " return jnp.mean((logits - y) ** 2)\n", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| "\n", | ||
| "params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n", | ||
| "base_optimizer = optax.sgd(learning_rate=0.1)\n", | ||
| "lookahead = optax.lookahead(base_optimizer)\n", | ||
| "opt_state = lookahead.init(params)\n", | ||
| "\n", | ||
| "@jax.jit\n", | ||
| "def update(params, opt_state, x, y):\n", | ||
| " grads = jax.grad(loss_fn)(params, x, y)\n", | ||
| " updates, new_opt_state = lookahead.update(grads, opt_state, params)\n", | ||
| " new_params = optax.apply_updates(params, updates)\n", | ||
| " return new_params, new_opt_state\n", | ||
| "\n", | ||
| "# Incorrect usage: not updating lookahead state properly in a loop\n", | ||
| "for step in range(5):\n", | ||
| " params, opt_state = update(params, opt_state, x, y)\n", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line correctly updates the |
||
| " print(f\"Step {step}, Loss: {loss_fn(params, x, y)}\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "a6793fd0", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 3. Apply the Fix\n", | ||
| "\n", | ||
| "To fix the bug, ensure that the lookahead optimizer state is updated correctly and that the slow weights are properly synchronized. Here is the corrected code:" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "862b51f5", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Corrected lookahead usage\n", | ||
| "params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n", | ||
| "base_optimizer = optax.sgd(learning_rate=0.1)\n", | ||
| "lookahead = optax.lookahead(base_optimizer)\n", | ||
| "opt_state = lookahead.init(params)\n", | ||
| "\n", | ||
| "@jax.jit\n", | ||
| "def update(params, opt_state, x, y):\n", | ||
| " grads = jax.grad(loss_fn)(params, x, y)\n", | ||
| " updates, new_opt_state = lookahead.update(grads, opt_state, params)\n", | ||
| " new_params = optax.apply_updates(params, updates)\n", | ||
| " return new_params, new_opt_state\n", | ||
| "\n", | ||
| "for step in range(5):\n", | ||
| " params, opt_state = update(params, opt_state, x, y)\n", | ||
| " # Correct: always use the updated opt_state\n", | ||
| " print(f\"Step {step}, Loss: {loss_fn(params, x, y)}\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "053261f5", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 4. Verify the Fix with Unit Tests\n", | ||
| "\n", | ||
| "Let's write a simple test to confirm that the lookahead optimizer now updates the parameters and state as expected." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "b4ffbe0f", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Simple test: check that parameters are updated\n", | ||
| "params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n", | ||
| "opt_state = lookahead.init(params)\n", | ||
| "initial_loss = loss_fn(params, x, y)\n", | ||
| "for _ in range(3):\n", | ||
| " params, opt_state = update(params, opt_state, x, y)\n", | ||
| "final_loss = loss_fn(params, x, y)\n", | ||
| "assert final_loss < initial_loss + 1e-5, \"Loss did not decrease as expected!\"\n", | ||
| "print(f\"Initial loss: {initial_loss}, Final loss: {final_loss}\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "09d7c3f4", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 5. Check Output in Output Pane\n", | ||
| "\n", | ||
| "The output above should show a decreasing loss value, confirming that the optimizer is working as expected after the fix." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "15f199a7", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 6. Run in Integrated Terminal\n", | ||
| "\n", | ||
| "To validate end-to-end functionality, you can run the fixed code in the integrated terminal or as a script. This ensures the bug is resolved in all environments." | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "language_info": { | ||
| "name": "python" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"mount_file_id":"1nnp8oaJ0xLYyGRKg7HxMvz3ogzdVhcQo","authorship_tag":"ABX9TyNmFgJvRVwFtguMONKHwZIm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"aOlzDFou39zu","executionInfo":{"status":"ok","timestamp":1769608951676,"user_tz":-330,"elapsed":3956,"user":{"displayName":"Auysh Debnath","userId":"15768827574649638244"}},"outputId":"8ae9fb90-4cf5-411a-8033-66cd53c8e719"},"outputs":[{"output_type":"stream","name":"stdout","text":["Traceback (most recent call last):\n"," File \"/content/drive/MyDrive/purdueprj/main_train.py\", line 25, in <module>\n"," main()\n"," File \"/content/drive/MyDrive/purdueprj/main_train.py\", line 14, in main\n"," ds = ModelNetDataset(root=\"/content/drive/MyDrive/ModelNet10\", split='train')\n"," ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"," File \"/content/drive/MyDrive/purdueprj/data/modelnet.py\", line 12, in __init__\n"," self.files = self._scan_files()\n"," ^^^^^^^^^^^^^^^^^^\n"," File \"/content/drive/MyDrive/purdueprj/data/modelnet.py\", line 17, in _scan_files\n"," for class_name in sorted(os.listdir(self.root)):\n"," ^^^^^^^^^^^^^^^^^^^^^\n","FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/ModelNet10'\n"]}],"source":["# !python /content/drive/MyDrive/purdueprj/main_train.py\n"]},{"cell_type":"code","source":["!mkdir -p /content/drive/MyDrive/ModelNet10\n","\n","!wget -O /content/drive/MyDrive/ModelNet10.zip \\\n"," https://huggingface.co/datasets/ShapeNet/ModelNet10/resolve/main/ModelNet10_npy.zip\n","\n","!unzip /content/drive/MyDrive/ModelNet10.zip -d /content/drive/MyDrive/\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tkQlOM5c6cTo","executionInfo":{"status":"ok","timestamp":1769609193866,"user_tz":-330,"elapsed":456,"user":{"displayName":"Auysh Debnath","userId":"15768827574649638244"}},"outputId":"5e457186-3a33-425c-dff8-cfb89153266e"},"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["--2026-01-28 14:06:32-- https://huggingface.co/datasets/ShapeNet/ModelNet10/resolve/main/ModelNet10_npy.zip\n","Resolving huggingface.co (huggingface.co)... 18.239.50.103, 18.239.50.49, 18.239.50.80, ...\n","Connecting to huggingface.co (huggingface.co)|18.239.50.103|:443... connected.\n","HTTP request sent, awaiting response... 401 Unauthorized\n","\n","Username/Password Authentication Failed.\n","Archive: /content/drive/MyDrive/ModelNet10.zip\n"," End-of-central-directory signature not found. Either this file is not\n"," a zipfile, or it constitutes one disk of a multi-part archive. In the\n"," latter case the central directory and zipfile comment will be found on\n"," the last disk(s) of this archive.\n","unzip: cannot find zipfile directory in one of /content/drive/MyDrive/ModelNet10.zip or\n"," /content/drive/MyDrive/ModelNet10.zip.zip, and cannot find /content/drive/MyDrive/ModelNet10.zip.ZIP, period.\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"EvHHfqNl6b_c"},"execution_count":null,"outputs":[]}]} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| import os | ||
| import torch | ||
| import numpy as np | ||
| from torch.utils.data import Dataset | ||
|
|
||
| class ModelNetDataset(Dataset): | ||
| def __init__(self, root, num_points=1024, split='train'): | ||
| self.root = root | ||
| self.num_points = num_points | ||
| self.split = split | ||
|
|
||
| self.files = self._scan_files() | ||
| self.data, self.labels = self._load_all() | ||
|
|
||
| def _scan_files(self): | ||
| items = [] | ||
| for class_name in sorted(os.listdir(self.root)): | ||
| class_path = os.path.join(self.root, class_name, self.split) | ||
| if not os.path.isdir(class_path): | ||
| continue | ||
| label = sorted(os.listdir(self.root)).index(class_name) | ||
| for f in os.listdir(class_path): | ||
| if f.endswith('.npy') or f.endswith('.txt'): | ||
| items.append((os.path.join(class_path, f), label)) | ||
| return items | ||
|
|
||
| def _load_points(self, path): | ||
| if path.endswith('.npy'): | ||
| pts = np.load(path).astype(np.float32) | ||
| else: | ||
| pts = np.loadtxt(path).astype(np.float32) | ||
| return pts | ||
|
|
||
| def _load_all(self): | ||
| all_pts, all_labels = [], [] | ||
| for path, label in self.files: | ||
| pts = self._load_points(path) | ||
|
|
||
| if pts.shape[0] >= self.num_points: | ||
| idx = np.random.choice(pts.shape[0], self.num_points, replace=False) | ||
| pts = pts[idx] | ||
| else: | ||
| pad = self.num_points - pts.shape[0] | ||
| pts = np.vstack([pts, pts[:pad]]) | ||
|
|
||
| all_pts.append(pts) | ||
| all_labels.append(label) | ||
|
|
||
| return np.array(all_pts), np.array(all_labels) | ||
|
|
||
| def __len__(self): | ||
| return len(self.labels) | ||
|
|
||
| def __getitem__(self, idx): | ||
| pts = self.data[idx] | ||
| label = self.labels[idx] | ||
| np.random.shuffle(pts) | ||
| return torch.tensor(pts, dtype=torch.float32), torch.tensor(label, dtype=torch.long) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| import torch | ||
|
|
||
| def slice_random(points, T=8): | ||
| N = points.shape[0] | ||
| perm = torch.randperm(N) | ||
| return torch.chunk(perm, T) | ||
|
|
||
| def slice_radial(points, T=8): | ||
| center = points.mean(dim=0) | ||
| dist = torch.norm(points - center, dim=1) | ||
| perm = torch.argsort(dist) # inner → outer | ||
| return torch.chunk(perm, T) | ||
|
|
||
| def slice_pca(points, T=8): | ||
| X = points - points.mean(dim=0) | ||
| U, S, V = torch.pca_lowrank(X) | ||
| pc1 = V[:, 0] | ||
| proj = X @ pc1 | ||
| perm = torch.argsort(proj) | ||
| return torch.chunk(perm, T) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| import torch | ||
| from torch.utils.data import DataLoader | ||
|
|
||
| from data.modelnet import ModelNetDataset | ||
| from models.pointnet_snn import PointNetSNN | ||
| from training.train_loop import train_one_epoch | ||
| from training.optimizers import build_optimizer | ||
|
|
||
| def main(): | ||
|
|
||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
|
|
||
| # Debug dataset | ||
| ds = ModelNetDataset(root="/content/drive/MyDrive/ModelNet10", split='train') | ||
| dataloader = DataLoader(ds, batch_size=1, shuffle=True) | ||
|
|
||
| model = PointNetSNN(num_classes=10).to(device) | ||
| optimizer = build_optimizer(model) | ||
|
|
||
| for epoch in range(5): | ||
| loss, acc = train_one_epoch(model, dataloader, optimizer, device) | ||
| print(f"Epoch {epoch} | Loss: {loss:.4f} | Acc: {acc:.4f}") | ||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| import torch | ||
| import torch.nn as nn | ||
| from models.snn_layers import LIFLayer | ||
|
|
||
| class PointNetBackbone(nn.Module): | ||
| def __init__(self, hidden_dims=[64,128,256]): | ||
| super().__init__() | ||
| self.layers = nn.ModuleList() | ||
| in_dim = 3 | ||
| for h in hidden_dims: | ||
| self.layers.append(LIFLayer(in_dim, h)) | ||
| in_dim = h | ||
|
|
||
| def reset_state(self, batch_size, device=None): | ||
| for layer in self.layers: | ||
| layer.reset_state(batch_size, device) | ||
|
|
||
| def forward(self, pts): | ||
| # pts: [B, N, 3] | ||
| B, N, _ = pts.shape | ||
| x = pts.view(B*N, -1) # merge points | ||
| for layer in self.layers: | ||
| spk, mem = layer(x) | ||
| x = mem | ||
| return mem.view(B, N, -1) # unmerge |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from models.pointnet_backbone import PointNetBackbone | ||
| from models.temporal_snn import TemporalSNN | ||
|
|
||
|
|
||
| class PointNetSNN(nn.Module): | ||
| """ | ||
| Full SNN-PointNet model: | ||
| - Per-point spiking MLP (PointNet backbone) | ||
| - Slice pooling (mean pooling) | ||
| - Temporal SNN accumulator | ||
| - Final classifier (inside temporal SNN module) | ||
| """ | ||
|
|
||
| def __init__(self, | ||
| point_dims=[64, 128, 256], # per-point feature sizes | ||
| temporal_dim=256, # temporal SNN hidden dimension | ||
| num_classes=10): # ModelNet10 for debugging | ||
| super().__init__() | ||
|
|
||
| self.backbone = PointNetBackbone(hidden_dims=point_dims) | ||
| self.temporal = TemporalSNN(dim=temporal_dim) | ||
| self.num_classes = num_classes | ||
|
|
||
| def reset_state(self, batch_size, device=None): | ||
| """Reset membrane states before each new sample.""" | ||
| self.backbone.reset_state(batch_size, device) | ||
| self.temporal.reset_state(batch_size, device) | ||
|
|
||
| def forward_step(self, pts_slice): | ||
| """ | ||
| Process a single slice of points. | ||
| pts_slice: [B, n_points, 3] | ||
|
|
||
| Returns: | ||
| logits_t : [B, num_classes] | ||
| """ | ||
|
|
||
| # 1. Per-point spiking MLP | ||
| per_point_feat = self.backbone(pts_slice) # [B, M, 256] | ||
|
|
||
| # 2. Mean pooling across slice points → slice embedding | ||
| slice_feat = per_point_feat.mean(dim=1) # [B, 256] | ||
|
|
||
| # 3. Feed slice embedding into temporal SNN | ||
| logits_t = self.temporal(slice_feat) # [B, num_classes] | ||
|
|
||
| return logits_t |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dummy data setup with
yas all zeros, combined with zero-initialized parameters, results in an initial loss of 0 and zero gradients. Consequently, the optimizer will not update the parameters, and the loss will not decrease. This prevents the notebook from demonstrating that the optimizer is working. To fix this, initializeywith non-zero values to ensure there is a non-zero loss and gradient at the start of training.