From 7d805eb0612d2965048402d85c85d215f12a18ce Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 19 Oct 2025 05:09:15 +0000 Subject: [PATCH 1/5] Initial plan From 38236949f09a1e0a130a9973827a20201ac90eb4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 19 Oct 2025 05:13:23 +0000 Subject: [PATCH 2/5] Add examples 15, 16, and 20 to documentation Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- docs/reference/examples.md | 3 +++ examples/README.md | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/docs/reference/examples.md b/docs/reference/examples.md index 1d54c490..4683e072 100644 --- a/docs/reference/examples.md +++ b/docs/reference/examples.md @@ -24,6 +24,9 @@ We've curated a growing collection of practical examples that showcase the power - **[12_gemm_all_scatter_bulk_synchronous](https://github.com/ROCm/iris/tree/main/examples/12_gemm_all_scatter_bulk_synchronous)**: Matrix multiplication with all-scatter using the bulk synchronous parallel approach - **[13_flash_decode](https://github.com/ROCm/iris/tree/main/examples/13_flash_decode)**: Fused Flash Decode Attention for accelerating LLM inference - **[14_all_gather_gemm](https://github.com/ROCm/iris/tree/main/examples/14_all_gather_gemm)**: Fused All-Gather + GEMM with Pull and Push models +- **[15_gemm_all_reduce_ring_based](https://github.com/ROCm/iris/tree/main/examples/15_gemm_all_reduce_ring_based)**: Matrix multiplication with ring-based all-reduce +- **[16_all_reduce_ring_based](https://github.com/ROCm/iris/tree/main/examples/16_all_reduce_ring_based)**: Ring-based all-reduce operation +- **[20_gemm_all_scatter_independent](https://github.com/ROCm/iris/tree/main/examples/20_gemm_all_scatter_independent)**: Independent GEMM and all-scatter operations with support for CSV input configurations ### Utilities - **[benchmark](https://github.com/ROCm/iris/tree/main/examples/benchmark)**: Benchmarking utilities and performance testing tools diff --git a/examples/README.md b/examples/README.md index 4972bd94..82c812cf 100644 --- a/examples/README.md +++ b/examples/README.md @@ -26,6 +26,8 @@ This directory contains various algorithm implementations for distributed comput - [`12_gemm_all_scatter_bulk_synchronous`](12_gemm_all_scatter_bulk_synchronous): Matrix multiplication with all-scatter using the bulk synchronous parallel approach - [`13_flash_decode`](13_flash_decode): Fused Flash Decode Attention for accelerating LLM inference - [`14_all_gather_gemm`](14_all_gather_gemm): Fused All-Gather + GEMM with Pull and Push models +- [`15_gemm_all_reduce_ring_based`](15_gemm_all_reduce_ring_based): Matrix multiplication with ring-based all-reduce +- [`16_all_reduce_ring_based`](16_all_reduce_ring_based): Ring-based all-reduce operation - [`20_gemm_all_scatter_independent`](20_gemm_all_scatter_independent): Independent GEMM and all-scatter operations with support for CSV input configurations ### Utilities @@ -82,6 +84,12 @@ python examples/14_all_gather_gemm/example_run_pull.py --num_ranks 8 # All-Gather + GEMM - Push model python examples/14_all_gather_gemm/example_run_push.py --num_ranks 8 +# Example command to run benchmark with ring-based all-reduce for GEMM +python examples/15_gemm_all_reduce_ring_based/benchmark.py --benchmark --validate --num_ranks 8 + +# Example command to run benchmark with ring-based all-reduce +python examples/16_all_reduce_ring_based/benchmark.py --benchmark --validate --num_ranks 8 + # Independent GEMM and all-scatter - single configuration python examples/20_gemm_all_scatter_independent/benchmark.py --benchmark --validate --num_ranks 8 From 5011648fa9a61647594d785f844b30d45253e333 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 19 Oct 2025 07:55:49 +0000 Subject: [PATCH 3/5] Merge main and add examples 17 and 21 to documentation Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- docs/conf.py | 15 +- docs/reference/examples.md | 2 + docs/reference/gluon/device-functions.md | 8 - docs/reference/triton/device-functions.md | 8 - .../benchmark.py | 343 +++++++++++++ .../gemm_one_shot_all_reduce_pc.py | 251 +++++++++ .../matmul_wrapper.py | 174 +++++++ .../benchmark.py | 475 ++++++++++++++++++ .../example_config.csv | 3 + .../gemm_one_shot_all_reduce_independent.py | 227 +++++++++ .../matmul_wrapper.py | 163 ++++++ examples/README.md | 40 +- examples/common/validation.py | 37 ++ iris/experimental/iris_gluon.py | 1 + 14 files changed, 1721 insertions(+), 26 deletions(-) create mode 100644 examples/17_gemm_one_shot_all_reduce_pc/benchmark.py create mode 100644 examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py create mode 100644 examples/17_gemm_one_shot_all_reduce_pc/matmul_wrapper.py create mode 100644 examples/21_gemm_one_shot_all_reduce_independent/benchmark.py create mode 100644 examples/21_gemm_one_shot_all_reduce_independent/example_config.csv create mode 100644 examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py create mode 100644 examples/21_gemm_one_shot_all_reduce_independent/matmul_wrapper.py diff --git a/docs/conf.py b/docs/conf.py index bc28c14d..d341b626 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -109,13 +109,22 @@ def __call__(self, func): return func -# Mock triton modules -sys.modules["triton"] = MagicMock() -sys.modules["triton.language"] = MagicMock() +# Mock triton.language first +triton_language_mock = MagicMock() +sys.modules["triton.language"] = triton_language_mock sys.modules["triton.language.core"] = MagicMock() sys.modules["triton.language.core"]._aggregate = lambda cls: cls # Preserve class +# Mock triton modules with docstring-preserving jit decorator +class TritonMock: + jit = PreserveDocstringMock() + language = triton_language_mock + + +sys.modules["triton"] = TritonMock() + + # Mock gluon with docstring-preserving jit class GluonMock: jit = PreserveDocstringMock() diff --git a/docs/reference/examples.md b/docs/reference/examples.md index 4683e072..7ff18529 100644 --- a/docs/reference/examples.md +++ b/docs/reference/examples.md @@ -26,7 +26,9 @@ We've curated a growing collection of practical examples that showcase the power - **[14_all_gather_gemm](https://github.com/ROCm/iris/tree/main/examples/14_all_gather_gemm)**: Fused All-Gather + GEMM with Pull and Push models - **[15_gemm_all_reduce_ring_based](https://github.com/ROCm/iris/tree/main/examples/15_gemm_all_reduce_ring_based)**: Matrix multiplication with ring-based all-reduce - **[16_all_reduce_ring_based](https://github.com/ROCm/iris/tree/main/examples/16_all_reduce_ring_based)**: Ring-based all-reduce operation +- **[17_gemm_one_shot_all_reduce_pc](https://github.com/ROCm/iris/tree/main/examples/17_gemm_one_shot_all_reduce_pc)**: Matrix multiplication with one-shot all-reduce using producer-consumer pattern with two distribution modes (striding and block) - **[20_gemm_all_scatter_independent](https://github.com/ROCm/iris/tree/main/examples/20_gemm_all_scatter_independent)**: Independent GEMM and all-scatter operations with support for CSV input configurations +- **[21_gemm_one_shot_all_reduce_independent](https://github.com/ROCm/iris/tree/main/examples/21_gemm_one_shot_all_reduce_independent)**: Independent GEMM and all-reduce operations with support for CSV input configurations and selective execution ### Utilities - **[benchmark](https://github.com/ROCm/iris/tree/main/examples/benchmark)**: Benchmarking utilities and performance testing tools diff --git a/docs/reference/gluon/device-functions.md b/docs/reference/gluon/device-functions.md index c9e34de0..a7bb6c4f 100644 --- a/docs/reference/gluon/device-functions.md +++ b/docs/reference/gluon/device-functions.md @@ -6,14 +6,6 @@ The Gluon API is **experimental** and may undergo breaking changes in future rel Device-side functions provided by Iris Gluon for remote memory operations and atomics. These methods are part of the `IrisDeviceCtx` aggregate used within Gluon kernels. -```{eval-rst} -.. automodule:: iris.experimental.iris_gluon - :noindex: - :members: - :no-undoc-members: - :no-show-inheritance: -``` - ## Initialization ### initialize diff --git a/docs/reference/triton/device-functions.md b/docs/reference/triton/device-functions.md index 50a0f18d..fa4f5776 100644 --- a/docs/reference/triton/device-functions.md +++ b/docs/reference/triton/device-functions.md @@ -2,14 +2,6 @@ Device-side functions provided by Iris for remote memory operations and atomics. -```{eval-rst} -.. automodule:: iris.iris - :noindex: - :members: - :no-undoc-members: - :no-show-inheritance: -``` - ## Memory transfer operations ### load diff --git a/examples/17_gemm_one_shot_all_reduce_pc/benchmark.py b/examples/17_gemm_one_shot_all_reduce_pc/benchmark.py new file mode 100644 index 00000000..5dfe9c9f --- /dev/null +++ b/examples/17_gemm_one_shot_all_reduce_pc/benchmark.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import sys +import os +import argparse +import json + +from examples.common.utils import ( + JSONWriter, + Timestamps, + is_triton_interpret_set, +) + +import iris + +from matmul_wrapper import matmul +from gemm_one_shot_all_reduce_pc import persistent_all_reduce +from examples.common.validation import validate_gemm + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse matrix dimensions and configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A") + parser.add_argument("-n", type=int, default=4608, help="Number of columns in matrix B") + parser.add_argument("-k", type=int, default=36864, help="Common dimension between matrices A and B") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + parser.add_argument("--BLK_M", type=int, default=256, help="Block size M") + parser.add_argument("--BLK_N", type=int, default=256, help="Block size N") + parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") + parser.add_argument("--gsize_m", type=int, default=6, help="Grid size M") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("--gemm_sms", type=int, default=256, help="Number of SMs for GEMM kernel") + parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Reduce kernel") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "--distribution", + type=int, + default=0, + choices=[0, 1], + help="Distribution mode: 0=striding, 1=block", + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Main benchmark logic + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + num_xcds = iris.hip.get_num_xcc() + + # GEMM + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "int8": + datatype = torch.int8 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + assert args["n"] % world_size == 0, f"N ({args['n']}) must be divisible by world size ({world_size})." + assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})." + + A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype) + B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T + + args["M"] = args["m"] + args["N"] = args["n"] + args["K"] = args["k"] + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + # Splitting + rows_per_gpu = args["k"] // world_size + args["k"] = rows_per_gpu + start_row = rank * rows_per_gpu + end_row = start_row + rows_per_gpu + local_B = B[start_row:end_row, :] + local_A = A[:, start_row:end_row] + + for key, value in args.items(): + json_writer.add_field(key, value) + + C_global = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype) + local_C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=torch.float32) + + total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) + total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) + total_tiles = total_blocks_M * total_blocks_N + + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + tile_ready = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + + bias = None + + num_xcds = iris.hip.get_num_xcc() + + gemm_stream = torch.cuda.Stream() + comm_stream = torch.cuda.Stream() + + json_writer.add_field("gemm_sms", args["gemm_sms"]) + json_writer.add_field("comm_sms", args["comm_sms"]) + json_writer.add_field("distribution", args["distribution"]) + + kernel_timing = { + "gemm": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + "communication": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + def preamble(): + shmem.barrier() + locks.zero_() + tile_ready.zero_() + local_C.zero_() + C_global.zero_() + shmem.barrier() + + def run_experiment(): + nonlocal local_C + nonlocal C_global + nonlocal kernel_timing + + shmem.barrier() + + if args["trace_tiles"]: + timestamps.reset() + shmem.barrier() + + torch.cuda.nvtx.range_push("GEMM + All-Reduce") + torch.cuda.nvtx.range_push("GEMM") + with torch.cuda.stream(gemm_stream): + kernel_timing["gemm"]["start_event"].record() + local_C = matmul.apply( + local_A, + local_B, + local_C, + C_global, + bias, + locks, + tile_ready, + rank, + world_size, + args["gemm_sms"], + args["BLK_M"], + args["BLK_N"], + args["BLK_K"], + args["gsize_m"], + shmem.get_heap_bases(), + "gfx942", + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["gemm"]["end_event"].record() + kernel_timing["gemm"]["experiments"] += 1 + + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("All-Reduce") + with torch.cuda.stream(comm_stream): + kernel_timing["communication"]["start_event"].record() + persistent_all_reduce[(args["comm_sms"],)]( + local_C, + C_global, + locks, + tile_ready, + args["M"], + args["N"], + local_C.stride(0), + local_C.stride(1), + C_global.stride(0), + C_global.stride(1), + args["BLK_M"], + args["BLK_N"], + args["gsize_m"], + args["comm_sms"], + num_xcds, + shmem.get_heap_bases(), + rank, + world_size, + args["distribution"], + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["communication"]["end_event"].record() + kernel_timing["communication"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + shmem.barrier() + + for k in ["gemm", "communication"]: + ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"]) + kernel_timing[k]["ms"] += ms + + torch.cuda.nvtx.range_pop() + + # Synchronize across all GPUs + shmem.barrier() + + # Warmup + run_experiment() + + shmem.barrier() + preamble() + shmem.barrier() + + for k in ["gemm", "communication"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + if args["validate"]: + shmem.info("Validating...") + matmul.set_debug(True) + # Validate global result + success = validate_gemm(A, B, C_global, shmem, atol=2) + passed_str = "passed" if success else "failed" + shmem.info(f"Final C validation {passed_str}.") + + # Wait for all to finish validation + shmem.barrier() + shmem.info("Validation completed") + + json_writer.add_field("success", success) + + if not is_triton_interpret_set(): + gemm_registers = matmul.get_matmul_registers() + gemm_spills = matmul.get_matmul_spills() + + json_writer.add_field("gemm_registers", gemm_registers) + json_writer.add_field("gemm_spills", gemm_spills) + + if args["benchmark"]: + matmul.set_debug(False) + shmem.info("Benchmarking...") + perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) + triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) + triton_tflops = perf(triton_ms) + dist_mode = "striding" if args["distribution"] == 0 else "block" + algo_string = f"one_shot_all_reduce_pc_{dist_mode}" + shmem.info(f"tile matmul + {algo_string} (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops") + + json_writer.add_field("tflops", triton_tflops) + json_writer.add_field("total_ms", triton_ms) + + for k in ["gemm", "communication"]: + json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) + json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + if args["trace_tiles"] and rank == 0: + gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3 + dist_mode = "striding" if args["distribution"] == 0 else "block" + algo_string = f"one_shot_all_reduce_pc_{dist_mode}" + filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json" + timestamps.to_json(filename, gpu_freq) + + shmem.barrier() + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py b/examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py new file mode 100644 index 00000000..c0680e3e --- /dev/null +++ b/examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +from examples.common.utils import read_realtime + +import sys +import os + +import iris + + +@triton.jit() +def persistent_gemm( + A, + B, + local_C, + bias_ptr, + locks, + tile_ready, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm_local, + stride_cn_local, + stride_bias, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + GEMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + """ + Producer kernel: Computes all tiles (each rank produces partial results). + All ranks process all tiles and produce partials because K is split across ranks. + """ + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm_local > 0) + tl.assume(stride_cn_local > 0) + + acc_dtype = tl.float32 if local_C.type.element_ty != tl.int8 else tl.int32 + + # All ranks process all tiles + for tile_id in range(pid, total_tiles, GEMM_SMS): + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) + + # Map tile_id to (pid_m, pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Compute GEMM for this tile + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + a = tl.load(tl.multiple_of(A_BASE, (1, 16))) + b = tl.load(tl.multiple_of(B_BASE, (16, 1))) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b) + + # Store result locally + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + mask = (rm[:, None] < M) & (rn[None, :] < N) + local_offset = rm[:, None] * stride_cm_local + rn[None, :] * stride_cn_local + + # Write to local buffer + tl.store(local_C + local_offset, acc, mask=mask, cache_modifier=".wt") + + # Signal that this tile is ready + tl.debug_barrier() + tl.store(locks + tile_id, 1, cache_modifier=".wt") + + # Signal to all remote ranks that this tile is ready + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.atomic_xchg(tile_ready + tile_id, 1, cur_rank, remote_rank, heap_bases, sem="release", scope="sys") + + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) + + +@triton.jit() +def persistent_all_reduce( + local_C, + C_global, + locks, + tile_ready, + M, + N, + stride_cm_local, + stride_cn_local, + stride_cm_global, + stride_cn_global, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + DISTRIBUTION: tl.constexpr, # 0 for striding, 1 for block + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + """ + Consumer kernel: Waits for tiles from all ranks, accumulates, and scatters results. + Each rank only processes a subset of tiles for reduction based on DISTRIBUTION. + """ + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + acc_dtype = tl.float32 if C_global.type.element_ty != tl.int8 else tl.int32 + + # Determine which tiles this rank is responsible for reducing + if DISTRIBUTION == 0: + # Striding: rank reduces tiles cur_rank, cur_rank + world_size, ... + tiles_per_rank = tl.cdiv(total_tiles, world_size) + start_tile = cur_rank + stride = world_size + else: + # Block: rank reduces continuous block of tiles + tiles_per_rank = tl.cdiv(total_tiles, world_size) + start_tile = cur_rank * tiles_per_rank + stride = 1 + + # Each SM processes tiles assigned to this rank for reduction + for tile_offset in range(pid, tiles_per_rank, COMM_SMS): + tile_id = start_tile + tile_offset * stride + + # Boundary check + if tile_id >= total_tiles: + break + + # Wait for all ranks to produce this tile (all ranks have partials) + # Local tile + while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1: + pass + + # Wait for remote ranks + for remote_rank in range(world_size): + if remote_rank != cur_rank: + while ( + iris.atomic_cas( + tile_ready + tile_id, 0, 0, cur_rank, remote_rank, heap_bases, sem="acquire", scope="sys" + ) + != 1 + ): + pass + + # Map tile_id to (pid_m, pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Compute offsets + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + mask = (rm[:, None] < M) & (rn[None, :] < N) + local_offset = rm[:, None] * stride_cm_local + rn[None, :] * stride_cn_local + + # Accumulate from all ranks + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for remote_rank in range(world_size): + partial = iris.load(local_C + local_offset, cur_rank, remote_rank, heap_bases, mask=mask) + acc += partial.to(acc_dtype) + + # Convert to output type + c_out = acc.to(C_global.type.element_ty) + + # Scatter to all ranks + global_offset = rm[:, None] * stride_cm_global + rn[None, :] * stride_cn_global + for remote_rank in range(world_size): + if remote_rank == cur_rank: + tl.store(C_global + global_offset, c_out, mask=mask) + else: + iris.store(C_global + global_offset, c_out, cur_rank, remote_rank, heap_bases, mask=mask) diff --git a/examples/17_gemm_one_shot_all_reduce_pc/matmul_wrapper.py b/examples/17_gemm_one_shot_all_reduce_pc/matmul_wrapper.py new file mode 100644 index 00000000..48155e10 --- /dev/null +++ b/examples/17_gemm_one_shot_all_reduce_pc/matmul_wrapper.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import random +import sys +import os + +from gemm_one_shot_all_reduce_pc import persistent_gemm + +from examples.common.utils import is_triton_interpret_set +import iris + +gemm_kernel = persistent_gemm + + +class matmul(torch.autograd.Function): + _debug = True + _registers = None + _spills = None + + _num_xcds = iris.hip.get_num_xcc() + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def get_matmul_registers(): + if matmul._debug: + return matmul._registers + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def get_matmul_spills(): + if matmul._debug: + return matmul._spills + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def _call( + a: torch.Tensor, + b: torch.Tensor, + local_C: torch.Tensor, + C_global: torch.Tensor, + bias: torch.Tensor, + locks: torch.Tensor, + tile_ready: torch.Tensor, + rank: int, + world_size: int, + gemm_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + heap_bases_ptr: torch.Tensor = None, + arch: str = "gfx942", + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + num_xcds = matmul._num_xcds + + # TODO: Use arch-specific values. + num_stages = 2 + num_warps = 8 + waves_per_eu = 0 + mfma = 16 + kpack = 1 + + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + total_tiles = total_blocks_M * total_blocks_N + even_k = K % BLK_K == 0 + use_bias = False + + # compute grid (work to do per SM on the first wave) + stride_bias = bias.stride(0) if use_bias else 0 + kk = gemm_kernel[(gemm_sms,)]( + a, + b, + local_C, + bias, + locks, + tile_ready, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + local_C.stride(0), + local_C.stride(1), + stride_bias, + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + GEMM_SMS=gemm_sms, + NUM_XCDS=num_xcds, + BIAS=use_bias, + EVEN_K=even_k, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfma, + kpack=kpack, + heap_bases=heap_bases_ptr, + cur_rank=rank, + world_size=world_size, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp_ptr=mm_begin_timestamp, + mm_end_timestamp_ptr=mm_end_timestamp, + ) + + matmul._registers = kk.n_regs + matmul._spills = kk.n_spills + + return local_C + + @staticmethod + def forward( + ctx, + a: torch.Tensor, + b: torch.Tensor, + local_C: torch.Tensor, + C_global: torch.Tensor, + bias: torch.Tensor, + locks: torch.Tensor, + tile_ready: torch.Tensor, + rank: int, + world_size: int, + gemm_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + heap_bases_ptr: torch.Tensor = None, + arch: str = "gfx942", + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + matmul._call( + a=a, + b=b, + local_C=local_C, + C_global=C_global, + bias=bias, + locks=locks, + tile_ready=tile_ready, + rank=rank, + world_size=world_size, + gemm_sms=gemm_sms, + BLK_M=BLK_M, + BLK_N=BLK_N, + BLK_K=BLK_K, + gsize_m=gsize_m, + heap_bases_ptr=heap_bases_ptr, + arch=arch, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp=mm_begin_timestamp, + mm_end_timestamp=mm_end_timestamp, + ) + return local_C diff --git a/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py b/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py new file mode 100644 index 00000000..fb45922c --- /dev/null +++ b/examples/21_gemm_one_shot_all_reduce_independent/benchmark.py @@ -0,0 +1,475 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import sys +import os +import argparse +import json +import csv + +from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set +from examples.common.validation import validate_gemm, validate_all_reduce + +import iris + +from matmul_wrapper import matmul +from gemm_one_shot_all_reduce_independent import persistent_all_reduce + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse matrix dimensions and configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A (GEMM)") + parser.add_argument("-n", type=int, default=4608, help="Number of columns in matrix B (GEMM)") + parser.add_argument("-k", type=int, default=36864, help="Common dimension between matrices A and B (GEMM)") + parser.add_argument("--m_comm", type=int, default=None, help="Number of rows for all-reduce tensor (defaults to m)") + parser.add_argument( + "--n_comm", type=int, default=None, help="Number of columns for all-reduce tensor (defaults to n)" + ) + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + parser.add_argument("--BLK_M", type=int, default=256, help="Block size M") + parser.add_argument("--BLK_N", type=int, default=64, help="Block size N") + parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") + parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("--gemm_sms", type=int, default=256, help="Number of SMs for GEMM kernel") + parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Reduce kernel") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--csv", + type=str, + default=None, + help="Path to CSV file with configurations (columns: m, n, k, datatype, blk_m, blk_n, blk_k, gemm_sms, comm_sms)", + ) + parser.add_argument( + "--only_gemm", + action="store_true", + help="Run only GEMM operation (cannot be used with --only_comm)", + ) + parser.add_argument( + "--only_comm", + action="store_true", + help="Run only communication (all-reduce) operation (cannot be used with --only_gemm)", + ) + parser.add_argument( + "--distribution", + type=int, + default=0, + choices=[0, 1], + help="Distribution mode for all-reduce: 0=striding, 1=block", + ) + + args = vars(parser.parse_args()) + + # Validate mutually exclusive flags + if args["only_gemm"] and args["only_comm"]: + parser.error("--only_gemm and --only_comm cannot be used together") + + return args + + +def load_configs_from_csv(csv_path): + """Load configurations from a CSV file. + + Expected CSV format: + m,n,k,datatype,blk_m,blk_n,blk_k,gemm_sms,comm_sms + 8192,4608,36864,fp16,128,128,64,256,48 + 8192,4096,12288,fp32,256,128,64,256,48 + ... + + Args: + csv_path: Path to the CSV file + + Returns: + List of configuration dictionaries + """ + configs = [] + with open(csv_path, "r") as f: + reader = csv.DictReader(f) + for row in reader: + config = { + "m": int(row["m"]), + "n": int(row["n"]), + "k": int(row["k"]), + "datatype": row["datatype"], + "BLK_M": int(row["blk_m"]), + "BLK_N": int(row["blk_n"]), + "BLK_K": int(row["blk_k"]), + "gemm_sms": int(row["gemm_sms"]), + "comm_sms": int(row["comm_sms"]), + } + configs.append(config) + return configs + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + + # GEMM + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "int8": + datatype = torch.int8 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + # Set default values for all-reduce dimensions if not provided + if args["m_comm"] is None: + args["m_comm"] = args["m"] + if args["n_comm"] is None: + args["n_comm"] = args["n"] + + A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype) + B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + local_A = A + local_B = B + + for key, value in args.items(): + json_writer.add_field(key, value) + + C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) + + # Create all-reduce tensors (independent from GEMM) + # Each rank has a value of rank+1 + all_reduce_local = shmem.full((args["m_comm"], args["n_comm"]), rank + 1.0, device="cuda", dtype=datatype) + all_reduce_result = shmem.zeros((args["m_comm"], args["n_comm"]), device="cuda", dtype=datatype) + + total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) + total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) + total_tiles = total_blocks_M * total_blocks_N + + bias = None + + num_xcds = iris.hip.get_num_xcc() + + # Independent streams for GEMM and all-reduce + gemm_stream = torch.cuda.Stream() + comm_stream = torch.cuda.Stream() + + json_writer.add_field("gemm_sms", args["gemm_sms"]) + json_writer.add_field("comm_sms", args["comm_sms"]) + + kernel_timing = { + "gemm": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + "communication": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Allocate Timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + def run_experiment(): + nonlocal C + nonlocal all_reduce_result + nonlocal kernel_timing + + shmem.barrier() + + if args["trace_tiles"]: + timestamps.reset() + shmem.barrier() + + # Determine what to run based on flags + run_gemm = not args["only_comm"] + run_comm = not args["only_gemm"] + + # Set NVTX range name based on what we're running + if run_gemm and run_comm: + nvtx_name = "GEMM + All-Reduce (Independent)" + elif run_gemm: + nvtx_name = "GEMM" + else: + nvtx_name = "All-Reduce" + + torch.cuda.nvtx.range_push(nvtx_name) + + if run_gemm: + torch.cuda.nvtx.range_push("GEMM") + with torch.cuda.stream(gemm_stream): + kernel_timing["gemm"]["start_event"].record() + C = matmul.apply( + local_A, + local_B, + C, + bias, + rank, + world_size, + args["gemm_sms"], + args["BLK_M"], + args["BLK_N"], + args["BLK_K"], + args["gsize_m"], + shmem.get_heap_bases(), + "gfx942", + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["gemm"]["end_event"].record() + kernel_timing["gemm"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + if run_comm: + torch.cuda.nvtx.range_push("All-Reduce") + with torch.cuda.stream(comm_stream): + kernel_timing["communication"]["start_event"].record() + persistent_all_reduce[(args["comm_sms"],)]( + all_reduce_local, + all_reduce_result, + args["m_comm"], + args["n_comm"], + all_reduce_local.stride(0), + all_reduce_local.stride(1), + all_reduce_result.stride(0), + all_reduce_result.stride(1), + args["BLK_M"], + args["BLK_N"], + args["gsize_m"], + args["comm_sms"], + num_xcds, + shmem.get_heap_bases(), + rank, + world_size, + args["distribution"], + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["communication"]["end_event"].record() + kernel_timing["communication"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + shmem.barrier() + + # Update timing for operations that were run + if run_gemm: + ms = kernel_timing["gemm"]["start_event"].elapsed_time(kernel_timing["gemm"]["end_event"]) + kernel_timing["gemm"]["ms"] += ms + if run_comm: + ms = kernel_timing["communication"]["start_event"].elapsed_time(kernel_timing["communication"]["end_event"]) + kernel_timing["communication"]["ms"] += ms + + torch.cuda.nvtx.range_pop() + + # Synchronize across all GPUs + shmem.barrier() + + # Warmup + run_experiment() + + shmem.barrier() + + for k in ["gemm", "communication"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + if args["validate"]: + # Ensure all GPU kernels have completed before validation + torch.cuda.synchronize() + shmem.barrier() + + shmem.info("Validating...") + matmul.set_debug(True) + + # Determine what to validate based on flags + validate_gemm_op = not args["only_comm"] + validate_comm_op = not args["only_gemm"] + + success_gemm = True + success_comm = True + + # Validate GEMM result if it was run + if validate_gemm_op: + shmem.info("Validating GEMM operation...") + success_gemm = validate_gemm(A, B, C, shmem) + passed_str = "passed" if success_gemm else "failed" + shmem.info(f"GEMM validation {passed_str}.") + # Wait for all to finish GEMM validation + shmem.barrier() + + # Validate all-reduce result if it was run + if validate_comm_op: + shmem.info("Validating all-reduce operation...") + success_comm = validate_all_reduce(all_reduce_local, all_reduce_result, shmem) + passed_str = "passed" if success_comm else "failed" + shmem.info(f"All-reduce validation {passed_str}.") + # Wait for all to finish communication validation + shmem.barrier() + + # Overall success + success = success_gemm and success_comm + overall_str = "passed" if success else "failed" + shmem.info(f"Overall validation {overall_str}.") + + # Wait for all to finish validation + shmem.barrier() + + json_writer.add_field("success", success) + if validate_gemm_op: + json_writer.add_field("success_gemm", success_gemm) + if validate_comm_op: + json_writer.add_field("success_comm", success_comm) + + if validate_gemm_op and not is_triton_interpret_set(): + gemm_registers = matmul.get_matmul_registers() + gemm_spills = matmul.get_matmul_spills() + + json_writer.add_field("gemm_registers", gemm_registers) + json_writer.add_field("gemm_spills", gemm_spills) + + shmem.info("Validation completed") + + if args["benchmark"]: + matmul.set_debug(False) + shmem.info("Benchmarking...") + perf = lambda ms: 2 * args["m"] * args["n"] * args["k"] * 1e-12 / (ms * 1e-3) + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + triton_tflops = perf(triton_ms) + + # Determine what was run based on flags + run_gemm = not args["only_comm"] + run_comm = not args["only_gemm"] + + if run_gemm and run_comm: + op_string = "tile matmul + one_shot_all_reduce (independent)" + elif run_gemm: + op_string = "tile matmul" + else: + op_string = "one_shot_all_reduce" + + shmem.info(f"{op_string} (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops") + + json_writer.add_field("tflops", triton_tflops) + json_writer.add_field("total_ms", triton_ms) + + # Only add timing for operations that were run + if run_gemm: + json_writer.add_field("gemm_ms", kernel_timing["gemm"]["ms"] / kernel_timing["gemm"]["experiments"]) + json_writer.add_field("gemm_experiments", kernel_timing["gemm"]["experiments"]) + if run_comm: + json_writer.add_field( + "communication_ms", kernel_timing["communication"]["ms"] / kernel_timing["communication"]["experiments"] + ) + json_writer.add_field("communication_experiments", kernel_timing["communication"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + if args["trace_tiles"] and rank == 0: + gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3 + algo_string = "one_shot_all_reduce_independent" + filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json" + timestamps.to_json(filename, gpu_freq) + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29500" + + # If CSV is provided, run sweep with configurations from CSV + if args["csv"] is not None: + configs = load_configs_from_csv(args["csv"]) + print(f"Loaded {len(configs)} configurations from {args['csv']}") + + for i, config in enumerate(configs): + # Create a copy of args and update with CSV config + run_args = args.copy() + run_args.update(config) + + print( + f"\nRunning configuration {i + 1}/{len(configs)}:\n" + + "\n".join( + f"\t{k}={config[k]}" + for k in ["m", "n", "k", "datatype", "BLK_M", "BLK_N", "BLK_K", "gemm_sms", "comm_sms"] + ) + ) + # Generate unique output filename for this configuration + base_name, ext = os.path.splitext(args["output_file"]) + run_args["output_file"] = ( + f"{base_name}_m{config['m']}_n{config['n']}_k{config['k']}_{config['datatype']}_{config['BLK_M']}_{config['BLK_N']}_{config['BLK_K']}_{config['gemm_sms']}_{config['comm_sms']}{ext}" + ) + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, run_args), + nprocs=num_ranks, + join=True, + ) + else: + # Single run with command line arguments + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/21_gemm_one_shot_all_reduce_independent/example_config.csv b/examples/21_gemm_one_shot_all_reduce_independent/example_config.csv new file mode 100644 index 00000000..d24f04f5 --- /dev/null +++ b/examples/21_gemm_one_shot_all_reduce_independent/example_config.csv @@ -0,0 +1,3 @@ +m,n,k,datatype,blk_m,blk_n,blk_k,gemm_sms,comm_sms +8192,4608,36864,fp16,256,64,64,256,48 +4096,4096,12288,fp32,128,128,64,240,56 diff --git a/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py b/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py new file mode 100644 index 00000000..40a24213 --- /dev/null +++ b/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +from examples.common.utils import read_realtime + +import sys +import os + +import iris + + +@triton.jit() +def persistent_gemm( + A, + B, + C, + bias_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + GEMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + """ + Independent GEMM operation that works on its own data. + No synchronization with all-reduce operation. + """ + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + for tile_id in range(pid, total_tiles, GEMM_SMS): + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + a = tl.load(tl.multiple_of(A_BASE, (1, 16))) + b = tl.load(tl.multiple_of(B_BASE, (16, 1))) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b) + + # Accumulator registers with C results + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + # Add compiler hints + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Define the C-mask (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N) + sub_mask = (rm[:, None] < M) & (rn[None, :] < N) + + # Calculate the local offset of C. + local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn + + # Timestamp for GEMM before store + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) + + tl.store(C + local_offset, c, mask=sub_mask, cache_modifier=".wt") + + +@triton.jit() +def persistent_all_reduce( + local_data, + global_result, + M, + N, + stride_local_m, + stride_local_n, + stride_global_m, + stride_global_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + DISTRIBUTION: tl.constexpr, # 0 for striding, 1 for block + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + """ + Independent all-reduce operation that works on its own data. + Each rank only reduces unique tiles based on distribution and scatters results. + """ + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + acc_dtype = tl.float32 if global_result.type.element_ty != tl.int8 else tl.int32 + + # Determine which tiles this rank is responsible for reducing + if DISTRIBUTION == 0: + # Striding: rank reduces tiles cur_rank, cur_rank + world_size, ... + tiles_per_rank = tl.cdiv(total_tiles, world_size) + start_tile = cur_rank + stride = world_size + else: + # Block: rank reduces continuous block of tiles + tiles_per_rank = tl.cdiv(total_tiles, world_size) + start_tile = cur_rank * tiles_per_rank + stride = 1 + + # Each SM processes tiles assigned to this rank for reduction + for tile_offset in range(pid, tiles_per_rank, COMM_SMS): + tile_id = start_tile + tile_offset * stride + + # Boundary check + if tile_id >= total_tiles: + break + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + # Compute offsets + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + sub_mask = (rm[:, None] < M) & (rn[None, :] < N) + local_offset = rm[:, None] * stride_local_m + rn[None, :] * stride_local_n + + # Accumulate from all ranks + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for remote_rank in range(world_size): + partial = iris.load(local_data + local_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask) + acc += partial.to(acc_dtype) + + # Convert to output type + result = acc.to(global_result.type.element_ty) + + # Scatter to all ranks + global_offset = rm[:, None] * stride_global_m + rn[None, :] * stride_global_n + for remote_rank in range(world_size): + if remote_rank == cur_rank: + tl.store(global_result + global_offset, result, mask=sub_mask, cache_modifier=".wt") + else: + iris.store(global_result + global_offset, result, cur_rank, remote_rank, heap_bases, mask=sub_mask) diff --git a/examples/21_gemm_one_shot_all_reduce_independent/matmul_wrapper.py b/examples/21_gemm_one_shot_all_reduce_independent/matmul_wrapper.py new file mode 100644 index 00000000..9691d8c0 --- /dev/null +++ b/examples/21_gemm_one_shot_all_reduce_independent/matmul_wrapper.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import random +import sys +import os + +from gemm_one_shot_all_reduce_independent import persistent_gemm + +from examples.common.utils import is_triton_interpret_set +import iris + +gemm_kernel = persistent_gemm + + +class matmul(torch.autograd.Function): + _debug = True + _registers = None + _spills = None + + _num_xcds = iris.hip.get_num_xcc() + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def get_matmul_registers(): + if matmul._debug: + return matmul._registers + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def get_matmul_spills(): + if matmul._debug: + return matmul._spills + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def _call( + a: torch.Tensor, + b: torch.Tensor, + C: torch.Tensor, + bias: torch.Tensor, + rank: int, + world_size: int, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + heap_bases_ptr: torch.Tensor = None, + arch: str = "gfx942", + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + num_xcds = matmul._num_xcds + + # TODO: Use arch-specific values. + num_stages = 2 + num_warps = 8 + waves_per_eu = 0 + mfma = 16 + kpack = 1 + + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + total_tiles = total_blocks_M * total_blocks_N + even_k = K % BLK_K == 0 + use_bias = False + + # compute grid (work to do per SM on the first wave) + stride_bias = bias.stride(0) if use_bias else 0 + kk = gemm_kernel[(num_sms,)]( + a, + b, + C, + bias, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + C.stride(0), + C.stride(1), + stride_bias, + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + GEMM_SMS=num_sms, + NUM_XCDS=num_xcds, + BIAS=use_bias, + EVEN_K=even_k, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfma, + kpack=kpack, + heap_bases=heap_bases_ptr, + cur_rank=rank, + world_size=world_size, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp_ptr=mm_begin_timestamp, + mm_end_timestamp_ptr=mm_end_timestamp, + ) + + matmul._registers = kk.n_regs + matmul._spills = kk.n_spills + + return C + + @staticmethod + def forward( + ctx, + a: torch.Tensor, + b: torch.Tensor, + C: torch.Tensor, + bias: torch.Tensor, + rank: int, + world_size: int, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + heap_bases_ptr: torch.Tensor = None, + arch: str = "gfx942", + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + matmul._call( + a=a, + b=b, + C=C, + bias=bias, + rank=rank, + world_size=world_size, + num_sms=num_sms, + BLK_M=BLK_M, + BLK_N=BLK_N, + BLK_K=BLK_K, + gsize_m=gsize_m, + heap_bases_ptr=heap_bases_ptr, + arch=arch, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp=mm_begin_timestamp, + mm_end_timestamp=mm_end_timestamp, + ) + return C diff --git a/examples/README.md b/examples/README.md index 82c812cf..c7a0b223 100644 --- a/examples/README.md +++ b/examples/README.md @@ -28,7 +28,9 @@ This directory contains various algorithm implementations for distributed comput - [`14_all_gather_gemm`](14_all_gather_gemm): Fused All-Gather + GEMM with Pull and Push models - [`15_gemm_all_reduce_ring_based`](15_gemm_all_reduce_ring_based): Matrix multiplication with ring-based all-reduce - [`16_all_reduce_ring_based`](16_all_reduce_ring_based): Ring-based all-reduce operation +- [`17_gemm_one_shot_all_reduce_pc`](17_gemm_one_shot_all_reduce_pc): Matrix multiplication with one-shot all-reduce using producer-consumer pattern with two distribution modes (striding and block) - [`20_gemm_all_scatter_independent`](20_gemm_all_scatter_independent): Independent GEMM and all-scatter operations with support for CSV input configurations +- [`21_gemm_one_shot_all_reduce_independent`](21_gemm_one_shot_all_reduce_independent): Independent GEMM and all-reduce operations with support for CSV input configurations and selective execution ### Utilities - [`benchmark`](benchmark): Benchmarking utilities and performance testing tools @@ -95,17 +97,41 @@ python examples/20_gemm_all_scatter_independent/benchmark.py --benchmark --valid # Independent GEMM and all-scatter - sweep with CSV configurations python examples/20_gemm_all_scatter_independent/benchmark.py --benchmark --validate --num_ranks 8 --csv dataset/gemm_config.csv + +# One-shot all-reduce with producer-consumer pattern - striding distribution +python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --benchmark --validate --num_ranks 8 --distribution 0 + +# One-shot all-reduce with producer-consumer pattern - block distribution +python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --benchmark --validate --num_ranks 8 --distribution 1 + +# Independent GEMM and all-reduce - run both operations +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --benchmark --validate --num_ranks 8 + +# Independent GEMM and all-reduce - run only GEMM +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --only_gemm --validate --num_ranks 8 + +# Independent GEMM and all-reduce - run only all-reduce +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --only_comm --validate --num_ranks 8 + +# Independent GEMM and all-reduce - sweep with CSV configurations +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --benchmark --num_ranks 8 --csv examples/21_gemm_one_shot_all_reduce_independent/example_config.csv ``` ### CSV Configuration Format -Example 20 supports loading multiple configurations from a CSV file using the `--csv` argument. The CSV file should have the following format: +**Note:** Only examples 20 and 21 support loading multiple configurations from a CSV file using the `--csv` argument. Example 17 does **not** support CSV configuration files. + +**Example 20 CSV format:** +```csv +m,n,k,datatype,blk_m,blk_n,blk_k,gemm_sms,comm_sms +8192,4608,36864,fp16,256,64,64,256,48 +8192,4096,12288,fp32,256,128,64,256,48 +4096,4096,8192,bf16,128,128,64,240,56 +``` +**Example 21 CSV format:** ```csv -m,n,k,datatype -8192,4608,36864,fp16 -8192,4096,12288,fp32 -8192,3584,14336,bf16 -4096,4096,8192,fp16 -2048,2048,4096,fp16 +m,n,k,datatype,blk_m,blk_n,blk_k,gemm_sms,comm_sms +8192,4608,36864,fp16,256,64,64,256,48 +4096,4096,12288,fp32,128,128,64,240,56 ``` diff --git a/examples/common/validation.py b/examples/common/validation.py index 47802b60..8046e92d 100644 --- a/examples/common/validation.py +++ b/examples/common/validation.py @@ -72,3 +72,40 @@ def validate_all_scatter(local_tensor, global_tensor, shmem, atol=1): return False return True + + +def validate_all_reduce(local_tensor, global_tensor, shmem, atol=1): + """ + Validate all-reduce operation where each rank's local tensor is reduced and result is on all ranks. + + Args: + local_tensor: The local tensor on this rank before all-reduce + global_tensor: The result tensor after all-reduce (should contain sum of all ranks) + shmem: Iris shmem object + atol: Absolute tolerance for comparison + + Returns: + bool: True if validation passes, False otherwise + """ + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Compute expected result: sum of all ranks' local tensors + # Each rank has value (rank+1), so sum should be 1+2+...+world_size = world_size*(world_size+1)/2 + expected = torch.full_like(local_tensor, world_size * (world_size + 1) / 2.0) + + diff_mask = ~torch.isclose(global_tensor, expected, atol=atol) + breaking_indices = torch.nonzero(diff_mask, as_tuple=False) + + if not torch.allclose(global_tensor, expected, atol=atol): + max_diff = (global_tensor - expected).abs().max().item() + shmem.info(f"All-reduce validation: Max absolute difference: {max_diff}") + for idx in breaking_indices: + idx = tuple(idx.tolist()) + computed_val = global_tensor[idx] + expected_val = expected[idx] + shmem.error(f"All-reduce mismatch at rank {rank}, index {idx}: got={computed_val}, expected={expected_val}") + break + return False + + return True diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 461be9c1..93c706c3 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -78,6 +78,7 @@ def __init__(self, cur_rank, num_ranks, heap_bases): self.num_ranks = num_ranks self.heap_bases = heap_bases + @staticmethod @gluon.jit def initialize(context_tensor): """ From 79b6d6cfed92be348afa1cbb9e2a8fbd55c3bec7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 19 Oct 2025 10:01:19 +0000 Subject: [PATCH 4/5] Add README.md files for new examples (15, 16, 17, 20, 21) Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .../15_gemm_all_reduce_ring_based/README.md | 54 ++++++++++++ examples/16_all_reduce_ring_based/README.md | 53 +++++++++++ .../17_gemm_one_shot_all_reduce_pc/README.md | 73 +++++++++++++++ .../20_gemm_all_scatter_independent/README.md | 71 +++++++++++++++ .../README.md | 88 +++++++++++++++++++ 5 files changed, 339 insertions(+) create mode 100644 examples/15_gemm_all_reduce_ring_based/README.md create mode 100644 examples/16_all_reduce_ring_based/README.md create mode 100644 examples/17_gemm_one_shot_all_reduce_pc/README.md create mode 100644 examples/20_gemm_all_scatter_independent/README.md create mode 100644 examples/21_gemm_one_shot_all_reduce_independent/README.md diff --git a/examples/15_gemm_all_reduce_ring_based/README.md b/examples/15_gemm_all_reduce_ring_based/README.md new file mode 100644 index 00000000..1c8be8bc --- /dev/null +++ b/examples/15_gemm_all_reduce_ring_based/README.md @@ -0,0 +1,54 @@ + + +# Matrix Multiplication with Ring-Based All-Reduce + +This example demonstrates a distributed matrix multiplication (GEMM) operation followed by a ring-based all-reduce communication pattern. The implementation uses a persistent kernel approach where GEMM computation and communication are overlapped. + +The ring-based all-reduce is an efficient collective operation that reduces data across all GPUs by forming a logical ring topology. Each GPU sends data to its neighbor while receiving from the other neighbor, completing the reduction in multiple passes around the ring. + +## Usage + +### Basic Run + +To run the benchmark with default parameters: + +```terminal +python examples/15_gemm_all_reduce_ring_based/benchmark.py --num_ranks 8 +``` + +### Validation + +To verify numerical correctness against a PyTorch reference: + +```terminal +python examples/15_gemm_all_reduce_ring_based/benchmark.py --validate --num_ranks 8 +``` + +### Benchmarking + +To run performance benchmarks: + +```terminal +python examples/15_gemm_all_reduce_ring_based/benchmark.py --benchmark --validate --num_ranks 8 +``` + +### Custom Matrix Dimensions + +You can specify custom matrix dimensions: + +```terminal +python examples/15_gemm_all_reduce_ring_based/benchmark.py --num_ranks 8 -m 4096 -n 4096 -k 4096 +``` + +### Options + +- `-m`: Number of rows in matrix A (default: 8192) +- `-n`: Number of columns in matrix B (default: 4608) +- `-k`: Common dimension between matrices A and B (default: 36864) +- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) +- `--validate`: Enable validation mode +- `--benchmark`: Enable benchmarking mode +- `--BLK_M`, `--BLK_N`, `--BLK_K`: Block sizes for tiling (defaults: 128, 128, 64) diff --git a/examples/16_all_reduce_ring_based/README.md b/examples/16_all_reduce_ring_based/README.md new file mode 100644 index 00000000..21380519 --- /dev/null +++ b/examples/16_all_reduce_ring_based/README.md @@ -0,0 +1,53 @@ + + +# Ring-Based All-Reduce + +This example demonstrates a standalone ring-based all-reduce collective operation across multiple GPUs. The ring-based all-reduce is an efficient communication pattern that reduces data across all GPUs by forming a logical ring topology. + +In this pattern, each GPU sends data to its neighbor while receiving from the other neighbor, completing the reduction in multiple passes around the ring. This approach provides excellent bandwidth utilization and scales well with the number of GPUs. + +## Usage + +### Basic Run + +To run the benchmark with default parameters: + +```terminal +python examples/16_all_reduce_ring_based/benchmark.py --num_ranks 8 +``` + +### Validation + +To verify numerical correctness: + +```terminal +python examples/16_all_reduce_ring_based/benchmark.py --validate --num_ranks 8 +``` + +### Benchmarking + +To run performance benchmarks: + +```terminal +python examples/16_all_reduce_ring_based/benchmark.py --benchmark --validate --num_ranks 8 +``` + +### Custom Matrix Dimensions + +You can specify custom dimensions for the data to reduce: + +```terminal +python examples/16_all_reduce_ring_based/benchmark.py --num_ranks 8 -m 8192 -n 4608 +``` + +### Options + +- `-m`: Number of rows in input/output matrix (default: 8192) +- `-n`: Number of columns in input/output matrix (default: 4608) +- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) +- `--validate`: Enable validation mode +- `--benchmark`: Enable benchmarking mode +- `--BLK_M`, `--BLK_N`: Block sizes for tiling (defaults: 128, 128) diff --git a/examples/17_gemm_one_shot_all_reduce_pc/README.md b/examples/17_gemm_one_shot_all_reduce_pc/README.md new file mode 100644 index 00000000..82cfa22d --- /dev/null +++ b/examples/17_gemm_one_shot_all_reduce_pc/README.md @@ -0,0 +1,73 @@ + + +# Matrix Multiplication with One-Shot All-Reduce (Producer-Consumer Pattern) + +This example demonstrates a distributed matrix multiplication (GEMM) operation with a one-shot all-reduce using a producer-consumer pattern. The implementation explores two distinct distribution modes for managing data communication between GPUs. + +## Distribution Modes + +The example supports two distribution strategies: + +### Mode 0: Striding Distribution +Data is distributed in a strided pattern across GPUs, providing fine-grained interleaving of work. + +### Mode 1: Block Distribution +Data is distributed in contiguous blocks across GPUs, providing coarse-grained partitioning of work. + +## Usage + +### Basic Run with Striding Distribution + +```terminal +python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --num_ranks 8 --distribution 0 +``` + +### Basic Run with Block Distribution + +```terminal +python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --num_ranks 8 --distribution 1 +``` + +### Validation + +To verify numerical correctness with striding distribution: + +```terminal +python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --validate --num_ranks 8 --distribution 0 +``` + +To verify with block distribution: + +```terminal +python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --validate --num_ranks 8 --distribution 1 +``` + +### Benchmarking + +To run performance benchmarks: + +```terminal +python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --benchmark --validate --num_ranks 8 --distribution 0 +``` + +### Custom Matrix Dimensions + +You can specify custom matrix dimensions: + +```terminal +python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --num_ranks 8 -m 4096 -n 4096 -k 4096 --distribution 0 +``` + +### Options + +- `-m`: Number of rows in matrix A (default: 8192) +- `-n`: Number of columns in matrix B (default: 4608) +- `-k`: Common dimension between matrices A and B (default: 36864) +- `--distribution`: Distribution mode (0=striding, 1=block) (default: 0) +- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) +- `--validate`: Enable validation mode +- `--benchmark`: Enable benchmarking mode +- `--BLK_M`, `--BLK_N`, `--BLK_K`: Block sizes for tiling (defaults: 256, 64, 64) diff --git a/examples/20_gemm_all_scatter_independent/README.md b/examples/20_gemm_all_scatter_independent/README.md new file mode 100644 index 00000000..602472d3 --- /dev/null +++ b/examples/20_gemm_all_scatter_independent/README.md @@ -0,0 +1,71 @@ + + +# Independent GEMM and All-Scatter Operations + +This example demonstrates independent execution of matrix multiplication (GEMM) and all-scatter communication operations. The implementation uses a bulk synchronous approach where computation and communication can be run separately or together. + +This example supports loading multiple configurations from a CSV file, allowing for automated sweeps across different matrix dimensions and parameters. + +## Usage + +### Basic Run + +To run both GEMM and all-scatter with default parameters: + +```terminal +python examples/20_gemm_all_scatter_independent/benchmark.py --num_ranks 8 +``` + +### Validation + +To verify numerical correctness: + +```terminal +python examples/20_gemm_all_scatter_independent/benchmark.py --validate --num_ranks 8 +``` + +### Benchmarking + +To run performance benchmarks: + +```terminal +python examples/20_gemm_all_scatter_independent/benchmark.py --benchmark --validate --num_ranks 8 +``` + +### CSV Configuration Sweep + +To run a sweep of configurations from a CSV file: + +```terminal +python examples/20_gemm_all_scatter_independent/benchmark.py --benchmark --validate --num_ranks 8 --csv dataset/gemm_config.csv +``` + +The CSV file should have the following format: +```csv +m,n,k,datatype,blk_m,blk_n,blk_k,gemm_sms,comm_sms +8192,4608,36864,fp16,256,64,64,256,48 +8192,4096,12288,fp32,256,128,64,256,48 +4096,4096,8192,bf16,128,128,64,240,56 +``` + +### Custom Matrix Dimensions + +You can specify custom matrix dimensions: + +```terminal +python examples/20_gemm_all_scatter_independent/benchmark.py --num_ranks 8 -m 4096 -n 4096 -k 4096 +``` + +### Options + +- `-m`: Number of rows in matrix A (default: 8192) +- `-n`: Number of columns in matrix B (default: 4608) +- `-k`: Common dimension between matrices A and B (default: 36864) +- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) +- `--validate`: Enable validation mode +- `--benchmark`: Enable benchmarking mode +- `--csv`: Path to CSV file with multiple configurations +- `--BLK_M`, `--BLK_N`, `--BLK_K`: Block sizes for tiling (defaults: 256, 64, 64) diff --git a/examples/21_gemm_one_shot_all_reduce_independent/README.md b/examples/21_gemm_one_shot_all_reduce_independent/README.md new file mode 100644 index 00000000..e9252722 --- /dev/null +++ b/examples/21_gemm_one_shot_all_reduce_independent/README.md @@ -0,0 +1,88 @@ + + +# Independent GEMM and One-Shot All-Reduce Operations + +This example demonstrates independent execution of matrix multiplication (GEMM) and one-shot all-reduce communication operations. The implementation allows for selective execution of either operation or both together, providing flexibility for testing and benchmarking. + +This example supports loading multiple configurations from a CSV file, enabling automated performance sweeps across different matrix dimensions and parameters. + +## Usage + +### Basic Run (Both Operations) + +To run both GEMM and all-reduce with default parameters: + +```terminal +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --num_ranks 8 +``` + +### Run Only GEMM Operation + +To execute only the matrix multiplication: + +```terminal +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --only_gemm --num_ranks 8 +``` + +### Run Only All-Reduce Operation + +To execute only the communication operation: + +```terminal +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --only_comm --num_ranks 8 +``` + +### Validation + +To verify numerical correctness: + +```terminal +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --validate --num_ranks 8 +``` + +### Benchmarking + +To run performance benchmarks: + +```terminal +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --benchmark --validate --num_ranks 8 +``` + +### CSV Configuration Sweep + +To run a sweep of configurations from a CSV file: + +```terminal +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --benchmark --num_ranks 8 --csv examples/21_gemm_one_shot_all_reduce_independent/example_config.csv +``` + +The CSV file should have the following format: +```csv +m,n,k,datatype,blk_m,blk_n,blk_k,gemm_sms,comm_sms +8192,4608,36864,fp16,256,64,64,256,48 +4096,4096,12288,fp32,128,128,64,240,56 +``` + +### Custom Matrix Dimensions + +You can specify custom matrix dimensions: + +```terminal +python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --num_ranks 8 -m 4096 -n 4096 -k 4096 +``` + +### Options + +- `-m`: Number of rows in matrix A (default: 8192) +- `-n`: Number of columns in matrix B (default: 4608) +- `-k`: Common dimension between matrices A and B (default: 36864) +- `--only_gemm`: Run only GEMM operation (mutually exclusive with `--only_comm`) +- `--only_comm`: Run only all-reduce operation (mutually exclusive with `--only_gemm`) +- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) +- `--validate`: Enable validation mode +- `--benchmark`: Enable benchmarking mode +- `--csv`: Path to CSV file with multiple configurations +- `--BLK_M`, `--BLK_N`, `--BLK_K`: Block sizes for tiling (defaults: 256, 64, 64) From 5f98d70579f0f3704125051ed5ab3ce86d5effd1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 19 Oct 2025 10:05:40 +0000 Subject: [PATCH 5/5] Revert individual example README files - examples/README.md already updated Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .../15_gemm_all_reduce_ring_based/README.md | 54 ------------ examples/16_all_reduce_ring_based/README.md | 53 ----------- .../17_gemm_one_shot_all_reduce_pc/README.md | 73 --------------- .../20_gemm_all_scatter_independent/README.md | 71 --------------- .../README.md | 88 ------------------- 5 files changed, 339 deletions(-) delete mode 100644 examples/15_gemm_all_reduce_ring_based/README.md delete mode 100644 examples/16_all_reduce_ring_based/README.md delete mode 100644 examples/17_gemm_one_shot_all_reduce_pc/README.md delete mode 100644 examples/20_gemm_all_scatter_independent/README.md delete mode 100644 examples/21_gemm_one_shot_all_reduce_independent/README.md diff --git a/examples/15_gemm_all_reduce_ring_based/README.md b/examples/15_gemm_all_reduce_ring_based/README.md deleted file mode 100644 index 1c8be8bc..00000000 --- a/examples/15_gemm_all_reduce_ring_based/README.md +++ /dev/null @@ -1,54 +0,0 @@ - - -# Matrix Multiplication with Ring-Based All-Reduce - -This example demonstrates a distributed matrix multiplication (GEMM) operation followed by a ring-based all-reduce communication pattern. The implementation uses a persistent kernel approach where GEMM computation and communication are overlapped. - -The ring-based all-reduce is an efficient collective operation that reduces data across all GPUs by forming a logical ring topology. Each GPU sends data to its neighbor while receiving from the other neighbor, completing the reduction in multiple passes around the ring. - -## Usage - -### Basic Run - -To run the benchmark with default parameters: - -```terminal -python examples/15_gemm_all_reduce_ring_based/benchmark.py --num_ranks 8 -``` - -### Validation - -To verify numerical correctness against a PyTorch reference: - -```terminal -python examples/15_gemm_all_reduce_ring_based/benchmark.py --validate --num_ranks 8 -``` - -### Benchmarking - -To run performance benchmarks: - -```terminal -python examples/15_gemm_all_reduce_ring_based/benchmark.py --benchmark --validate --num_ranks 8 -``` - -### Custom Matrix Dimensions - -You can specify custom matrix dimensions: - -```terminal -python examples/15_gemm_all_reduce_ring_based/benchmark.py --num_ranks 8 -m 4096 -n 4096 -k 4096 -``` - -### Options - -- `-m`: Number of rows in matrix A (default: 8192) -- `-n`: Number of columns in matrix B (default: 4608) -- `-k`: Common dimension between matrices A and B (default: 36864) -- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) -- `--validate`: Enable validation mode -- `--benchmark`: Enable benchmarking mode -- `--BLK_M`, `--BLK_N`, `--BLK_K`: Block sizes for tiling (defaults: 128, 128, 64) diff --git a/examples/16_all_reduce_ring_based/README.md b/examples/16_all_reduce_ring_based/README.md deleted file mode 100644 index 21380519..00000000 --- a/examples/16_all_reduce_ring_based/README.md +++ /dev/null @@ -1,53 +0,0 @@ - - -# Ring-Based All-Reduce - -This example demonstrates a standalone ring-based all-reduce collective operation across multiple GPUs. The ring-based all-reduce is an efficient communication pattern that reduces data across all GPUs by forming a logical ring topology. - -In this pattern, each GPU sends data to its neighbor while receiving from the other neighbor, completing the reduction in multiple passes around the ring. This approach provides excellent bandwidth utilization and scales well with the number of GPUs. - -## Usage - -### Basic Run - -To run the benchmark with default parameters: - -```terminal -python examples/16_all_reduce_ring_based/benchmark.py --num_ranks 8 -``` - -### Validation - -To verify numerical correctness: - -```terminal -python examples/16_all_reduce_ring_based/benchmark.py --validate --num_ranks 8 -``` - -### Benchmarking - -To run performance benchmarks: - -```terminal -python examples/16_all_reduce_ring_based/benchmark.py --benchmark --validate --num_ranks 8 -``` - -### Custom Matrix Dimensions - -You can specify custom dimensions for the data to reduce: - -```terminal -python examples/16_all_reduce_ring_based/benchmark.py --num_ranks 8 -m 8192 -n 4608 -``` - -### Options - -- `-m`: Number of rows in input/output matrix (default: 8192) -- `-n`: Number of columns in input/output matrix (default: 4608) -- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) -- `--validate`: Enable validation mode -- `--benchmark`: Enable benchmarking mode -- `--BLK_M`, `--BLK_N`: Block sizes for tiling (defaults: 128, 128) diff --git a/examples/17_gemm_one_shot_all_reduce_pc/README.md b/examples/17_gemm_one_shot_all_reduce_pc/README.md deleted file mode 100644 index 82cfa22d..00000000 --- a/examples/17_gemm_one_shot_all_reduce_pc/README.md +++ /dev/null @@ -1,73 +0,0 @@ - - -# Matrix Multiplication with One-Shot All-Reduce (Producer-Consumer Pattern) - -This example demonstrates a distributed matrix multiplication (GEMM) operation with a one-shot all-reduce using a producer-consumer pattern. The implementation explores two distinct distribution modes for managing data communication between GPUs. - -## Distribution Modes - -The example supports two distribution strategies: - -### Mode 0: Striding Distribution -Data is distributed in a strided pattern across GPUs, providing fine-grained interleaving of work. - -### Mode 1: Block Distribution -Data is distributed in contiguous blocks across GPUs, providing coarse-grained partitioning of work. - -## Usage - -### Basic Run with Striding Distribution - -```terminal -python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --num_ranks 8 --distribution 0 -``` - -### Basic Run with Block Distribution - -```terminal -python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --num_ranks 8 --distribution 1 -``` - -### Validation - -To verify numerical correctness with striding distribution: - -```terminal -python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --validate --num_ranks 8 --distribution 0 -``` - -To verify with block distribution: - -```terminal -python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --validate --num_ranks 8 --distribution 1 -``` - -### Benchmarking - -To run performance benchmarks: - -```terminal -python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --benchmark --validate --num_ranks 8 --distribution 0 -``` - -### Custom Matrix Dimensions - -You can specify custom matrix dimensions: - -```terminal -python examples/17_gemm_one_shot_all_reduce_pc/benchmark.py --num_ranks 8 -m 4096 -n 4096 -k 4096 --distribution 0 -``` - -### Options - -- `-m`: Number of rows in matrix A (default: 8192) -- `-n`: Number of columns in matrix B (default: 4608) -- `-k`: Common dimension between matrices A and B (default: 36864) -- `--distribution`: Distribution mode (0=striding, 1=block) (default: 0) -- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) -- `--validate`: Enable validation mode -- `--benchmark`: Enable benchmarking mode -- `--BLK_M`, `--BLK_N`, `--BLK_K`: Block sizes for tiling (defaults: 256, 64, 64) diff --git a/examples/20_gemm_all_scatter_independent/README.md b/examples/20_gemm_all_scatter_independent/README.md deleted file mode 100644 index 602472d3..00000000 --- a/examples/20_gemm_all_scatter_independent/README.md +++ /dev/null @@ -1,71 +0,0 @@ - - -# Independent GEMM and All-Scatter Operations - -This example demonstrates independent execution of matrix multiplication (GEMM) and all-scatter communication operations. The implementation uses a bulk synchronous approach where computation and communication can be run separately or together. - -This example supports loading multiple configurations from a CSV file, allowing for automated sweeps across different matrix dimensions and parameters. - -## Usage - -### Basic Run - -To run both GEMM and all-scatter with default parameters: - -```terminal -python examples/20_gemm_all_scatter_independent/benchmark.py --num_ranks 8 -``` - -### Validation - -To verify numerical correctness: - -```terminal -python examples/20_gemm_all_scatter_independent/benchmark.py --validate --num_ranks 8 -``` - -### Benchmarking - -To run performance benchmarks: - -```terminal -python examples/20_gemm_all_scatter_independent/benchmark.py --benchmark --validate --num_ranks 8 -``` - -### CSV Configuration Sweep - -To run a sweep of configurations from a CSV file: - -```terminal -python examples/20_gemm_all_scatter_independent/benchmark.py --benchmark --validate --num_ranks 8 --csv dataset/gemm_config.csv -``` - -The CSV file should have the following format: -```csv -m,n,k,datatype,blk_m,blk_n,blk_k,gemm_sms,comm_sms -8192,4608,36864,fp16,256,64,64,256,48 -8192,4096,12288,fp32,256,128,64,256,48 -4096,4096,8192,bf16,128,128,64,240,56 -``` - -### Custom Matrix Dimensions - -You can specify custom matrix dimensions: - -```terminal -python examples/20_gemm_all_scatter_independent/benchmark.py --num_ranks 8 -m 4096 -n 4096 -k 4096 -``` - -### Options - -- `-m`: Number of rows in matrix A (default: 8192) -- `-n`: Number of columns in matrix B (default: 4608) -- `-k`: Common dimension between matrices A and B (default: 36864) -- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) -- `--validate`: Enable validation mode -- `--benchmark`: Enable benchmarking mode -- `--csv`: Path to CSV file with multiple configurations -- `--BLK_M`, `--BLK_N`, `--BLK_K`: Block sizes for tiling (defaults: 256, 64, 64) diff --git a/examples/21_gemm_one_shot_all_reduce_independent/README.md b/examples/21_gemm_one_shot_all_reduce_independent/README.md deleted file mode 100644 index e9252722..00000000 --- a/examples/21_gemm_one_shot_all_reduce_independent/README.md +++ /dev/null @@ -1,88 +0,0 @@ - - -# Independent GEMM and One-Shot All-Reduce Operations - -This example demonstrates independent execution of matrix multiplication (GEMM) and one-shot all-reduce communication operations. The implementation allows for selective execution of either operation or both together, providing flexibility for testing and benchmarking. - -This example supports loading multiple configurations from a CSV file, enabling automated performance sweeps across different matrix dimensions and parameters. - -## Usage - -### Basic Run (Both Operations) - -To run both GEMM and all-reduce with default parameters: - -```terminal -python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --num_ranks 8 -``` - -### Run Only GEMM Operation - -To execute only the matrix multiplication: - -```terminal -python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --only_gemm --num_ranks 8 -``` - -### Run Only All-Reduce Operation - -To execute only the communication operation: - -```terminal -python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --only_comm --num_ranks 8 -``` - -### Validation - -To verify numerical correctness: - -```terminal -python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --validate --num_ranks 8 -``` - -### Benchmarking - -To run performance benchmarks: - -```terminal -python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --benchmark --validate --num_ranks 8 -``` - -### CSV Configuration Sweep - -To run a sweep of configurations from a CSV file: - -```terminal -python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --benchmark --num_ranks 8 --csv examples/21_gemm_one_shot_all_reduce_independent/example_config.csv -``` - -The CSV file should have the following format: -```csv -m,n,k,datatype,blk_m,blk_n,blk_k,gemm_sms,comm_sms -8192,4608,36864,fp16,256,64,64,256,48 -4096,4096,12288,fp32,128,128,64,240,56 -``` - -### Custom Matrix Dimensions - -You can specify custom matrix dimensions: - -```terminal -python examples/21_gemm_one_shot_all_reduce_independent/benchmark.py --num_ranks 8 -m 4096 -n 4096 -k 4096 -``` - -### Options - -- `-m`: Number of rows in matrix A (default: 8192) -- `-n`: Number of columns in matrix B (default: 4608) -- `-k`: Common dimension between matrices A and B (default: 36864) -- `--only_gemm`: Run only GEMM operation (mutually exclusive with `--only_comm`) -- `--only_comm`: Run only all-reduce operation (mutually exclusive with `--only_gemm`) -- `--datatype`: Data type for computation (`fp16`, `fp32`, `bf16`, `int8`) (default: fp16) -- `--validate`: Enable validation mode -- `--benchmark`: Enable benchmarking mode -- `--csv`: Path to CSV file with multiple configurations -- `--BLK_M`, `--BLK_N`, `--BLK_K`: Block sizes for tiling (defaults: 256, 64, 64)