Added OrbaxCheckpoint for Keras 3.0 for data-centric saving and restore #21762
Conversation
Supports the following features:
- Asynchronous Checkpointing
- Composite Checkpointing
- Preservation Policies
- Save Decision Policies
- Transformations
- Custom Handlers
Summary of Changes (Gemini Code Assist): This pull request integrates Orbax checkpointing into Keras 3.0, providing a robust and flexible mechanism for saving and restoring training progress.
Code Review
This pull request introduces OrbaxCheckpoint, a new Keras callback for advanced checkpointing using the Orbax library. This is a significant feature addition that enables asynchronous saving, composite checkpoints, and other powerful capabilities. The implementation is extensive and is supported by a comprehensive suite of tests.
My review has identified several important issues that need attention. There are critical correctness and performance bugs in the main implementation: the batch-based saving logic is flawed, and the asynchronous saving feature is effectively disabled by blocking calls. Additionally, some features are incomplete, and there are minor areas for improvement in the tests to enhance maintainability. I have provided specific suggestions to address these points. After these fixes, this will be a very valuable addition to Keras.
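The asynchronous-saving concern above can be made concrete with a minimal, hypothetical sketch (not the PR's actual code): an async save hands the work to a background thread, and a blocking `wait_until_finished()`-style barrier should only be invoked when the result is truly needed; calling it immediately after every save makes the feature effectively synchronous.

```python
# Hypothetical illustration of the async-save pattern; names are made up
# for this sketch and are not the Keras or Orbax API.
import threading


class AsyncSaver:
    """Minimal async-save skeleton: serialization runs on a background thread."""

    def __init__(self):
        self._thread = None

    def save_async(self, state, sink):
        # Launch the (possibly slow) write in the background and return at once.
        self._thread = threading.Thread(target=lambda: sink.update(state))
        self._thread.start()

    def wait_until_finished(self):
        # Blocking barrier: call only when the checkpoint must be complete
        # (e.g. before restoring, or at end of training). Calling this right
        # after save_async() defeats the purpose of async checkpointing.
        if self._thread is not None:
            self._thread.join()
            self._thread = None


sink = {}
saver = AsyncSaver()
saver.save_async({"step": 1}, sink)  # returns immediately
saver.wait_until_finished()          # block only when needed
```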
    def __init__(
        self,
        directory,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        mode="auto",
        save_freq="epoch",
        max_to_keep=5,
        keep_period=None,
        initial_value_threshold=None,
        save_optimizer_state=True,
        save_on_background=True,
        save_metadata=None,
        save_data_iterator=None,
        save_metrics_state=False,
        async_timeout_secs=600,
        enable_background_delete=False,
        post_finalization_callback=None,
        save_transforms=None,
        save_decision_policy=None,
        save_interval=None,
    ):
The __init__ method has 20 arguments, which is quite high. The Keras API design guidelines suggest reconsidering signatures with more than 6-7 arguments.[1] While I understand the need to expose Orbax's functionality, it might be worth exploring whether some of these could be grouped into a configuration object to improve readability and usability, similar to how ocp.CheckpointManagerOptions is used internally.
Style Guide References
[1] The style guide recommends that functions with more than 6-7 arguments be re-evaluated for simplification, possibly by breaking them into smaller objects or modular pieces.
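One way to act on this suggestion is to group related constructor arguments into small option objects, mirroring how Orbax itself uses `ocp.CheckpointManagerOptions`. The sketch below is purely illustrative: the class and field names are hypothetical, not part of the PR or the Keras API.

```python
# Hypothetical grouping of OrbaxCheckpoint's many __init__ arguments into
# configuration objects. All names here are illustrative assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointRetention:
    """Options controlling how many checkpoints are kept and cleaned up."""
    max_to_keep: int = 5
    keep_period: Optional[int] = None
    enable_background_delete: bool = False


@dataclass
class CheckpointSaveOptions:
    """Options controlling what is saved and how the save runs."""
    save_optimizer_state: bool = True
    save_on_background: bool = True
    save_metrics_state: bool = False
    async_timeout_secs: int = 600


# A constructor shaped like
#   OrbaxCheckpoint(directory, monitor=..., retention=..., save_options=...)
# would then need only a handful of top-level arguments instead of 20.
opts = CheckpointSaveOptions(save_on_background=False)
retention = CheckpointRetention(max_to_keep=3)
```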
Codecov Report: ❌ Patch coverage is … Additional details and impacted files:

@@            Coverage Diff             @@
##           master   #21762      +/-   ##
==========================================
- Coverage   82.69%   82.45%    -0.24%
==========================================
  Files         573      581        +8
  Lines       58888    59655      +767
  Branches     9218     9385      +167
==========================================
+ Hits        48696    49190      +494
- Misses       7845     8052      +207
- Partials     2347     2413       +66

View full report in Codecov by Sentry.
Thanks for the PR. This checkpointing system has a ton of features!
Quick first pass.
A couple more comments I forgot.
- Remove conditional export decorator to ensure OrbaxCheckpoint is always available
- Remove unnecessary exception handling in state tree operations
- Update process index check comment for clarity
- Format code to comply with 80-character line limit
- Add distribution_lib modules for backend-specific distributed training support

- Remove unused 'result' variable in _reconstruct_state_tree_with_values
- Fix long comment line in test file
- Apply code formatting changes

…st handling
- Implement OrbaxCheckpoint callback for async checkpointing with state tree handling
- Add conditional exports for optional orbax-checkpoint dependency
- Use pytest.importorskip for clean optional dependency testing
- Ensure graceful handling when orbax-checkpoint is not installed
The JAX implementation of def process_id() is missing.
General questions:
- Does this as-is support all backends?
- Does this support JAX sharding? I don't see anything related to sharding (which may be normal). What about re-sharding?
    if isinstance(value, (int, float)):
        # Convert Python scalar to numpy scalar
        return np.array(value, dtype=obj.dtype)
    elif isinstance(value, np.ndarray):
        # value is a numpy array; collapse single-element arrays to scalars
        if value.ndim == 0:
            return np.array(value.item(), dtype=obj.dtype)
        elif value.ndim == 1 and value.size == 1:
            return np.array(value.item(), dtype=obj.dtype)
        else:
            return value.astype(obj.dtype).reshape(obj.shape)
    else:
        return np.array(value, dtype=obj.dtype)
Why are these needed? Doesn't Orbax give you back what you saved?
The conversions are needed because Orbax gives you back exactly what you saved, but the saved data might not match the current model's expectations. This reconstruction code handles mismatches that can occur due to model surgery or changes to the model configuration.
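A standalone illustration of this kind of coercion (an assumed scenario, not the PR's exact code): a restored value may arrive as a Python scalar or a 0-d/single-element array, while the target variable expects a specific numpy dtype and shape.

```python
# Minimal sketch of dtype/shape coercion on restore. The function name and
# signature are hypothetical; they only mirror the idea in the diff above.
import numpy as np


def coerce_to_spec(value, dtype, shape):
    """Coerce a restored value to a target variable's dtype and shape."""
    arr = np.asarray(value)
    if arr.size == 1:
        # Python scalars, 0-d arrays, and 1-element arrays all collapse
        # to a typed numpy scalar.
        return np.array(arr.item(), dtype=dtype)
    # Larger arrays are cast and reshaped to the target spec.
    return arr.astype(dtype).reshape(shape)


coerce_to_spec(3, np.float32, ())            # Python int -> float32 scalar
coerce_to_spec(np.array([1.5]), np.float32, ())  # 1-element array -> scalar
```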
    save_data_iterator: Dict or callable, data iterator state to save with
        each checkpoint. If callable, it will be called with (epoch, logs)
        and should return a dict with serializable iterator state.
        Defaults to None.
Curious what this feature is for? It seems very similar to the one above.
The save_data_iterator parameter is indeed similar to save_metadata, but serves a very specific and important purpose in training resumption scenarios.
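To make the contract concrete, here is a hedged sketch of what such a callable might look like, following only the docstring above (called with `(epoch, logs)`, returns a serializable dict). The helper name and the fields inside the returned dict are illustrative assumptions, not part of the PR.

```python
# Hypothetical save_data_iterator callable: records enough pipeline state
# that training could resume mid-dataset after a restore. All names and
# fields here are illustrative assumptions.
def make_iterator_state_fn(dataset_size, batch_size):
    def save_data_iterator(epoch, logs):
        steps_per_epoch = dataset_size // batch_size
        return {
            "epoch": epoch,
            # How many batches the pipeline has consumed so far; on restore,
            # the data pipeline could skip this many batches.
            "batches_seen": (epoch + 1) * steps_per_epoch,
            # Deterministic per-epoch seed so reshuffling is reproducible.
            "shuffle_seed": 1234 + epoch,
        }

    return save_data_iterator


state_fn = make_iterator_state_fn(dataset_size=50_000, batch_size=128)
state = state_fn(epoch=2, logs={"loss": 0.31})
```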
Can you give an example? We didn't have that in the regular checkpoint format, so I'm curious.
Force-pushed f69aba8 to 8097bd2
- Preserve nested state tree structures instead of flattening for better layer name preservation
- Add backward compatibility for old flattened format checkpoints
- Simplify test class by using self.get_temp_dir() instead of setUp/tearDown
- Remove silent pytest.importorskip, add explicit skip conditions for backend-specific tests
- Move process_id function from backend to distribution module
- Update imports to use centralized LazyModule for orbax.checkpoint
- Test across all backends (JAX, TensorFlow, PyTorch) - all passing
Force-pushed 8097bd2 to 276ea9a
        return tree.map_structure(convert_scalars, state_tree)


    def _flatten_state_tree_values(state_tree):
Remove.