diff --git a/src/benchmark_utils.py b/src/benchmark_utils.py
index ee1c478..20f0662 100644
--- a/src/benchmark_utils.py
+++ b/src/benchmark_utils.py
@@ -18,13 +18,28 @@
 import subprocess
 import shutil
 
+# Maps a JAX collective function to a regex matching its main HLO op name.
+TARGET_TASK_NAME_COLLECTIVES_MAP = {
+    "all_to_all_ici_op": r"all-to-all\.[0-9]+",
+    "all_gather_ici_op": r"all-gather\.[0-9]+",
+    "psum_ici_op": r"all-reduce\.[0-9]+",
+    "ppermute_ici_op": r"collective-permute\.[0-9]+",
+}
 
 def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
   """Simple utility to time a function for multiple runs."""
   assert task is not None
 
   if trace_dir:
-    return timeit_from_trace(f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir)
+    try:
+      outcomes_ms = timeit_from_trace(
+          f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir
+      )
+      if outcomes_ms is not None:
+        return outcomes_ms
+      print("Warning: timeit_from_trace returned no results. Falling back to manual timing.")
+    except Exception as e:
+      print(f"Warning: failed to get metrics from trace ({e}). Falling back to manual timing.")
 
   outcomes_ms = []
   jax.block_until_ready(f(*args))  # warm it up!
@@ -60,8 +75,19 @@ def get_trace(log_dir: str) -> dict[str, Any]:
   return trace
 
 
-def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
+def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float] | None:
+
+  # Check whether the given task name is a collective with a corresponding TPU operation.
+  # This is a workaround and should be reverted or refactored in the future.
+  # If the task is not in the map, fall back to the default behavior of measuring timing from the CPU side.
+  if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
+    try:
+      task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
+      return get_metrics_from_trace_tpu(trace, task)
+    except Exception:
+      return None
   event_matcher = re.compile(task)
+
   if "traceEvents" not in trace:
     raise KeyError("Key 'traceEvents' not found in trace.")
 
@@ -85,6 +111,26 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
     raise
   return durations_ms
 
+def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]:
+  event_matcher = re.compile(task)
+
+  if "traceEvents" not in trace:
+    raise KeyError("Key 'traceEvents' not found in trace.")
+
+  events = []
+  for e in trace["traceEvents"]:
+    if "name" in e and event_matcher.match(e["name"]):
+      events.append(e)
+
+  # Find the TPU with the smallest `pid` value and treat it as TPU-0.
+  min_pid = min(e["pid"] for e in events)
+  events_from_min_pid = [e for e in events if e["pid"] == min_pid]
+  try:
+    durations_ms = [float(e["args"]["device_duration_ps"]) / 1e9 for e in events_from_min_pid]  # ps -> ms
+  except KeyError:
+    print("KeyError: Key 'device_duration_ps' not found in the event object.")
+    raise
+  return durations_ms
 
 def is_local_directory_path(dir: str) -> bool:
   """
diff --git a/src/run_benchmark.py b/src/run_benchmark.py
index 594f4a7..2df741a 100644
--- a/src/run_benchmark.py
+++ b/src/run_benchmark.py
@@ -346,7 +346,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
   test_name = f"t_{benchmark_name}_" + "".join(
       random.choices(string.ascii_uppercase + string.digits, k=10)
   )
-  write_to_csv(f"{csv_path}/{test_name}.csv", calculate_metrics_results)
+  write_to_csv(f"{csv_path}/{test_name}.tsv", calculate_metrics_results)
 
 
 def main(config_path: str, multithreaded: bool):
@@ -455,7 +455,7 @@ def run_benchmark_multithreaded(benchmark_config):
   calculate_metrics_results.append({"metadata": metadata, "metrics": metrics})
 
   if csv_path:
-    write_to_csv(f"{csv_path}/{test_name}.csv", calculate_metrics_results)
+    write_to_csv(f"{csv_path}/{test_name}.tsv", calculate_metrics_results)
 
 
 if __name__ == "__main__":
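
For reference, a minimal sketch of how the new TPU-side trace parsing could be exercised against a synthetic trace. The event payloads below are illustrative assumptions, not captured from a real TPU profile, and the import path assumes the module layout in this patch:

    # Hypothetical usage sketch: the fake events below only mimic the fields
    # get_metrics_from_trace_tpu reads, namely "name", "pid", and
    # args["device_duration_ps"] (picoseconds).
    from benchmark_utils import TARGET_TASK_NAME_COLLECTIVES_MAP, get_metrics_from_trace_tpu

    fake_trace = {
        "traceEvents": [
            # Two devices (pids 2 and 7) report the same all-reduce; only the
            # smallest pid (treated as TPU-0) contributes to the result.
            {"name": "all-reduce.1", "pid": 2, "args": {"device_duration_ps": 1.5e9}},
            {"name": "all-reduce.1", "pid": 7, "args": {"device_duration_ps": 2.0e9}},
        ]
    }

    pattern = TARGET_TASK_NAME_COLLECTIVES_MAP["psum_ici_op"]  # r"all-reduce\.[0-9]+"
    print(get_metrics_from_trace_tpu(fake_trace, pattern))  # [1.5]  (1.5e9 ps == 1.5 ms)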