diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index 59d14565..1b8417dc 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -136,7 +136,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) bias = None @@ -157,6 +157,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate Timestamps timestamps = Timestamps(num_tiles=total_tiles) + def preamble(): + shmem.barrier() + locks.zero_() + shmem.barrier() + def run_experiment(): nonlocal local_C nonlocal global_C @@ -244,7 +249,7 @@ def run_experiment(): matmul.set_debug(False) shmem.info("Benchmarking...") perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) - triton_ms = iris.do_bench(run_experiment, shmem.barrier) + triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) triton_tflops = perf(triton_ms) algo_string = "all_scatter" shmem.info( diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index ac2d2e35..aac520da 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -142,8 +142,7 @@ def persistent_gemm_all_scatter_wg_specialization( tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") - tl.debug_barrier() - tl.store(locks + tile_id, 1, cache_modifier=".wt") + tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu") else: # pid >= GEMM_SMS COMM_SMS = NUM_SMS - GEMM_SMS @@ -165,8 +164,12 @@ def persistent_gemm_all_scatter_wg_specialization( global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global # End: masks/offset calculations. - while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1: - pass + # Spin on volatile load until flag is non-zero (cheap) + flag_val = tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) + while flag_val == 0: + flag_val = tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) + # Use atomic_cas with dependency on loaded value to prevent reordering + tl.atomic_cas(locks + tile_id, flag_val, 0, sem="acquire", scope="gpu") for remote_rank in range(world_size): if remote_rank != cur_rank: diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 4849b053..561a37dc 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -136,7 +136,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) bias = None @@ -166,6 +166,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate Timestamps timestamps = Timestamps(num_tiles=total_tiles) + def preamble(): + shmem.barrier() + locks.zero_() + shmem.barrier() + def run_experiment(): nonlocal C nonlocal kernel_timing @@ -275,7 +280,7 @@ def run_experiment(): matmul.set_debug(False) shmem.info("Benchmarking...") perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) - triton_ms = iris.do_bench(run_experiment, shmem.barrier) + triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) triton_tflops = perf(triton_ms) algo_string = "all_scatter" shmem.info( diff --git a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py index a8311943..3620f061 100644 --- a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py +++ b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py @@ -133,8 +133,7 @@ def persistent_gemm( tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) tl.store(C + global_offset, c, mask=sub_mask, cache_modifier=".wt") - tl.debug_barrier() - tl.store(locks + tile_id, 1, cache_modifier=".wt") + tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu") @triton.jit() @@ -185,8 +184,12 @@ def persistent_all_scatter( global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global # End: masks/offset calculations. - while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1: - pass + # Spin on volatile load until flag is non-zero (cheap) + flag_val = tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) + while flag_val == 0: + flag_val = tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) + # Use atomic_cas with dependency on loaded value to prevent reordering + tl.atomic_cas(locks + tile_id, flag_val, 0, sem="acquire", scope="gpu") for remote_rank in range(world_size): if remote_rank != cur_rank: