Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 73 additions & 25 deletions autoparallel/graph_passes/debug_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,39 @@ def _get_tid(node):
if _is_communication_node(node):
if node.target == torch.ops._c10d_functional.wait_tensor.default:
return 0
return node.args[-1]
return f"group-{node.args[-1]}"
return 0


def _get_category(node):
    """Classify a node into a perfetto trace category.

    Communication nodes map to "nccl"; call_function nodes whose target
    looks like a matmul map to "gemm" and attention-like targets to
    "attention"; everything else is a generic "kernel".
    """
    if _is_communication_node(node):
        return "nccl"
    if node.op != "call_function":
        return "kernel"
    # Lowercase once and reuse for both substring scans.
    lowered_target = str(node.target).lower()
    if any(marker in lowered_target for marker in ("mm", "bmm", "matmul", "addmm")):
        return "gemm"
    if any(marker in lowered_target for marker in ("flash", "attention", "sdpa")):
        return "attention"
    return "kernel"


def _get_collective_type(node):
    """Return a display name for a collective op, or None for non-comm nodes.

    Matches on substrings of the op target, in priority order; unknown
    communication ops yield "Unknown".
    """
    if not _is_communication_node(node):
        return None
    target_str = str(node.target)
    # Ordered substring -> label dispatch; first match wins.
    for needle, label in (
        ("all_gather", "AllGather"),
        ("reduce_scatter", "ReduceScatter"),
        ("all_reduce", "AllReduce"),
        ("wait_tensor", "Wait"),
    ):
        if needle in target_str:
            return label
    return "Unknown"


def get_repr(arg, mode="full"):
def get_dtype_repr(dtype):
return dtype_abbrs[dtype]
Expand Down Expand Up @@ -221,51 +250,70 @@ def get_dtype_repr(dtype):
def create_execution_trace(
    gm: torch.fx.GraphModule,
    runtime_estimator: Callable[[torch.fx.Node], float],
    name: str = "fake_trace",
    file_path: str | None = None,
) -> dict[str, Any]:
    """
    Create a perfetto trace from a GraphModule representing its execution
    trace. This is useful for inspecting communication-computation overlapping
    for different reordering strategies.

    Args:
        gm: the graph whose nodes are simulated in program order.
        runtime_estimator: maps a node to its estimated runtime in ms.
        name: base name embedded in the trace's "traceName" field.
        file_path: if given, the trace JSON is also written to this path.

    Returns:
        A Chrome/perfetto trace-event dict with "traceEvents", "traceName"
        and "displayTimeUnit" keys.
    """
    # NOTE: this span in the original was a merged diff containing both the
    # pre- and post-change lines (duplicate signatures, duplicate dur/time
    # bookkeeping) and was not valid Python; this is the reconstructed
    # post-change version.
    launch_overhead = 1  # fixed per-launch cost, 1us
    ms_to_us = 1000

    trace_events = []
    # Per-tid simulated clocks; tid 0 is the compute stream, collectives run
    # on per-group tids produced by _get_tid.
    curr_time: dict[int | str, float] = {0: 0}
    # When each in-flight collective node will finish, in us.
    global_time: dict[torch.fx.Node, float] = {}

    for node_idx, node in enumerate(gm.graph.nodes):
        dur = runtime_estimator(node) * ms_to_us
        tid = _get_tid(node)
        if tid not in curr_time:
            # New comm stream starts aligned with the compute stream.
            curr_time[tid] = curr_time[0]

        cat = _get_category(node)
        coll_type = _get_collective_type(node)
        node_name = f"nccl:{coll_type}:{node.name}" if coll_type else str(node)

        event = {"ph": "X", "cat": cat, "name": node_name, "pid": 0, "tid": tid}

        if _is_communication_node(node):
            if tid == 0 and is_wait_tensor(node) and node.args[0].op != "placeholder":
                # if it's wait tensor, walk up chained waits to find the collective
                # Now this may happen for all_to_all in dsv3 (wait(wait(all_to_all)))
                comm_node = node.args[0]
                while is_wait_tensor(comm_node):
                    comm_node = comm_node.args[0]
                # Sync the compute stream with the collective's finish time.
                comm_end_time = global_time.pop(comm_node)
                curr_time[tid] = max(curr_time[tid], comm_end_time)
            else:
                # A collective cannot be issued before the compute stream
                # reaches its launch point.
                curr_time[tid] = max(curr_time[0], curr_time[tid])

        event["ts"] = curr_time[tid]
        event["dur"] = dur
        curr_time[tid] += dur + launch_overhead
        if tid != 0:
            # Launching a collective still costs the compute stream overhead.
            curr_time[0] += launch_overhead
        # keep track of when a given collective will finish
        global_time[node] = curr_time[tid]

        event["args"] = {
            "order": node_idx,
            "output": get_repr(node, mode="content_only"),
            "inputs": [get_repr(arg) for arg in node.args],
        }

        # Zero-duration nodes (e.g. metadata ops) would clutter the trace.
        if dur > 0.0:
            trace_events.append(event)

    trace = {
        "traceEvents": trace_events,
        "traceName": f"{name}_trace.json",
        "displayTimeUnit": "us",
    }

    if file_path is not None:
        with open(file_path, "w") as fp:
            json.dump(trace, fp, indent=2)

    return trace
