diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py index e6c98c811..228084fa4 100644 --- a/axlearn/cloud/gcp/measurement.py +++ b/axlearn/cloud/gcp/measurement.py @@ -25,6 +25,7 @@ from typing import Optional, Sequence import jax +import orbax.checkpoint as ocp from absl import flags, logging from ml_goodput_measurement import goodput from ml_goodput_measurement import monitoring as goodput_monitoring @@ -134,6 +135,19 @@ def record_event(self, event: measurement.EventType, *args, **kwargs): ) # pylint: enable=try-except-raise + def create_checkpoint_logger(self) -> Optional[ocp.logging.CloudLogger]: + try: + logging.info("Creating a Goodput checkpoint logger.") + return ocp.logging.CloudLogger( + options=ocp.logging.CloudLoggerOptions( + job_name=self._job_name, + logger_name=self._logger_name, + ) + ) + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning("Failed to create Goodput checkpoint logger: %s", e, exc_info=True) + return None + @contextlib.contextmanager def _maybe_monitor_goodput(self, *args, **kwargs): """Monitor cumulative goodput if enabled. @@ -221,35 +235,32 @@ def record(self, event: measurement.Event, *args, **kwargs): """ # Lazily instantiate the recorder. This avoids invoking jax before setup is complete. if self._recorder is None: - cfg: GoodputRecorder.Config = self.config + if jax.process_index() == 0: + logging.info("Lazily instantiating goodput recorder.") self._recorder = goodput.GoodputRecorder( - job_name=cfg.name, - logger_name=f"goodput_logger_{cfg.name}", + job_name=self._job_name, + logger_name=self._logger_name, logging_enabled=(jax.process_index() == 0), ) - if event == measurement.Event.START_JOB: - self._recorder.record_job_start_time(*args, **kwargs) - elif event == measurement.Event.END_JOB: - self._recorder.record_job_end_time(*args, **kwargs) - elif event == measurement.Event.START_STEP: - self._recorder.record_step_start_time(*args, **kwargs) - elif event == measurement.Event.START_ACCELERATOR_INIT: - self._recorder.record_tpu_init_start_time(*args, **kwargs) - elif event == measurement.Event.END_ACCELERATOR_INIT: - self._recorder.record_tpu_init_end_time(*args, **kwargs) - elif event == measurement.Event.START_TRAINING_PREPARATION: - self._recorder.record_training_preparation_start_time(*args, **kwargs) - elif event == measurement.Event.END_TRAINING_PREPARATION: - self._recorder.record_training_preparation_end_time(*args, **kwargs) - elif event == measurement.Event.START_DATA_LOADING: - self._recorder.record_data_loading_start_time(*args, **kwargs) - elif event == measurement.Event.END_DATA_LOADING: - self._recorder.record_data_loading_end_time(*args, **kwargs) - elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT: - self._recorder.record_custom_badput_event_start_time(*args, **kwargs) - elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT: - self._recorder.record_custom_badput_event_end_time(*args, **kwargs) + start_method_name = f"record_{event.value}_start_time" + end_method_name = f"record_{event.value}_end_time" + + record_event_start = getattr(self._recorder, start_method_name, None) + record_event_end = getattr(self._recorder, end_method_name, None) + + if record_event_start: + try: + record_event_start(*args, **kwargs) + except RuntimeError as e: + logging.warning( + "Failed to record start of event %s. 
Error: %s", event.value, e, exc_info=True + ) + # pylint: disable=try-except-raise + try: + yield # Run the user code in the context + except Exception: + raise else: logging.log_first_n( logging.WARNING, diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py index a5245980b..6cfb09784 100644 --- a/axlearn/cloud/gcp/measurement_test.py +++ b/axlearn/cloud/gcp/measurement_test.py @@ -373,3 +373,46 @@ def test_maybe_monitor_all( else: mock_monitor_instance.start_rolling_window_goodput_uploader.assert_not_called() mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called() + + @mock.patch("jax.process_index", return_value=0) + def test_create_checkpoint_logger_success(self, _): + """Tests that create_checkpoint_logger creates a CloudLogger with correct config.""" + cfg = GoodputRecorder.default_config().set( + name="test-job", + upload_dir="/test", + upload_interval=30, + ) + recorder = GoodputRecorder(cfg) + + with mock.patch("orbax.checkpoint.logging.CloudLogger") as mock_logger_cls: + mock_logger_instance = mock_logger_cls.return_value + logger = recorder.create_checkpoint_logger() + + mock_logger_cls.assert_called_once() + self.assertIs(logger, mock_logger_instance) + + _, kwargs = mock_logger_cls.call_args + options = kwargs["options"] + self.assertEqual(options.job_name, "test-job") + self.assertEqual(options.logger_name, "goodput_logger_test-job") + + @mock.patch("jax.process_index", return_value=0) + def test_create_checkpoint_logger_failure(self, _): + """Tests that create_checkpoint_logger logs a warning on failure and returns None.""" + cfg = GoodputRecorder.default_config().set( + name="fail-job", + upload_dir="/test", + upload_interval=30, + ) + recorder = GoodputRecorder(cfg) + + with mock.patch( + "orbax.checkpoint.logging.CloudLogger", side_effect=RuntimeError("TestError") + ) as mock_logger_cls, mock.patch.object(logging, "warning") as mock_warning: + logger = recorder.create_checkpoint_logger() + self.assertIsNone(logger) + mock_logger_cls.assert_called_once() + mock_warning.assert_called_once() + self.assertIn( + "Failed to create Goodput checkpoint logger", mock_warning.call_args[0][0] + ) diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 2eb205b7f..3d07ada18 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -17,7 +17,7 @@ import tensorflow as tf from absl import logging -from axlearn.common import utils +from axlearn.common import measurement, utils from axlearn.common.checkpointer import ( STEP_NUM_DIGITS, STEP_PREFIX, @@ -232,6 +232,9 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool: step_prefix=STEP_PREFIX, step_format_fixed_length=STEP_NUM_DIGITS, ) + self._checkpoint_logger = None + if measurement.global_recorder: + self._checkpoint_logger = measurement.global_recorder.create_checkpoint_logger() self._manager = ocp.CheckpointManager( directory=cfg.dir, options=ocp.CheckpointManagerOptions( @@ -255,6 +258,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool: restore_concurrent_gb=cfg.max_concurrent_restore_gb, ), }, + logger=self._checkpoint_logger, ) def _get_spec(self, *, step: int, state: Nested[Any]) -> Nested[Any]: diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index f868488e0..3faae6b86 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py 
@@ -27,7 +27,7 @@ from jax.experimental.array_serialization import serialization from axlearn.common import file_system as fs -from axlearn.common import utils, utils_spmd +from axlearn.common import measurement, utils, utils_spmd from axlearn.common.checkpointer import ( STEP_NUM_DIGITS, STEP_PREFIX, @@ -667,6 +667,9 @@ def _composite_save_policy(*, step: int, evaler_summaries: dict[str, Any]): # See comments of _eval_summaries in `OrbaxCheckpointer`. self._eval_summaries = None self._reached_preemption = False + self._checkpoint_logger = None + if measurement.global_recorder: + self._checkpoint_logger = measurement.global_recorder.create_checkpoint_logger() # pylint: disable-next=redefined-builtin def ckpt_dir(self, step: int, dir: Optional[str] = None) -> str: @@ -731,6 +734,7 @@ def _orbax_save_fn( cleanup_tmp_directories=True, enable_async_checkpointing=True, ), + logger=self._checkpoint_logger, ) return self._tensor_manager diff --git a/axlearn/common/checkpointer_orbax_emergency_test.py b/axlearn/common/checkpointer_orbax_emergency_test.py index 406d34760..63303474b 100644 --- a/axlearn/common/checkpointer_orbax_emergency_test.py +++ b/axlearn/common/checkpointer_orbax_emergency_test.py @@ -10,6 +10,7 @@ import tempfile from contextlib import ExitStack, closing from typing import Optional +from unittest import mock import jax import numpy as np @@ -18,7 +19,7 @@ from absl.testing import parameterized from jax import numpy as jnp -from axlearn.common import utils_spmd +from axlearn.common import measurement, utils_spmd from axlearn.common.checkpointer_orbax_emergency import ( OrbaxEmergencyCheckpointer, _dump_process_info, @@ -299,3 +300,28 @@ def start_processes(reverse_process_id: bool = False): finally: for p in processes: p.kill() + + @mock.patch("orbax.checkpoint._src.multihost.multihost.initialize_runtime_to_distributed_ids") + @mock.patch("orbax.checkpoint._src.multihost.multihost.initialize_distributed_to_device_ids") + def test_emergency_checkpointer_initializes_logger_from_global_recorder( + self, mock_init_runtime, mock_init_device_ids + ): # pylint: disable=unused-argument + """Tests OrbaxEmergencyCheckpointer initializes _checkpoint_logger.""" + with tempfile.TemporaryDirectory() as temp_dir, mock.patch.object( + measurement, "global_recorder", mock.MagicMock() + ) as mock_recorder: + mock_logger = mock.MagicMock() + mock_recorder.create_checkpoint_logger.return_value = mock_logger + + cfg = OrbaxEmergencyCheckpointer.default_config().set( + name="test_logger", + trainer_dir=temp_dir, + dir=temp_dir, + local_dir=temp_dir, + replica_axis_index=0, + ) + + ckpt: OrbaxEmergencyCheckpointer = cfg.instantiate(parent=None) + + mock_recorder.create_checkpoint_logger.assert_called_once() + self.assertEqual(ckpt._checkpoint_logger, mock_logger) diff --git a/axlearn/common/checkpointer_orbax_test.py b/axlearn/common/checkpointer_orbax_test.py index cbb676309..e45e5db23 100644 --- a/axlearn/common/checkpointer_orbax_test.py +++ b/axlearn/common/checkpointer_orbax_test.py @@ -10,13 +10,14 @@ import os import tempfile from typing import Sequence +from unittest import mock import jax import orbax.checkpoint as ocp from jax import numpy as jnp from jax.experimental import mesh_utils -from axlearn.common import test_utils +from axlearn.common import measurement, test_utils from axlearn.common.checkpointer import read_index_file from axlearn.common.checkpointer_orbax import OrbaxCheckpointer @@ -52,3 +53,21 @@ def test_index(self): ), ) self.assertEqual(ref_index, test_index["index"]) + + 
def test_initializes_checkpoint_logger_from_global_recorder(self): + """Tests that OrbaxCheckpointer initializes _checkpoint_logger if global_recorder is set.""" + with tempfile.TemporaryDirectory() as temp_dir, mock.patch.object( + measurement, "global_recorder", mock.MagicMock() + ) as mock_recorder: + mock_logger = mock.MagicMock(spec=ocp.logging.CloudLogger) + mock_recorder.create_checkpoint_logger.return_value = mock_logger + + ckpt = ( + OrbaxCheckpointer.default_config() + .set(name="test", dir=temp_dir) + .instantiate(parent=None) + ) + + # Ensure create_checkpoint_logger was called and the logger was set. + mock_recorder.create_checkpoint_logger.assert_called_once() + self.assertEqual(ckpt._checkpoint_logger, mock_logger) diff --git a/axlearn/common/launch_trainer.py b/axlearn/common/launch_trainer.py index bba28533e..697dbef6f 100644 --- a/axlearn/common/launch_trainer.py +++ b/axlearn/common/launch_trainer.py @@ -2,6 +2,7 @@ """Utilities to launch a trainer.""" +import contextlib import json import os from typing import Any, Optional @@ -128,8 +129,7 @@ def get_trainer_config( return trainer_config -def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: - measurement.record_event(measurement.Event.START_JOB) +def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any: trainer_config_debug_string = trainer_config.debug_string() logging.info("Trainer config:\n%s", trainer_config_debug_string) if jax.process_index() == 0: @@ -150,5 +150,13 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: trainer: SpmdTrainer = trainer_config.instantiate(parent=None) prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed) output = trainer.run(prng_key) - measurement.record_event(measurement.Event.END_JOB) return output + + +def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: + recorder = measurement.global_recorder + job_events_manager = ( + recorder.record_event(measurement.EventType.JOB) if recorder else contextlib.nullcontext() + ) + with job_events_manager: + return _run_trainer_impl(trainer_config) diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py index 2f617b4cd..8d170a950 100644 --- a/axlearn/common/launch_trainer_main.py +++ b/axlearn/common/launch_trainer_main.py @@ -13,7 +13,6 @@ def main(_): launch.setup() trainer_config = launch_trainer.get_trainer_config() trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder)) - measurement.start_monitoring() launch_trainer.run_trainer(trainer_config) diff --git a/axlearn/common/measurement.py b/axlearn/common/measurement.py index a4af06659..83ba65d05 100644 --- a/axlearn/common/measurement.py +++ b/axlearn/common/measurement.py @@ -131,6 +131,10 @@ def maybe_monitor_all(self): """ yield + def create_checkpoint_logger(self) -> Optional[object]: + """Optionally returns a fully functional and independent checkpoint logger.""" + return None + _recorders: dict[str, type] = {} _T = TypeVar("_T") diff --git a/axlearn/common/measurement_test.py b/axlearn/common/measurement_test.py index 7fa6b3386..1f0ec9ab4 100644 --- a/axlearn/common/measurement_test.py +++ b/axlearn/common/measurement_test.py @@ -96,3 +96,6 @@ def test_initialize(self, recorder_type, expected): # Ensure that maybe_monitor_all does not fail (just enter and exit context). with measurement.global_recorder.maybe_monitor_all(): pass + + # Ensure that create_checkpoint_logger does not crash. 
+ measurement.global_recorder.create_checkpoint_logger() diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index ac78a7c75..9671cee10 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -241,118 +241,121 @@ def __init__( self._device_monitor = maybe_instantiate(cfg.device_monitor) self._recorder = maybe_instantiate(cfg.recorder) self._is_initialized: bool = False - self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT) + # Accelerator initialization. + with self._record_event(measurement.EventType.ACCELERATOR_INIT): + if cfg.model.dtype is None: + raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") + if cfg.model.param_init is None: + cfg.model.param_init = DefaultInitializer.default_config() + logging.info( + "model.param_init is not specified. Default to DefaultInitializer: %s", + cfg.model.param_init, + ) - if cfg.model.dtype is None: - raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") - if cfg.model.param_init is None: - cfg.model.param_init = DefaultInitializer.default_config() - logging.info( - "model.param_init is not specified. Default to DefaultInitializer: %s", - cfg.model.param_init, + self._per_param_train_dtype = maybe_instantiate( + canonicalize_per_param_dtype(cfg.train_dtype) ) - self._per_param_train_dtype = maybe_instantiate( - canonicalize_per_param_dtype(cfg.train_dtype) - ) - - # Create the device mesh. - if devices is None: - self._step_log( - "Devices: global=%s local=%s %s", - jax.device_count(), - jax.local_device_count(), - [device.platform for device in jax.local_devices()], - ) - else: - local_devices = [d for d in devices.flatten() if d.process_index == jax.process_index()] - self._step_log( - "Devices: global=%s local=%s %s", - len(devices), - len(local_devices), - [device.platform for device in local_devices], - ) - self._step_log("Mesh shape: %s", cfg.mesh_shape) - devices = ( - utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices - ) - mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) - self._step_log("Global mesh: %s", mesh) - self._mesh = mesh - self._context_manager: Callable[[], ContextManager] = ( - maybe_instantiate(cfg.context_manager) or contextlib.nullcontext - ) - xsc_check_policy = None - if cfg.xsc_check_policy: - if jax.default_backend() != "tpu": - # XSC is currently only supported on TPU XLA backend. - logging.warning( - "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." + # Create the device mesh. + if devices is None: + self._step_log( + "Devices: global=%s local=%s %s", + jax.device_count(), + jax.local_device_count(), + [device.platform for device in jax.local_devices()], ) else: - xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) - self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy - self._compiled_train_step: Optional[jax.stages.Compiled] = None - - # Create all children within the mesh context so that utils.input_partition_spec() works - # properly. 
- with self.mesh(): - if cfg.batch_axis_names is not None: - cfg.input = maybe_set_config( - cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + local_devices = [ + d for d in devices.flatten() if d.process_index == jax.process_index() + ] + self._step_log( + "Devices: global=%s local=%s %s", + len(devices), + len(local_devices), + [device.platform for device in local_devices], ) - self.input: Input = self._add_child( - "input", maybe_set_config(cfg.input, is_training=True) - ) - # Start from the beginning of the input dataset by default. - self._input_iter = iter(self.input.dataset()) - cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", "train_train" + self._step_log("Mesh shape: %s", cfg.mesh_shape) + devices = ( + utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices ) - self._add_child("summary_writer", cfg.summary_writer) - self._add_child("model", cfg.model) - self._add_child("learner", cfg.learner) - cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") - self._add_child("checkpointer", cfg.checkpointer) - if cfg.init_state_builder is not None: - self._add_child("init_state_builder", cfg.init_state_builder) - - self._model_param_specs = self.model.create_parameter_specs_recursively() - model_param_partition_specs = jax.tree.map( - lambda spec: spec.mesh_axes, self._model_param_specs + mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) + self._step_log("Global mesh: %s", mesh) + self._mesh = mesh + self._context_manager: Callable[[], ContextManager] = ( + maybe_instantiate(cfg.context_manager) or contextlib.nullcontext ) - for name, spec in utils.flatten_items(self._model_param_specs): - self._step_log("Model param spec: %s=%s", name, spec) - self._learner_state_partition_specs = self.learner.create_state_partition_specs( - self._model_param_specs - ) - for name, spec in utils.flatten_items(self._learner_state_partition_specs): - self._step_log("Learner state spec: %s=%s", name, spec) - self._trainer_state_specs = TrainerState( - prng_key=ParameterSpec(dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None)), - model=self._model_param_specs, - learner=self._learner_state_partition_specs, - ) - self._trainer_state_partition_specs: TrainerState = jax.tree.map( - lambda spec: spec.sharding, self._trainer_state_specs - ) - # Create evalers, which depend on model_param_partition_specs. - self._evalers = {} - for evaler_name, evaler_cfg in cfg.evalers.items(): - evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", evaler_name - ) + xsc_check_policy = None + if cfg.xsc_check_policy: + if jax.default_backend() != "tpu": + # XSC is currently only supported on TPU XLA backend. + logging.warning( + "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." + ) + else: + xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) + self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy + self._compiled_train_step: Optional[jax.stages.Compiled] = None + + # Create all children within the mesh context so that utils.input_partition_spec() works + # properly. 
+ with self.mesh(): if cfg.batch_axis_names is not None: - maybe_set_config( - evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + cfg.input = maybe_set_config( + cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) ) - self._evalers[evaler_name] = self._add_child( - evaler_name, - evaler_cfg, - model=self.model, - model_param_partition_specs=model_param_partition_specs, + self.input: Input = self._add_child( + "input", maybe_set_config(cfg.input, is_training=True) + ) + # Start from the beginning of the input dataset by default. + self._input_iter = iter(self.input.dataset()) + cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", "train_train" + ) + self._add_child("summary_writer", cfg.summary_writer) + self._add_child("model", cfg.model) + self._add_child("learner", cfg.learner) + cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") + self._add_child("checkpointer", cfg.checkpointer) + if cfg.init_state_builder is not None: + self._add_child("init_state_builder", cfg.init_state_builder) + + self._model_param_specs = self.model.create_parameter_specs_recursively() + model_param_partition_specs = jax.tree.map( + lambda spec: spec.mesh_axes, self._model_param_specs + ) + for name, spec in utils.flatten_items(self._model_param_specs): + self._step_log("Model param spec: %s=%s", name, spec) + self._learner_state_partition_specs = self.learner.create_state_partition_specs( + self._model_param_specs ) - self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) + for name, spec in utils.flatten_items(self._learner_state_partition_specs): + self._step_log("Learner state spec: %s=%s", name, spec) + self._trainer_state_specs = TrainerState( + prng_key=ParameterSpec( + dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None) + ), + model=self._model_param_specs, + learner=self._learner_state_partition_specs, + ) + self._trainer_state_partition_specs: TrainerState = jax.tree.map( + lambda spec: spec.sharding, self._trainer_state_specs + ) + # Create evalers, which depend on model_param_partition_specs. + self._evalers = {} + for evaler_name, evaler_cfg in cfg.evalers.items(): + evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", evaler_name + ) + if cfg.batch_axis_names is not None: + maybe_set_config( + evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + ) + self._evalers[evaler_name] = self._add_child( + evaler_name, + evaler_cfg, + model=self.model, + model_param_partition_specs=model_param_partition_specs, + ) @property def step(self): @@ -370,6 +373,15 @@ def trainer_state_specs(self): def trainer_state_partition_specs(self): return self._trainer_state_partition_specs + @contextlib.contextmanager + def _record_event(self, event: measurement.EventType, *args, **kwargs): + """A helper to record an event if a recorder is configured.""" + if self._recorder: + with self._recorder.record_event(event, *args, **kwargs) as event_manager: + yield event_manager + else: + yield + def _train_step_input_partition_specs(self): # Note that subclasses may override this method to set a partition spec for pjit which is # different from that of the input partition spec. @@ -563,6 +575,7 @@ def run( different types of values such as WeightedScalar, Tensor, or string, depending on the specific `metric_calculator` config of the evaler. """ + with ( ( self._device_monitor.start_monitoring() @@ -582,8 +595,9 @@ def run( ) # Prepare training. 
- if not self._prepare_training(prng_key): - return None + with self._record_event(measurement.EventType.TRAINING_PREPARATION): + if not self._prepare_training(prng_key): + return None self._is_initialized = True @@ -596,10 +610,10 @@ def run( input_iterator = self.input.batches(self._input_iter) while True: - self._maybe_record_event(measurement.Event.START_DATA_LOADING) try: - input_batch = next(input_iterator) - self._maybe_record_event(measurement.Event.END_DATA_LOADING) + with self._record_event(measurement.EventType.DATA_LOADING): + input_batch = next(input_iterator) + logging.log_first_n( logging.INFO, "host_input_batch=%s", 3, utils.shapes(input_batch) ) @@ -609,18 +623,18 @@ def run( self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) - self._maybe_record_event(measurement.Event.START_STEP, self._step) - output = self._run_step( - utils.host_to_global_array( - input_batch, - partition=self._train_step_input_partition_specs(), - ), - force_run_evals=( - force_run_eval_sets_at_max_step - if self.step >= cfg.max_step - else None - ), - ) + with self._record_event(measurement.EventType.STEP, self._step): + output = self._run_step( + utils.host_to_global_array( + input_batch, + partition=self._train_step_input_partition_specs(), + ), + force_run_evals=( + force_run_eval_sets_at_max_step + if self.step >= cfg.max_step + else None + ), + ) self.vlog(3, "Done step %s", self.step) num_steps += 1 if num_steps % 100 == 0: @@ -634,9 +648,6 @@ def run( self._step_log("Reached max_step=%s. Stopping", cfg.max_step) break except StopIteration: - # Add END_DATA_LOADING event here to close the unpaired START_DATA_LOADING - # event. - self._maybe_record_event(measurement.Event.END_DATA_LOADING) break if self.step < cfg.max_step: self._step_log("Reached end of inputs. Stopping") @@ -877,7 +888,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: A boolean indicating whether the model training should start. If not, return None from the `run` function. """ - self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION) cfg = self.config # Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`. 
@@ -910,7 +920,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: return False self._jit_train_step = self._pjit_train_step() - self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION) return True def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int]: @@ -1051,36 +1060,29 @@ def _get_compiled_train_step_fn( mesh_shape=cfg.mesh_shape, mesh_axis_names=cfg.mesh_axis_names, device_kind=device_kind ) if not with_xsc: - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, + with self._record_event( + measurement.EventType.CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_NO_XSC", - ) - self._compiled_train_step = self.compile_train_step( - trainer_state=trainer_state, input_batch=input_batch, compiler_options=options - ) - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, - custom_badput_event_type="COMPILATION_NO_XSC", - ) + ): + self._compiled_train_step = self.compile_train_step( + trainer_state=trainer_state, input_batch=input_batch, compiler_options=options + ) return self._compiled_train_step + logging.log_first_n(logging.INFO, "Compiling XSC train step.", 1) - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, + with self._record_event( + measurement.EventType.CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_WITH_XSC", - ) - compiled_jit_train_step_fn = self.compile_train_step( - trainer_state=trainer_state, - input_batch=input_batch, - compiler_options=options - | infer_xsc_compiler_options( - halt_on_detection=True, repeat_count=1, device_kind=device_kind - ), - ) - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, - custom_badput_event_type="COMPILATION_WITH_XSC", - ) + ): + compiled_jit_train_step_fn = self.compile_train_step( + trainer_state=trainer_state, + input_batch=input_batch, + compiler_options=options + | infer_xsc_compiler_options( + halt_on_detection=True, repeat_count=1, device_kind=device_kind + ), + ) return compiled_jit_train_step_fn def _run_step( @@ -1138,26 +1140,23 @@ def _run_eval( force_runs: Optional[set[str]] = None, ) -> dict[str, Any]: """Runs evaluations and returns the corresponding summaries.""" - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ) - evaler_summaries = {} - # Note: we will use the same eval key as the training keys of the future step, - # which should be okay. - prng_key = self._trainer_state.prng_key - for evaler_name, evaler in self._evalers.items(): - prng_key, summaries, _ = evaler.eval_step( - self.step, - prng_key=prng_key, - model_params=self.model_params_for_eval(), - train_summaries=train_summaries, - force_run=bool(force_runs is not None and evaler_name in force_runs), - ) - evaler_summaries[evaler_name] = summaries - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ) - return evaler_summaries + with self._record_event( + measurement.EventType.CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" + ): + evaler_summaries = {} + # Note: we will use the same eval key as the training keys of the future step, + # which should be okay. 
+ prng_key = self._trainer_state.prng_key + for evaler_name, evaler in self._evalers.items(): + prng_key, summaries, _ = evaler.eval_step( + self.step, + prng_key=prng_key, + model_params=self.model_params_for_eval(), + train_summaries=train_summaries, + force_run=bool(force_runs is not None and evaler_name in force_runs), + ) + evaler_summaries[evaler_name] = summaries + return evaler_summaries def _pjit_train_step(self) -> jax.stages.Wrapped: return pjit(
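
Note (illustrative sketch, not part of the diff above): the core pattern in this change is replacing paired START_*/END_* event calls with a single record_event(EventType, ...) context manager, resolved dynamically via record_<event>_start_time / record_<event>_end_time, plus a no-op fallback when no recorder is configured (as in the trainer's new _record_event helper and run_trainer). The self-contained Python sketch below uses hypothetical stand-in names (DemoEventType, DemoRecorder, run_one_step); the finally-based end call is an assumption, since the visible hunk truncates before GoodputRecorder's end-time handling.

    # Illustrative sketch only; mirrors the start/end pairing pattern from the diff.
    # All Demo* names are hypothetical stand-ins, not axlearn APIs.
    import contextlib
    import enum
    import logging
    from typing import Optional

    logging.basicConfig(level=logging.INFO)


    class DemoEventType(enum.Enum):
        # Subset of event values referenced by the diff.
        STEP = "step"
        DATA_LOADING = "data_loading"


    class DemoRecorder:
        """Resolves record_<event>_start_time / record_<event>_end_time by name."""

        def record_step_start_time(self, step: int):
            logging.info("step %d started", step)

        def record_step_end_time(self, step: int):
            logging.info("step %d finished", step)

        @contextlib.contextmanager
        def record_event(self, event: DemoEventType, *args, **kwargs):
            start = getattr(self, f"record_{event.value}_start_time", None)
            end = getattr(self, f"record_{event.value}_end_time", None)
            if start:
                try:
                    start(*args, **kwargs)
                except RuntimeError as e:
                    logging.warning("Failed to record start of %s: %s", event.value, e)
            try:
                yield
            finally:
                # Emitting the end marker even when the wrapped code raises keeps
                # START/END events paired (assumption: the real implementation does
                # the equivalent after the truncated portion of the hunk).
                if end:
                    try:
                        end(*args, **kwargs)
                    except RuntimeError as e:
                        logging.warning("Failed to record end of %s: %s", event.value, e)


    def run_one_step(recorder: Optional[DemoRecorder], step: int):
        # Callers fall back to a no-op context when no recorder is configured,
        # matching the _record_event helper added to SpmdTrainer.
        ctx = (
            recorder.record_event(DemoEventType.STEP, step)
            if recorder
            else contextlib.nullcontext()
        )
        with ctx:
            pass  # train-step work goes here


    run_one_step(DemoRecorder(), step=1)

This pairing is what lets the trainer drop manual bookkeeping such as the extra END_DATA_LOADING previously emitted on StopIteration. The checkpoint-logging half of the change follows the same shape: create_checkpoint_logger() returns an ocp.logging.CloudLogger (or None on failure), and the Orbax checkpointers simply forward it as logger= to ocp.CheckpointManager.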