Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions Desktop/GSoC/Deepmind/optax/examples/lookahead_mnist.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2c0670cb",
"metadata": {},
"source": [
"# Optax Lookahead Optimizer: Bug Identification and Fix\n",
"\n",
"This notebook demonstrates how to identify, fix, and verify a bug related to the usage of the Optax lookahead optimizer. We will:\n",
"\n",
"1. Identify the issue in the code.\n",
"2. Reproduce the bug.\n",
"3. Apply the fix.\n",
"4. Verify the fix with unit tests.\n",
"5. Check the output.\n",
"6. Run the fixed code in the integrated terminal."
]
},
{
"cell_type": "markdown",
"id": "dde58b8b",
"metadata": {},
"source": [
"## 1. Identify the Issue\n",
"\n",
"Suppose we have a bug in our usage of the Optax lookahead optimizer, such as incorrect initialization or improper application in a training loop. Below is a snippet of the problematic code section:\n",
"\n",
"```python\n",
"import optax\n",
"base_optimizer = optax.sgd(learning_rate=0.1)\n",
"lookahead = optax.lookahead(base_optimizer)\n",
"# ...\n",
"# Incorrect usage: not updating the lookahead state properly\n",
"```\n",
"\n",
"The issue: The lookahead optimizer state is not being updated correctly during training, leading to suboptimal or incorrect training behavior."
]
},
{
"cell_type": "markdown",
"id": "6d5deef3",
"metadata": {},
"source": [
"## 2. Reproduce the Bug\n",
"\n",
"Let's reproduce the bug by running a minimal MNIST training loop using the incorrect lookahead usage. This will show the error or unexpected behavior."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54741a5a",
"metadata": {},
"outputs": [],
"source": [
"# Minimal MNIST training loop with incorrect lookahead usage\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import numpy as np\n",
"\n",
"# Dummy data for demonstration\n",
"x = jnp.ones((32, 784))\n",
"y = jnp.zeros((32,), dtype=jnp.int32)\n",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The dummy data setup with y as all zeros, combined with zero-initialized parameters, results in an initial loss of 0 and zero gradients. Consequently, the optimizer will not update the parameters, and the loss will not decrease. This prevents the notebook from demonstrating that the optimizer is working. To fix this, initialize y with non-zero values to ensure there is a non-zero loss and gradient at the start of training.

y = jnp.ones((32,), dtype=jnp.int32)

"\n",
"# Simple model\n",
"def model(params, x):\n",
" return jnp.dot(x, params['w']) + params['b']\n",
"\n",
"def loss_fn(params, x, y):\n",
" logits = model(params, x)\n",
" return jnp.mean((logits - y) ** 2)\n",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The loss_fn will raise a ValueError because of a shape mismatch during subtraction. logits has a shape of (32, 10), while y has a shape of (32,). These shapes are not compatible for broadcasting. To fix this, you should reshape y to (32, 1) to make it a column vector, which can then be broadcast correctly across the logits matrix.

    return jnp.mean((logits - y[:, None]) ** 2)

"\n",
"params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n",
"base_optimizer = optax.sgd(learning_rate=0.1)\n",
"lookahead = optax.lookahead(base_optimizer)\n",
"opt_state = lookahead.init(params)\n",
"\n",
"@jax.jit\n",
"def update(params, opt_state, x, y):\n",
" grads = jax.grad(loss_fn)(params, x, y)\n",
" updates, new_opt_state = lookahead.update(grads, opt_state, params)\n",
" new_params = optax.apply_updates(params, updates)\n",
" return new_params, new_opt_state\n",
"\n",
"# Incorrect usage: not updating lookahead state properly in a loop\n",
"for step in range(5):\n",
" params, opt_state = update(params, opt_state, x, y)\n",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This line correctly updates the opt_state, which contradicts the section's goal of demonstrating a bug. To properly illustrate the 'incorrect usage', you should simulate a common error, such as failing to update the optimizer state. For example, you could discard the new state returned from the update function.

    params, _ = update(params, opt_state, x, y)  # Bug: opt_state is not updated for the next iteration

" print(f\"Step {step}, Loss: {loss_fn(params, x, y)}\")"
]
},
{
"cell_type": "markdown",
"id": "a6793fd0",
"metadata": {},
"source": [
"## 3. Apply the Fix\n",
"\n",
"To fix the bug, ensure that the lookahead optimizer state is updated correctly and that the slow weights are properly synchronized. Here is the corrected code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "862b51f5",
"metadata": {},
"outputs": [],
"source": [
"# Corrected lookahead usage\n",
"params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n",
"base_optimizer = optax.sgd(learning_rate=0.1)\n",
"lookahead = optax.lookahead(base_optimizer)\n",
"opt_state = lookahead.init(params)\n",
"\n",
"@jax.jit\n",
"def update(params, opt_state, x, y):\n",
" grads = jax.grad(loss_fn)(params, x, y)\n",
" updates, new_opt_state = lookahead.update(grads, opt_state, params)\n",
" new_params = optax.apply_updates(params, updates)\n",
" return new_params, new_opt_state\n",
"\n",
"for step in range(5):\n",
" params, opt_state = update(params, opt_state, x, y)\n",
" # Correct: always use the updated opt_state\n",
" print(f\"Step {step}, Loss: {loss_fn(params, x, y)}\")"
]
},
{
"cell_type": "markdown",
"id": "053261f5",
"metadata": {},
"source": [
"## 4. Verify the Fix with Unit Tests\n",
"\n",
"Let's write a simple test to confirm that the lookahead optimizer now updates the parameters and state as expected."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4ffbe0f",
"metadata": {},
"outputs": [],
"source": [
"# Simple test: check that parameters are updated\n",
"params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n",
"opt_state = lookahead.init(params)\n",
"initial_loss = loss_fn(params, x, y)\n",
"for _ in range(3):\n",
" params, opt_state = update(params, opt_state, x, y)\n",
"final_loss = loss_fn(params, x, y)\n",
"assert final_loss < initial_loss + 1e-5, \"Loss did not decrease as expected!\"\n",
"print(f\"Initial loss: {initial_loss}, Final loss: {final_loss}\")"
]
},
{
"cell_type": "markdown",
"id": "09d7c3f4",
"metadata": {},
"source": [
"## 5. Check Output in Output Pane\n",
"\n",
"The output above should show a decreasing loss value, confirming that the optimizer is working as expected after the fix."
]
},
{
"cell_type": "markdown",
"id": "15f199a7",
"metadata": {},
"source": [
"## 6. Run in Integrated Terminal\n",
"\n",
"To validate end-to-end functionality, you can run the fixed code in the integrated terminal or as a script. This ensures the bug is resolved in all environments."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions Desktop/GSoC/Deepmind/optax/optax
Submodule optax added at ac177a
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"mount_file_id":"1nnp8oaJ0xLYyGRKg7HxMvz3ogzdVhcQo","authorship_tag":"ABX9TyNmFgJvRVwFtguMONKHwZIm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"aOlzDFou39zu","executionInfo":{"status":"ok","timestamp":1769608951676,"user_tz":-330,"elapsed":3956,"user":{"displayName":"Auysh Debnath","userId":"15768827574649638244"}},"outputId":"8ae9fb90-4cf5-411a-8033-66cd53c8e719"},"outputs":[{"output_type":"stream","name":"stdout","text":["Traceback (most recent call last):\n"," File \"/content/drive/MyDrive/purdueprj/main_train.py\", line 25, in <module>\n"," main()\n"," File \"/content/drive/MyDrive/purdueprj/main_train.py\", line 14, in main\n"," ds = ModelNetDataset(root=\"/content/drive/MyDrive/ModelNet10\", split='train')\n"," ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"," File \"/content/drive/MyDrive/purdueprj/data/modelnet.py\", line 12, in __init__\n"," self.files = self._scan_files()\n"," ^^^^^^^^^^^^^^^^^^\n"," File \"/content/drive/MyDrive/purdueprj/data/modelnet.py\", line 17, in _scan_files\n"," for class_name in sorted(os.listdir(self.root)):\n"," ^^^^^^^^^^^^^^^^^^^^^\n","FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/ModelNet10'\n"]}],"source":["# !python /content/drive/MyDrive/purdueprj/main_train.py\n"]},{"cell_type":"code","source":["!mkdir -p /content/drive/MyDrive/ModelNet10\n","\n","!wget -O /content/drive/MyDrive/ModelNet10.zip \\\n"," https://huggingface.co/datasets/ShapeNet/ModelNet10/resolve/main/ModelNet10_npy.zip\n","\n","!unzip /content/drive/MyDrive/ModelNet10.zip -d /content/drive/MyDrive/\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tkQlOM5c6cTo","executionInfo":{"status":"ok","timestamp":1769609193866,"user_tz":-330,"elapsed":456,"user":{"displayName":"Auysh Debnath","userId":"15768827574649638244"}},"outputId":"5e457186-3a33-425c-dff8-cfb89153266e"},"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["--2026-01-28 14:06:32-- https://huggingface.co/datasets/ShapeNet/ModelNet10/resolve/main/ModelNet10_npy.zip\n","Resolving huggingface.co (huggingface.co)... 18.239.50.103, 18.239.50.49, 18.239.50.80, ...\n","Connecting to huggingface.co (huggingface.co)|18.239.50.103|:443... connected.\n","HTTP request sent, awaiting response... 401 Unauthorized\n","\n","Username/Password Authentication Failed.\n","Archive: /content/drive/MyDrive/ModelNet10.zip\n"," End-of-central-directory signature not found. Either this file is not\n"," a zipfile, or it constitutes one disk of a multi-part archive. In the\n"," latter case the central directory and zipfile comment will be found on\n"," the last disk(s) of this archive.\n","unzip: cannot find zipfile directory in one of /content/drive/MyDrive/ModelNet10.zip or\n"," /content/drive/MyDrive/ModelNet10.zip.zip, and cannot find /content/drive/MyDrive/ModelNet10.zip.ZIP, period.\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"EvHHfqNl6b_c"},"execution_count":null,"outputs":[]}]}
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import torch
import numpy as np
from torch.utils.data import Dataset

class ModelNetDataset(Dataset):
def __init__(self, root, num_points=1024, split='train'):
self.root = root
self.num_points = num_points
self.split = split

self.files = self._scan_files()
self.data, self.labels = self._load_all()

def _scan_files(self):
items = []
for class_name in sorted(os.listdir(self.root)):
class_path = os.path.join(self.root, class_name, self.split)
if not os.path.isdir(class_path):
continue
label = sorted(os.listdir(self.root)).index(class_name)
for f in os.listdir(class_path):
if f.endswith('.npy') or f.endswith('.txt'):
items.append((os.path.join(class_path, f), label))
return items

def _load_points(self, path):
if path.endswith('.npy'):
pts = np.load(path).astype(np.float32)
else:
pts = np.loadtxt(path).astype(np.float32)
return pts

def _load_all(self):
all_pts, all_labels = [], []
for path, label in self.files:
pts = self._load_points(path)

if pts.shape[0] >= self.num_points:
idx = np.random.choice(pts.shape[0], self.num_points, replace=False)
pts = pts[idx]
else:
pad = self.num_points - pts.shape[0]
pts = np.vstack([pts, pts[:pad]])

all_pts.append(pts)
all_labels.append(label)

return np.array(all_pts), np.array(all_labels)

def __len__(self):
return len(self.labels)

def __getitem__(self, idx):
pts = self.data[idx]
label = self.labels[idx]
np.random.shuffle(pts)
return torch.tensor(pts, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch

def slice_random(points, T=8):
N = points.shape[0]
perm = torch.randperm(N)
return torch.chunk(perm, T)

def slice_radial(points, T=8):
center = points.mean(dim=0)
dist = torch.norm(points - center, dim=1)
perm = torch.argsort(dist) # inner → outer
return torch.chunk(perm, T)

def slice_pca(points, T=8):
X = points - points.mean(dim=0)
U, S, V = torch.pca_lowrank(X)
pc1 = V[:, 0]
proj = X @ pc1
perm = torch.argsort(proj)
return torch.chunk(perm, T)
25 changes: 25 additions & 0 deletions Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/main_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
from torch.utils.data import DataLoader

from data.modelnet import ModelNetDataset
from models.pointnet_snn import PointNetSNN
from training.train_loop import train_one_epoch
from training.optimizers import build_optimizer

def main():

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Debug dataset
ds = ModelNetDataset(root="/content/drive/MyDrive/ModelNet10", split='train')
dataloader = DataLoader(ds, batch_size=1, shuffle=True)

model = PointNetSNN(num_classes=10).to(device)
optimizer = build_optimizer(model)

for epoch in range(5):
loss, acc = train_one_epoch(model, dataloader, optimizer, device)
print(f"Epoch {epoch} | Loss: {loss:.4f} | Acc: {acc:.4f}")

if __name__ == "__main__":
main()
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
from models.snn_layers import LIFLayer

class PointNetBackbone(nn.Module):
def __init__(self, hidden_dims=[64,128,256]):
super().__init__()
self.layers = nn.ModuleList()
in_dim = 3
for h in hidden_dims:
self.layers.append(LIFLayer(in_dim, h))
in_dim = h

def reset_state(self, batch_size, device=None):
for layer in self.layers:
layer.reset_state(batch_size, device)

def forward(self, pts):
# pts: [B, N, 3]
B, N, _ = pts.shape
x = pts.view(B*N, -1) # merge points
for layer in self.layers:
spk, mem = layer(x)
x = mem
return mem.view(B, N, -1) # unmerge
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn

from models.pointnet_backbone import PointNetBackbone
from models.temporal_snn import TemporalSNN


class PointNetSNN(nn.Module):
"""
Full SNN-PointNet model:
- Per-point spiking MLP (PointNet backbone)
- Slice pooling (mean pooling)
- Temporal SNN accumulator
- Final classifier (inside temporal SNN module)
"""

def __init__(self,
point_dims=[64, 128, 256], # per-point feature sizes
temporal_dim=256, # temporal SNN hidden dimension
num_classes=10): # ModelNet10 for debugging
super().__init__()

self.backbone = PointNetBackbone(hidden_dims=point_dims)
self.temporal = TemporalSNN(dim=temporal_dim)
self.num_classes = num_classes

def reset_state(self, batch_size, device=None):
"""Reset membrane states before each new sample."""
self.backbone.reset_state(batch_size, device)
self.temporal.reset_state(batch_size, device)

def forward_step(self, pts_slice):
"""
Process a single slice of points.
pts_slice: [B, n_points, 3]

Returns:
logits_t : [B, num_classes]
"""

# 1. Per-point spiking MLP
per_point_feat = self.backbone(pts_slice) # [B, M, 256]

# 2. Mean pooling across slice points → slice embedding
slice_feat = per_point_feat.mean(dim=1) # [B, 256]

# 3. Feed slice embedding into temporal SNN
logits_t = self.temporal(slice_feat) # [B, num_classes]

return logits_t
Loading