8 changes: 8 additions & 0 deletions autoparallel/tools/overlap_simulator/colls32_8.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- -------------
1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888
1 8 reduce_scatter_tensor 0.0173 0.0238 0.0368 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518
1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952
0 32 all_gather_into_tensor 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 87.1247 173.807 347.171 693.901 1387.36
0 32 reduce_scatter_tensor 0.2114 0.2612 0.3608 0.4615 0.6455 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835
0 32 all_gather_into_tensor_out 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 87.1247 173.807 347.171 693.901 1387.36
6 changes: 6 additions & 0 deletions autoparallel/tools/overlap_simulator/colls64_1.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- -------------
0 64 all_reduce 0.20 0.35 0.60 1.10 2.10 4.10 8.10 16.10 32.10 64.10 128.10 256.10
0 64 all_gather_into_tensor 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50
0 64 reduce_scatter_tensor 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50
0 64 all_gather_into_tensor_out 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50
7 changes: 7 additions & 0 deletions autoparallel/tools/overlap_simulator/colls8_8.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- -------------
1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952
1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888
0 8 reduce_scatter_tensor 0.0866 0.1151 0.1566 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532
0 8 all_gather_into_tensor_out 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 78.4001 156.694 313.281
0 8 all_gather_into_tensor 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 78.4001 156.694 313.281
5 changes: 5 additions & 0 deletions autoparallel/tools/overlap_simulator/colls_dsv3_bw_128.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 128 all_gather_into_tensor 0.7353 1.1812 2.1541 4.3082 8.6164 17.2329 34.4658 68.9316 137.8632 275.7263 551.4527 1102.9053
0 128 reduce_scatter_tensor 0.2101 0.2337 0.2809 0.3151 0.3435 0.4002 0.5136 0.7195 1.1094 1.8893 3.5414 7.0828
1025 16 reduce_scatter_tensor 0.0230 0.0459 0.0918 0.1837 0.3674 0.7348 1.4695 2.9391 5.8781 11.7562 23.5124 47.0249
6 changes: 6 additions & 0 deletions autoparallel/tools/overlap_simulator/colls_dsv3_bw_64.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 64 all_gather_into_tensor 0.3625 0.5548 0.9394 1.7562 3.5123 7.0246 14.0493 28.0986 56.1971 112.3942 224.7885 449.5770
0 64 reduce_scatter_tensor 0.1140 0.1283 0.1570 0.1765 0.2047 0.2609 0.3734 0.5679 0.9258 1.6418 3.1324 6.2648
513 8 all_gather_into_tensor 0.1723 0.3445 0.6890 1.3780 2.7561 5.5121 11.0243 22.0486 44.0972 88.1943 176.3887 352.7773
513 8 reduce_scatter_tensor 0.0203 0.0406 0.0813 0.1626 0.3252 0.6504 1.3007 2.6015 5.2029 10.4058 20.8116 41.6233
4 changes: 4 additions & 0 deletions autoparallel/tools/overlap_simulator/colls_dsv3_fw_128.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 128 all_gather_into_tensor 0.7353 1.1812 2.1541 4.3082 8.6164 17.2329 34.4658 68.9316 137.8632 275.7263 551.4527 1102.9053
1025 16 all_gather_into_tensor 0.4071 0.8142 1.6284 3.2567 6.5135 13.0269 26.0538 52.1076 104.2153 208.4305 416.8611 833.7222
4 changes: 4 additions & 0 deletions autoparallel/tools/overlap_simulator/colls_dsv3_fw_64.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 64 all_gather_into_tensor 0.3625 0.5548 0.9394 1.7562 3.5123 7.0246 14.0493 28.0986 56.1971 112.3942 224.7885 449.5770
513 8 all_gather_into_tensor 0.1723 0.3445 0.6890 1.3780 2.7561 5.5121 11.0243 22.0486 44.0972 88.1943 176.3887 352.7773
Loading
Loading