50 changes: 48 additions & 2 deletions src/benchmark_utils.py
@@ -18,13 +18,28 @@
import subprocess
import shutil

+# Maps a JAX (collective) function name to a regex for its main HLO operation.
+TARGET_TASK_NAME_COLLECTIVES_MAP = {
+    "all_to_all_ici_op": r"all-to-all.[0-9]+",
+    "all_gather_ici_op": r"all-gather.[0-9]+",
+    "psum_ici_op": r"all-reduce.[0-9]+",
+    "ppermute_ici_op": r"collective-permute.[0-9]+",
+}

def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
  """Simple utility to time a function for multiple runs."""
  assert task is not None

  if trace_dir:
-    return timeit_from_trace(f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir)
+    try:
+      outcomes_ms = timeit_from_trace(
+          f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir
+      )
+      if outcomes_ms is not None:
+        return outcomes_ms
+      print("Warning: timeit_from_trace returned empty results. Falling back to manual timing.")
+    except Exception as e:
+      print(f"Warning: Failed to get metrics from trace due to: {e}. Falling back to manual timing.")

  outcomes_ms = []
  jax.block_until_ready(f(*args))  # warm it up!
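
As an illustration (not part of this PR), here is a minimal sketch of how simple_timeit might be invoked under the new behavior; the module path, task name, and argument values are assumptions:

import jax
import jax.numpy as jnp

from benchmark_utils import simple_timeit  # assumed import path

@jax.jit
def matmul(a, b):
  return a @ b

dim = 1024
x = jnp.ones((dim, dim))
y = jnp.ones((dim, dim))

# With trace_dir set, the timing is taken from the profiler trace when available;
# if the trace path fails or returns nothing, simple_timeit falls back to the
# manual timing loop shown above.
timing_result = simple_timeit(matmul, x, y, matrix_dim=dim, tries=5,
                              task="matmul_op", trace_dir="/tmp/benchmark_traces")
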
@@ -60,8 +75,19 @@ def get_trace(log_dir: str) -> dict[str, Any]:
  return trace


-def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
+def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:

+  # Check if the given task name is a collective with a corresponding TPU operation.
+  # This is a workaround and should be reverted or refactored in the future.

Collaborator:
add comment: if task is not present in the map, fallback to the default behavior to measure the timing from the CPU end.

+  # If the task is not present in the map, fall back to the default behavior and measure the timing from the CPU end.
+  if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
+    try:
+      task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
+      return get_metrics_from_trace_tpu(trace, task)
+    except Exception:
+      # Returning None signals the caller to fall back to manual timing.
+      return None
  event_matcher = re.compile(task)

  if "traceEvents" not in trace:
    raise KeyError("Key 'traceEvents' not found in trace.")

@@ -85,6 +111,26 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
    raise
  return durations_ms

+def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]:
+  """Returns per-event device durations (in ms) for `task` from the lowest-pid TPU."""
+  event_matcher = re.compile(task)
+
+  if "traceEvents" not in trace:
+    raise KeyError("Key 'traceEvents' not found in trace.")
+
+  events = []
+  for e in trace["traceEvents"]:
+    if "name" in e and event_matcher.match(e["name"]):
+      events.append(e)
+
+  # For each trace, find the TPU with the smallest `pid` value and consider it to be TPU-0.
+  min_pid = min([e["pid"] for e in events])
+  events_from_min_pid = [e for e in events if e["pid"] == min_pid]
+  try:
+    durations_ms = [float(e["args"]["device_duration_ps"]) / 1e9 for e in events_from_min_pid]
+  except KeyError:
+    print("KeyError: Key 'device_duration_ps' not found in the event object")
+    raise
+  return durations_ms

def is_local_directory_path(dir: str) -> bool:
  """
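As an illustration (not part of this PR), here is a minimal sketch of the trace shape the new TPU path consumes; the event names, pids, and durations are invented, but the fields match what the code above reads (traceEvents, name, pid, args.device_duration_ps):

from benchmark_utils import get_metrics_from_trace  # assumed import path

fake_trace = {
    "traceEvents": [
        # Two TPU pids; only events from the smallest pid (7) are kept as "TPU-0".
        {"name": "all-reduce.1", "pid": 7, "args": {"device_duration_ps": 2_500_000}},
        {"name": "all-reduce.2", "pid": 7, "args": {"device_duration_ps": 3_000_000}},
        {"name": "all-reduce.1", "pid": 9, "args": {"device_duration_ps": 9_999_999}},
    ]
}

# "psum_ici_op" is in TARGET_TASK_NAME_COLLECTIVES_MAP, so the call is routed to
# get_metrics_from_trace_tpu and durations come from device_duration_ps
# (picoseconds), converted to milliseconds.
print(get_metrics_from_trace(fake_trace, "psum_ici_op"))  # [0.0025, 0.003]

A task name that is not in TARGET_TASK_NAME_COLLECTIVES_MAP skips this branch and keeps the original CPU-side timing behavior.
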
4 changes: 2 additions & 2 deletions src/run_benchmark.py
@@ -346,7 +346,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
  test_name = f"t_{benchmark_name}_" + "".join(
      random.choices(string.ascii_uppercase + string.digits, k=10)
  )
-  write_to_csv(f"{csv_path}/{test_name}.csv", calculate_metrics_results)
+  write_to_csv(f"{csv_path}/{test_name}.tsv", calculate_metrics_results)


def main(config_path: str, multithreaded: bool):
@@ -455,7 +455,7 @@ def run_benchmark_multithreaded(benchmark_config):
    calculate_metrics_results.append({"metadata": metadata, "metrics": metrics})

  if csv_path:
-    write_to_csv(f"{csv_path}/{test_name}.csv", calculate_metrics_results)
+    write_to_csv(f"{csv_path}/{test_name}.tsv", calculate_metrics_results)


if __name__ == "__main__":
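The only functional change in this file is the output extension (.csv to .tsv). write_to_csv itself is not shown in this diff, so it cannot be confirmed here that it writes tab-separated values; if it does, the relevant piece presumably resembles this sketch, with the helper name, field layout, and delimiter all being assumptions:

import csv

def write_tsv(path, results):
  # results: list of {"metadata": {...}, "metrics": {...}} as assembled above.
  rows = [{**r["metadata"], **r["metrics"]} for r in results]
  with open(path, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=rows[0].keys(), delimiter="\t")
    writer.writeheader()
    writer.writerows(rows)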