-
-
Notifications
You must be signed in to change notification settings - Fork 53
[mcall_correlated] Convert to JAX and content checks #616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
📖 Netlify Preview Ready! Preview URL: https://pr-616--sunny-cactus-210e3e.netlify.app (7b5bb5d) 📚 Changed Lecture Pages: mccall_correlated |
|
📖 Netlify Preview Ready! Preview URL: https://pr-616--sunny-cactus-210e3e.netlify.app (d3a847c) 📚 Changed Lecture Pages: mccall_correlated |
|
📖 Netlify Preview Ready! Preview URL: https://pr-616--sunny-cactus-210e3e.netlify.app (c43170a) 📚 Changed Lecture Pages: mccall_correlated |
|
📖 Netlify Preview Ready! Preview URL: https://pr-616--sunny-cactus-210e3e.netlify.app (03884e1) 📚 Changed Lecture Pages: mccall_correlated |
|
📖 Netlify Preview Ready! Preview URL: https://pr-616--sunny-cactus-210e3e.netlify.app (b22bcf0) 📚 Changed Lecture Pages: mccall_correlated |
Removed @jax.jit decorator from the Q operator function since it's called within the already-jitted compute_fixed_point function. JAX documentation recommends avoiding nested jit decorators as they create compilation boundaries that prevent XLA from optimizing the full computation graph. Performance testing showed ~10% improvement by letting the outer jit compile the entire computation including Q. Also cleaned up formatting: - Renamed 'state' to 'loop_state' for clarity in while_loop functions - Improved function signature formatting - Standardized code style 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
|
📖 Netlify Preview Ready! Preview URL: https://pr-616--sunny-cactus-210e3e.netlify.app (3f361f1) 📚 Changed Lecture Pages: mccall_correlated |
|
Thanks @HumphreyYang ! Merging. |
This PR updates
mcall_correlatedby converting the code to JAX.Lecture runtime cuts from 79.07 to 9.14 seconds.
The main loop cuts from 5.85 to 0.16 seconds with
block_until_ready()