117 changes: 48 additions & 69 deletions init2winit/checkpoint.py
@@ -18,30 +18,28 @@
This is useful for training neural networks with stax, where model parameters
are nested numpy arrays.
"""
import os
import sys

from absl import flags
from absl import logging
from flax.training import checkpoints as flax_checkpoints
import jax
# pylint: disable=g-importing-member
from jax.experimental.multihost_utils import process_allgather
import orbax.checkpoint as ocp

FLAGS = flags.FLAGS


def load_pytree(pytree_file, orbax_checkpointer=None):
def load_pytree(pytree_file, orbax_checkpoint_manager=None):
"""Loads the checkpointed pytree."""
latest = load_latest_checkpoint(pytree_file,
target=None,
orbax_checkpointer=orbax_checkpointer)
if latest:
# Because we pass target=None, flax checkpointing will return the raw
# state dict, where 'state' will be a dict with keys ['0', '1', ...]
# instead of a list.
return latest['pytree']
return None
if orbax_checkpoint_manager is None:
orbax_checkpoint_manager = ocp.CheckpointManager(
pytree_file,
)
restore_step = orbax_checkpoint_manager.latest_step()
try:
restored = orbax_checkpoint_manager.restore(restore_step)
return restored
except FileNotFoundError:
return None


def maybe_restore_checkpoint(
@@ -51,7 +49,7 @@ def maybe_restore_checkpoint(
unreplicated_training_metrics_state,
train_dir,
external_checkpoint_path=None,
orbax_checkpointer=None):
orbax_checkpoint_manager=None):
"""Optionally restores from a checkpoint.

The checkpoint logic is as follows: if there is a checkpoint in `train_dir`,
@@ -68,7 +66,7 @@ def maybe_restore_checkpoint(
train_dir: (str) The training directory where we will look for a checkpoint.
external_checkpoint_path: (str) If this argument is set, then we will load
the external checkpoint stored there.
orbax_checkpointer: orbax.Checkpointer
orbax_checkpoint_manager: orbax.CheckpointManager

Returns:
unreplicated_optimizer_state
@@ -89,24 +87,27 @@
training_metrics_grabber=unreplicated_training_metrics_state,
global_step=uninitialized_global_step,
preemption_count=0,
sum_train_cost=0.0)
sum_train_cost=0.0,
)
logging.info('Loading latest checkpoint from train_dir: %s', train_dir)
latest_ckpt = load_latest_checkpoint(train_dir,
target=unreplicated_checkpoint_state,
orbax_checkpointer=orbax_checkpointer)
latest_ckpt = load_latest_checkpoint(
target=unreplicated_checkpoint_state,
orbax_checkpoint_manager=orbax_checkpoint_manager,
)
logging.info('Loading checkpoint from train_dir %s complete.', train_dir)
# Load_latest_checkpoint() will return unreplicated_checkpoint_state if
# train_dir does not exist or if it exists and contains no checkpoints.
# Note that we could likely change the below line to:
# found_checkpoint = latest_ckpt != unreplicated_checkpoint_state
found_checkpoint = (latest_ckpt['global_step'] != uninitialized_global_step)
found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step

# If there's a latest checkpoint in the train_dir, restore from that.
if found_checkpoint:
ckpt_to_return = latest_ckpt
is_restored = True # We do want trainer to increment preemption_count.
logging.info('Restoring checkpoint from ckpt_%d',
latest_ckpt['global_step'])
logging.info(
'Restoring checkpoint from ckpt_%d', latest_ckpt['global_step']
)
# Else, if external_checkpoint_path is non-null, restore from that checkpoint.
elif external_checkpoint_path is not None:
# TODO(jeremycohen) This code will crash if we try to load an external
@@ -121,7 +122,6 @@
ckpt_to_return = load_checkpoint(
external_checkpoint_path,
target=unreplicated_checkpoint_state,
orbax_checkpointer=orbax_checkpointer,
)
is_restored = False # We don't want trainer to increment preemption_count.

@@ -159,16 +159,14 @@ def maybe_restore_checkpoint(


def save_unreplicated_checkpoint(
train_dir,
optimizer_state,
params,
batch_stats,
training_metrics_state,
global_step,
preemption_count,
sum_train_cost,
orbax_checkpointer,
max_to_keep=1):
orbax_checkpoint_manager):
"""Saves pytree, step, preemption_count, and sum_train_cost to train_dir."""
logging.info('Saving checkpoint to ckpt_%d', global_step)
# jax.device_get doesn't work if jax.Array lives on multiple hosts.
@@ -191,20 +189,15 @@ def save_unreplicated_checkpoint(
params=unreplicated_params,
batch_stats=unreplicated_batch_stats,
training_metrics_grabber=unreplicated_training_metrics_state)
save_checkpoint(train_dir,
global_step,
save_checkpoint(global_step,
state,
max_to_keep=max_to_keep,
orbax_checkpointer=orbax_checkpointer)
orbax_checkpoint_manager=orbax_checkpoint_manager)
logging.info('Done saving checkpoint.')


def save_checkpoint(train_dir,
step,
def save_checkpoint(step,
state,
prefix='ckpt_',
max_to_keep=None,
orbax_checkpointer=None):
orbax_checkpoint_manager=None):
"""Saves checkpoint to train_dir/{prefix}{step}.

A list of checkpoints will be stored in train_dir. The user
@@ -214,68 +207,54 @@ def save_checkpoint(train_dir,
is not None.

Args:
train_dir: (str) Directory to create the checkpoint directory in.
step: (int) Step of the checkpoint.
state: (dict) The state to save.
prefix: (str) Prefix of the checkpoint name.
max_to_keep: (int) Checkpoints older than the max_to_keep'th will be
deleted. Defaults to never deleting.
orbax_checkpointer: orbax.Checkpointer
orbax_checkpoint_manager: orbax.CheckpointManager

Returns:
The path of the checkpoint directory.
"""
if max_to_keep is None:
max_to_keep = sys.maxsize
flax_checkpoints.save_checkpoint_multiprocess(
train_dir,
target=state,
step=step,
prefix=prefix,
keep=max_to_keep,
overwrite=True,
orbax_checkpointer=orbax_checkpointer,
)
save_dir = os.path.join(train_dir, prefix + str(step))
return save_dir
orbax_checkpoint_manager.save(step, args=ocp.args.StandardSave(state))

return orbax_checkpoint_manager.directory


def load_checkpoint(
checkpoint_path, target=None, prefix='ckpt_', orbax_checkpointer=None
checkpoint_path,
target=None,
orbax_checkpointer=None,
):
"""Loads the specified checkpoint."""
restored = flax_checkpoints.restore_checkpoint(
if orbax_checkpointer is None:
orbax_checkpointer = ocp.StandardCheckpointer()

restored = orbax_checkpointer.restore(
checkpoint_path,
target=target,
prefix=prefix,
orbax_checkpointer=orbax_checkpointer,
item=target,
restore_args=ocp.args.StandardRestore(target, strict=False),
)
return restored


def load_latest_checkpoint(
train_dir, target=None, prefix='ckpt_', orbax_checkpointer=None
):
def load_latest_checkpoint(target=None, orbax_checkpoint_manager=None):
"""Loads the most recent checkpoint listed in train_dir.

Args:
train_dir: the directory to read checkpoints from.
target: used for Flax checkpointing, a pytree whose structure will be used
to structure the restored checkpoint data.
prefix: the prefix of the names of checkpoint files.
orbax_checkpointer: orbax.Checkpointer
orbax_checkpoint_manager: An orbax.CheckpointManager instance.
Returns:
The state restored from the checkpoint. If using Flax checkpointing and
target=None, this will return an unstructured dictionary containing the
checkpoint data, as returned by to_state_dict in serialization.py:
https://github.com/google/flax/blob/master/flax/serialization.py#L67. If
the directory doesn't exist, it will return the original target.
"""
restore_step = orbax_checkpoint_manager.latest_step()
try:
restored = flax_checkpoints.restore_checkpoint(
train_dir, target=target, prefix=prefix,
orbax_checkpointer=orbax_checkpointer
restored = orbax_checkpoint_manager.restore(
restore_step, args=ocp.args.StandardRestore(target, strict=False)
)
return restored
except ValueError:
except FileNotFoundError:
return target
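A rough usage sketch of the refactored module (not part of the diff): callers now construct a single ocp.CheckpointManager and pass it to both save and restore. The directory, options, and toy state below are illustrative assumptions; the manager options mirror the test setup further down.

from init2winit import checkpoint
import orbax.checkpoint as ocp

# Illustrative directory and options; retention now lives in the manager
# options instead of a max_to_keep argument on save_checkpoint.
manager = ocp.CheckpointManager(
    '/tmp/example_train_dir',
    options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True),
)

state = dict(global_step=0, preemption_count=0, sum_train_cost=0.0,
             params={'w': 1.0})

# New signature: (step, state, orbax_checkpoint_manager), no train_dir.
checkpoint.save_checkpoint(0, state, orbax_checkpoint_manager=manager)
manager.wait_until_finished()  # Orbax saves may complete asynchronously.

# Restores the latest step back into the structure of `state`.
restored = checkpoint.load_latest_checkpoint(
    target=state, orbax_checkpoint_manager=manager)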
5 changes: 1 addition & 4 deletions init2winit/gradient_statistics_callback.py
@@ -146,10 +146,7 @@ def run_eval(self, params, batch_stats, optimizer_state, global_step):
)

checkpoint.save_checkpoint(
self.save_path,
step=global_step,
state=state,
prefix='measurement_',
max_to_keep=None)
state=state)

return {}
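Since save_path and prefix='measurement_' no longer exist on save_checkpoint, measurement snapshots now land wherever the shared manager writes. If a separate location were still wanted, one hypothetical arrangement (not something this change implements) is a dedicated manager for the callback; the path, prefix, and state below are assumptions.

from init2winit import checkpoint
import orbax.checkpoint as ocp

# Hypothetical: a manager reserved for measurement checkpoints, replacing the
# old save_path/prefix arguments with manager-level configuration.
measurement_manager = ocp.CheckpointManager(
    '/tmp/example_train_dir/measurements',  # illustrative path
    options=ocp.CheckpointManagerOptions(step_prefix='measurement', create=True),
)
global_step = 0
state = {'grad_norm': 0.0}  # illustrative measurement state
checkpoint.save_checkpoint(global_step, state,
                           orbax_checkpoint_manager=measurement_manager)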
44 changes: 21 additions & 23 deletions init2winit/test_checkpoint.py
@@ -59,8 +59,12 @@ def setUp(self):
model_init_fn = jax.jit(
functools.partial(model.flax_module.init, train=False))
init_dict = model_init_fn({'params': params_rng}, xs)
self.orbax_checkpointer = orbax_checkpoint.AsyncCheckpointer(
orbax_checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
self.orbax_checkpoint_manager = orbax_checkpoint.CheckpointManager(
self.test_dir,
options=orbax_checkpoint.CheckpointManagerOptions(
max_to_keep=1, create=True
),
)
self.params = init_dict['params']

def tearDown(self):
@@ -73,10 +77,13 @@ def test_save_load_roundtrip(self):
"""Test that saving and loading produces the original state."""
baz = ['a', 'b', 'ccc']
state = dict(params=self.params, global_step=5, completed_epochs=4, baz=baz)
checkpoint.save_checkpoint(self.test_dir, 0, state,
orbax_checkpointer=self.orbax_checkpointer)
checkpoint.save_checkpoint(
0,
state,
orbax_checkpoint_manager=self.orbax_checkpoint_manager,
)
latest = checkpoint.load_latest_checkpoint(
self.test_dir, target=state, orbax_checkpointer=self.orbax_checkpointer
target=state, orbax_checkpoint_manager=self.orbax_checkpoint_manager
)

self.assertEqual(latest['baz'], baz)
@@ -90,22 +97,18 @@ def test_delete_old_checkpoints(self):
global_step=5,
completed_epochs=4,)
checkpoint.save_checkpoint(
self.test_dir,
0,
state1,
orbax_checkpointer=self.orbax_checkpointer,
max_to_keep=1)
orbax_checkpoint_manager=self.orbax_checkpoint_manager)

state2 = dict(params=self.params,
global_step=10,
completed_epochs=8)
checkpoint.save_checkpoint(
self.test_dir,
1,
state2,
orbax_checkpointer=self.orbax_checkpointer,
max_to_keep=1)
self.orbax_checkpointer.wait_until_finished()
orbax_checkpoint_manager=self.orbax_checkpoint_manager)
self.orbax_checkpoint_manager.wait_until_finished()
dir_contents = gfile.glob(os.path.join(self.test_dir, '*'))
# Due to the Flax-to-Orbax migration, using the Orbax AsyncCheckpointer will
# result in 'max_to_keep + 1' files.
@@ -135,7 +138,6 @@ def test_all_variables_restored(self):
initial_training_metrics = {'ema': 0}

checkpoint.save_checkpoint(
train_dir=fresh_train_dir,
step=global_step,
state=dict(global_step=global_step,
preemption_count=preemption_count,
Expand All @@ -144,8 +146,7 @@ def test_all_variables_restored(self):
params=saved_params,
batch_stats=saved_batch_stats,
training_metrics_grabber=saved_training_metrics),
orbax_checkpointer=self.orbax_checkpointer,
max_to_keep=1)
orbax_checkpoint_manager=self.orbax_checkpoint_manager,)

(
ret_state,
@@ -162,7 +163,7 @@ def test_all_variables_restored(self):
initial_batch_stats,
initial_training_metrics,
fresh_train_dir,
orbax_checkpointer=self.orbax_checkpointer,
orbax_checkpoint_manager=self.orbax_checkpoint_manager,
)

assert pytree_equal(
@@ -211,12 +212,11 @@ def test_maybe_restore_from_checkpoint_logic(self):
external_dir = tempfile.mkdtemp()

# two helper functions
def save_checkpoint(train_dir, global_step, preemption_count,
def save_checkpoint(global_step, preemption_count,
sum_train_cost, params):
"""Helper function to save a checkpoint."""

checkpoint.save_checkpoint(
train_dir=train_dir,
step=global_step,
state=dict(global_step=global_step,
preemption_count=preemption_count,
@@ -225,8 +225,7 @@ def save_checkpoint(train_dir, global_step, preemption_count,
params=params,
batch_stats={},
training_metrics_grabber={}),
orbax_checkpointer=self.orbax_checkpointer,
max_to_keep=1)
orbax_checkpoint_manager=self.orbax_checkpoint_manager)

def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path):
"""Helper function to replicate_and_maybe_restore a checkpoint."""
@@ -235,16 +234,15 @@ def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path):
ret_global_step, ret_sum_train_cost, ret_preemption_count,
ret_is_restored) = checkpoint.maybe_restore_checkpoint(
{}, params, {}, {}, train_dir, external_checkpoint_path,
orbax_checkpointer=self.orbax_checkpointer)
orbax_checkpoint_manager=self.orbax_checkpoint_manager)

ret_params_unrep = ret_params

return (ret_params_unrep, ret_global_step, ret_sum_train_cost,
ret_preemption_count, ret_is_restored)

# Save external checkpoint.
save_checkpoint(train_dir=external_dir,
global_step=5,
save_checkpoint(global_step=5,
preemption_count=4,
sum_train_cost=7.0,
params=external_params)
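A condensed sketch of the retention behaviour test_delete_old_checkpoints exercises, using an illustrative scratch directory; with max_to_keep=1 the manager keeps only the newest finalized step.

import numpy as np
import orbax.checkpoint as ocp

manager = ocp.CheckpointManager(
    '/tmp/example_retention_dir',  # illustrative path
    options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True),
)
manager.save(0, args=ocp.args.StandardSave({'x': np.zeros(3)}))
manager.save(1, args=ocp.args.StandardSave({'x': np.ones(3)}))
manager.wait_until_finished()  # saves can be asynchronous
assert manager.latest_step() == 1
# With max_to_keep=1, step 0 is garbage-collected once step 1 is finalized.
print(list(manager.all_steps()))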
7 changes: 6 additions & 1 deletion init2winit/test_utils.py
@@ -27,6 +27,7 @@
from init2winit import utils
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp


def _identity(i):
@@ -94,7 +95,11 @@ def testAppendPytree(self):
for pytree in pytrees:
logger.append_pytree(pytree)

latest = checkpoint.load_latest_checkpoint(pytree_path, prefix='')
# TODO(kasimbeg): FIX this
orbax_checkpoint_manager = ocp.CheckpointManager(pytree_path)
latest = checkpoint.load_latest_checkpoint(
target=pytrees[0], orbax_checkpoint_manager=orbax_checkpoint_manager
)
saved_pytrees = latest['pytree'] if latest else []
self.assertEqual(
pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))])
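The restore above passes a concrete target pytree so Orbax can rebuild the original nesting; a minimal sketch of that pattern with illustrative paths and data, where strict=False (as in the diff) tolerates structure mismatches between the target and the stored tree.

import numpy as np
import orbax.checkpoint as ocp

manager = ocp.CheckpointManager(
    '/tmp/example_pytree_dir',  # illustrative path
    options=ocp.CheckpointManagerOptions(create=True),
)
saved = {'pytree': {'0': np.zeros(2), '1': np.ones(2)}}
manager.save(0, args=ocp.args.StandardSave(saved))
manager.wait_until_finished()

# Restoring with a matching target reproduces the nested dict structure.
restored = manager.restore(
    manager.latest_step(),
    args=ocp.args.StandardRestore(saved, strict=False),
)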