Refactor TitansMAG forward logic and fix memory state handling #32
ryancinsight wants to merge 252 commits into main from
Conversation
fix(readme): correct repo URL and directory path in Quick Start
- isolate data loading
- pair
- encode to bytes for vocab
- data loading from json
- data loading from csv
- csv files added
- cargo run works!
- cargo update and dataset_loader redundant paren

Co-authored-by: anshumanpatil <info@anshumanpatil.com>
Co-authored-by: Nikhil Sriram <nikhil.sriram5@gmail.com>
Co-authored-by: hobs <github@totalgood.com>
CI script to build and test
Fix Readme Page Badge
…urves

- Add AdaptiveScalar type supporting fixed values or Richards curve modulation
- Implement set_training_progress method throughout the layer hierarchy for adaptive parameters
- Add CLI arguments for adaptive modulation of ce_weight, min_snr_gamma, and MoH thresholds
- Integrate training-progress tracking into training loops and forward passes
- Update diffusion training to use adaptive scalars for loss weights and SNR gamma
- Validated that TitansMAL correctly processes input through NeuralMemory and then SlidingWindowAttention.
- Added a `#[cfg(test)]` module to `src/memory/titans/mal.rs` to verify the forward pass output shape.
- Note: integration tests could not be run due to pre-existing compilation errors in `src/layers/transformer/block.rs` and `model_config` (unrelated to this task).
Introduce AdaptiveScalar enum to modulate MoH activation thresholds based on training progress. This enables dynamic gating behavior where thresholds can be fixed, follow a Richards curve, or be learned. The modulation is integrated into MoHGating, training pipeline, and all block types (Transformer, Diffusion, HRM, LRM, Mamba2). Update configuration structures to include moh_threshold_modulation field and modify initialization to use Box<RichardsCurve> for proper ownership. Add proptest regression tests for adaptive behavior.
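For illustration, a minimal Rust sketch of the `AdaptiveScalar` idea. The fields, the simplified Richards (generalized-logistic) formula, and the `resolve` method are assumptions for exposition, not the crate's actual API; the learned variant mentioned above is omitted.

```rust
/// Simplified Richards (generalized-logistic) curve; real fields differ.
pub struct RichardsCurve {
    pub lower: f32,    // left asymptote
    pub upper: f32,    // right asymptote
    pub growth: f32,   // growth rate
    pub midpoint: f32, // training progress at the inflection point
}

impl RichardsCurve {
    /// Evaluate the curve at training progress `t` in [0, 1].
    pub fn value(&self, t: f32) -> f32 {
        self.lower
            + (self.upper - self.lower) / (1.0 + (-self.growth * (t - self.midpoint)).exp())
    }
}

/// A scalar that is either fixed or modulated over training progress.
pub enum AdaptiveScalar {
    Fixed(f32),
    Curve(Box<RichardsCurve>), // boxed, mirroring the ownership change above
}

impl AdaptiveScalar {
    pub fn resolve(&self, training_progress: f32) -> f32 {
        match self {
            AdaptiveScalar::Fixed(v) => *v,
            AdaptiveScalar::Curve(c) => c.value(training_progress),
        }
    }
}
```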
- Add predict_with_limit method to control generation length
- Eliminate unnecessary allocations in neural memory by using row views
- Optimize engram memory with a scratch buffer and Zip operations
- Improve hybrid memory with cached positional encoding and scratch buffers
- Add contrastive margin and gradient methods to adaptive residuals
- Remove outdated attention documentation file
…e cloning

Replaced full struct cloning with manual lightweight construction of the RichardsCurve instance. This avoids copying heavy heap-allocated fields (optimizer, grad_norm_history, gamma, bias) when creating temporary scaled views for gating operations.

Performance improvement: ~83% reduction in execution time (from ~1.72 µs to ~288 ns per call). Added benchmark `benches/richards_curve_bench.rs` to verify the improvement.
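The pattern, as a hedged sketch: construct the temporary with empty heavy fields instead of cloning the whole struct. Field names follow the commit message; the types are placeholders, and the `optimizer` field is omitted for brevity.

```rust
struct RichardsCurve {
    lower: f32,
    upper: f32,
    gamma: Vec<f32>,             // heavy, heap-allocated
    bias: Vec<f32>,              // heavy, heap-allocated
    grad_norm_history: Vec<f32>, // heavy, heap-allocated
}

impl RichardsCurve {
    /// Lightweight construction: copy only the scalars needed to evaluate a
    /// scaled gate, leaving the heavy training-only state empty rather than
    /// paying for `self.clone()`.
    fn scaled_view(&self, scale: f32) -> RichardsCurve {
        RichardsCurve {
            lower: self.lower * scale,
            upper: self.upper * scale,
            gamma: Vec::new(), // unused by the gating evaluation
            bias: Vec::new(),
            grad_norm_history: Vec::new(),
        }
    }
}
```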
…915794032518470943
- Implement `HeadCache` for managing key/value states in attention heads (sketched below).
- Update `PolyAttention` to support an incremental forward pass with caching.
- Optimize `LLM::forward_with_limit` to pass only new tokens during generation.
- Add dynamic switching between sequential and parallel execution in the attention forward pass to optimize for small batches.
- Add benchmark `benches/inference.rs` validating a ~28x speedup.
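A minimal sketch of the caching shape. The `HeadCache` name comes from the commit; the fields and `extend` method are assumptions.

```rust
/// Per-head key/value cache for incremental decoding.
#[derive(Default)]
struct HeadCache {
    keys: Vec<Vec<f32>>,   // one key vector per cached position
    values: Vec<Vec<f32>>, // one value vector per cached position
}

impl HeadCache {
    /// Append projections for the newly generated tokens only; earlier
    /// positions are reused instead of being recomputed each step.
    fn extend(&mut self, new_keys: Vec<Vec<f32>>, new_values: Vec<Vec<f32>>) {
        self.keys.extend(new_keys);
        self.values.extend(new_values);
    }

    /// Number of positions currently cached.
    fn seq_len(&self) -> usize {
        self.keys.len()
    }
}
```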
…ation-4029054152255653832
Optimizes `get_data_from_json` by attempting to deserialize into a strongly typed `TextRow` struct before falling back to `serde_json::Value`. This avoids the overhead of parsing generic Values and allocating Maps for the common case where the JSON is an array of objects with a "text" field. Measured improvement: ~22% reduction in loading time for 10k rows (39.75ms -> 30.77ms). Includes new benchmark `benches/json_loading.rs`.
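A hedged sketch of the typed-first/fallback pattern with serde; `load_texts` is a hypothetical stand-in for `get_data_from_json`, and the exact fallback logic is assumed.

```rust
use serde::Deserialize;

/// Fast path for the common shape: an array of objects with a "text" field.
#[derive(Deserialize)]
struct TextRow {
    text: String,
}

fn load_texts(raw: &str) -> Result<Vec<String>, serde_json::Error> {
    // Try the strongly typed form first: no intermediate Value tree,
    // no per-object Map allocations.
    match serde_json::from_str::<Vec<TextRow>>(raw) {
        Ok(rows) => Ok(rows.into_iter().map(|r| r.text).collect()),
        // Irregular input: fall back to generic serde_json::Value parsing.
        Err(_) => {
            let values: Vec<serde_json::Value> = serde_json::from_str(raw)?;
            Ok(values
                .iter()
                .filter_map(|v| v.get("text").and_then(|t| t.as_str()))
                .map(String::from)
                .collect())
        }
    }
}
```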
- Replaced a heap-allocated Vec with a stack-allocated buffer in the inner loop (see the sketch below)
- Used `grad_weights_scalar_into` for zero-allocation gradient computation
- Added a safety assertion for the scalar parameter count
- Added a benchmark case for grad_weights_matrix
- Achieved ~22% performance improvement in the micro-benchmark
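A rough sketch of the `_into` pattern described above; the buffer size, the placeholder gradient math, and the `accumulate` helper are all assumptions.

```rust
const MAX_SCALAR_PARAMS: usize = 8; // assumed upper bound on scalar params

/// `_into` variant: the caller owns the buffer, so the inner loop allocates nothing.
fn grad_weights_scalar_into(x: f32, params: &[f32], out: &mut [f32]) {
    assert_eq!(params.len(), out.len());
    for (o, p) in out.iter_mut().zip(params) {
        *o = x * p; // placeholder math; the real gradient computation differs
    }
}

fn accumulate(xs: &[f32], params: &[f32]) -> f32 {
    // Safety assertion analogous to the commit's: params must fit the buffer.
    assert!(params.len() <= MAX_SCALAR_PARAMS);
    let mut buf = [0.0_f32; MAX_SCALAR_PARAMS]; // stack-allocated, reused each iteration
    let mut acc = 0.0;
    for &x in xs {
        grad_weights_scalar_into(x, params, &mut buf[..params.len()]);
        acc += buf[..params.len()].iter().sum::<f32>();
    }
    acc
}
```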
…4091504918954747775
…747997173202827830
Refactored PolyAttention to allow extracting and injecting state (cache) for gradient computation. Updated TitansMAC to capture this state during chunked forward processing and use it during chunked backward processing. This eliminates the need to clone and re-run the attention core during the backward pass, fixing the state assumption issue and improving correctness for chunked inputs (a rough sketch of the API shape follows the list).

- Added `PolyAttentionCache` struct.
- Added `take_cache` and `compute_gradients_with_cache` to `PolyAttention`.
- Refactored `compute_gradients_parallel` to `compute_gradients_parallel_from_state`.
- Updated `TitansMAC` to use `PolyAttentionCache` in `SegmentForwardData`.
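The method names below come from the commit; all types and internals are assumptions made only to show the handoff shape.

```rust
/// State captured during the chunked forward pass (contents are assumed).
struct PolyAttentionCache {
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}

struct PolyAttention {
    cache: Option<PolyAttentionCache>,
}

impl PolyAttention {
    /// Hand the recorded forward state to the caller (e.g., TitansMAC stores
    /// it per segment) instead of cloning and re-running the core later.
    fn take_cache(&mut self) -> Option<PolyAttentionCache> {
        self.cache.take()
    }

    /// Backward pass seeded from the recorded state rather than a clone-and-replay.
    fn compute_gradients_with_cache(
        &self,
        cache: &PolyAttentionCache,
        grad_out: &[f32],
    ) -> Vec<f32> {
        let _ = cache;
        vec![0.0; grad_out.len()] // gradient math elided
    }
}
```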
…8210137866213949 Fix PolyAttention state assumption in TitansMAC
Replaced full-dataset pre-tokenization with lazy, batched tokenization inside training loops (`train_with_warmup`, `train_with_warmup_eprop`, `train_trm_autoencoding`, `train_diffusion_ce`). This significantly reduces memory pressure for large datasets by only holding one batch of tokens in memory at a time, at the cost of re-tokenizing data each epoch (which is parallelized). Fixes memory inefficiency in `src/models/llm.rs`.
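A minimal sketch of the lazy-batch pattern; the `train_epoch` signature and `tokenize` callback are hypothetical, and rayon is used here because the commit says per-epoch tokenization is parallelized, though the actual mechanism is not confirmed.

```rust
use rayon::prelude::*;

/// Only one batch of token ids is alive at a time; data is re-tokenized each
/// epoch, but tokenization within a batch runs in parallel.
fn train_epoch(texts: &[String], batch_size: usize, tokenize: impl Fn(&str) -> Vec<u32> + Sync) {
    for batch in texts.chunks(batch_size) {
        let tokens: Vec<Vec<u32>> = batch.par_iter().map(|t| tokenize(t)).collect();
        // ... run the training step on `tokens` ...
        drop(tokens); // this batch's tokens are freed before the next chunk
    }
}
```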
Fixed a critical issue in `train_diffusion_ce` where the model was training on the entire dataset (including the validation split) due to incorrect loop boundaries.

- Split `data` into `train_data` and `val_data` upfront using the validation ratio (see the sketch below).
- The training loop now iterates only over `train_data` chunks.
- The validation loop iterates over `val_data` chunks (unchanged logic, just context).
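A minimal sketch of the upfront split; the `split_dataset` helper is hypothetical and the rounding behavior is assumed.

```rust
/// Carve off the validation slice up front so the training loop can never
/// touch it. `val_ratio` is assumed to lie in (0, 1).
fn split_dataset<T>(data: &[T], val_ratio: f32) -> (&[T], &[T]) {
    let train_len = ((data.len() as f32) * (1.0 - val_ratio)).floor() as usize;
    data.split_at(train_len.min(data.len()))
}

// Usage: let (train_data, val_data) = split_dataset(&data, 0.1);
```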
…-9622796312022304099 ⚡ Optimize training loops with lazy batch tokenization
Add gradient computation and parameter updates to the engram memory module and the hybrid memory layer. The engram cache now properly handles zero-size configurations, and the hybrid memory caches intermediate results during the forward pass to enable an efficient backward pass. The implementation includes:

- compute_gradients and apply_gradients methods for EngramMemory
- Full backpropagation support for HybridMemory with proper gradient flow
- Cache optimizations to avoid repeated computations
- Support for both adaptive and fixed routing during training
- Proper handling of cache statistics within the engram memory structure
Move cache insertion after usage to avoid unnecessary clone operations. Refactor cache lookup logic to eliminate redundant cloning and improve hit handling when tier 1 is disabled. Simplify eviction logic in insert to maintain proper size limits across both cache tiers.
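A rough sketch of the eviction idea under assumed types: demote tier-1 overflow into tier 2 on insert rather than dropping it, and skip tier 1 entirely when it is disabled.

```rust
use std::collections::VecDeque;

struct TwoTierCache<V> {
    tier1: VecDeque<(u64, V)>, // small hot tier; capacity 0 disables it
    tier2: VecDeque<(u64, V)>,
    tier1_cap: usize,
    tier2_cap: usize,
}

impl<V> TwoTierCache<V> {
    fn insert(&mut self, key: u64, value: V) {
        if self.tier1_cap > 0 {
            self.tier1.push_front((key, value));
            // Evict from tier 1 by demoting into tier 2, not by dropping.
            while self.tier1.len() > self.tier1_cap {
                if let Some(demoted) = self.tier1.pop_back() {
                    self.tier2.push_front(demoted);
                }
            }
        } else {
            // Tier 1 disabled: inserts go straight to tier 2.
            self.tier2.push_front((key, value));
        }
        // Enforce the tier-2 size limit last, so both tiers stay bounded.
        while self.tier2.len() > self.tier2_cap {
            self.tier2.pop_back();
        }
    }
}
```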
- Add gradient_count methods to NeuralMemory and EngramMemory
- Extend LLM forward passes to handle DiffusionBlock and LRM layers with similarity context
- Implement activation similarity matrix tracking and context propagation in LRM and DiffusionBlock
- Enhance HybridMemory with learnable routing hyperparameters and improved gradient flow
- Add a similarity context strength parameter with an optimizer in DiffusionBlock
…ty propagation

- Introduce `moh_moe_contrastive_weight` configuration to align MoH head activity with MoE routing distributions
- Propagate activation similarity context between diffusion blocks to improve coherence
- Cache FiLM gamma/beta vectors to avoid reallocation in the forward pass
- Extend adaptive residuals and feed-forward layers to utilize head activity features
- Implement the contrastive loss as a symmetric KL divergence between head-conditioned and token-routed expert distributions
- Added `self.memory.reset_memory()` to `TitansMAG::forward` to ensure consistent state initialization, matching the assumption in `compute_gradients` that execution starts from `init_memory`. This prevents state leakage across independent forward calls (e.g., between batches).
- Refactored `src/memory/titans/mag.rs` to use `NeuralMemory::mlp_forward` instead of a duplicated local helper function, improving code maintainability.
- Exposed `NeuralMemory::mlp_forward` as `pub(crate)` in `src/memory/titans/neural.rs` to support the refactoring.
- Added `test_titans_mag_deterministic_forward` to verify that repeated forward passes on the same input produce identical outputs, confirming the state reset logic works.

Co-authored-by: ryancinsight <55164720+ryancinsight@users.noreply.github.com>
Review by RecurseML
🔍 Review performed on a004bfd..7f24bef
✨ No bugs found, your code is sparkling clean
✅ Files analyzed, no issues (2)
• src/memory/titans/mag.rs
• src/memory/titans/neural.rs
Code Review
This pull request introduces a critical fix to the TitansMAG forward pass by ensuring the memory state is reset between calls. This prevents state leakage across batches and aligns the forward pass with the assumptions of the backward pass, resolving a potential bug in training loops. The changes also include a valuable refactoring that removes duplicated code by centralizing the mlp_forward logic into the NeuralMemory implementation. A new regression test has been added to verify that the forward pass is now deterministic. The changes are well-implemented, clearly explained, and improve both the correctness and maintainability of the code.
This PR refactors the `TitansMAG` forward pass implementation to ensure correctness and maintainability.

Key changes:

- Added `reset_memory()` at the start of `forward`. In the Titans architecture, the neural memory is typically treated as "fast weights" learned over the context window. The gradient computation (`compute_gradients`) assumes the memory starts at `init_memory`. Without an explicit reset in `forward`, the memory state would persist across calls (e.g., across training batches), leading to a disconnect between the forward pass state and the backward pass assumptions. This fixes potential bugs in training loops.
- Removed the duplicated `mlp_forward` helper function in `mag.rs` and exposed the existing implementation in `NeuralMemory` (`neural.rs`) to the crate. This eliminates code duplication.
- Added `test_titans_mag_deterministic_forward` to confirm that the module is stateless across calls, as expected for a standard layer in this context. All existing tests passed.

PR created automatically by Jules for task 6565573961390046446 started by @ryancinsight
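A self-contained toy illustrating the reset-at-start-of-forward fix and its regression test. The real `TitansMAG` and `NeuralMemory` are far richer; everything below beyond those names and the reset pattern is an assumption for exposition.

```rust
/// Toy stand-in: a layer whose "fast weights" drift during forward
/// unless reset to their initial value at the start of each call.
struct NeuralMemory {
    state: f32,
    init: f32,
}

impl NeuralMemory {
    fn reset_memory(&mut self) {
        self.state = self.init;
    }
}

struct TitansMAG {
    memory: NeuralMemory,
}

impl TitansMAG {
    fn forward(&mut self, input: &[f32]) -> Vec<f32> {
        // The fix: start every independent call from the initial memory,
        // matching the assumption baked into the backward pass.
        self.memory.reset_memory();
        input
            .iter()
            .map(|x| {
                // Fast-weight update confined to this context window.
                self.memory.state += 0.1 * x;
                x * self.memory.state
            })
            .collect()
    }
}

#[test]
fn test_titans_mag_deterministic_forward() {
    let mut mag = TitansMAG {
        memory: NeuralMemory { state: 0.0, init: 0.0 },
    };
    let input = [0.5_f32; 8];
    let first = mag.forward(&input);
    let second = mag.forward(&input);
    // Without the reset, the second call would see leaked state and diverge.
    assert_eq!(first, second);
}
```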
High-level PR Summary
This PR refactors the `TitansMAG` forward pass to fix memory state handling and eliminate code duplication. The key fix adds `reset_memory()` at the start of `forward` to ensure the memory state is reset between calls, preventing state persistence across training batches that would violate backward pass assumptions. The refactor also removes a duplicated `mlp_forward` helper function and exposes the existing implementation from `NeuralMemory` as `pub(crate)` instead. A new test verifies that forward passes are now deterministic across calls.

⏱️ Estimated Review Time: 5-15 minutes
💡 Review Order Suggestion
1. src/memory/titans/neural.rs
2. src/memory/titans/mag.rs