Skip to content

Commit acd30c9

Browse files
committed
Add exception get_metrics_from_trace_tpu.
1 parent 44727bc commit acd30c9

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/benchmark_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None
3131
assert task is not None
3232

3333
if trace_dir:
34-
return timeit_from_trace(f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir)
34+
outcomes_ms = timeit_from_trace(f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir)
35+
if outcomes_ms is not None:
36+
return outcomes_ms
3537

3638
outcomes_ms = []
3739
jax.block_until_ready(f(*args)) # warm it up!
@@ -71,9 +73,13 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
7173

7274
# Check if the given task name is a collective with corresponding TPU opertion.
7375
# This is a workaround and should be reverted or refactored in future.
76+
# If task is not present in the map, fallback to the default behavior to measure the timing from the CPU end.
7477
if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
75-
task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
76-
return get_metrics_from_trace_tpu(trace, task)
78+
try:
79+
task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
80+
return get_metrics_from_trace_tpu(trace, task)
81+
except:
82+
return None
7783
event_matcher = re.compile(task)
7884

7985
if "traceEvents" not in trace:

0 commit comments

Comments
 (0)