From 653a85caf21b29847e97ed312307723b7c4901fd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 21:59:02 +0000 Subject: [PATCH 1/6] Initial plan From 95d600075a58239ebabf74f56789ef8ac2669fae Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:05:33 +0000 Subject: [PATCH 2/6] Add pytest test for gemm_atomics_all_reduce example Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .../examples/test_gemm_atomics_all_reduce.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 tests/examples/test_gemm_atomics_all_reduce.py diff --git a/tests/examples/test_gemm_atomics_all_reduce.py b/tests/examples/test_gemm_atomics_all_reduce.py new file mode 100644 index 00000000..8298ddcf --- /dev/null +++ b/tests/examples/test_gemm_atomics_all_reduce.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import importlib.util +from pathlib import Path + +import numpy as np +import pytest +import torch +import triton +import triton.language as tl + +import iris +from examples.common.utils import Timestamps +from examples.common.validation import validate_gemm + +current_dir = Path(__file__).parent +matmul_wrapper_path = (current_dir / "../../examples/08_gemm_atomics_all_reduce/matmul_wrapper.py").resolve() + +# Import matmul_wrapper module +matmul_spec = importlib.util.spec_from_file_location("matmul_wrapper", matmul_wrapper_path) +matmul_module = importlib.util.module_from_spec(matmul_spec) +matmul_spec.loader.exec_module(matmul_module) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.float32, + torch.bfloat16, + ], +) +@pytest.mark.parametrize( + "m, n, k", + [ + (512, 512, 512), + (1024, 1024, 1024), + ], +) +@pytest.mark.parametrize( + "block_m, block_n, block_k", + [ + (64, 64, 32), + (128, 128, 32), + ], +) +def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k): + # Initialize iris with appropriate heap size + heap_size = 1 << 30 # 1GB + shmem = iris.iris(heap_size) + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + + # Skip test if matrix dimensions are not divisible by world size + if n % world_size != 0 or k % world_size != 0: + pytest.skip(f"Matrix dimensions not divisible by world size {world_size}") + + # Create test matrices + A = shmem.randn(m, k, device="cuda", dtype=dtype) + B = shmem.randn(n, k, device="cuda", dtype=dtype).T + C = shmem.zeros((m, n), device="cuda", dtype=dtype) + + # Split matrices according to rank + rows_per_gpu = k // world_size + start_row = rank * rows_per_gpu + end_row = start_row + rows_per_gpu + local_B = B[start_row:end_row, :] + local_A = A[:, start_row:end_row] + + # Create output matrices + global_C = shmem.zeros((m, n), device="cuda", dtype=dtype) + local_C = shmem.zeros((m, n), device="cuda", dtype=dtype) + + # Setup parameters + total_blocks_M = triton.cdiv(m, block_m) + total_blocks_N = triton.cdiv(n, block_n) + total_tiles = total_blocks_M * total_blocks_N + + # Use conservative number of SMs + gemm_sms = min(cu_count // 2, 128) # Use half of available CUs, max 128 + + # Create required tensors + tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + locks = shmem.zeros((gemm_sms,), device="cuda", dtype=torch.int32) + P = shmem.zeros( + (gemm_sms, block_m * block_n), + device="cuda", + dtype=torch.float32, + ) + bias = None + + # Setup timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + # Synchronize before test + shmem.barrier() + + # Reset tile_completed + iris.memset_tensor(tile_completed, 0) + shmem.barrier() + + # Run the GEMM all-reduce operation + matmul_module.matmul.set_debug(False) + + result_C = matmul_module.matmul.apply( + local_A, + local_B, + local_C, + global_C, + bias, + P, + locks, + tile_completed, + rank, + world_size, + gemm_sms, + block_m, + block_n, + block_k, + 8, # gsize_m + True, # two_tiles + 4, # num_stages + 4, # num_warps + 2, # waves_per_eu + 16, # mfmaInstrSize + 1, # kpack + shmem.get_heap_bases(), + cu_count, + False, # trace_tiles + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + + # Synchronize after computation + shmem.barrier() + + # Validate results + success = validate_gemm(A, B, global_C, shmem, atol=1e-1) + + # Assert test passed + assert success, "GEMM all-reduce validation failed" + + # Verify that we got a non-zero result + assert not torch.allclose(global_C, torch.zeros_like(global_C)), "Result should not be all zeros" From 5a630b7c6c50c6c52633cf23c077611041f933d4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:07:04 +0000 Subject: [PATCH 3/6] Simplify test parameters and reduce matrix sizes for better performance Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- tests/examples/test_gemm_atomics_all_reduce.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/examples/test_gemm_atomics_all_reduce.py b/tests/examples/test_gemm_atomics_all_reduce.py index 8298ddcf..c19b6449 100644 --- a/tests/examples/test_gemm_atomics_all_reduce.py +++ b/tests/examples/test_gemm_atomics_all_reduce.py @@ -29,21 +29,19 @@ [ torch.float16, torch.float32, - torch.bfloat16, ], ) @pytest.mark.parametrize( "m, n, k", [ + (256, 256, 256), (512, 512, 512), - (1024, 1024, 1024), ], ) @pytest.mark.parametrize( "block_m, block_n, block_k", [ (64, 64, 32), - (128, 128, 32), ], ) def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k): @@ -81,7 +79,7 @@ def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k): total_tiles = total_blocks_M * total_blocks_N # Use conservative number of SMs - gemm_sms = min(cu_count // 2, 128) # Use half of available CUs, max 128 + gemm_sms = min(cu_count // 2, 64) # Use half of available CUs, max 64 # Create required tensors tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) From d3a733656536c1d855401af99313aabdb7a1a956 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:35:42 +0000 Subject: [PATCH 4/6] Fix pytest import errors by moving torch references and adding proper error handling Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .../examples/test_gemm_atomics_all_reduce.py | 72 +++++++++---------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/tests/examples/test_gemm_atomics_all_reduce.py b/tests/examples/test_gemm_atomics_all_reduce.py index c19b6449..915bab04 100644 --- a/tests/examples/test_gemm_atomics_all_reduce.py +++ b/tests/examples/test_gemm_atomics_all_reduce.py @@ -5,46 +5,42 @@ import importlib.util from pathlib import Path -import numpy as np import pytest -import torch -import triton -import triton.language as tl - -import iris -from examples.common.utils import Timestamps -from examples.common.validation import validate_gemm - -current_dir = Path(__file__).parent -matmul_wrapper_path = (current_dir / "../../examples/08_gemm_atomics_all_reduce/matmul_wrapper.py").resolve() - -# Import matmul_wrapper module -matmul_spec = importlib.util.spec_from_file_location("matmul_wrapper", matmul_wrapper_path) -matmul_module = importlib.util.module_from_spec(matmul_spec) -matmul_spec.loader.exec_module(matmul_module) - - -@pytest.mark.parametrize( - "dtype", - [ - torch.float16, - torch.float32, - ], -) -@pytest.mark.parametrize( - "m, n, k", - [ - (256, 256, 256), - (512, 512, 512), - ], -) -@pytest.mark.parametrize( - "block_m, block_n, block_k", - [ - (64, 64, 32), - ], -) + +# Try to import dependencies - skip test if not available +try: + import numpy as np + import torch + import triton + import triton.language as tl + import iris + from examples.common.utils import Timestamps + from examples.common.validation import validate_gemm + + # Define test parameters after successful import + DTYPES = [torch.float16, torch.float32] + MATRIX_SIZES = [(256, 256, 256), (512, 512, 512)] + BLOCK_SIZES = [(64, 64, 32)] + +except ImportError as e: + pytest.skip(f"Skipping gemm_atomics_all_reduce test due to missing dependencies: {e}", allow_module_level=True) + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("m, n, k", MATRIX_SIZES) +@pytest.mark.parametrize("block_m, block_n, block_k", BLOCK_SIZES) def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k): + # Import matmul_wrapper module at test time + try: + current_dir = Path(__file__).parent + matmul_wrapper_path = (current_dir / "../../examples/08_gemm_atomics_all_reduce/matmul_wrapper.py").resolve() + + matmul_spec = importlib.util.spec_from_file_location("matmul_wrapper", matmul_wrapper_path) + matmul_module = importlib.util.module_from_spec(matmul_spec) + matmul_spec.loader.exec_module(matmul_module) + except (ImportError, FileNotFoundError) as e: + pytest.skip(f"Skipping test due to import error: {e}") + # Initialize iris with appropriate heap size heap_size = 1 << 30 # 1GB shmem = iris.iris(heap_size) From 05344623e6af774e4e5f6a774778066b615b0fd1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 30 Aug 2025 22:38:26 +0000 Subject: [PATCH 5/6] Apply Ruff auto-fixes --- tests/examples/test_gemm_atomics_all_reduce.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/examples/test_gemm_atomics_all_reduce.py b/tests/examples/test_gemm_atomics_all_reduce.py index 915bab04..c617c1ca 100644 --- a/tests/examples/test_gemm_atomics_all_reduce.py +++ b/tests/examples/test_gemm_atomics_all_reduce.py @@ -16,12 +16,12 @@ import iris from examples.common.utils import Timestamps from examples.common.validation import validate_gemm - + # Define test parameters after successful import DTYPES = [torch.float16, torch.float32] MATRIX_SIZES = [(256, 256, 256), (512, 512, 512)] BLOCK_SIZES = [(64, 64, 32)] - + except ImportError as e: pytest.skip(f"Skipping gemm_atomics_all_reduce test due to missing dependencies: {e}", allow_module_level=True) @@ -34,13 +34,13 @@ def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k): try: current_dir = Path(__file__).parent matmul_wrapper_path = (current_dir / "../../examples/08_gemm_atomics_all_reduce/matmul_wrapper.py").resolve() - + matmul_spec = importlib.util.spec_from_file_location("matmul_wrapper", matmul_wrapper_path) matmul_module = importlib.util.module_from_spec(matmul_spec) matmul_spec.loader.exec_module(matmul_module) except (ImportError, FileNotFoundError) as e: pytest.skip(f"Skipping test due to import error: {e}") - + # Initialize iris with appropriate heap size heap_size = 1 << 30 # 1GB shmem = iris.iris(heap_size) From 62f94fb64fe4992dc9643025491622f00abbeffe Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 31 Aug 2025 03:35:09 +0000 Subject: [PATCH 6/6] Refactor GEMM atomics all-reduce example to use reusable function and update test Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .../08_gemm_atomics_all_reduce/benchmark.py | 150 +++++++++++++++++- .../examples/test_gemm_atomics_all_reduce.py | 134 ++++------------ 2 files changed, 179 insertions(+), 105 deletions(-) diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index 45503492..3f365ec1 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -73,6 +73,134 @@ def parse_args(): return vars(parser.parse_args()) +def run_gemm_all_reduce( + A, + B, + shmem, + block_m=256, + block_n=128, + block_k=64, + gsize_m=6, + two_tiles=True, + num_stages=1, + num_warps=8, + waves_per_eu=0, + mfma_instr_size=16, + kpack=2, + gemm_sms=None, + trace_tiles=False, +): + """ + Run GEMM all-reduce operation on input matrices A and B. + + Args: + A: Input matrix A (M x K) + B: Input matrix B (N x K) - will be transposed internally + shmem: Iris shmem object + block_m, block_n, block_k: Block sizes for GEMM + gsize_m: Grid size M + two_tiles: Use two tiles + num_stages: Number of stages + num_warps: Number of warps + waves_per_eu: Waves per execution unit + mfma_instr_size: MFMA instruction size + kpack: K packing size + gemm_sms: Number of SMs for GEMM (defaults to half of available CUs) + trace_tiles: Enable tile tracing + + Returns: + Tuple of (global_C, local_C) where global_C is the all-reduced result + """ + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + + M, K = A.shape + N = B.shape[0] # B is expected to be N x K, will be transposed + + # Validate matrix dimensions + assert N % world_size == 0, f"N ({N}) must be divisible by world size ({world_size})." + assert K % world_size == 0, f"K ({K}) must be divisible by world size ({world_size})." + + # Transpose B if needed + if B.shape != (K, N): + B = B.T + + # Set default gemm_sms if not provided + if gemm_sms is None: + gemm_sms = min(cu_count // 2, 64) + + # Split matrices according to rank + rows_per_gpu = K // world_size + start_row = rank * rows_per_gpu + end_row = start_row + rows_per_gpu + local_B = B[start_row:end_row, :] + local_A = A[:, start_row:end_row] + + # Create output matrices + global_C = shmem.zeros((M, N), device="cuda", dtype=A.dtype) + local_C = shmem.zeros((M, N), device="cuda", dtype=A.dtype) + + # Setup parameters + total_blocks_M = triton.cdiv(M, block_m) + total_blocks_N = triton.cdiv(N, block_n) + total_tiles = total_blocks_M * total_blocks_N + + # Create required tensors + tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + locks = shmem.zeros((gemm_sms,), device="cuda", dtype=torch.int32) + P = shmem.zeros( + (gemm_sms, block_m * block_n), + device="cuda", + dtype=torch.float32, + ) + bias = None + + # Setup timestamps if tracing + timestamps = Timestamps(num_tiles=total_tiles) if trace_tiles else None + + # Synchronize before computation + shmem.barrier() + iris.memset_tensor(tile_completed, 0) + shmem.barrier() + + # Run the GEMM all-reduce operation + matmul.set_debug(False) + result_C = matmul.apply( + local_A, + local_B, + local_C, + global_C, + bias, + P, + locks, + tile_completed, + rank, + world_size, + gemm_sms, + block_m, + block_n, + block_k, + gsize_m, + two_tiles, + num_stages, + num_warps, + waves_per_eu, + mfma_instr_size, + kpack, + shmem.get_heap_bases(), + cu_count, + trace_tiles, + timestamps.mm_begin_timestamp if timestamps else None, + timestamps.mm_end_timestamp if timestamps else None, + ) + + # Synchronize after computation + shmem.barrier() + + return global_C, local_C + + def main(): args = parse_args() @@ -239,9 +367,27 @@ def run_experiment(): if args["validate"]: shmem.info("Validating...") - matmul.set_debug(False) + # Use the reusable function for validation + global_C_validate, _ = run_gemm_all_reduce( + A, + B, + shmem, + block_m=args["BLK_M"], + block_n=args["BLK_N"], + block_k=args["BLK_K"], + gsize_m=args["gsize_m"], + two_tiles=args["two_tiles"], + num_stages=args["num_stages"], + num_warps=args["num_warps"], + waves_per_eu=args["waves_per_eu"], + mfma_instr_size=args["mfmaInstrSize"], + kpack=args["kpack"], + gemm_sms=args["gemm_sms"], + trace_tiles=False, + ) + # Validate global result - success = validate_gemm(A, B, global_C, shmem, atol=2) + success = validate_gemm(A, B, global_C_validate, shmem, atol=2) passed_str = "passed" if success else "failed" shmem.info(f"Final C validation {passed_str}.") diff --git a/tests/examples/test_gemm_atomics_all_reduce.py b/tests/examples/test_gemm_atomics_all_reduce.py index c617c1ca..e77a677f 100644 --- a/tests/examples/test_gemm_atomics_all_reduce.py +++ b/tests/examples/test_gemm_atomics_all_reduce.py @@ -6,48 +6,33 @@ from pathlib import Path import pytest +import torch +import iris +from examples.common.validation import validate_gemm -# Try to import dependencies - skip test if not available -try: - import numpy as np - import torch - import triton - import triton.language as tl - import iris - from examples.common.utils import Timestamps - from examples.common.validation import validate_gemm +# Import the benchmark module +current_dir = Path(__file__).parent +benchmark_path = (current_dir / "../../examples/08_gemm_atomics_all_reduce/benchmark.py").resolve() +spec = importlib.util.spec_from_file_location("benchmark", benchmark_path) +benchmark_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(benchmark_module) - # Define test parameters after successful import - DTYPES = [torch.float16, torch.float32] - MATRIX_SIZES = [(256, 256, 256), (512, 512, 512)] - BLOCK_SIZES = [(64, 64, 32)] - -except ImportError as e: - pytest.skip(f"Skipping gemm_atomics_all_reduce test due to missing dependencies: {e}", allow_module_level=True) +# Test parameters +DTYPES = [torch.float16, torch.float32] +MATRIX_SIZES = [(256, 256, 256), (512, 512, 512)] +BLOCK_SIZES = [(64, 64, 32)] @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("m, n, k", MATRIX_SIZES) @pytest.mark.parametrize("block_m, block_n, block_k", BLOCK_SIZES) def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k): - # Import matmul_wrapper module at test time - try: - current_dir = Path(__file__).parent - matmul_wrapper_path = (current_dir / "../../examples/08_gemm_atomics_all_reduce/matmul_wrapper.py").resolve() - - matmul_spec = importlib.util.spec_from_file_location("matmul_wrapper", matmul_wrapper_path) - matmul_module = importlib.util.module_from_spec(matmul_spec) - matmul_spec.loader.exec_module(matmul_module) - except (ImportError, FileNotFoundError) as e: - pytest.skip(f"Skipping test due to import error: {e}") - # Initialize iris with appropriate heap size heap_size = 1 << 30 # 1GB shmem = iris.iris(heap_size) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() # Skip test if matrix dimensions are not divisible by world size if n % world_size != 0 or k % world_size != 0: @@ -55,82 +40,25 @@ def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k): # Create test matrices A = shmem.randn(m, k, device="cuda", dtype=dtype) - B = shmem.randn(n, k, device="cuda", dtype=dtype).T - C = shmem.zeros((m, n), device="cuda", dtype=dtype) - - # Split matrices according to rank - rows_per_gpu = k // world_size - start_row = rank * rows_per_gpu - end_row = start_row + rows_per_gpu - local_B = B[start_row:end_row, :] - local_A = A[:, start_row:end_row] - - # Create output matrices - global_C = shmem.zeros((m, n), device="cuda", dtype=dtype) - local_C = shmem.zeros((m, n), device="cuda", dtype=dtype) - - # Setup parameters - total_blocks_M = triton.cdiv(m, block_m) - total_blocks_N = triton.cdiv(n, block_n) - total_tiles = total_blocks_M * total_blocks_N - - # Use conservative number of SMs - gemm_sms = min(cu_count // 2, 64) # Use half of available CUs, max 64 - - # Create required tensors - tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) - locks = shmem.zeros((gemm_sms,), device="cuda", dtype=torch.int32) - P = shmem.zeros( - (gemm_sms, block_m * block_n), - device="cuda", - dtype=torch.float32, + B = shmem.randn(n, k, device="cuda", dtype=dtype) + + # Run the GEMM all-reduce operation using the benchmark function + global_C, local_C = benchmark_module.run_gemm_all_reduce( + A, + B, + shmem, + block_m=block_m, + block_n=block_n, + block_k=block_k, + gsize_m=8, + two_tiles=True, + num_stages=4, + num_warps=4, + waves_per_eu=2, + mfma_instr_size=16, + kpack=1, + trace_tiles=False, ) - bias = None - - # Setup timestamps - timestamps = Timestamps(num_tiles=total_tiles) - - # Synchronize before test - shmem.barrier() - - # Reset tile_completed - iris.memset_tensor(tile_completed, 0) - shmem.barrier() - - # Run the GEMM all-reduce operation - matmul_module.matmul.set_debug(False) - - result_C = matmul_module.matmul.apply( - local_A, - local_B, - local_C, - global_C, - bias, - P, - locks, - tile_completed, - rank, - world_size, - gemm_sms, - block_m, - block_n, - block_k, - 8, # gsize_m - True, # two_tiles - 4, # num_stages - 4, # num_warps - 2, # waves_per_eu - 16, # mfmaInstrSize - 1, # kpack - shmem.get_heap_bases(), - cu_count, - False, # trace_tiles - timestamps.mm_begin_timestamp, - timestamps.mm_end_timestamp, - ) - - # Synchronize after computation - shmem.barrier() # Validate results success = validate_gemm(A, B, global_C, shmem, atol=1e-1)