From 56099c447caa6e285e94753fc94bf13d8e4a2dc5 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Tue, 23 Dec 2025 22:38:23 -0800
Subject: [PATCH] checkpoint change

PiperOrigin-RevId: 848426218
---
 init2winit/checkpoint.py                   | 117 +++++++++------------
 init2winit/gradient_statistics_callback.py |   5 +-
 init2winit/test_checkpoint.py              |  44 ++++----
 init2winit/test_utils.py                   |   7 +-
 init2winit/trainer_lib/base_trainer.py     |  39 ++++---
 init2winit/utils.py                        |   8 +-
 6 files changed, 102 insertions(+), 118 deletions(-)

diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py
index eb94a1c2..f0eb8fb1 100644
--- a/init2winit/checkpoint.py
+++ b/init2winit/checkpoint.py
@@ -18,30 +18,28 @@
 This is useful for training neural networks with stax, where model parameters
 are nested numpy arrays.
 """
-import os
-import sys
-
 from absl import flags
 from absl import logging
-from flax.training import checkpoints as flax_checkpoints
 import jax
 # pylint: disable=g-importing-member
 from jax.experimental.multihost_utils import process_allgather
+import orbax.checkpoint as ocp

 FLAGS = flags.FLAGS


-def load_pytree(pytree_file, orbax_checkpointer=None):
+def load_pytree(pytree_file, orbax_checkpoint_manager=None):
   """Loads the checkpointed pytree."""
-  latest = load_latest_checkpoint(pytree_file,
-                                  target=None,
-                                  orbax_checkpointer=orbax_checkpointer)
-  if latest:
-    # Because we pass target=None, flax checkpointing will return the raw
-    # state dict, where 'state' will be a dict with keys ['0', '1', ...]
-    # instead of a list.
-    return latest['pytree']
-  return None
+  if orbax_checkpoint_manager is None:
+    orbax_checkpoint_manager = ocp.CheckpointManager(
+        pytree_file,
+    )
+  # latest_step() returns None when no checkpoint has been written yet, so
+  # guard explicitly instead of relying on restore() to raise.
+  restore_step = orbax_checkpoint_manager.latest_step()
+  if restore_step is None:
+    return None
+  restored = orbax_checkpoint_manager.restore(restore_step)
+  # The state is saved as {'pytree': ...}; unwrap it so callers still receive
+  # the bare pytree, matching the previous behavior.
+  return restored['pytree']


 def maybe_restore_checkpoint(
@@ -51,7 +49,7 @@ def maybe_restore_checkpoint(
     unreplicated_training_metrics_state,
     train_dir,
     external_checkpoint_path=None,
-    orbax_checkpointer=None):
+    orbax_checkpoint_manager=None):
   """Optionally restores from a checkpoint.

   The checkpoint logic is as follows: if there is a checkpoint in `train_dir`,
@@ -68,7 +66,7 @@ def maybe_restore_checkpoint(
     train_dir: (str) The training directory where we will look for a checkpoint.
     external_checkpoint_path: (str) If this argument is set, then we will load
       the external checkpoint stored there.
-    orbax_checkpointer: orbax.Checkpointer
+    orbax_checkpoint_manager: orbax.CheckpointManager

   Returns:
     unreplicated_optimizer_state
@@ -89,24 +87,27 @@ def maybe_restore_checkpoint(
       training_metrics_grabber=unreplicated_training_metrics_state,
       global_step=uninitialized_global_step,
       preemption_count=0,
-      sum_train_cost=0.0)
+      sum_train_cost=0.0,
+  )

   logging.info('Loading latest checkpoint from train_dir: %s', train_dir)
-  latest_ckpt = load_latest_checkpoint(train_dir,
-                                       target=unreplicated_checkpoint_state,
-                                       orbax_checkpointer=orbax_checkpointer)
+  latest_ckpt = load_latest_checkpoint(
+      target=unreplicated_checkpoint_state,
+      orbax_checkpoint_manager=orbax_checkpoint_manager,
+  )
   logging.info('Loading checkpoint from train_dir %s complete.', train_dir)

   # load_latest_checkpoint() will return unreplicated_checkpoint_state if
   # train_dir does not exist or if it exists and contains no checkpoints.
   # Note that we could likely change the below line to:
   # found_checkpoint = latest_ckpt != unreplicated_checkpoint_state
-  found_checkpoint = (latest_ckpt['global_step'] != uninitialized_global_step)
+  found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step

   # If there's a latest checkpoint in the train_dir, restore from that.
   if found_checkpoint:
     ckpt_to_return = latest_ckpt
     is_restored = True  # We do want trainer to increment preemption_count.
-    logging.info('Restoring checkpoint from ckpt_%d',
-                 latest_ckpt['global_step'])
+    logging.info(
+        'Restoring checkpoint from ckpt_%d', latest_ckpt['global_step']
+    )
   # Else, if external_checkpoint_path is non-null, restore from that checkpoint.
   elif external_checkpoint_path is not None:
     # TODO(jeremycohen) This code will crash if we try to load an external
@@ -121,7 +122,6 @@ def maybe_restore_checkpoint(
     ckpt_to_return = load_checkpoint(
         external_checkpoint_path,
         target=unreplicated_checkpoint_state,
-        orbax_checkpointer=orbax_checkpointer,
     )
     is_restored = False  # We don't want trainer to increment preemption_count.

@@ -159,7 +159,6 @@


 def save_unreplicated_checkpoint(
-    train_dir,
     optimizer_state,
     params,
     batch_stats,
@@ -167,8 +166,7 @@ def save_unreplicated_checkpoint(
     global_step,
     preemption_count,
     sum_train_cost,
-    orbax_checkpointer,
-    max_to_keep=1):
-  """Saves pytree, step, preemption_count, and sum_train_cost to train_dir."""
+    orbax_checkpoint_manager):
+  """Saves pytree, step, preemption_count, and sum_train_cost via the manager."""
   logging.info('Saving checkpoint to ckpt_%d', global_step)
   # jax.device_get doesn't work if jax.Array lives on multiple hosts.
@@ -191,20 +189,15 @@ def save_unreplicated_checkpoint(
       params=unreplicated_params,
       batch_stats=unreplicated_batch_stats,
       training_metrics_grabber=unreplicated_training_metrics_state)
-  save_checkpoint(train_dir,
-                  global_step,
+  save_checkpoint(global_step,
                   state,
-                  max_to_keep=max_to_keep,
-                  orbax_checkpointer=orbax_checkpointer)
+                  orbax_checkpoint_manager=orbax_checkpoint_manager)
   logging.info('Done saving checkpoint.')


-def save_checkpoint(train_dir,
-                    step,
+def save_checkpoint(step,
                     state,
-                    prefix='ckpt_',
-                    max_to_keep=None,
-                    orbax_checkpointer=None):
-  """Saves checkpoint to train_dir/{prefix}{step}.
-
-  A list of checkpoints will be stored in train_dir. The user
-  is not None.
+                    orbax_checkpoint_manager):
+  """Saves a checkpoint for the given step via `orbax_checkpoint_manager`.
+
+  Checkpoints are written under the manager's directory; deletion of old
+  checkpoints is governed by the manager's `max_to_keep` option.

   Args:
-    train_dir: (str) Directory to create the checkpoint directory in.
     step: (int) Step of the checkpoint.
     state: (dict) The state to save.
-    prefix: (str) Prefix of the checkpoint name.
-    max_to_keep: (int) Checkpoints older than the max_to_keep'th will be
-      deleted. Defaults to never deleting.
-    orbax_checkpointer: orbax.Checkpointer
+    orbax_checkpoint_manager: orbax.CheckpointManager used to write the
+      checkpoint. Required, since all path handling now lives in the manager.

   Returns:
     The path of the checkpoint directory.
""" - if max_to_keep is None: - max_to_keep = sys.maxsize - flax_checkpoints.save_checkpoint_multiprocess( - train_dir, - target=state, - step=step, - prefix=prefix, - keep=max_to_keep, - overwrite=True, - orbax_checkpointer=orbax_checkpointer, - ) - save_dir = os.path.join(train_dir, prefix + str(step)) - return save_dir + orbax_checkpoint_manager.save(step, args=ocp.args.StandardSave(state)) + + return orbax_checkpoint_manager.directory def load_checkpoint( - checkpoint_path, target=None, prefix='ckpt_', orbax_checkpointer=None + checkpoint_path, + target=None, + orbax_checkpointer=None, ): """Loads the specified checkpoint.""" - restored = flax_checkpoints.restore_checkpoint( + if orbax_checkpointer is None: + orbax_checkpointer = ocp.StandardCheckpointer() + + restored = orbax_checkpointer.restore( checkpoint_path, - target=target, - prefix=prefix, - orbax_checkpointer=orbax_checkpointer, + item=target, + restore_args=ocp.args.StandardRestore(target, strict=False), ) return restored -def load_latest_checkpoint( - train_dir, target=None, prefix='ckpt_', orbax_checkpointer=None -): +def load_latest_checkpoint(target=None, orbax_checkpoint_manager=None): """Loads the most recent checkpoint listed in train_dir. Args: - train_dir: the directory to read checkpoints from. target: used for Flax checkpointing, a pytree whose structure will be used to structure the restored checkpoint data. - prefix: the prefix of the names of checkpoint files. - orbax_checkpointer: orbax.Checkpointer + orbax_checkpoint_manager: An orbax.CheckpointManager instance. Returns: The state restored from the checkpoint. If using Flax checkpointing and target=None, this will return a unstructured dictionary containing the @@ -271,11 +250,11 @@ def load_latest_checkpoint( https://github.com/google/flax/blob/master/flax/serialization.py#L67. If the directory doesn't exist, it will return the original target. 
""" + restore_step = orbax_checkpoint_manager.latest_step() try: - restored = flax_checkpoints.restore_checkpoint( - train_dir, target=target, prefix=prefix, - orbax_checkpointer=orbax_checkpointer + restored = orbax_checkpoint_manager.restore( + restore_step, args=ocp.args.StandardRestore(target, strict=False) ) return restored - except ValueError: + except FileNotFoundError: return target diff --git a/init2winit/gradient_statistics_callback.py b/init2winit/gradient_statistics_callback.py index 76fbe78e..d69af2c5 100644 --- a/init2winit/gradient_statistics_callback.py +++ b/init2winit/gradient_statistics_callback.py @@ -146,10 +146,7 @@ def run_eval(self, params, batch_stats, optimizer_state, global_step): ) checkpoint.save_checkpoint( - self.save_path, step=global_step, - state=state, - prefix='measurement_', - max_to_keep=None) + state=state) return {} diff --git a/init2winit/test_checkpoint.py b/init2winit/test_checkpoint.py index 7d2fff7f..0db009bd 100644 --- a/init2winit/test_checkpoint.py +++ b/init2winit/test_checkpoint.py @@ -59,8 +59,12 @@ def setUp(self): model_init_fn = jax.jit( functools.partial(model.flax_module.init, train=False)) init_dict = model_init_fn({'params': params_rng}, xs) - self.orbax_checkpointer = orbax_checkpoint.AsyncCheckpointer( - orbax_checkpoint.PyTreeCheckpointHandler(), timeout_secs=60) + self.orbax_checkpoint_manager = orbax_checkpoint.CheckpointManager( + self.test_dir, + options=orbax_checkpoint.CheckpointManagerOptions( + max_to_keep=1, create=True + ), + ) self.params = init_dict['params'] def tearDown(self): @@ -73,10 +77,13 @@ def test_save_load_roundtrip(self): """Test that saving and loading produces the original state.""" baz = ['a', 'b', 'ccc'] state = dict(params=self.params, global_step=5, completed_epochs=4, baz=baz) - checkpoint.save_checkpoint(self.test_dir, 0, state, - orbax_checkpointer=self.orbax_checkpointer) + checkpoint.save_checkpoint( + 0, + state, + orbax_checkpoint_manager=self.orbax_checkpoint_manager, + ) latest = checkpoint.load_latest_checkpoint( - self.test_dir, target=state, orbax_checkpointer=self.orbax_checkpointer + target=state, orbax_checkpoint_manager=self.orbax_checkpoint_manager ) self.assertEqual(latest['baz'], baz) @@ -90,22 +97,18 @@ def test_delete_old_checkpoints(self): global_step=5, completed_epochs=4,) checkpoint.save_checkpoint( - self.test_dir, 0, state1, - orbax_checkpointer=self.orbax_checkpointer, - max_to_keep=1) + orbax_checkpoint_manager=self.orbax_checkpoint_manager) state2 = dict(params=self.params, global_step=10, completed_epochs=8) checkpoint.save_checkpoint( - self.test_dir, 1, state2, - orbax_checkpointer=self.orbax_checkpointer, - max_to_keep=1) - self.orbax_checkpointer.wait_until_finished() + orbax_checkpoint_manager=self.orbax_checkpoint_manager) + self.orbax_checkpoint_manager.wait_until_finished() dir_contents = gfile.glob(os.path.join(self.test_dir, '*')) # Due to Flax Orbax migration using Orbax AsyncCheckpointer will result # in 'max_to_keep + 1' files. 
@@ -135,7 +138,6 @@ def test_all_variables_restored(self):
     initial_training_metrics = {'ema': 0}

     checkpoint.save_checkpoint(
-        train_dir=fresh_train_dir,
         step=global_step,
         state=dict(global_step=global_step,
                    preemption_count=preemption_count,
@@ -144,8 +146,7 @@ def test_all_variables_restored(self):
                    params=saved_params,
                    batch_stats=saved_batch_stats,
                    training_metrics_grabber=saved_training_metrics),
-        orbax_checkpointer=self.orbax_checkpointer,
-        max_to_keep=1)
+        orbax_checkpoint_manager=self.orbax_checkpoint_manager,)

     (
         ret_state,
@@ -162,7 +163,7 @@ def test_all_variables_restored(self):
         initial_batch_stats,
         initial_training_metrics,
         fresh_train_dir,
-        orbax_checkpointer=self.orbax_checkpointer,
+        orbax_checkpoint_manager=self.orbax_checkpoint_manager,
     )

     assert pytree_equal(
@@ -211,12 +212,11 @@ def test_maybe_restore_from_checkpoint_logic(self):
     external_dir = tempfile.mkdtemp()

     # two helper functions
-    def save_checkpoint(train_dir, global_step, preemption_count,
+    def save_checkpoint(global_step, preemption_count,
                         sum_train_cost, params):
       """Helper function to save a checkpoint."""
       checkpoint.save_checkpoint(
-          train_dir=train_dir,
           step=global_step,
           state=dict(global_step=global_step,
                      preemption_count=preemption_count,
@@ -225,8 +225,7 @@ def test_maybe_restore_from_checkpoint_logic(self):
                      params=params,
                      batch_stats={},
                      training_metrics_grabber={}),
-          orbax_checkpointer=self.orbax_checkpointer,
-          max_to_keep=1)
+          orbax_checkpoint_manager=self.orbax_checkpoint_manager)

     def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path):
       """Helper function to replicate_and_maybe_restore a checkpoint."""
@@ -235,7 +234,7 @@ def test_maybe_restore_from_checkpoint_logic(self):
           ret_global_step, ret_sum_train_cost, ret_preemption_count,
           ret_is_restored) = checkpoint.maybe_restore_checkpoint(
               {}, params, {}, {}, train_dir, external_checkpoint_path,
-              orbax_checkpointer=self.orbax_checkpointer)
+              orbax_checkpoint_manager=self.orbax_checkpoint_manager)

       ret_params_unrep = ret_params

@@ -243,8 +242,7 @@ def test_maybe_restore_from_checkpoint_logic(self):
               ret_preemption_count, ret_is_restored)

     # Save external checkpoint.
-    save_checkpoint(train_dir=external_dir,
-                    global_step=5,
-                    preemption_count=4,
-                    sum_train_cost=7.0,
-                    params=external_params)
+    # The external checkpoint must land in external_dir rather than in
+    # self.test_dir (which is where the helper's manager writes), so save it
+    # directly through a manager rooted at external_dir.
+    external_manager = orbax_checkpoint.CheckpointManager(
+        external_dir,
+        options=orbax_checkpoint.CheckpointManagerOptions(create=True))
+    checkpoint.save_checkpoint(
+        step=5,
+        state=dict(global_step=5,
+                   preemption_count=4,
+                   sum_train_cost=7.0,
+                   params=external_params,
+                   batch_stats={},
+                   training_metrics_grabber={}),
+        orbax_checkpoint_manager=external_manager)
+    external_manager.wait_until_finished()
diff --git a/init2winit/test_utils.py b/init2winit/test_utils.py
index cee4d033..70d14a2b 100644
--- a/init2winit/test_utils.py
+++ b/init2winit/test_utils.py
@@ -27,6 +27,7 @@
 from init2winit import utils
 import jax.numpy as jnp
 import numpy as np
+import orbax.checkpoint as ocp


 def _identity(i):
@@ -94,7 +95,11 @@ def testAppendPytree(self):
     for pytree in pytrees:
       logger.append_pytree(pytree)

-    latest = checkpoint.load_latest_checkpoint(pytree_path, prefix='')
+    # TODO(kasimbeg): FIX this
+    orbax_checkpoint_manager = ocp.CheckpointManager(pytree_path)
+    latest = checkpoint.load_latest_checkpoint(
+        target=pytrees[0], orbax_checkpoint_manager=orbax_checkpoint_manager
+    )
     saved_pytrees = latest['pytree'] if latest else []
     self.assertEqual(
         pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))])
diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py
index 18ac768d..2ce50c08 100644
--- a/init2winit/trainer_lib/base_trainer.py
+++ b/init2winit/trainer_lib/base_trainer.py
@@ -174,10 +174,11 @@ def __init__(
             choose_store_cell=True,
         ),
     )
-    self._orbax_checkpointer = ocp.AsyncCheckpointer(
-        ocp.PyTreeCheckpointHandler(use_ocdbt=False),
-        timeout_secs=600,
-        file_options=orbax_file_options,
+    self._orbax_checkpoint_manager = ocp.CheckpointManager(
+        self._checkpoint_dir,
+        options=ocp.CheckpointManagerOptions(
+            max_to_keep=1, create=True, file_options=orbax_file_options
+        ),
     )
     self._early_stopping_target_name = early_stopping_target_name
     self._early_stopping_target_value = early_stopping_target_value
@@ -208,10 +209,16 @@ def __init__(
     assert eval_batch_size % (jax.device_count()) == 0

     # Only used if checkpoint_steps is non-empty. Standard checkpoints are
-    # saved in train_dir.
+    # saved in _checkpoint_dir.
     self._extra_checkpoint_dir = os.path.join(
         self._checkpoint_dir, 'checkpoints'
     )
+    self._orbax_checkpoint_manager_extra = ocp.CheckpointManager(
+        self._extra_checkpoint_dir,
+        options=ocp.CheckpointManagerOptions(
+            create=True, file_options=orbax_file_options
+        ),
+    )

     # During eval, we can donate the 'batch' buffer. We don't donate the
     # 'params' and 'batch_stats' buffers as we don't re-assign those values in
@@ -228,7 +235,7 @@ def __init__(
         self._training_algorithm_class)

   def wait_until_orbax_checkpointer_finished(self):
-    self._orbax_checkpointer.wait_until_finished()
+    self._orbax_checkpoint_manager.wait_until_finished()
+    self._orbax_checkpoint_manager_extra.wait_until_finished()

   def log_model_info(self, unreplicated_params):
     if jax.process_index() == 0:
@@ -269,7 +276,7 @@ def maybe_restore_from_checkpoint(self, unreplicated_metrics_state,
         train_dir=self._checkpoint_dir,
         external_checkpoint_path=self._external_checkpoint_path,
-        orbax_checkpointer=self._orbax_checkpointer,
+        orbax_checkpoint_manager=self._orbax_checkpoint_manager,
     )

     if self._is_restored:
@@ -335,13 +342,14 @@ def setup_data_loader(self, data_rng, global_step):

     return dataset

-  def _save(self, checkpoint_dir, max_to_keep=1):
+  def _save(self, checkpoint_manager=None):
     if utils.use_mock_tpu_backend():
       logging.info('Skip saving checkpoint when running with mock backend.')
       return
+    if checkpoint_manager is None:
+      checkpoint_manager = self._orbax_checkpoint_manager
     checkpoint.save_unreplicated_checkpoint(
-        checkpoint_dir,
         self._optimizer_state,
         self._params,
         self._batch_stats,
@@ -349,8 +357,7 @@ def _save(self, checkpoint_manager=None):
         self._global_step,
         self._preemption_count,
         self._sum_train_cost,
-        self._orbax_checkpointer,
-        max_to_keep=max_to_keep,
+        checkpoint_manager,
     )

   def _get_step_frequency(self, cur_step, start_step, start_time):
@@ -451,7 +458,7 @@ def _eval(self, start_step, start_time, save=True):
     )
     self._run_eval_callbacks(report)
     if save:
-      self._save(self._checkpoint_dir)
+      self._save()
     steps_since_last_eval = self._global_step - self._prev_eval_step
     steps_per_sec_no_eval = steps_since_last_eval / time_since_last_eval
     run_time = time.time() - self._time_at_prev_eval_end
@@ -646,7 +653,7 @@ def train(self):
      self._prev_eval_step = self._global_step

      if self._global_step in self._checkpoint_steps:
-        self._save(self._extra_checkpoint_dir, max_to_keep=None)
+        self._save(checkpoint_manager=self._orbax_checkpoint_manager_extra)

     for _ in range(start_step, self._num_train_steps):
       with jax.profiler.StepTraceAnnotation(
@@ -682,13 +689,14 @@ def train(self):
             self._sum_train_cost,
         )
         if self._global_step in self._checkpoint_steps:
-          self._save(self._extra_checkpoint_dir, max_to_keep=None)
+          self._save(checkpoint_manager=self._orbax_checkpoint_manager_extra)

         # TODO(gdahl, gilmer): consider moving this test up.
         # NB: Since this test is after we increment self._global_step, having 0
         # in eval_steps does nothing.
         if trainer_utils.should_eval(
-            self._global_step, self._eval_frequency, self._eval_steps):
+            self._global_step, self._eval_frequency, self._eval_steps
+        ):
           try:
             report = self._eval(start_step, start_time)
           except utils.TrainingDivergedError as e:
@@ -709,6 +717,7 @@ def train(self):
     yield report
     # To make sure the last checkpoint was correctly saved.
     self.wait_until_orbax_checkpointer_finished()
+    self._orbax_checkpoint_manager.close()
+    self._orbax_checkpoint_manager_extra.close()
diff --git a/init2winit/utils.py b/init2winit/utils.py
index f0ba297a..37d7c2c6 100644
--- a/init2winit/utils.py
+++ b/init2winit/utils.py
@@ -261,21 +261,17 @@ def append_scalar_metrics(self, metrics):
       # size 512. We could only flush at the end of training to optimize this.
       self._tb_metric_writer.flush()

-  def write_pytree(self, pytree, prefix='training_metrics'):
+  def write_pytree(self, pytree):
     """Record a serializable pytree to disk, overwriting any previous state.

     Args:
       pytree: Any serializable pytree
-      prefix: The prefix for the checkpoint. Save path is
-        self._pytree_path/prefix
     """
     state = dict(pytree=pytree)
-    checkpoint.save_checkpoint(
-        self._pytree_path,
-        step='',
-        state=state,
-        prefix=prefix,
-        max_to_keep=None)
+    # save_checkpoint now requires an orbax CheckpointManager (this assumes
+    # `import orbax.checkpoint as ocp` at the top of this module).
+    # max_to_keep=1 approximates the old overwrite-in-place behavior: each
+    # write advances the step and the manager deletes the previous one.
+    manager = ocp.CheckpointManager(
+        self._pytree_path,
+        options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True))
+    latest = manager.latest_step()
+    next_step = 0 if latest is None else latest + 1
+    checkpoint.save_checkpoint(
+        step=next_step,
+        state=state,
+        orbax_checkpoint_manager=manager)
+    manager.wait_until_finished()

   def append_pytree(self, pytree, prefix='training_metrics'):
     """Append and record a serializable pytree to disk.
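
Taken together with the checkpoint.py changes above, the metrics-logger path round-trips through load_pytree, which builds its own CheckpointManager over the directory, restores the latest step, and unwraps the 'pytree' key. A hedged sketch of that round trip; the /tmp path and the toy payload are illustrative only:

    import orbax.checkpoint as ocp
    from init2winit import checkpoint

    pytree_path = '/tmp/i2w_pytree_demo'
    manager = ocp.CheckpointManager(
        pytree_path,
        options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True))

    # Mirrors write_pytree: the payload is wrapped under a 'pytree' key.
    checkpoint.save_checkpoint(
        step=0,
        state=dict(pytree={'ema': 0.9}),
        orbax_checkpoint_manager=manager)
    manager.wait_until_finished()

    # load_pytree returns the bare payload, or None if nothing was saved yet.
    print(checkpoint.load_pytree(pytree_path))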