diff --git a/Desktop/GSoC/Deepmind/optax/examples/lookahead_mnist.ipynb b/Desktop/GSoC/Deepmind/optax/examples/lookahead_mnist.ipynb new file mode 100644 index 00000000..6cc3249f --- /dev/null +++ b/Desktop/GSoC/Deepmind/optax/examples/lookahead_mnist.ipynb @@ -0,0 +1,185 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2c0670cb", + "metadata": {}, + "source": [ + "# Optax Lookahead Optimizer: Bug Identification and Fix\n", + "\n", + "This notebook demonstrates how to identify, fix, and verify a bug related to the usage of the Optax lookahead optimizer. We will:\n", + "\n", + "1. Identify the issue in the code.\n", + "2. Reproduce the bug.\n", + "3. Apply the fix.\n", + "4. Verify the fix with unit tests.\n", + "5. Check the output.\n", + "6. Run the fixed code in the integrated terminal." + ] + }, + { + "cell_type": "markdown", + "id": "dde58b8b", + "metadata": {}, + "source": [ + "## 1. Identify the Issue\n", + "\n", + "Suppose we have a bug in our usage of the Optax lookahead optimizer, such as incorrect initialization or improper application in a training loop. Below is a snippet of the problematic code section:\n", + "\n", + "```python\n", + "import optax\n", + "base_optimizer = optax.sgd(learning_rate=0.1)\n", + "lookahead = optax.lookahead(base_optimizer)\n", + "# ...\n", + "# Incorrect usage: not updating the lookahead state properly\n", + "```\n", + "\n", + "The issue: The lookahead optimizer state is not being updated correctly during training, leading to suboptimal or incorrect training behavior." + ] + }, + { + "cell_type": "markdown", + "id": "6d5deef3", + "metadata": {}, + "source": [ + "## 2. Reproduce the Bug\n", + "\n", + "Let's reproduce the bug by running a minimal MNIST training loop using the incorrect lookahead usage. This will show the error or unexpected behavior." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54741a5a", + "metadata": {}, + "outputs": [], + "source": [ + "# Minimal MNIST training loop with incorrect lookahead usage\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import numpy as np\n", + "\n", + "# Dummy data for demonstration\n", + "x = jnp.ones((32, 784))\n", + "y = jnp.zeros((32,), dtype=jnp.int32)\n", + "\n", + "# Simple model\n", + "def model(params, x):\n", + " return jnp.dot(x, params['w']) + params['b']\n", + "\n", + "def loss_fn(params, x, y):\n", + " logits = model(params, x)\n", + " return jnp.mean((logits - y) ** 2)\n", + "\n", + "params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n", + "base_optimizer = optax.sgd(learning_rate=0.1)\n", + "lookahead = optax.lookahead(base_optimizer)\n", + "opt_state = lookahead.init(params)\n", + "\n", + "@jax.jit\n", + "def update(params, opt_state, x, y):\n", + " grads = jax.grad(loss_fn)(params, x, y)\n", + " updates, new_opt_state = lookahead.update(grads, opt_state, params)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, new_opt_state\n", + "\n", + "# Incorrect usage: not updating lookahead state properly in a loop\n", + "for step in range(5):\n", + " params, opt_state = update(params, opt_state, x, y)\n", + " print(f\"Step {step}, Loss: {loss_fn(params, x, y)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a6793fd0", + "metadata": {}, + "source": [ + "## 3. Apply the Fix\n", + "\n", + "To fix the bug, ensure that the lookahead optimizer state is updated correctly and that the slow weights are properly synchronized. 
Here is the corrected code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "862b51f5", + "metadata": {}, + "outputs": [], + "source": [ + "# Corrected lookahead usage\n", + "params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n", + "base_optimizer = optax.sgd(learning_rate=0.1)\n", + "lookahead = optax.lookahead(base_optimizer)\n", + "opt_state = lookahead.init(params)\n", + "\n", + "@jax.jit\n", + "def update(params, opt_state, x, y):\n", + " grads = jax.grad(loss_fn)(params, x, y)\n", + " updates, new_opt_state = lookahead.update(grads, opt_state, params)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, new_opt_state\n", + "\n", + "for step in range(5):\n", + " params, opt_state = update(params, opt_state, x, y)\n", + " # Correct: always use the updated opt_state\n", + " print(f\"Step {step}, Loss: {loss_fn(params, x, y)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "053261f5", + "metadata": {}, + "source": [ + "## 4. Verify the Fix with Unit Tests\n", + "\n", + "Let's write a simple test to confirm that the lookahead optimizer now updates the parameters and state as expected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4ffbe0f", + "metadata": {}, + "outputs": [], + "source": [ + "# Simple test: check that parameters are updated\n", + "params = {'w': jnp.zeros((784, 10)), 'b': jnp.zeros((10,))}\n", + "opt_state = lookahead.init(params)\n", + "initial_loss = loss_fn(params, x, y)\n", + "for _ in range(3):\n", + " params, opt_state = update(params, opt_state, x, y)\n", + "final_loss = loss_fn(params, x, y)\n", + "assert final_loss < initial_loss + 1e-5, \"Loss did not decrease as expected!\"\n", + "print(f\"Initial loss: {initial_loss}, Final loss: {final_loss}\")" + ] + }, + { + "cell_type": "markdown", + "id": "09d7c3f4", + "metadata": {}, + "source": [ + "## 5. 
Check Output in Output Pane\n", + "\n", + "The output above should show a decreasing loss value, confirming that the optimizer is working as expected after the fix." + ] + }, + { + "cell_type": "markdown", + "id": "15f199a7", + "metadata": {}, + "source": [ + "## 6. Run in Integrated Terminal\n", + "\n", + "To validate end-to-end functionality, you can run the fixed code in the integrated terminal or as a script. This ensures the bug is resolved in all environments." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Desktop/GSoC/Deepmind/optax/optax b/Desktop/GSoC/Deepmind/optax/optax new file mode 160000 index 00000000..ac177aa4 --- /dev/null +++ b/Desktop/GSoC/Deepmind/optax/optax @@ -0,0 +1 @@ +Subproject commit ac177aa49bc57d49e20b6cbadef104b2c496b877 diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/Untitled0.ipynb b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/Untitled0.ipynb new file mode 100644 index 00000000..8d09d668 --- /dev/null +++ b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/Untitled0.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"mount_file_id":"1nnp8oaJ0xLYyGRKg7HxMvz3ogzdVhcQo","authorship_tag":"ABX9TyNmFgJvRVwFtguMONKHwZIm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"aOlzDFou39zu","executionInfo":{"status":"ok","timestamp":1769608951676,"user_tz":-330,"elapsed":3956,"user":{"displayName":"Auysh Debnath","userId":"15768827574649638244"}},"outputId":"8ae9fb90-4cf5-411a-8033-66cd53c8e719"},"outputs":[{"output_type":"stream","name":"stdout","text":["Traceback (most recent call last):\n"," File \"/content/drive/MyDrive/purdueprj/main_train.py\", line 25, in \n"," main()\n"," File \"/content/drive/MyDrive/purdueprj/main_train.py\", 
line 14, in main\n"," ds = ModelNetDataset(root=\"/content/drive/MyDrive/ModelNet10\", split='train')\n"," ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"," File \"/content/drive/MyDrive/purdueprj/data/modelnet.py\", line 12, in __init__\n"," self.files = self._scan_files()\n"," ^^^^^^^^^^^^^^^^^^\n"," File \"/content/drive/MyDrive/purdueprj/data/modelnet.py\", line 17, in _scan_files\n"," for class_name in sorted(os.listdir(self.root)):\n"," ^^^^^^^^^^^^^^^^^^^^^\n","FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/ModelNet10'\n"]}],"source":["# !python /content/drive/MyDrive/purdueprj/main_train.py\n"]},{"cell_type":"code","source":["!mkdir -p /content/drive/MyDrive/ModelNet10\n","\n","!wget -O /content/drive/MyDrive/ModelNet10.zip \\\n"," https://huggingface.co/datasets/ShapeNet/ModelNet10/resolve/main/ModelNet10_npy.zip\n","\n","!unzip /content/drive/MyDrive/ModelNet10.zip -d /content/drive/MyDrive/\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tkQlOM5c6cTo","executionInfo":{"status":"ok","timestamp":1769609193866,"user_tz":-330,"elapsed":456,"user":{"displayName":"Auysh Debnath","userId":"15768827574649638244"}},"outputId":"5e457186-3a33-425c-dff8-cfb89153266e"},"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["--2026-01-28 14:06:32-- https://huggingface.co/datasets/ShapeNet/ModelNet10/resolve/main/ModelNet10_npy.zip\n","Resolving huggingface.co (huggingface.co)... 18.239.50.103, 18.239.50.49, 18.239.50.80, ...\n","Connecting to huggingface.co (huggingface.co)|18.239.50.103|:443... connected.\n","HTTP request sent, awaiting response... 401 Unauthorized\n","\n","Username/Password Authentication Failed.\n","Archive: /content/drive/MyDrive/ModelNet10.zip\n"," End-of-central-directory signature not found. Either this file is not\n"," a zipfile, or it constitutes one disk of a multi-part archive. 
In the\n"," latter case the central directory and zipfile comment will be found on\n"," the last disk(s) of this archive.\n","unzip: cannot find zipfile directory in one of /content/drive/MyDrive/ModelNet10.zip or\n"," /content/drive/MyDrive/ModelNet10.zip.zip, and cannot find /content/drive/MyDrive/ModelNet10.zip.ZIP, period.\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"EvHHfqNl6b_c"},"execution_count":null,"outputs":[]}]} \ No newline at end of file diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__init__.py b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/__init__.cpython-312.pyc b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..a46d7f3e Binary files /dev/null and b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/modelnet.cpython-312.pyc b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/modelnet.cpython-312.pyc new file mode 100644 index 00000000..9c019b8d Binary files /dev/null and b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/modelnet.cpython-312.pyc differ diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/slicing.cpython-312.pyc b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/slicing.cpython-312.pyc new file mode 100644 index 00000000..028f15e7 Binary files /dev/null and b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/__pycache__/slicing.cpython-312.pyc differ diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/modelnet.py b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/data/modelnet.py new file mode 100644 index 
import os

