-
-
Notifications
You must be signed in to change notification settings - Fork 53
Commit 22983b3
[mcall_correlated] Convert to JAX and content checks (#616)
* convert to JAX and make title style sheet compliant
* minor variable name update
* minor updates
* minor update
* update jax install
* Remove @jax.jit from Q operator for better performance
Removed @jax.jit decorator from the Q operator function since it's called
within the already-jitted compute_fixed_point function. JAX documentation
recommends avoiding nested jit decorators as they create compilation
boundaries that prevent XLA from optimizing the full computation graph.
Performance testing showed ~10% improvement by letting the outer jit
compile the entire computation including Q.
Also cleaned up formatting:
- Renamed 'state' to 'loop_state' for clarity in while_loop functions
- Improved function signature formatting
- Standardized code style
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
---------
Co-authored-by: John Stachurski <john.stachurski@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>1 parent 74949c4 commit 22983b3Copy full SHA for 22983b3
File tree
Expand file treeCollapse file tree
1 file changed
+171
-170
lines changedOpen diff view settings
Filter options
- lectures
Expand file treeCollapse file tree
1 file changed
+171
-170
lines changedOpen diff view settings
0 commit comments