diff --git a/paxml/decode_programs.py b/paxml/decode_programs.py index 531ef617..c5c8030a 100644 --- a/paxml/decode_programs.py +++ b/paxml/decode_programs.py @@ -398,6 +398,13 @@ def _run_decode_loop( 'Finished decoding input batch %d for %s', step_num, self._name ) + if ( + profiler is not None + and step_num - self._task.decode.profiler_capture_step == + profiler._capture_num_steps + ): + profiler.stop_capture_async() + if jax.process_index() == 0: # Copy the tensor from device memory to ram, since accumulating such # tensor on devices may cause HBM OOM, when diff --git a/paxml/profiling.py b/paxml/profiling.py index 3b40b7b4..2e8f5547 100644 --- a/paxml/profiling.py +++ b/paxml/profiling.py @@ -16,7 +16,15 @@ """Expose functionalities for profiling code.""" from absl import logging +from ctypes import cdll +libcudart = cdll.LoadLibrary('libcudart.so') +def cudaProfilerStart(): + libcudart.cudaProfilerStart() +def cudaProfilerStop(): + libcudart.cudaProfilerStop() +def cudaDeviceSynchronize(): + libcudart.cudaDeviceSynchronize() class Profiler: """Dummy class to capture code profiles. @@ -64,8 +72,13 @@ def capture_async(self) -> None: The duration of the trace corresponds to step_duration_estimate_sec. """ + cudaProfilerStart() logging.info('Dummy profiler currently does not capture any trace.') + def stop_capture_async(self) -> None: + cudaDeviceSynchronize() + cudaProfilerStop() + def update_step_moving_mean(self, duration_sec: float): """Updates the step duration moving average with a step duration estimate. diff --git a/paxml/programs.py b/paxml/programs.py index ecc88be6..b8cbc90a 100644 --- a/paxml/programs.py +++ b/paxml/programs.py @@ -392,6 +392,14 @@ def run(self, state: TrainState, step: int) -> TrainProgramOutput: if do_profile and step - self._initial_step < profiler_capture_step: self._profiler.update_step_moving_mean(train_period.elapsed) + + if ( + do_profile + and step - self._initial_step == + profiler_capture_step + self._profiler._capture_num_steps + ): + self._profiler.stop_capture_async() + logging.log_first_n( logging.INFO, '[PAX STATUS]: Writing summaries (attempt).', 5 )