From 89419e3140c8a4d2a906e53115c1041e2868208b Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sun, 19 Oct 2025 19:03:56 +0000
Subject: [PATCH 1/3] Initial plan

From 2427d5ae99c1f91064b41506470811042757a7d5 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sun, 19 Oct 2025 19:09:33 +0000
Subject: [PATCH 2/3] Fix timing discrepancy when using --gemm_only or
 --comm_only flags

When using the --gemm_only or --comm_only flags, total_ms incorrectly
included overhead from barriers, NVTX ranges, and stream management,
which made total_ms larger than the individual kernel times (gemm_ms or
communication_ms).

The fix ensures that when only one operation runs, total_ms uses the
individual kernel time measured with CUDA events, which accurately
reflects the actual kernel execution time without that overhead. When
both operations run together, we continue to use do_bench() to properly
measure overlapped/sequential execution time.

Fixes examples 20 and 21.

Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com>
---
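Reviewer note (below the ---, so ignored by git am): a minimal sketch of
the CUDA-event timing pattern the per-kernel numbers rely on.
time_kernel_ms and kernel_fn are illustrative names, not part of the
benchmarks, and a CUDA- or ROCm-enabled PyTorch build is assumed:

    import torch

    def time_kernel_ms(kernel_fn, iters: int = 10) -> float:
        """Average GPU time of kernel_fn in milliseconds via CUDA events."""
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        total_ms = 0.0
        for _ in range(iters):
            start.record()        # marker enqueued on the current stream
            kernel_fn()           # kernel under measurement
            end.record()          # marker enqueued after the kernel
            end.synchronize()     # block until the end marker completes
            # elapsed_time() returns GPU-side milliseconds between markers,
            # excluding host-side costs such as barriers or NVTX bookkeeping.
            total_ms += start.elapsed_time(end)
        return total_ms / iters

Dividing kernel_timing["gemm"]["ms"] by kernel_timing["gemm"]["experiments"]
in the hunks below follows the same pattern: an event-measured average per
iteration.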
 .../benchmark.py | 21 +++++++++++++++----
 .../benchmark.py | 19 ++++++++++++++---
 2 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/examples/20_gemm_all_scatter_independent/benchmark.py b/examples/20_gemm_all_scatter_independent/benchmark.py
index abfad938..238c13c7 100755
--- a/examples/20_gemm_all_scatter_independent/benchmark.py
+++ b/examples/20_gemm_all_scatter_independent/benchmark.py
@@ -387,14 +387,27 @@ def run_experiment():
     matmul.set_debug(False)
     shmem.info("Benchmarking...")
     perf = lambda ms: 2 * args["m"] * args["n"] * args["k"] * 1e-12 / (ms * 1e-3)
-    triton_ms = iris.do_bench(run_experiment, shmem.barrier)
-    triton_tflops = perf(triton_ms)
-    algo_string = "all_scatter"
-
+    # Determine what was run based on flags
     run_gemm = not args["only_comm"]
     run_comm = not args["only_gemm"]
 
+    # When both operations run, measure total time including potential overlap
+    # When only one operation runs, use its individual kernel time to avoid overhead
+    if run_gemm and run_comm:
+        triton_ms = iris.do_bench(run_experiment, shmem.barrier)
+    else:
+        # Run benchmark to populate kernel_timing with accurate measurements
+        iris.do_bench(run_experiment, shmem.barrier)
+        # Use the individual kernel time as total time
+        if run_gemm:
+            triton_ms = kernel_timing["gemm"]["ms"] / kernel_timing["gemm"]["experiments"]
+        else:
+            triton_ms = kernel_timing["communication"]["ms"] / kernel_timing["communication"]["experiments"]
+
+    triton_tflops = perf(triton_ms)
+    algo_string = "all_scatter"
+
     if run_gemm and run_comm:
         op_string = f"tile matmul + {algo_string}"
     elif run_gemm:
diff --git a/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py b/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py
index fb45922c..4f36a8b4 100644
--- a/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py
+++ b/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py
@@ -381,13 +381,26 @@ def run_experiment():
     matmul.set_debug(False)
     shmem.info("Benchmarking...")
     perf = lambda ms: 2 * args["m"] * args["n"] * args["k"] * 1e-12 / (ms * 1e-3)
-    triton_ms = iris.do_bench(run_experiment, shmem.barrier)
-    triton_tflops = perf(triton_ms)
-
+    # Determine what was run based on flags
     run_gemm = not args["only_comm"]
     run_comm = not args["only_gemm"]
 
+    # When both operations run, measure total time including potential overlap
+    # When only one operation runs, use its individual kernel time to avoid overhead
+    if run_gemm and run_comm:
+        triton_ms = iris.do_bench(run_experiment, shmem.barrier)
+    else:
+        # Run benchmark to populate kernel_timing with accurate measurements
+        iris.do_bench(run_experiment, shmem.barrier)
+        # Use the individual kernel time as total time
+        if run_gemm:
+            triton_ms = kernel_timing["gemm"]["ms"] / kernel_timing["gemm"]["experiments"]
+        else:
+            triton_ms = kernel_timing["communication"]["ms"] / kernel_timing["communication"]["experiments"]
+
+    triton_tflops = perf(triton_ms)
+
     if run_gemm and run_comm:
         op_string = "tile matmul + one_shot_all_reduce (independent)"
     elif run_gemm:

From 80869af3736469416195069bc2afdd2258e732c7 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sun, 19 Oct 2025 20:39:49 +0000
Subject: [PATCH 3/3] Reduce timing overhead by measuring actual work instead
 of do_bench wrapper
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Instead of working around the overhead by ignoring total_ms, this fix
reduces the overhead itself by:

1. Adding timing events that wrap the actual work in run_experiment()
2. Recording the start after the initial barrier and the end before the
   final barrier
3. Using these events for total_ms instead of the do_bench() return value

This ensures total_ms accurately reflects kernel execution time plus
NVTX/stream overhead (the actual work), excluding the barriers that
do_bench() adds. Now total_ms ≈ gemm_ms + communication_ms when both
operations run.

Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com>
---
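Reviewer note (below the ---, so ignored by git am): a minimal sketch of
the "total" event pair this commit introduces. work, barrier, and this
trimmed-down kernel_timing are placeholders shaped like the benchmark's
objects, not the actual implementation:

    import torch

    kernel_timing = {
        "total": {
            "start_event": torch.cuda.Event(enable_timing=True),
            "end_event": torch.cuda.Event(enable_timing=True),
            "ms": 0.0,
            "experiments": 0,
        },
    }

    def run_experiment(work, barrier):
        barrier()                                       # ranks enter together
        kernel_timing["total"]["start_event"].record()  # start after the barrier
        work()                                          # GEMM and/or communication
        kernel_timing["total"]["end_event"].record()    # end before the barrier
        kernel_timing["total"]["experiments"] += 1
        barrier()                                       # ranks leave together
        kernel_timing["total"]["end_event"].synchronize()  # events must be complete
        kernel_timing["total"]["ms"] += kernel_timing["total"]["start_event"].elapsed_time(
            kernel_timing["total"]["end_event"]
        )

    # The reported total is the mean of the event-measured iterations rather
    # than do_bench's wall-clock result, so the extra barriers do_bench adds
    # around each call no longer inflate it:
    # total_ms = kernel_timing["total"]["ms"] / kernel_timing["total"]["experiments"]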
 .../benchmark.py | 43 +++++++++++--------
 .../benchmark.py | 41 ++++++++++--------
 2 files changed, 49 insertions(+), 35 deletions(-)

diff --git a/examples/20_gemm_all_scatter_independent/benchmark.py b/examples/20_gemm_all_scatter_independent/benchmark.py
index 238c13c7..5aa23c9f 100755
--- a/examples/20_gemm_all_scatter_independent/benchmark.py
+++ b/examples/20_gemm_all_scatter_independent/benchmark.py
@@ -220,6 +220,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
             "ms": 0,
             "experiments": 0,
         },
+        "total": {
+            "start_event": torch.cuda.Event(enable_timing=True),
+            "end_event": torch.cuda.Event(enable_timing=True),
+            "ms": 0,
+            "experiments": 0,
+        },
     }
 
     # Allocate Timestamps
@@ -237,6 +243,9 @@ def run_experiment():
         timestamps.reset()
         shmem.barrier()
 
+        # Record start of actual work (after initial barrier)
+        kernel_timing["total"]["start_event"].record()
+
         # Determine what to run based on flags
         run_gemm = not args["only_comm"]
         run_comm = not args["only_gemm"]
@@ -303,6 +312,10 @@ def run_experiment():
                 kernel_timing["communication"]["experiments"] += 1
             torch.cuda.nvtx.range_pop()
 
+        # Record end of actual work (before final barrier)
+        kernel_timing["total"]["end_event"].record()
+        kernel_timing["total"]["experiments"] += 1
+
         shmem.barrier()
 
         # Update timing for operations that were run
@@ -312,6 +325,10 @@ def run_experiment():
         if run_comm:
             ms = kernel_timing["communication"]["start_event"].elapsed_time(kernel_timing["communication"]["end_event"])
             kernel_timing["communication"]["ms"] += ms
+
+        # Update total timing
+        ms = kernel_timing["total"]["start_event"].elapsed_time(kernel_timing["total"]["end_event"])
+        kernel_timing["total"]["ms"] += ms
 
         torch.cuda.nvtx.range_pop()
 
@@ -323,7 +340,7 @@ def run_experiment():
 
         shmem.barrier()
 
-        for k in ["gemm", "communication"]:
+        for k in ["gemm", "communication", "total"]:
             kernel_timing[k]["ms"] = 0
             kernel_timing[k]["experiments"] = 0
 
@@ -387,27 +404,17 @@ def run_experiment():
     matmul.set_debug(False)
     shmem.info("Benchmarking...")
     perf = lambda ms: 2 * args["m"] * args["n"] * args["k"] * 1e-12 / (ms * 1e-3)
-
+    # Run benchmark to populate all timing measurements
+    iris.do_bench(run_experiment, shmem.barrier)
+    # Use the total timing recorded inside run_experiment (excludes do_bench overhead)
+    triton_ms = kernel_timing["total"]["ms"] / kernel_timing["total"]["experiments"]
+    triton_tflops = perf(triton_ms)
+    algo_string = "all_scatter"
+
+    # Determine what was run based on flags
     run_gemm = not args["only_comm"]
     run_comm = not args["only_gemm"]
 
-    # When both operations run, measure total time including potential overlap
-    # When only one operation runs, use its individual kernel time to avoid overhead
-    if run_gemm and run_comm:
-        triton_ms = iris.do_bench(run_experiment, shmem.barrier)
-    else:
-        # Run benchmark to populate kernel_timing with accurate measurements
-        iris.do_bench(run_experiment, shmem.barrier)
-        # Use the individual kernel time as total time
-        if run_gemm:
-            triton_ms = kernel_timing["gemm"]["ms"] / kernel_timing["gemm"]["experiments"]
-        else:
-            triton_ms = kernel_timing["communication"]["ms"] / kernel_timing["communication"]["experiments"]
-
-    triton_tflops = perf(triton_ms)
-    algo_string = "all_scatter"
-
     if run_gemm and run_comm:
         op_string = f"tile matmul + {algo_string}"
     elif run_gemm:
diff --git a/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py b/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py
index 4f36a8b4..539795cc 100644
--- a/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py
+++ b/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py
@@ -211,6 +211,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
             "ms": 0,
             "experiments": 0,
         },
+        "total": {
+            "start_event": torch.cuda.Event(enable_timing=True),
+            "end_event": torch.cuda.Event(enable_timing=True),
+            "ms": 0,
+            "experiments": 0,
+        },
     }
 
     # Allocate Timestamps
@@ -227,6 +233,9 @@ def run_experiment():
         timestamps.reset()
         shmem.barrier()
 
+        # Record start of actual work (after initial barrier)
+        kernel_timing["total"]["start_event"].record()
+
         # Determine what to run based on flags
         run_gemm = not args["only_comm"]
         run_comm = not args["only_gemm"]
@@ -297,6 +306,10 @@ def run_experiment():
                 kernel_timing["communication"]["experiments"] += 1
             torch.cuda.nvtx.range_pop()
 
+        # Record end of actual work (before final barrier)
+        kernel_timing["total"]["end_event"].record()
+        kernel_timing["total"]["experiments"] += 1
+
         shmem.barrier()
 
         # Update timing for operations that were run
@@ -306,6 +319,10 @@ def run_experiment():
         if run_comm:
             ms = kernel_timing["communication"]["start_event"].elapsed_time(kernel_timing["communication"]["end_event"])
             kernel_timing["communication"]["ms"] += ms
+
+        # Update total timing
+        ms = kernel_timing["total"]["start_event"].elapsed_time(kernel_timing["total"]["end_event"])
+        kernel_timing["total"]["ms"] += ms
 
         torch.cuda.nvtx.range_pop()
 
@@ -317,7 +334,7 @@ def run_experiment():
 
         shmem.barrier()
 
-        for k in ["gemm", "communication"]:
+        for k in ["gemm", "communication", "total"]:
             kernel_timing[k]["ms"] = 0
             kernel_timing[k]["experiments"] = 0
 
@@ -381,26 +398,16 @@ def run_experiment():
     matmul.set_debug(False)
     shmem.info("Benchmarking...")
    perf = lambda ms: 2 * args["m"] * args["n"] * args["k"] * 1e-12 / (ms * 1e-3)
-
+    # Run benchmark to populate all timing measurements
+    iris.do_bench(run_experiment, shmem.barrier)
+    # Use the total timing recorded inside run_experiment (excludes do_bench overhead)
+    triton_ms = kernel_timing["total"]["ms"] / kernel_timing["total"]["experiments"]
+    triton_tflops = perf(triton_ms)
+
+    # Determine what was run based on flags
     run_gemm = not args["only_comm"]
     run_comm = not args["only_gemm"]
 
-    # When both operations run, measure total time including potential overlap
-    # When only one operation runs, use its individual kernel time to avoid overhead
-    if run_gemm and run_comm:
-        triton_ms = iris.do_bench(run_experiment, shmem.barrier)
-    else:
-        # Run benchmark to populate kernel_timing with accurate measurements
-        iris.do_bench(run_experiment, shmem.barrier)
-        # Use the individual kernel time as total time
-        if run_gemm:
-            triton_ms = kernel_timing["gemm"]["ms"] / kernel_timing["gemm"]["experiments"]
-        else:
-            triton_ms = kernel_timing["communication"]["ms"] / kernel_timing["communication"]["experiments"]
-
-    triton_tflops = perf(triton_ms)
-
     if run_gemm and run_comm:
         op_string = "tile matmul + one_shot_all_reduce (independent)"
     elif run_gemm: