Add support for atomic RMW operations (tl.atomic_add, etc.) #200
Conversation
Co-authored-by: Jokeren <2306281+Jokeren@users.noreply.github.com>
@copilot This is apparently wrong. The wrapper should be similar to atomic_cas. Also, you don't have any tests.
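A patching wrapper in that style might look roughly like this (a sketch only: the function name, hook signature, and recorder interface are illustrative assumptions, not the actual triton-viz API):

```python
def wrap_atomic_rmw(original_fn, record):
    """Wrap an interpreter-level atomic_rmw entry point, recording each call.

    Hypothetical names throughout: triton-viz's real hook signature may differ.
    """
    def wrapper(rmw_op, ptr, val, mask, sem):
        record((rmw_op, mask))  # log the op and its mask for the visualizer
        return original_fn(rmw_op, ptr, val, mask, sem)
    return wrapper
```

The point of this shape, mirroring the atomic_cas wrapper, is that the original operation still executes unchanged while the visualizer observes every access.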
Fixed in commit 1c494b4.
@copilot Also test atomic_cas.
@copilot run pre-commit |
tl.atomic_add is not supported in triton-viz. This PR adds an AtomicRMW operation class to data.py and adds AtomicRMW to the op_list and mapping in patch.py.
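As a rough sketch of what such an operation record and its registration might look like (field names and the registration shape below are illustrative assumptions, not the actual data.py / patch.py definitions):

```python
from dataclasses import dataclass

@dataclass
class AtomicRMW:
    """Record of one atomic read-modify-write access, for visualization."""
    op: str         # e.g. "add", "max", "xchg"
    ptr: int        # base pointer of the accessed buffer
    offsets: tuple  # element offsets touched by this op
    masked: tuple   # per-element mask values

# patch.py-style registration: list the op and map interpreter names to it.
OP_LIST = [AtomicRMW]
OP_MAPPING = {"atomic_add": AtomicRMW, "atomic_max": AtomicRMW}
```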
This section details the original issue you should resolve.
<issue_title>[FEATURE] Support tl.atomic_add in int8_matmul_quantization.py</issue_title>
<issue_description>
```
Traceback (most recent call last):
File "/home/hwu27/workspace/triton-viz/.venv/bin/triton-sanitizer", line 10, in
sys.exit(apply())
^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/wrapper.py", line 58, in apply
runpy.run_path(script, run_name="main")
File "", line 286, in run_path
File "", line 98, in _run_module_code
File "", line 88, in _run_code
File "int8_matmul_quantization.py", line 267, in
result_gold = test_quantize_and_matmul()
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "int8_matmul_quantization.py", line 253, in test_quantize_and_matmul
c_quantized = matmul_quantize_int8(fpa, b, b_scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "int8_matmul_quantization.py", line 194, in matmul_quantize_int8
return matmul_int8(a, a_scale, b, b_scale, out)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "int8_matmul_quantization.py", line 211, in matmul_int8
matmul_kernel[grid](
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 390, in
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 239, in run
benchmark()
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 228, in benchmark
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 160, in _bench
return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/testing.py", line 149, in do_bench
fn()
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 146, in kernel_call
self.fn.run(
File "/home/hwu27/workspace/triton-viz/triton_viz/core/trace.py", line 68, in run
ret = self.interpreter_fn.run(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 1380, in run
return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/triton_viz/core/patch.py", line 488, in _grid_executor_call
run_grid_loops(grid)
File "/home/hwu27/workspace/triton-viz/triton_viz/core/patch.py", line 456, in run_grid_loops
self.fn(**call_args)
File "int8_matmul_quantization.py", line 188, in matmul_kernel
tl.atomic_add(c_ptrs, c, mask=c_mask)
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 781, in
new_member = lambda *args, member=member, **kwargs: (member(*args, **
^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/core.py", line 42, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/core.py", line 2373, in atomic_add
return _semantic.atomic_add(pointer, val, mask, sem, scope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/semantic.py", line 1423, in atomic_add
return self.tensor(self.builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 679, in create_atomic_rmw
return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: atomic_rmw(): incompatible function arguments. The following argument types are supported:
1. (arg0: triton._C.libtriton.interpreter.RMW_OP, arg1: typing.Annotated[numpy.typing.ArrayLike, numpy.uint64], arg2: numpy.ndarray, arg3: typing.Annotated[numpy.typing.ArrayLike, numpy.bool],...
```
</issue_description>
Fixes: [FEATURE] Support tl.atomic_add in int8_matmul_quantization.py #149
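For reference, the interpreter-level semantics the new op has to model can be emulated with NumPy (a simplified sketch; the real interpreter dispatches through `_interpreter.atomic_rmw` with an `RMW_OP` enum, and per-lane ordering for repeated offsets is simplified here to a single pre-op snapshot):

```python
import numpy as np

def atomic_add_sim(buf, offsets, vals, mask):
    """Emulate tl.atomic_add on the CPU interpreter: masked read-modify-write.

    Returns the values read before the update, as atomic_add does.
    np.add.at accumulates correctly even when offsets repeat.
    """
    old = buf[offsets].copy()              # snapshot of the old values
    np.add.at(buf, offsets[mask], vals[mask])  # apply only masked lanes
    return old
```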