Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
ba7028b
Add initial cache modifiers code and docs
mawad-amd Sep 10, 2025
276713b
Add test
mawad-amd Sep 10, 2025
6f6818f
Apply Ruff auto-fixes
github-actions[bot] Sep 10, 2025
9ad63a0
Use `None` for default value
mawad-amd Sep 13, 2025
677c966
Apply Ruff auto-fixes
github-actions[bot] Sep 13, 2025
af3592d
Cleanup the test
mawad-amd Sep 14, 2025
8a411a2
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
ff26f96
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Sep 14, 2025
8f76d95
Check return value
mawad-amd Sep 14, 2025
87fb74a
Remove volatile from store
mawad-amd Sep 14, 2025
162ec39
Add test store
mawad-amd Sep 14, 2025
01da6ca
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
9a27ead
Add put/get modifiers
mawad-amd Sep 14, 2025
99ee66c
Add tests for put and get cache modifiers
mawad-amd Sep 14, 2025
74d0133
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
e76d4c5
Test default values
mawad-amd Sep 14, 2025
451ee99
Fix default value docstring
mawad-amd Sep 14, 2025
b524f40
Fix tests
mawad-amd Sep 14, 2025
b8bd8a7
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
0a157b5
Sync cache modifiers branch with main and add cache modifiers to copy…
Copilot Oct 11, 2025
b127b91
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Oct 11, 2025
c9f314f
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Oct 24, 2025
e74aacd
Fix device mismatch in test_copy_cache_modifiers assertions (#271)
Copilot Oct 29, 2025
88970ee
Fix pointer arithmetic in test_copy_cache_modifiers (#273)
Copilot Oct 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 94 additions & 13 deletions iris/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,7 +1521,7 @@ def __translate(ptr, from_rank, to_rank, heap_bases):


@triton.jit
def load(pointer, to_rank, from_rank, heap_bases, mask=None):
def load(pointer, to_rank, from_rank, heap_bases, mask=None, cache_modifier=None, volatile=False):
"""
Loads a value from the specified rank's memory location.

Expand All @@ -1530,12 +1530,28 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
data from the target memory location. If the `from_rank` and `to_rank` are the same,
this function performs a local load operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global load instruction. These affect cache usage across the CU,
L2, and last-level caches.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local.
to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local.
from_rank (int): The rank ID from which to read the data.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the load.

Supported values:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.
Ensures global coherence by invalidating stale GPU cache lines.

volatile (bool, optional): If True, disables compiler optimizations that
could reorder or eliminate the load.

Returns:
Block: The loaded value from the target memory location.
Expand All @@ -1550,12 +1566,12 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
>>> return data
"""
translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases)
result = tl.load(translated_ptr, mask=mask)
result = tl.load(translated_ptr, mask=mask, cache_modifier=cache_modifier, volatile=volatile)
return result


@triton.jit
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modifier=None):
"""
Writes data to the specified rank's memory location.

Expand All @@ -1564,13 +1580,25 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
the provided data to the target memory location. If the `from_rank` and `to_rank` are the same,
this function performs a local store operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global store instruction. These affect cache usage across the CU (L1),
L2, and last-level cache (LLC), following the CDNA ISA.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local.
value (Block): The tensor of elements to be stored.
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
to_rank (int): The rank ID to which the data will be written.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:

- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None
Expand All @@ -1585,11 +1613,21 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
>>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases)
"""
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
tl.store(translated_ptr, value, mask=mask)
tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier)


@triton.jit
def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
def copy(
src_ptr,
dst_ptr,
from_rank,
to_rank,
cur_rank,
heap_bases,
mask=None,
load_cache_modifier=None,
store_cache_modifier=None,
):
"""
Copies data from the specified rank's memory into the destination rank's memory.
This function performs the transfer by translating `src_ptr` from the `from_rank`'s address
Expand All @@ -1607,6 +1645,19 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None

Expand Down Expand Up @@ -1635,12 +1686,14 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype)

data = tl.load(translated_src, mask=mask)
tl.store(translated_dst, data, mask=mask)
data = tl.load(translated_src, mask=mask, cache_modifier=load_cache_modifier)
tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
def get(
from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None
):
"""
Copies data from the specified rank's memory to the current rank's local memory.

Expand All @@ -1657,6 +1710,19 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None

Expand All @@ -1669,13 +1735,15 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases)

