Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/10_gemm_all_scatter_wg_specialization/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
total_tiles = total_blocks_M * total_blocks_N

locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8)
locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)

bias = None

Expand All @@ -157,6 +157,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
# Allocate Timestamps
timestamps = Timestamps(num_tiles=total_tiles)

def preamble():
shmem.barrier()
locks.zero_()
shmem.barrier()

def run_experiment():
nonlocal local_C
nonlocal global_C
Expand Down Expand Up @@ -244,7 +249,7 @@ def run_experiment():
matmul.set_debug(False)
shmem.info("Benchmarking...")
perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
triton_ms = iris.do_bench(run_experiment, shmem.barrier)
triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
triton_tflops = perf(triton_ms)
algo_string = "all_scatter"
shmem.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ def persistent_gemm_all_scatter_wg_specialization(
tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)

tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt")
tl.debug_barrier()
tl.store(locks + tile_id, 1, cache_modifier=".wt")
tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu")

else: # pid >= GEMM_SMS
COMM_SMS = NUM_SMS - GEMM_SMS
Expand All @@ -165,8 +164,12 @@ def persistent_gemm_all_scatter_wg_specialization(
global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global
# End: masks/offset calculations.

while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
pass
# Spin on volatile load until flag is non-zero (cheap)
flag_val = tl.load(locks + tile_id, cache_modifier=".cv", volatile=True)
while flag_val == 0:
flag_val = tl.load(locks + tile_id, cache_modifier=".cv", volatile=True)
# Use atomic_cas with dependency on loaded value to prevent reordering
tl.atomic_cas(locks + tile_id, flag_val, 0, sem="acquire", scope="gpu")

for remote_rank in range(world_size):
if remote_rank != cur_rank:
Expand Down
9 changes: 7 additions & 2 deletions examples/11_gemm_all_scatter_producer_consumer/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
total_tiles = total_blocks_M * total_blocks_N

locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8)
locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)

bias = None

Expand Down Expand Up @@ -166,6 +166,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
# Allocate Timestamps
timestamps = Timestamps(num_tiles=total_tiles)

def preamble():
shmem.barrier()
locks.zero_()
shmem.barrier()

def run_experiment():
nonlocal C
nonlocal kernel_timing
Expand Down Expand Up @@ -275,7 +280,7 @@ def run_experiment():
matmul.set_debug(False)
shmem.info("Benchmarking...")
perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
triton_ms = iris.do_bench(run_experiment, shmem.barrier)
triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
triton_tflops = perf(triton_ms)
algo_string = "all_scatter"
shmem.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ def persistent_gemm(
tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)

tl.store(C + global_offset, c, mask=sub_mask, cache_modifier=".wt")
tl.debug_barrier()
tl.store(locks + tile_id, 1, cache_modifier=".wt")
tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu")


@triton.jit()
Expand Down Expand Up @@ -185,8 +184,12 @@ def persistent_all_scatter(
global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global
# End: masks/offset calculations.

while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
pass
# Spin on volatile load until flag is non-zero (cheap)
flag_val = tl.load(locks + tile_id, cache_modifier=".cv", volatile=True)
while flag_val == 0:
flag_val = tl.load(locks + tile_id, cache_modifier=".cv", volatile=True)
# Use atomic_cas with dependency on loaded value to prevent reordering
tl.atomic_cas(locks + tile_id, flag_val, 0, sem="acquire", scope="gpu")

for remote_rank in range(world_size):
if remote_rank != cur_rank:
Expand Down
Loading