Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 73 additions & 25 deletions autoparallel/graph_passes/debug_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,39 @@ def _get_tid(node):
if _is_communication_node(node):
if node.target == torch.ops._c10d_functional.wait_tensor.default:
return 0
return node.args[-1]
return f"group-{node.args[-1]}"
return 0


def _get_category(node):
    """Return the perfetto trace category string for ``node``.

    Categories:
      - ``"nccl"``      for communication (collective) nodes,
      - ``"gemm"``      for call_function nodes whose target name looks like
                        a matmul variant (mm/bmm/matmul/addmm),
      - ``"attention"`` for attention-style kernels (flash/attention/sdpa),
      - ``"kernel"``    for everything else.
    """
    if _is_communication_node(node):
        return "nccl"
    if node.op == "call_function":
        # Lowercase the target name once instead of once per substring probe.
        target_lower = str(node.target).lower()
        if any(x in target_lower for x in ("mm", "bmm", "matmul", "addmm")):
            return "gemm"
        if any(x in target_lower for x in ("flash", "attention", "sdpa")):
            return "attention"
    return "kernel"


def _get_collective_type(node):
    """Get the type of collective operation.

    Returns ``None`` for non-communication nodes, a human-readable label
    (e.g. ``"AllGather"``) for known collectives, and ``"Unknown"`` for a
    communication node whose target matches no known pattern.
    """
    if not _is_communication_node(node):
        return None
    target_str = str(node.target)
    # Ordered substring probes: first match wins, same order as before.
    patterns = (
        ("all_gather", "AllGather"),
        ("reduce_scatter", "ReduceScatter"),
        ("all_reduce", "AllReduce"),
        ("wait_tensor", "Wait"),
    )
    for needle, label in patterns:
        if needle in target_str:
            return label
    return "Unknown"


def get_repr(arg, mode="full"):
def get_dtype_repr(dtype):
return dtype_abbrs[dtype]
Expand Down Expand Up @@ -259,51 +288,70 @@ def get_dtype_repr(dtype):
def create_execution_trace(
    gm: torch.fx.GraphModule,
    runtime_estimator: Callable[[torch.fx.Node], float],
    name: str = "fake_trace",
    file_path: str | None = None,
) -> dict[str, Any]:
    """
    Create a perfetto trace from a GraphModule representing its execution
    trace. This is useful for inspecting communication-computation overlapping
    for different reordering strategies.

    Args:
        gm: graph module whose nodes are simulated in topological order.
        runtime_estimator: maps a node to its estimated runtime in
            milliseconds (converted to microseconds for the trace).
        name: base name embedded in the trace's ``traceName`` field.
        file_path: when given, the trace dict is also dumped as JSON there.

    Returns:
        The trace as a dict in the Chrome Trace Event ("traceEvents") format.
    """
    launch_overhead = 1  # per-kernel launch overhead, 1us
    ms_to_us = 1000

    trace_events = []
    # Simulated clock per tid: tid 0 is the compute stream, communication
    # nodes run on per-group tids ("group-<pg>") produced by _get_tid.
    curr_time: dict[int | str, float] = {0: 0}
    # When each collective finishes, keyed by its node.
    global_time: dict[torch.fx.Node, float] = {}

    for node_idx, node in enumerate(gm.graph.nodes):
        dur = runtime_estimator(node) * ms_to_us
        tid = _get_tid(node)
        if tid not in curr_time:
            # A new comm stream starts at the compute stream's current time.
            curr_time[tid] = curr_time[0]

        cat = _get_category(node)
        coll_type = _get_collective_type(node)
        node_name = f"nccl:{coll_type}:{node.name}" if coll_type else str(node)

        event = {"ph": "X", "cat": cat, "name": node_name, "pid": 0, "tid": tid}

        if _is_communication_node(node):
            if tid == 0 and is_wait_tensor(node) and node.args[0].op != "placeholder":
                # if it's wait tensor, walk up chained waits to find the collective
                # Now this may happen for all_to_all in dsv3 (wait(wait(all_to_all)))
                comm_node = node.args[0]
                while is_wait_tensor(comm_node):
                    comm_node = comm_node.args[0]
                # The wait cannot complete before the collective it syncs on.
                comm_end_time = global_time[comm_node]
                curr_time[tid] = max(curr_time[tid], comm_end_time)
            else:
                # A collective can only be launched once the compute stream
                # has reached its launch point.
                curr_time[tid] = max(curr_time[0], curr_time[tid])

        event["ts"] = curr_time[tid]
        event["dur"] = dur
        curr_time[tid] += dur + launch_overhead
        if tid != 0:
            # Launching a collective still costs the compute stream overhead.
            curr_time[0] += launch_overhead
        # keep track of when a given collective will finish
        global_time[node] = curr_time[tid]

        event["args"] = {
            "order": node_idx,
            "output": get_repr(node, mode="content_only"),
            "inputs": [get_repr(arg) for arg in node.args],
        }

        # Skip zero-duration events (e.g. free ops) to keep the trace readable.
        if dur > 0.0:
            trace_events.append(event)

    trace = {
        "traceEvents": trace_events,
        "traceName": f"{name}_trace.json",
        "displayTimeUnit": "us",
    }

    if file_path is not None:
        with open(file_path, "w") as fp:
            json.dump(trace, fp, indent=2)

    return trace