data = tl.load(translated_from_ptr, mask=mask)
data = tl.load(translated_from_ptr, mask=mask, cache_modifier=load_cache_modifier)

tl.store(to_ptr, data, mask=mask)
tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
def put(
from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None
):
"""
Copies data from the current rank's local memory to the specified rank's memory.
This function performs a memory write operation by loading data from the current
Expand All @@ -1691,6 +1759,19 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None

Expand All @@ -1703,9 +1784,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases)

data = tl.load(from_ptr, mask=mask)
data = tl.load(from_ptr, mask=mask, cache_modifier=load_cache_modifier)

tl.store(translated_to_ptr, data, mask=mask)
tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
Expand Down
107 changes: 107 additions & 0 deletions tests/unittests/test_copy_cache_modifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
import triton
import triton.language as tl
import pytest
import iris
from itertools import product


@triton.jit
def copy_kernel(
    data,
    results,
    cur_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    load_cache_modifier: tl.constexpr,
    store_cache_modifier: tl.constexpr,
):
    # Kernel under test: pushes this rank's row of `data` into every rank's
    # `results` buffer via iris.copy, exercising the given cache modifiers.
    # `data` / `results` are expected to be (num_ranks, BLOCK_SIZE) buffers on
    # the iris symmetric heap — TODO confirm against the host-side test.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # NOTE(review): with the (1,)-sized launch grid used by the test, pid == 0
    # and this mask is trivially all-true; it mainly exercises the mask path.
    mask = offsets < BLOCK_SIZE

    # Test copy with cache modifiers - copy from current rank to other ranks
    for target_rank in range(num_ranks):
        # Row `cur_rank` of each (num_ranks, BLOCK_SIZE) buffer.
        src_data = data + BLOCK_SIZE * cur_rank
        dest_data = results + BLOCK_SIZE * cur_rank
        # Dispatch on which modifiers are set so that a `None` modifier is
        # omitted from the call and iris.copy's own default applies —
        # presumably to avoid forwarding None through tl.constexpr; confirm.
        if load_cache_modifier is None and store_cache_modifier is None:
            iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask)
        elif load_cache_modifier is None:
            iris.copy(
                src_data + offsets,
                dest_data + offsets,
                cur_rank,
                target_rank,
                cur_rank,
                heap_bases,
                mask=mask,
                store_cache_modifier=store_cache_modifier,
            )
        elif store_cache_modifier is None:
            iris.copy(
                src_data + offsets,
                dest_data + offsets,
                cur_rank,
                target_rank,
                cur_rank,
                heap_bases,
                mask=mask,
                load_cache_modifier=load_cache_modifier,
            )
        else:
            iris.copy(
                src_data + offsets,
                dest_data + offsets,
                cur_rank,
                target_rank,
                cur_rank,
                heap_bases,
                mask=mask,
                load_cache_modifier=load_cache_modifier,
                store_cache_modifier=store_cache_modifier,
            )


# Cache-modifier values to sweep for the load and store sides of iris.copy.
# None means "do not pass the argument" (iris.copy default applies); the
# strings are forwarded to tl.load/tl.store `cache_modifier`, including ""
# to exercise explicitly passing an empty modifier.
LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"]
STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"]


@pytest.mark.parametrize(
    "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS))
)
def test_copy_cache_modifiers(load_cache_modifier, store_cache_modifier):
    """Run iris.copy under every (load, store) cache-modifier pairing and verify the copied data."""
    shmem = iris.iris(1 << 20)
    cur_rank = shmem.get_rank()
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()

    BLOCK_SIZE = 16

    # Source buffer on the symmetric heap: row i holds (cur_rank + num_ranks) * (i + 1).
    src_buf = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32)
    base = cur_rank + num_ranks
    for row in range(num_ranks):
        src_buf[row, :] = base * (row + 1)

    # Destination buffer; remote ranks deposit their rows here via iris.copy.
    dst_buf = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32)

    copy_kernel[(1,)](
        src_buf, dst_buf, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier
    )

    # Wait until every rank's copies have landed before validating.
    shmem.barrier()

    # Slot rank_id should now carry rank_id's constant: (rank_id + num_ranks) * (rank_id + 1).
    for rank_id in range(num_ranks):
        expected_value = (rank_id + num_ranks) * (rank_id + 1)
        reference = torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=dst_buf.device)
        assert torch.allclose(dst_buf[rank_id], reference), (
            f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}"
        )
Loading
Loading