From 42827f4185d48c93c0634fa43226c2446cc0cd90 Mon Sep 17 00:00:00 2001 From: Shawn Wang Date: Wed, 17 Jul 2024 19:58:56 -0700 Subject: [PATCH] add cuda profiler hook --- paxml/decode_programs.py | 7 +++++++ paxml/profiling.py | 13 +++++++++++++ paxml/programs.py | 8 ++++++++ 3 files changed, 28 insertions(+) diff --git a/paxml/decode_programs.py b/paxml/decode_programs.py index 531ef6174..c5c8030ab 100644 --- a/paxml/decode_programs.py +++ b/paxml/decode_programs.py @@ -398,6 +398,13 @@ def _run_decode_loop( 'Finished decoding input batch %d for %s', step_num, self._name ) + if ( + profiler is not None + and step_num - self._task.decode.profiler_capture_step == + profiler._capture_num_steps + ): + profiler.stop_capture_async() + if jax.process_index() == 0: # Copy the tensor from device memory to ram, since accumulating such # tensor on devices may cause HBM OOM, when diff --git a/paxml/profiling.py b/paxml/profiling.py index 3b40b7b4f..2e8f55471 100644 --- a/paxml/profiling.py +++ b/paxml/profiling.py @@ -16,7 +16,15 @@ """Expose functionalities for profiling code.""" from absl import logging +from ctypes import cdll +libcudart = cdll.LoadLibrary('libcudart.so') +def cudaProfilerStart(): + libcudart.cudaProfilerStart() +def cudaProfilerStop(): + libcudart.cudaProfilerStop() +def cudaDeviceSynchronize(): + libcudart.cudaDeviceSynchronize() class Profiler: """Dummy class to capture code profiles. @@ -64,8 +72,13 @@ def capture_async(self) -> None: The duration of the trace corresponds to step_duration_estimate_sec. """ + cudaProfilerStart() logging.info('Dummy profiler currently does not capture any trace.') + def stop_capture_async(self) -> None: + cudaDeviceSynchronize() + cudaProfilerStop() + def update_step_moving_mean(self, duration_sec: float): """Updates the step duration moving average with a step duration estimate. diff --git a/paxml/programs.py b/paxml/programs.py index ecc88be67..b8cbc90aa 100644 --- a/paxml/programs.py +++ b/paxml/programs.py @@ -392,6 +392,14 @@ def run(self, state: TrainState, step: int) -> TrainProgramOutput: if do_profile and step - self._initial_step < profiler_capture_step: self._profiler.update_step_moving_mean(train_period.elapsed) + + if ( + do_profile + and step - self._initial_step == + profiler_capture_step + self._profiler._capture_num_steps + ): + self._profiler.stop_capture_async() + logging.log_first_n( logging.INFO, '[PAX STATUS]: Writing summaries (attempt).', 5 )