8 changes: 8 additions & 0 deletions autoparallel/tools/overlap_simulator/colls16_8.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- -------------
1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888
1 8 reduce_scatter_tensor 0.0173 0.0238 0.0368 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518
1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952
0 16 all_gather_into_tensor 0.4977 0.8538 1.5291 2.8398 5.4613 10.7042 21.1899 42.1614 84.1045 167.9904 335.763 671.3073
0 16 reduce_scatter_tensor 0.1282 0.1638 0.2247 0.3136 0.4858 0.8166 1.4697 2.7525 5.2865 10.3546 20.4909 40.7633
0 16 all_gather_into_tensor_out 0.4977 0.8538 1.5291 2.8398 5.4613 10.7042 21.1899 42.1614 84.1045 167.9904 335.763 671.3073
8 changes: 8 additions & 0 deletions autoparallel/tools/overlap_simulator/colls32_8.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- -------------
1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888
1 8 reduce_scatter_tensor 0.0173 0.0238 0.0368 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518
1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952
0 32 all_gather_into_tensor 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 87.1247 173.807 347.171 693.901 1387.36
0 32 reduce_scatter_tensor 0.2114 0.2612 0.3608 0.4615 0.6455 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835
0 32 all_gather_into_tensor_out 1.0136 1.7497 3.1512 5.86 11.2777 22.113 43.7835 87.1247 173.807 347.171 693.901 1387.36
6 changes: 6 additions & 0 deletions autoparallel/tools/overlap_simulator/colls64_1.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- -------------
0 64 all_reduce 0.20 0.35 0.60 1.10 2.10 4.10 8.10 16.10 32.10 64.10 128.10 256.10
0 64 all_gather_into_tensor 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50
0 64 reduce_scatter_tensor 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50
0 64 all_gather_into_tensor_out 0.25 0.45 0.80 1.50 2.90 5.70 11.30 22.50 44.90 89.70 179.30 358.50
7 changes: 7 additions & 0 deletions autoparallel/tools/overlap_simulator/colls8_8.table
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Group Group Size Collective 1MB (ms) 2MB (ms) 4MB (ms) 8MB (ms) 16MB (ms) 32MB (ms) 64MB (ms) 128MB (ms) 256MB (ms) 512MB (ms) 1024MB (ms) 2048MB (ms)
------- ------------ -------------------------- ---------- ---------- ---------- ---------- ----------- ----------- ----------- ------------ ------------ ------------ ------------- -------------
1 8 all_reduce 0.028 0.041 0.0628 0.0849 0.1292 0.2179 0.3822 0.7084 1.3609 2.6658 5.2756 10.4952
1 8 all_gather_into_tensor 0.0495 0.0716 0.1138 0.1953 0.3584 0.6846 1.3371 2.642 5.2518 10.4714 20.9105 41.7888
0 8 reduce_scatter_tensor 0.0866 0.1151 0.1566 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532
0 8 all_gather_into_tensor_out 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 78.4001 156.694 313.281
0 8 all_gather_into_tensor 0.2397 0.4059 0.7181 1.3297 2.5531 4.9998 9.8931 19.6798 39.2532 78.4001 156.694 313.281
Loading
Loading