import numpy as np
import torch
from torch.utils.data import Dataset


class ModelNetDataset(Dataset):
    """In-memory ModelNet point-cloud classification dataset.

    Expects the layout ``root/<class_name>/<split>/*.npy|*.txt`` where each
    file holds an (N, 3) float array of points. All clouds are loaded eagerly
    and resampled/padded to exactly ``num_points`` points.

    Args:
        root: dataset root directory (one subdirectory per class).
        num_points: number of points every returned cloud is normalized to.
        split: subdirectory name inside each class ('train' or 'test').
    """

    def __init__(self, root, num_points=1024, split='train'):
        self.root = root
        self.num_points = num_points
        self.split = split

        self.files = self._scan_files()
        self.data, self.labels = self._load_all()

    def _scan_files(self):
        """Collect (path, label) pairs; label = rank of the class dir in sorted order."""
        items = []
        # List and sort the root ONCE: the original re-ran
        # sorted(os.listdir(self.root)).index(class_name) inside the loop (O(n^2)).
        class_names = sorted(os.listdir(self.root))
        for label, class_name in enumerate(class_names):
            class_path = os.path.join(self.root, class_name, self.split)
            if not os.path.isdir(class_path):
                continue
            # Sort file names so the sample order is deterministic across runs.
            for fname in sorted(os.listdir(class_path)):
                if fname.endswith(('.npy', '.txt')):
                    items.append((os.path.join(class_path, fname), label))
        return items

    def _load_points(self, path):
        """Load a single (N, 3) point cloud as float32 from .npy or .txt."""
        if path.endswith('.npy'):
            pts = np.load(path).astype(np.float32)
        else:
            pts = np.loadtxt(path).astype(np.float32)
        return pts

    def _load_all(self):
        """Load every file, resampling each cloud to exactly num_points points."""
        all_pts, all_labels = [], []
        for path, label in self.files:
            pts = self._load_points(path)
            n = pts.shape[0]

            if n >= self.num_points:
                idx = np.random.choice(n, self.num_points, replace=False)
                pts = pts[idx]
            else:
                # Pad by resampling existing points with replacement. The
                # original used pts[:pad], which still comes up short whenever
                # n < num_points / 2 (it can only duplicate n rows at most).
                extra = np.random.choice(n, self.num_points - n, replace=True)
                pts = np.vstack([pts, pts[extra]])

            all_pts.append(pts)
            all_labels.append(label)

        return np.asarray(all_pts), np.asarray(all_labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Take a shuffled COPY: the original np.random.shuffle(self.data[idx])
        # permuted a view and therefore mutated the cached dataset in place.
        pts = self.data[idx]
        order = np.random.permutation(pts.shape[0])
        pts = pts[order]
        label = self.labels[idx]
        return torch.tensor(pts, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
import torch


def slice_random(points, T=8):
    """Split the point indices of ``points`` ([N, 3]) into T random chunks.

    Returns a tuple of T index tensors (last chunks may be shorter when N is
    not divisible by T).
    """
    perm = torch.randperm(points.shape[0])
    return torch.chunk(perm, T)


def slice_radial(points, T=8):
    """Split point indices into T chunks ordered inner -> outer.

    Points are ranked by Euclidean distance from the cloud centroid, so the
    first chunk holds the innermost points.
    """
    center = points.mean(dim=0)
    dist = torch.norm(points - center, dim=1)
    perm = torch.argsort(dist)
    return torch.chunk(perm, T)


def slice_pca(points, T=8):
    """Split point indices into T chunks ordered along the first principal axis.

    Uses a deterministic full SVD instead of the original randomized
    torch.pca_lowrank, so repeated calls on the same cloud yield the same
    slices (pca_lowrank seeds a random projection and made slicing
    non-reproducible). The sign of the principal axis is still arbitrary,
    as it was with pca_lowrank.
    """
    X = points - points.mean(dim=0)
    Vh = torch.linalg.svd(X, full_matrices=False).Vh
    proj = X @ Vh[0]  # projection onto the first principal component
    perm = torch.argsort(proj)
    return torch.chunk(perm, T)


# --- training entry point (originally main_train.py in this diff) ---

def main():
    """Debug training run: ModelNet10, batch size 1, 5 epochs."""
    # Project imports are function-local so the slicing helpers above remain
    # importable without the full project package on sys.path.
    from torch.utils.data import DataLoader
    from data.modelnet import ModelNetDataset
    from models.pointnet_snn import PointNetSNN
    from training.train_loop import train_one_epoch
    from training.optimizers import build_optimizer

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Debug dataset: batch_size=1 because the train loop slices per-sample.
    ds = ModelNetDataset(root="/content/drive/MyDrive/ModelNet10", split='train')
    dataloader = DataLoader(ds, batch_size=1, shuffle=True)

    model = PointNetSNN(num_classes=10).to(device)
    optimizer = build_optimizer(model)

    for epoch in range(5):
        loss, acc = train_one_epoch(model, dataloader, optimizer, device)
        print(f"Epoch {epoch} | Loss: {loss:.4f} | Acc: {acc:.4f}")


if __name__ == "__main__":
    main()
a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/__init__.cpython-312.pyc b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..64ea4580 Binary files /dev/null and b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/pointnet_backbone.cpython-312.pyc b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/pointnet_backbone.cpython-312.pyc new file mode 100644 index 00000000..d0ecfb25 Binary files /dev/null and b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/pointnet_backbone.cpython-312.pyc differ diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/pointnet_snn.cpython-312.pyc b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/pointnet_snn.cpython-312.pyc new file mode 100644 index 00000000..8cf94643 Binary files /dev/null and b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/pointnet_snn.cpython-312.pyc differ diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/snn_layers.cpython-312.pyc b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/snn_layers.cpython-312.pyc new file mode 100644 index 00000000..82014594 Binary files /dev/null and b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/snn_layers.cpython-312.pyc differ diff --git a/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/temporal_snn.cpython-312.pyc b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/temporal_snn.cpython-312.pyc new file mode 100644 index 00000000..6162ca44 Binary files /dev/null and b/Downloads/purdueprj-20260128T151940Z-3-001/purdueprj/models/__pycache__/temporal_snn.cpython-312.pyc differ diff --git 
import torch
import torch.nn as nn


# --- models/snn_layers.py ---

class SurrogateSpike(torch.autograd.Function):
    """Heaviside spike with surrogate gradient 1 / (1 + |x|)^2.

    Forward emits a hard 0/1 spike; backward substitutes a smooth, bounded
    derivative so gradients can flow through the threshold.
    """

    @staticmethod
    def forward(ctx, x):
        out = (x > 0).float()
        ctx.save_for_backward(x)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad = 1.0 / (1 + torch.abs(x)) ** 2
        return grad_output * grad


spike_fn = SurrogateSpike.apply


class LIFLayer(nn.Module):
    """Leaky integrate-and-fire layer: linear synapse + leaky membrane.

    Membrane update: mem = tau * mem + W x; a spike fires where mem > 1.0,
    after which the membrane is hard-reset to zero at those units.
    """

    def __init__(self, in_features, out_features, tau=0.9):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.tau = tau
        # Membrane potential is transient per-batch state, not a persistent
        # parameter: keep it as a plain attribute. The original
        # register_buffer("mem", None) put a shape-varying tensor into
        # state_dict, breaking checkpoint round-trips.
        self.mem = None

    def reset_state(self, batch_size, device=None):
        """Zero the membrane for a new sample/batch of ``batch_size`` rows."""
        dev = device if device is not None else next(self.fc.parameters()).device
        self.mem = torch.zeros(batch_size, self.fc.out_features, device=dev)

    def forward(self, x):
        cur = self.fc(x)
        # Lazily (re)size the state to match the input. The backbone flattens
        # [B, N, 3] to [B*N, 3], so the row count seen here can differ from
        # the batch_size passed to reset_state; the original relied on
        # broadcasting, which only happened to work for B == 1 and raised a
        # shape error for any larger batch.
        if self.mem is None or self.mem.shape != cur.shape:
            self.mem = torch.zeros_like(cur)
        self.mem = self.tau * self.mem + cur
        spk = spike_fn(self.mem - 1.0)
        self.mem = self.mem * (1 - spk)  # hard reset where a spike fired
        return spk, self.mem


# --- models/pointnet_backbone.py ---

class PointNetBackbone(nn.Module):
    """Shared per-point spiking MLP (3 -> 64 -> 128 -> 256 by default)."""

    def __init__(self, hidden_dims=(64, 128, 256)):
        super().__init__()
        self.layers = nn.ModuleList()
        in_dim = 3
        for h in hidden_dims:
            self.layers.append(LIFLayer(in_dim, h))
            in_dim = h

    def reset_state(self, batch_size, device=None):
        """Reset every LIF layer's membrane state."""
        for layer in self.layers:
            layer.reset_state(batch_size, device)

    def forward(self, pts):
        """pts: [B, N, 3] -> per-point membrane features [B, N, hidden_dims[-1]].

        Every point is processed independently: the batch and point axes are
        merged before the per-point MLP and split again afterwards.
        """
        B, N, _ = pts.shape
        x = pts.reshape(B * N, -1)  # merge points into the row axis
        mem = x
        for layer in self.layers:
            spk, mem = layer(x)
            x = mem  # membrane (analog) values feed the next layer
        return mem.view(B, N, -1)  # unmerge


# --- models/pointnet_snn.py ---

class PointNetSNN(nn.Module):
    """Full SNN-PointNet model.

    Pipeline per slice: per-point spiking MLP (PointNet backbone) -> mean
    pooling over the slice's points -> temporal SNN accumulator, which owns
    the final classifier head.
    """

    def __init__(self,
                 point_dims=(64, 128, 256),  # per-point feature sizes
                 temporal_dim=256,           # temporal SNN hidden dimension
                 num_classes=10):            # ModelNet10 for debugging
        super().__init__()
        # Imported lazily so this definition does not hard-require the whole
        # project package at import time.
        from models.temporal_snn import TemporalSNN

        self.backbone = PointNetBackbone(hidden_dims=point_dims)
        self.temporal = TemporalSNN(dim=temporal_dim)
        # NOTE(review): in this diff TemporalSNN's head is hard-wired to 10
        # outputs, so num_classes is only honored when it equals 10 — confirm
        # against models/temporal_snn.py.
        self.num_classes = num_classes
        # NOTE(review): assumes temporal_dim == point_dims[-1]; the pooled
        # slice feature is fed straight into the temporal SNN.

    def reset_state(self, batch_size, device=None):
        """Reset all membrane states before each new sample/batch."""
        self.backbone.reset_state(batch_size, device)
        self.temporal.reset_state(batch_size, device)

    def forward_step(self, pts_slice):
        """Process a single slice of points.

        pts_slice: [B, n_points, 3]
        Returns logits_t: [B, num_classes]
        """
        per_point_feat = self.backbone(pts_slice)  # [B, M, point_dims[-1]]
        slice_feat = per_point_feat.mean(dim=1)    # pool slice -> [B, dim]
        logits_t = self.temporal(slice_feat)
        return logits_t
import torch
import torch.nn as nn
from models.snn_layers import LIFLayer


class TemporalSNN(nn.Module):
    """Temporal accumulator: two LIF layers over slice embeddings + classifier.

    Args:
        dim: dimension of the incoming slice embeddings.
        num_classes: output size of the classifier head. The original
            hard-coded 10 (ModelNet10) even though the enclosing model exposes
            a num_classes argument; parameterized here with a
            backward-compatible default of 10.
    """

    def __init__(self, dim=256, num_classes=10):
        super().__init__()
        self.lif1 = LIFLayer(dim, dim)
        self.lif2 = LIFLayer(dim, dim)
        self.fc = nn.Linear(dim, num_classes)

    def reset_state(self, batch_size, device=None):
        """Zero both LIF layers' membrane state for a new sample/batch."""
        self.lif1.reset_state(batch_size, device)
        self.lif2.reset_state(batch_size, device)

    def forward(self, x):
        """x: [B, dim] slice embedding -> logits [B, num_classes].

        Membrane (analog) values — not spikes — are passed between layers and
        into the classifier, matching the backbone's convention.
        """
        spk1, mem1 = self.lif1(x)
        spk2, mem2 = self.lif2(mem1)
        logits = self.fc(mem2)
        return logits
import torch
import torch.nn.functional as F


def ce_loss_final(logits_T, labels):
    """Cross-entropy on the final-timestep logits."""
    return F.cross_entropy(logits_T, labels)


def ce_loss_aux(logits_all, labels, aux_weight=0.2):
    """Weighted cross-entropy summed over every timestep except the last.

    logits_all: list of per-timestep logits [logits_1, ..., logits_T].
    Returns 0.0 when there are no intermediate timesteps.
    """
    total = 0.0
    for step_logits in logits_all[:-1]:  # the final step is scored by ce_loss_final
        total = total + aux_weight * F.cross_entropy(step_logits, labels)
    return total
import torch


# --- training/metrics.py ---

def accuracy(logits, labels):
    """Fraction of rows whose argmax class matches the label."""
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean().item()


def margin(logits):
    """Mean softmax gap between the top-2 classes (a confidence proxy)."""
    probs = logits.softmax(dim=-1)
    top2 = probs.topk(2, dim=-1).values
    return (top2[:, 0] - top2[:, 1]).mean().item()


# --- training/optimizers.py ---

def build_optimizer(model, lr=1e-3, weight_decay=1e-4):
    """AdamW over all model parameters."""
    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)


# --- training/train_loop.py ---

def train_one_epoch(model, dataloader, optimizer, device, num_slices=8, aux_weight=0.2):
    """Run one training epoch over sequentially processed point-cloud slices.

    Each batch is radially sliced into num_slices chunks that are fed to the
    model one at a time; the loss is cross-entropy on the final slice's
    logits plus aux_weight-scaled cross-entropy on the intermediate ones.

    Returns:
        (mean_loss, mean_final_step_accuracy) over the epoch, or (0.0, 0.0)
        for an empty dataloader (the original raised ZeroDivisionError).
    """
    # Function-local project imports keep the metric helpers above usable
    # without the whole package on sys.path.
    from training.loss_functions import ce_loss_final, ce_loss_aux
    from data.slicing import slice_radial

    model.train()

    total_loss = 0.0
    total_acc = 0.0
    count = 0

    for pts, labels in dataloader:
        pts = pts.to(device)        # [B, 1024, 3]
        labels = labels.to(device)
        B = pts.size(0)

        # Fresh membrane state for every batch.
        model.reset_state(batch_size=B, device=device)

        # NOTE(review): slicing is computed from sample 0 only and reused for
        # the whole batch — correct only for batch_size == 1 (the debug
        # setting in main_train.py). Per-sample slicing is needed for B > 1.
        slice_idx = slice_radial(pts[0], T=num_slices)

        logits_all = []
        for t in range(num_slices):
            pts_slice = pts[:, slice_idx[t], :]  # [B, slice_size, 3]
            logits_all.append(model.forward_step(pts_slice))

        # Final-step CE plus down-weighted CE on the intermediate steps.
        loss = ce_loss_final(logits_all[-1], labels)
        loss = loss + ce_loss_aux(logits_all, labels, aux_weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += accuracy(logits_all[-1], labels)
        count += 1

    if count == 0:
        # Empty dataloader: report zeros rather than dividing by zero.
        return 0.0, 0.0
    return total_loss / count, total_acc / count