Skip to content

[FEATURE] Support tl.atomic_add in int8_matmul_quantization.py #149

@mark14wu

Description

@mark14wu

Running int8_matmul_quantization.py under triton-sanitizer fails when the kernel reaches `tl.atomic_add`: the interpreter's `atomic_rmw` binding raises a `TypeError` because the pointer/value/mask arguments it receives do not match the numpy array types the C++ signature expects. Full traceback:

Traceback (most recent call last):
  File "/home/hwu27/workspace/triton-viz/.venv/bin/triton-sanitizer", line 10, in <module>
    sys.exit(apply())
             ^^^^^^^
  File "/home/hwu27/workspace/triton-viz/triton_viz/wrapper.py", line 58, in apply
    runpy.run_path(script, run_name="__main__")
  File "<frozen runpy>", line 286, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "int8_matmul_quantization.py", line 267, in <module>
    result_gold = test_quantize_and_matmul()
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "int8_matmul_quantization.py", line 253, in test_quantize_and_matmul
    c_quantized = matmul_quantize_int8(fpa, b, b_scale)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "int8_matmul_quantization.py", line 194, in matmul_quantize_int8
    return matmul_int8(a, a_scale, b, b_scale, out)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "int8_matmul_quantization.py", line 211, in matmul_int8
    matmul_kernel[grid](
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 390, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 239, in run
    benchmark()
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 228, in benchmark
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 160, in _bench
    return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/testing.py", line 149, in do_bench
    fn()
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 146, in kernel_call
    self.fn.run(
  File "/home/hwu27/workspace/triton-viz/triton_viz/core/trace.py", line 68, in run
    ret = self.interpreter_fn.run(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 1380, in run
    return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/triton_viz/core/patch.py", line 488, in _grid_executor_call
    run_grid_loops(grid)
  File "/home/hwu27/workspace/triton-viz/triton_viz/core/patch.py", line 456, in run_grid_loops
    self.fn(**call_args)
  File "int8_matmul_quantization.py", line 188, in matmul_kernel
    tl.atomic_add(c_ptrs, c, mask=c_mask)
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 781, in <lambda>
    new_member = lambda *args, member=member, **kwargs: (member(*args, **
                                                         ^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/core.py", line 42, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/core.py", line 2373, in atomic_add
    return _semantic.atomic_add(pointer, val, mask, sem, scope)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/language/semantic.py", line 1423, in atomic_add
    return self.tensor(self.builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope),
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hwu27/workspace/triton-viz/.venv/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 679, in create_atomic_rmw
    return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: atomic_rmw(): incompatible function arguments. The following argument types are supported:
    1. (arg0: triton._C.libtriton.interpreter.RMW_OP, arg1: typing.Annotated[numpy.typing.ArrayLike, numpy.uint64], arg2: numpy.ndarray, arg3: typing.Annotated[numpy.typing.ArrayLike, numpy.bool], arg4: triton._C.libtriton.interpreter.MEM_SEMANTIC) -> numpy.ndarray

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions