diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py
index 45503492..3f365ec1 100755
--- a/examples/08_gemm_atomics_all_reduce/benchmark.py
+++ b/examples/08_gemm_atomics_all_reduce/benchmark.py
@@ -73,6 +73,134 @@ def parse_args():
     return vars(parser.parse_args())
 
 
+def run_gemm_all_reduce(
+    A,
+    B,
+    shmem,
+    block_m=256,
+    block_n=128,
+    block_k=64,
+    gsize_m=6,
+    two_tiles=True,
+    num_stages=1,
+    num_warps=8,
+    waves_per_eu=0,
+    mfma_instr_size=16,
+    kpack=2,
+    gemm_sms=None,
+    trace_tiles=False,
+):
+    """
+    Run a distributed GEMM followed by an atomics-based all-reduce on A and B.
+
+    Args:
+        A: Input matrix A (M x K)
+        B: Input matrix B (N x K); transposed to K x N internally
+        shmem: Iris shmem object
+        block_m, block_n, block_k: Block sizes for the GEMM tiles
+        gsize_m: Tile group size along M
+        two_tiles: Enable two-tile scheduling
+        num_stages: Number of pipeline stages
+        num_warps: Number of warps per workgroup
+        waves_per_eu: Waves per execution unit
+        mfma_instr_size: MFMA instruction size
+        kpack: K packing size
+        gemm_sms: Number of SMs for the GEMM (defaults to half the available CUs, capped at 64)
+        trace_tiles: Enable per-tile timestamp tracing
+
+    Returns:
+        Tuple of (global_C, local_C), where global_C is the all-reduced result
+    """
+    rank = shmem.get_rank()
+    world_size = shmem.get_num_ranks()
+    cu_count = shmem.get_cu_count()
+
+    M, K = A.shape
+    N = B.shape[0]  # B arrives as N x K
+
+    # Validate matrix dimensions
+    assert N % world_size == 0, f"N ({N}) must be divisible by world size ({world_size})."
+    assert K % world_size == 0, f"K ({K}) must be divisible by world size ({world_size})."
+
+    # B arrives as N x K (see docstring); transpose it to the K x N layout the
+    # kernel consumes. A shape check cannot detect this for square inputs.
+    B = B.T
+
+    # Default gemm_sms: half the available CUs, capped at 64
+    if gemm_sms is None:
+        gemm_sms = min(cu_count // 2, 64)
+
+    # Shard the K dimension across ranks; each rank computes a partial C
+    rows_per_gpu = K // world_size
+    start_row = rank * rows_per_gpu
+    end_row = start_row + rows_per_gpu
+    local_B = B[start_row:end_row, :]
+    local_A = A[:, start_row:end_row]
+
+    # Create output matrices
+    global_C = shmem.zeros((M, N), device="cuda", dtype=A.dtype)
+    local_C = shmem.zeros((M, N), device="cuda", dtype=A.dtype)
+
+    # Tile counts
+    total_blocks_M = triton.cdiv(M, block_m)
+    total_blocks_N = triton.cdiv(N, block_n)
+    total_tiles = total_blocks_M * total_blocks_N
+
+    # Create required tensors
+    tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
+    locks = shmem.zeros((gemm_sms,), device="cuda", dtype=torch.int32)
+    P = shmem.zeros(
+        (gemm_sms, block_m * block_n),
+        device="cuda",
+        dtype=torch.float32,
+    )
+    bias = None
+
+    # Set up timestamps if tracing
+    timestamps = Timestamps(num_tiles=total_tiles) if trace_tiles else None
+
+    # Synchronize before computation
+    shmem.barrier()
+    iris.memset_tensor(tile_completed, 0)
+    shmem.barrier()
+
+    # Run the GEMM all-reduce operation
+    matmul.set_debug(False)
+    matmul.apply(
+        local_A,
+        local_B,
+        local_C,
+        global_C,
+        bias,
+        P,
+        locks,
+        tile_completed,
+        rank,
+        world_size,
+        gemm_sms,
+        block_m,
+        block_n,
+        block_k,
+        gsize_m,
+        two_tiles,
+        num_stages,
+        num_warps,
+        waves_per_eu,
+        mfma_instr_size,
+        kpack,
+        shmem.get_heap_bases(),
+        cu_count,
+        trace_tiles,
+        timestamps.mm_begin_timestamp if timestamps else None,
+        timestamps.mm_end_timestamp if timestamps else None,
+    )
+
+    # Synchronize after computation
+    shmem.barrier()
+
+    return global_C, local_C
+
+
 def main():
     args = parse_args()
 
@@ -239,9 +367,27 @@ def run_experiment():
 
     if args["validate"]:
         shmem.info("Validating...")
-        matmul.set_debug(False)
+        # Use the reusable function for validation
+        global_C_validate, _ = run_gemm_all_reduce(
+            A,
+            B,
+            shmem,
+            block_m=args["BLK_M"],
+            block_n=args["BLK_N"],
+            block_k=args["BLK_K"],
+            gsize_m=args["gsize_m"],
+            two_tiles=args["two_tiles"],
+            num_stages=args["num_stages"],
+            num_warps=args["num_warps"],
+            waves_per_eu=args["waves_per_eu"],
+            mfma_instr_size=args["mfmaInstrSize"],
+            kpack=args["kpack"],
+            gemm_sms=args["gemm_sms"],
+            trace_tiles=False,
+        )
+
         # Validate global result
-        success = validate_gemm(A, B, global_C, shmem, atol=2)
+        success = validate_gemm(A, B, global_C_validate, shmem, atol=2)
         passed_str = "passed" if success else "failed"
         shmem.info(f"Final C validation {passed_str}.")
 
diff --git a/tests/examples/test_gemm_atomics_all_reduce.py b/tests/examples/test_gemm_atomics_all_reduce.py
new file mode 100644
index 00000000..e77a677f
--- /dev/null
+++ b/tests/examples/test_gemm_atomics_all_reduce.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+import importlib.util
+from pathlib import Path
+
+import pytest
+import torch
+import iris
+from examples.common.validation import validate_gemm
+
+# Import the benchmark module by file path (it is not an installed package)
+current_dir = Path(__file__).parent
+benchmark_path = (current_dir / "../../examples/08_gemm_atomics_all_reduce/benchmark.py").resolve()
+spec = importlib.util.spec_from_file_location("benchmark", benchmark_path)
+benchmark_module = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(benchmark_module)
+
+# Test parameters
+DTYPES = [torch.float16, torch.float32]
+MATRIX_SIZES = [(256, 256, 256), (512, 512, 512)]
+BLOCK_SIZES = [(64, 64, 32)]
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("m, n, k", MATRIX_SIZES)
+@pytest.mark.parametrize("block_m, block_n, block_k", BLOCK_SIZES)
+def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k):
+    # Initialize iris with a 1 GiB symmetric heap
+    heap_size = 1 << 30
+    shmem = iris.iris(heap_size)
+
+    rank = shmem.get_rank()
+    world_size = shmem.get_num_ranks()
+
+    # Skip the test if the matrix dimensions are not divisible by the world size
+    if n % world_size != 0 or k % world_size != 0:
+        pytest.skip(f"Matrix dimensions not divisible by world size {world_size}")
+
+    # Create test matrices
+    A = shmem.randn(m, k, device="cuda", dtype=dtype)
+    B = shmem.randn(n, k, device="cuda", dtype=dtype)
+
+    # Run the GEMM all-reduce operation using the benchmark function
+    global_C, _ = benchmark_module.run_gemm_all_reduce(
+        A,
+        B,
+        shmem,
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        gsize_m=8,
+        two_tiles=True,
+        num_stages=4,
+        num_warps=4,
+        waves_per_eu=2,
+        mfma_instr_size=16,
+        kpack=1,
+        trace_tiles=False,
+    )
+
+    # Validate results
+    success = validate_gemm(A, B, global_C, shmem, atol=1e-1)
+
+    # Assert test passed
+    assert success, "GEMM all-reduce validation failed"
+
+    # Verify that we got a non-zero result
+    assert not torch.allclose(global_C, torch.zeros_like(global_C)), "Result should not be all zeros"
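
For reference, the new helper can also be driven standalone, outside the
benchmark CLI and the pytest suite. A minimal sketch, assuming the iris
runtime is launched with one process per GPU (e.g. via mpirun) and that
benchmark.py is importable; the bare "benchmark" import path below is
illustrative, and the test above shows the file-path import:

    import torch
    import iris

    from benchmark import run_gemm_all_reduce  # illustrative import path

    # One process per GPU; allocate a 1 GiB symmetric heap per rank.
    shmem = iris.iris(1 << 30)
    m, n, k = 4096, 4096, 4096

    # A is M x K; B is passed as N x K and transposed to K x N inside the helper.
    A = shmem.randn(m, k, device="cuda", dtype=torch.float16)
    B = shmem.randn(n, k, device="cuda", dtype=torch.float16)

    # Each rank multiplies its K-shard of A and B and atomically accumulates
    # partial tiles into the shared global_C; the helper's internal barriers
    # make the result visible on every rank before returning.
    global_C, _ = run_gemm_all_reduce(A, B, shmem)

Because the benchmark's validation path and the test call the same
run_gemm_all_reduce entry point, the GEMM + all-reduce logic is exercised
through a single code path.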