Conversation

@Playmaker3334

Summary

This PR adds GPU support to the Engram demo and fixes two bugs that prevented it from running on CUDA devices.

Bug Fixes

  1. Device mismatch - torch.from_numpy() in Engram.forward() always created CPU tensors, causing runtime errors when the model was on a GPU (see the sketch after this list)
  2. Index out of bounds - CompressedTokenizer._compress() performed no bounds checking and could crash on certain input_ids
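
A minimal sketch of the device fix; hashes_to_device and np_hashes are illustrative names, not the PR's actual code:

import numpy as np
import torch

# torch.from_numpy() always yields a CPU tensor; the fix is to move the
# result onto the caller's device before any mixed-device op can run.
def hashes_to_device(hash_array: np.ndarray, device: torch.device) -> torch.Tensor:
    return torch.from_numpy(hash_array).to(device)  # no-op when already on CPU

# Inside Engram.forward() this would look roughly like:
#   hashes = hashes_to_device(np_hashes, input_ids.device)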

New Features

HybridNgramHashMapping

Replaces NgramHashMapping with a proper nn.Module that supports both CPU and GPU:

if input_size < self.gpu_threshold:
    # NumPy path - lower overhead for small inputs
    return self._hash_cpu(input_ids)  # CPU counterpart to _hash_gpu (name assumed)
else:
    # PyTorch path - better throughput for large inputs
    return self._hash_gpu(input_ids)

Key changes:

  • Multipliers stored as register_buffer() for automatic device transfer
  • _hash_gpu() uses torch.bitwise_xor instead of numpy
  • Configurable threshold via EngramConfig.gpu_threshold
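
A condensed sketch of how the buffer and the GPU hash path fit together; the constructor signature and the exact mixing arithmetic are assumptions, only register_buffer(), torch.bitwise_xor, and gpu_threshold come from this PR:

import torch
import torch.nn as nn

class HybridNgramHashMapping(nn.Module):
    def __init__(self, num_hashes: int = 4, gpu_threshold: int = 4096, seed: int = 0):
        super().__init__()
        self.gpu_threshold = gpu_threshold
        gen = torch.Generator().manual_seed(seed)
        # A buffer (unlike a plain tensor attribute) follows .to()/.cuda(),
        # so the multipliers always end up on the same device as the model.
        self.register_buffer(
            "multipliers",
            torch.randint(1, 2**31 - 1, (num_hashes,), generator=gen),
        )

    def _hash_gpu(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Multiply-xor mixing expressed in torch ops only, so it runs on
        # whichever device input_ids and the buffer share.
        mixed = input_ids.unsqueeze(-1) * self.multipliers
        return torch.bitwise_xor(mixed, mixed >> 16)

Because multipliers is a buffer, model.to("cuda") carries it along automatically, which is what keeps the torch path device-safe without manual copies.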

CompressedTokenizer

  • compress_cpu(): fast numpy path with bounds checking
  • compress_gpu(): lazy tensor initialization, tracks device
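
Rough sketches of the two paths; the table-lookup formulation is an assumption, but the bounds checking, lazy initialization, and device tracking are the behaviors this PR adds:

import numpy as np
import torch

def compress_cpu(input_ids: np.ndarray, table: np.ndarray) -> np.ndarray:
    # Bounds check: clip out-of-range ids rather than indexing past the table.
    return table[np.clip(input_ids, 0, len(table) - 1)]

class _LazyGpuTable:
    def __init__(self, table: np.ndarray):
        self._np_table = table
        self._tensor = None  # built on first GPU call, not at construction

    def compress_gpu(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Lazily materialize the lookup tensor on the incoming ids' device,
        # and rebuild it if the ids later arrive on a different device.
        if self._tensor is None or self._tensor.device != input_ids.device:
            self._tensor = torch.from_numpy(self._np_table).to(input_ids.device)
        return self._tensor[input_ids.clamp(0, self._tensor.numel() - 1)]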

Benchmark Suite

New files for validation:

  • benchmark.py: measures latency across configs
  • test_correctness.py: verifies numerical equivalence
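
benchmark.py's internals aren't reproduced here, but measuring CUDA latency honestly requires explicit synchronization; the usual pattern looks like this:

import time
import torch

def time_forward(model, input_ids, warmup: int = 10, iters: int = 100) -> float:
    for _ in range(warmup):
        model(input_ids)
    if input_ids.is_cuda:
        torch.cuda.synchronize()  # don't start the clock with kernels still queued
    start = time.perf_counter()
    for _ in range(iters):
        model(input_ids)
    if input_ids.is_cuda:
        torch.cuda.synchronize()  # wait for the last kernel before stopping the clock
    return (time.perf_counter() - start) / iters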

Benchmark Results

Metric            Value
Mean speedup      1.02x
Median speedup    1.01x
Max speedup       1.09x
Memory delta      -0.03%

Tested on an NVIDIA GPU with batch sizes [2, 4, 8] and sequence lengths [128, 256, 512].

Backward Compatibility

  • Original engram_demo_v1.py unchanged
  • All original APIs preserved in optimized version
  • Numerical outputs match within rtol=1e-4, atol=1e-5
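
At those tolerances the equivalence check presumably reduces to torch.testing.assert_close; the actual harness in test_correctness.py is not shown here:

import torch

def check_equivalence(original_out: torch.Tensor, optimized_out: torch.Tensor) -> None:
    # Tolerances quoted above: rtol=1e-4, atol=1e-5.
    torch.testing.assert_close(optimized_out, original_out, rtol=1e-4, atol=1e-5)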

- Convert CUDA tensors to CPU before numpy conversion in CompressedTokenizer
- Fixes TypeError when running on GPU: 'can't convert cuda:0 device type tensor to numpy'
- Maintains backward compatibility with CPU-only usage
- Use actual tokenizer vocab size instead of config value
- Prevents IndexError when generating test data
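
The tokenizer-side fix amounts to routing CUDA tensors through .cpu() before handing them to NumPy, roughly:

import torch

def ids_to_numpy(input_ids):
    # NumPy can only view host memory, so copy CUDA tensors to CPU first;
    # CPU tensors and plain arrays pass through, preserving old behavior.
    if isinstance(input_ids, torch.Tensor):
        return input_ids.detach().cpu().numpy()
    return input_ids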