
Conversation

@amitsrivastava78
Collaborator

Supports the following features:

  • Asynchronous Checkpointing
  • Composite Checkpointing
  • Preservation Policies
  • Save Decision Policies
  • Transformations - Custom Handlers

@gemini-code-assist
Contributor

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Orbax checkpointing into Keras 3.0, providing a robust and flexible mechanism for saving and restoring training progress. The new OrbaxCheckpoint callback offers features like asynchronous saving, customizable save policies, and the ability to save complex states including model weights, optimizer variables, metrics, and data iterator positions. This enhancement significantly improves the reliability and efficiency of training large models, especially in distributed environments, by leveraging Orbax's advanced capabilities.

Highlights

  • New OrbaxCheckpoint Callback: Introduces a new OrbaxCheckpoint callback for Keras 3.0, enabling advanced data-centric saving and restoration of model states.
  • Asynchronous Checkpointing: Supports asynchronous saving of model weights and optimizer states, allowing training to continue without I/O blocking.
  • Comprehensive Checkpointing Features: Includes support for composite checkpointing, preservation policies (e.g., max_to_keep, keep_period), save decision policies (e.g., save_interval), and custom transformations during saving.
  • Distributed Training Support: Adds a get_process_index utility function to the Keras backend, facilitating distributed training setups by identifying the primary process for checkpoint operations across JAX, TensorFlow, and PyTorch.
  • Extensible with Custom Handlers: Exposes advanced Orbax functionalities like CheckpointManager, TypeHandler, and register_type_handler to allow users to define custom serialization logic for complex objects.
  • Iterator State Saving and Restoration: Enables saving and restoring the state of data iterators, crucial for seamless training resumption from a specific point, with backend-specific examples for TensorFlow, JAX, and PyTorch.
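To make the "composite checkpointing" idea above concrete, here is a rough pure-Python sketch of the kind of payload such a callback might assemble at the end of an epoch. All names here (`build_composite_state` and its parameters) are illustrative, not the actual Keras or Orbax API.

```python
# Hypothetical sketch: group every piece of training state into one
# composite checkpoint payload (weights + optimizer + metrics + iterator).
def build_composite_state(weights, optimizer_vars, metrics, iterator_pos):
    return {
        "model_weights": weights,
        "optimizer_state": optimizer_vars,
        "metrics_state": metrics,
        "data_iterator": {"position": iterator_pos},
    }

state = build_composite_state(
    weights=[0.1, 0.2],
    optimizer_vars={"step": 42},
    metrics={"val_loss": 0.37},
    iterator_pos=1280,
)
print(sorted(state))
# → ['data_iterator', 'metrics_state', 'model_weights', 'optimizer_state']
```

Saving all of these under one checkpoint step is what lets a restore bring the entire training run — not just the weights — back to a consistent point.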
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces OrbaxCheckpoint, a new Keras callback for advanced checkpointing using the Orbax library. This is a significant feature addition that enables asynchronous saving, composite checkpoints, and other powerful capabilities. The implementation is extensive and is supported by a comprehensive suite of tests.

My review has identified several important issues that need attention. There are critical correctness and performance bugs in the main implementation: the batch-based saving logic is flawed, and the asynchronous saving feature is effectively disabled by blocking calls. Additionally, some features are incomplete, and there are minor areas for improvement in the tests to enhance maintainability. I have provided specific suggestions to address these points. After these fixes, this will be a very valuable addition to Keras.

Comment on lines +119 to +141
def __init__(
    self,
    directory,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    mode="auto",
    save_freq="epoch",
    max_to_keep=5,
    keep_period=None,
    initial_value_threshold=None,
    save_optimizer_state=True,
    save_on_background=True,
    save_metadata=None,
    save_data_iterator=None,
    save_metrics_state=False,
    async_timeout_secs=600,
    enable_background_delete=False,
    post_finalization_callback=None,
    save_transforms=None,
    save_decision_policy=None,
    save_interval=None,
):


medium

The __init__ method has 20 arguments, which is quite high. The Keras API design guidelines suggest reconsidering signatures with more than 6-7 arguments.¹ While I understand the need to expose Orbax's functionality, it might be worth exploring if some of these could be grouped into a configuration object to improve readability and usability, similar to how ocp.CheckpointManagerOptions is used internally.

Style Guide References

Footnotes

  1. The style guide recommends that functions with more than 6-7 arguments should be re-evaluated for simplification, possibly by breaking them into smaller objects or modular pieces.
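One way to act on this suggestion is to bundle related keywords into small config objects. The sketch below is purely illustrative — `CheckpointPolicy`, `AsyncOptions`, and `make_callback` are made-up names, loosely mirroring how ocp.CheckpointManagerOptions groups options internally; they are not part of the proposed Keras API.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical config objects grouping related OrbaxCheckpoint arguments.
@dataclass
class CheckpointPolicy:
    max_to_keep: int = 5
    keep_period: Optional[int] = None
    save_interval: Optional[int] = None

@dataclass
class AsyncOptions:
    save_on_background: bool = True
    async_timeout_secs: int = 600
    enable_background_delete: bool = False

def make_callback(directory, policy=None, async_options=None):
    # A slimmer signature: two config objects instead of ~20 keywords.
    return {
        "directory": directory,
        "policy": policy or CheckpointPolicy(),
        "async": async_options or AsyncOptions(),
    }

cb = make_callback("/tmp/ckpt", policy=CheckpointPolicy(max_to_keep=3))
```

The trade-off is one extra import for users, in exchange for a signature that stays readable as Orbax features are added.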

@codecov-commenter

codecov-commenter commented Oct 22, 2025

Codecov Report

❌ Patch coverage is 50.59524% with 166 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.45%. Comparing base (47fcb39) to head (276ea9a).
⚠️ Report is 14 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/callbacks/orbax_checkpoint.py 49.16% 128 Missing and 25 partials ⚠️
keras/src/backend/torch/distribution_lib.py 12.50% 7 Missing ⚠️
keras/src/backend/tensorflow/distribution_lib.py 66.66% 2 Missing ⚠️
keras/api/_tf_keras/keras/callbacks/__init__.py 0.00% 1 Missing ⚠️
keras/api/_tf_keras/keras/distribution/__init__.py 0.00% 1 Missing ⚠️
keras/src/backend/numpy/distribution_lib.py 50.00% 1 Missing ⚠️
keras/src/backend/openvino/distribution_lib.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21762      +/-   ##
==========================================
- Coverage   82.69%   82.45%   -0.24%     
==========================================
  Files         573      581       +8     
  Lines       58888    59655     +767     
  Branches     9218     9385     +167     
==========================================
+ Hits        48696    49190     +494     
- Misses       7845     8052     +207     
- Partials     2347     2413      +66     
Flag Coverage Δ
keras 82.27% <50.29%> (-0.22%) ⬇️
keras-jax 63.23% <47.91%> (-0.01%) ⬇️
keras-numpy 57.31% <15.17%> (-0.41%) ⬇️
keras-openvino 34.18% <14.88%> (-0.22%) ⬇️
keras-tensorflow 64.03% <48.80%> (+0.01%) ⬆️
keras-torch 63.54% <47.91%> (-0.03%) ⬇️


Collaborator

@hertschuh hertschuh left a comment


Thanks for the PR. This checkpointing system has a ton of features!

Quick first pass.

Collaborator

@hertschuh hertschuh left a comment


A couple more comments I forgot.

- Remove conditional export decorator to ensure OrbaxCheckpoint is always available
- Remove unnecessary exception handling in state tree operations
- Update process index check comment for clarity
- Format code to comply with 80-character line limit
- Add distribution_lib modules for backend-specific distributed training support
- Remove unused 'result' variable in _reconstruct_state_tree_with_values
- Fix long comment line in test file
- Apply code formatting changes
…st handling

- Implement OrbaxCheckpoint callback for async checkpointing with state tree handling
- Add conditional exports for optional orbax-checkpoint dependency
- Use pytest.importorskip for clean optional dependency testing
- Ensure graceful handling when orbax-checkpoint is not installed
Collaborator

@hertschuh hertschuh left a comment


The JAX implementation of def process_id() is missing.

General questions:

  • Does this as-is support all backends?
  • Does this support JAX sharding? I don't see anything related to sharding (which may be normal). What about re-sharding?

Comment on lines 80 to 92
if isinstance(value, (int, float)):
    # Convert Python scalar to numpy scalar
    return np.array(value, dtype=obj.dtype)
elif isinstance(value, np.ndarray):
    # value is a numpy array, convert to scalar if needed
    if value.ndim == 0:
        return np.array(value.item(), dtype=obj.dtype)
    elif value.ndim == 1 and value.size == 1:
        return np.array(value.item(), dtype=obj.dtype)
    else:
        return value.astype(obj.dtype).reshape(obj.shape)
else:
    return np.array(value, dtype=obj.dtype)
Collaborator


Why are these needed? Doesn't Orbax give you back what you saved?

Collaborator Author


The conversions are needed because Orbax gives you back exactly what you saved, but the saved data might not match the current model's expectations. This reconstruction code handles mismatches that can occur due to model surgery or the model configuration having changed.
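A small standalone illustration of the mismatch being handled: the value read back from disk may arrive as a Python scalar, a 0-d array, or a length-1 array, while the live model variable expects a specific dtype. Here `target_dtype` and `restore_scalar` are stand-ins for `obj.dtype` and the branches in the snippet above, not the actual implementation.

```python
import numpy as np

# Hypothetical stand-in for `obj.dtype` from the reconstruction code.
target_dtype = np.float32

def restore_scalar(value):
    # Collapse 0-d and length-1 arrays to a scalar of the target dtype,
    # mirroring the ndim==0 / (ndim==1 and size==1) branches above.
    if isinstance(value, np.ndarray) and value.ndim <= 1 and value.size == 1:
        return np.array(value.item(), dtype=target_dtype)
    return np.array(value, dtype=target_dtype)

a = restore_scalar(3)                # Python int from a restored tree
b = restore_scalar(np.array(3.0))    # 0-d array
c = restore_scalar(np.array([3.0]))  # length-1 array
assert a.dtype == b.dtype == c.dtype == np.float32
```

All three inputs end up as the same dtype, which is what the model variable assignment downstream requires.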

Comment on lines +223 to +226
save_data_iterator: Dict or callable, data iterator state to save with
each checkpoint. If callable, it will be called with (epoch, logs)
and should return a dict with serializable iterator state.
Defaults to None.
Collaborator


Curious what this feature is for? It seems very similar to the one above.

Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The save_data_iterator parameter is indeed similar to save_metadata, but serves a very specific and important purpose in training resumption scenarios.
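As a concrete sketch of that purpose: the callable receives `(epoch, logs)` and returns a small serializable dict recording where the input pipeline is, so a restored run can skip already-consumed batches. Everything here (`CountingIterator`, `iterator_state_fn`) is a toy illustration, not the Keras or Orbax API.

```python
# Toy iterator that tracks how many batches it has yielded.
class CountingIterator:
    def __init__(self, data):
        self.data = list(data)
        self.batches_seen = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.batches_seen >= len(self.data):
            raise StopIteration
        item = self.data[self.batches_seen]
        self.batches_seen += 1
        return item

it = CountingIterator(range(10))
next(it); next(it); next(it)  # consume three batches

def iterator_state_fn(epoch, logs):
    # What the callback would persist alongside the weights.
    return {"epoch": epoch, "batches_seen": it.batches_seen}

print(iterator_state_fn(0, {}))
# → {'epoch': 0, 'batches_seen': 3}
```

On restore, a real pipeline would use `batches_seen` to fast-forward the iterator, so training resumes mid-epoch at the exact batch where the checkpoint was taken — something plain metadata saving does not give you by itself.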

Collaborator


Can you give an example? We didn't have that in the regular checkpoint format, so I'm curious.

@amitsrivastava78 force-pushed the orbax-checkpoint-test-improvements branch from f69aba8 to 8097bd2 on October 28, 2025 at 08:47
- Preserve nested state tree structures instead of flattening for better layer name preservation
- Add backward compatibility for old flattened format checkpoints
- Simplify test class by using self.get_temp_dir() instead of setUp/tearDown
- Remove silent pytest.importorskip, add explicit skip conditions for backend-specific tests
- Move process_id function from backend to distribution module
- Update imports to use centralized LazyModule for orbax.checkpoint
- Test across all backends (JAX, TensorFlow, PyTorch) - all passing
    return tree.map_structure(convert_scalars, state_tree)


def _flatten_state_tree_values(state_tree):
Collaborator


Remove.
