Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/test_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,43 @@ def test_copy_kernel():
N,
TILE_N=128,
)


# ======== Atomic Operations Tests =========
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def atomic_add_kernel(
output_ptr,
value: tl.constexpr,
):
# Simple atomic add operation
tl.atomic_add(output_ptr, value)


def test_atomic_add():
"""Test that atomic_add operations work with the sanitizer."""
y = torch.zeros(1, dtype=torch.float32)
grid = (1,)
atomic_add_kernel[grid](y, value=5.0)
# Note: The sanitizer analyzes symbolically, so the actual value may not be updated
# This test verifies that the operation doesn't crash


@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def atomic_cas_kernel(
output_ptr,
cmp_value: tl.constexpr,
new_value: tl.constexpr,
):
# Simple atomic compare-and-swap operation
tl.atomic_cas(output_ptr, cmp_value, new_value)


def test_atomic_cas():
"""Test that atomic_cas operations work with the sanitizer."""
y = torch.zeros(1, dtype=torch.float32)
grid = (1,)
atomic_cas_kernel[grid](y, cmp_value=0.0, new_value=5.0)
# Note: The sanitizer analyzes symbolically, so the actual value may not be updated
# This test verifies that the operation doesn't crash
15 changes: 14 additions & 1 deletion triton_viz/clients/sanitizer/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
CumSum,
Bitcast,
AtomicCas,
AtomicRMW,
)
from ..utils import (
check_out_of_bounds_access,
Expand Down Expand Up @@ -723,7 +724,7 @@ class SymbolicExpr:
POINTER_OPS = ("make_block_ptr", "addptr", "advance")
BROADCAST_OPS = ("splat", "expand_dims", "broadcast", "reshape", "join")
CAST_OPS = ("cast_impl", "bitcast")
ATOMIC_OPS = ("atomic_cas",)
ATOMIC_OPS = ("atomic_cas", "atomic_rmw")
SUPPORTED_OPS = (
BASIC_OPS
+ INDIRECT_OPS
Expand Down Expand Up @@ -814,6 +815,7 @@ class SymbolicExpr:
"bitcast": Spec(req=("src", "dst_type"), post=_cast_impl_post),
# Atomic operations
"atomic_cas": Spec(req=("ptr", "cmp", "val")),
"atomic_rmw": Spec(req=("ptr", "val", "mask")),
# Misc
"advance": Spec(req=("ptr", "offsets")),
"umulhi": Spec(req=("lhs", "rhs")),
Expand Down Expand Up @@ -1975,6 +1977,16 @@ def op_atomic_cas_overrider(ptr, cmp, val, sem, scope):
result.sem = sem # Store sem as an attribute
return result

def op_atomic_rmw_overrider(rmwOp, ptr, val, mask, sem, scope):
ptr_sym = SymbolicExpr.from_value(ptr)
val_sym = SymbolicExpr.from_value(val)
mask_sym = SymbolicExpr.from_value(mask)
# rmwOp and sem are enums, not regular values, so we pass them directly
result = SymbolicExpr("atomic_rmw", ptr_sym, val_sym, mask_sym)
result.rmwOp = rmwOp # Store rmwOp as an attribute
result.sem = sem # Store sem as an attribute
return result

OP_TYPE_TO_OVERRIDER: dict[type[Op], Callable] = {
ProgramId: op_program_id_overrider,
RawLoad: op_raw_load_overrider,
Expand Down Expand Up @@ -2010,6 +2022,7 @@ def op_atomic_cas_overrider(ptr, cmp, val, sem, scope):
CumSum: op_cumsum_overrider,
Bitcast: op_bitcast_overrider,
AtomicCas: op_atomic_cas_overrider,
AtomicRMW: op_atomic_rmw_overrider,
}

if op_type in OP_TYPE_TO_OVERRIDER:
Expand Down
5 changes: 5 additions & 0 deletions triton_viz/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ class AtomicCas(Op):
name: ClassVar[str] = "atomic_cas"


@dataclass
class AtomicRMW(Op):
name: ClassVar[str] = "atomic_rmw"


@dataclass
class Tensor:
ptr: int
Expand Down
4 changes: 4 additions & 0 deletions triton_viz/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
CumSum,
Bitcast,
AtomicCas,
AtomicRMW,
)
import inspect
import ast
Expand Down Expand Up @@ -89,6 +90,7 @@
CumSum,
Bitcast,
AtomicCas,
AtomicRMW,
]

# Hardcoded operation attribute names to avoid issues with lambda functions
Expand Down Expand Up @@ -123,6 +125,7 @@
Trans: "create_trans",
Bitcast: "create_bitcast",
AtomicCas: "create_atomic_cas",
AtomicRMW: "create_atomic_rmw",
}

original_ops = {
Expand Down Expand Up @@ -156,6 +159,7 @@
Trans: interpreter_builder.create_trans,
Bitcast: interpreter_builder.create_bitcast,
AtomicCas: interpreter_builder.create_atomic_cas,
AtomicRMW: interpreter_builder.create_atomic_rmw,
}
reduce_map: dict[type[Op], Callable] = {
ReduceMax: tl.max,
Expand Down