From d8f9a27ae29dd22285ff052e68257b28dc653871 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 30 Oct 2025 15:59:06 -0700
Subject: [PATCH] Added transpose rhs option to ragged dot kernel, useful for
 implementing the backward function.

PiperOrigin-RevId: 826220830
---
 .../pallas/ops/gpu/ragged_dot_mgpu.py | 146 ++++++++++--------
 1 file changed, 84 insertions(+), 62 deletions(-)

diff --git a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py
index 2b63787ac7d0..1b42f519f6bf 100644
--- a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py
@@ -99,6 +99,7 @@ def ragged_dot(
     block_k: int,
     max_concurrent_steps: int,
     grid_block_n: int,
+    transpose_rhs: bool = False,
 ) -> jax.Array:
   if lhs.dtype != rhs.dtype:
     raise NotImplementedError(
@@ -107,6 +108,9 @@ def ragged_dot(
   m, k = lhs.shape
   g, k2, n = rhs.shape
+  if transpose_rhs:
+    k2, n = n, k2
+
   if group_sizes.shape[0] != g:
     raise ValueError(
         f"Expected group_sizes to have shape {g} but got {group_sizes.shape}"
     )
@@ -134,7 +138,9 @@ def mn_loop(loop_info: plgpu.NDLoopInfo):  # pylint: disable=unused-variable
     def acc_scope(acc_ref):
       plgpu.emit_pipeline(
           lambda _, lhs_smem, rhs_smem: plgpu.wgmma(
-              acc_ref, lhs_smem, rhs_smem
+              acc_ref,
+              lhs_smem,
+              plgpu.transpose_ref(rhs_smem, (1, 0)) if transpose_rhs else rhs_smem,
           ),
           grid=(k // block_k,),
           in_specs=[
@@ -144,8 +150,8 @@ def acc_scope(acc_ref):
                   delay_release=1,
               ),
               plgpu.BlockSpec(
-                  (block_k, block_n),
-                  lambda k: (k, ni),
+                  (block_n, block_k) if transpose_rhs else (block_k, block_n),
+                  lambda k: (ni, k) if transpose_rhs else (k, ni),
                   delay_release=1,
               ),
           ],
@@ -233,68 +239,84 @@ def _():
 
 
 def main(unused_argv):
-  m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
-  kx, ky, kz = random.split(random.key(1234), num=3)
+  for transpose_rhs in [False, True]:
+    m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
+    kx, ky, kz = random.split(random.key(1234), num=3)
 
-  lhs = jax.random.normal(kx, (m, k), jnp.float16)
-  rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
-  group_boundaries = jax.lax.sort(
-      jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
-  )
-  group_starts = lax.concatenate(
-      [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
-  )
-  group_ends = lax.concatenate(
-      [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
-  )
-  group_sizes = group_ends - group_starts
-  assert group_sizes.shape == (num_groups,)
-
-  block_m = block_n = (64, 128, 192)
-  block_k = (64,)
-  max_concurrent_steps = (2, 4, 5, 6)
-  grid_block_n = (1, 2, 4, 8, 16)
-  configs = itertools.product(
-      block_m, block_n, block_k, max_concurrent_steps, grid_block_n
-  )
-  names = (
-      "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
-  )
-  best_runtime = float("inf")
-  best_kwargs = {}
-  for config in configs:
-    kwargs = dict(zip(names, config))
-    if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
-      continue
-    try:
-      f = functools.partial(ragged_dot, group_sizes=group_sizes, **kwargs)
-      _, runtime = profiler.measure(f)(lhs, rhs)
-    except ValueError as e:
-      if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
-        raise
-      runtime = float("inf")
-    # Enable this to get more detailed information.
+    lhs = jax.random.normal(kx, (m, k), jnp.float16)
+    if transpose_rhs:
+      rhs = jax.random.normal(ky, (num_groups, n, k), jnp.float16)
     else:
-      print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
-    if runtime < best_runtime:  # pytype: disable=unsupported-operands
-      best_runtime = runtime
-      best_kwargs = kwargs
-  if not best_kwargs:
-    raise ValueError("No valid configuration found")
-
-  ref, ref_runtime = profiler.measure(jax.lax.ragged_dot)(
-      lhs, rhs, group_sizes=group_sizes
-  )
-  result = ragged_dot(lhs, rhs, group_sizes=group_sizes, **best_kwargs)
-  np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
+      rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
+    group_boundaries = jax.lax.sort(
+        jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
+    )
+    group_starts = lax.concatenate(
+        [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
+    )
+    group_ends = lax.concatenate(
+        [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
+    )
+    group_sizes = group_ends - group_starts
+    assert group_sizes.shape == (num_groups,)
+
+    block_m = block_n = (64, 128, 192)
+    block_k = (64,)
+    max_concurrent_steps = (2, 4, 5, 6)
+    grid_block_n = (1, 2, 4, 8, 16)
+    configs = itertools.product(
+        block_m, block_n, block_k, max_concurrent_steps, grid_block_n
+    )
+    names = (
+        "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
+    )
+    best_runtime = float("inf")
+    best_kwargs = {}
+    for config in configs:
+      kwargs = dict(zip(names, config))
+      if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
+        continue
+      try:
+        f = functools.partial(
+            ragged_dot, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
+            **kwargs
+        )
+        _, runtime = profiler.measure(f)(lhs, rhs)
+      except ValueError as e:
+        if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
+          raise
+        runtime = float("inf")
+      # Enable this to get more detailed information.
+      else:
+        print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
+      if runtime < best_runtime:  # pytype: disable=unsupported-operands
+        best_runtime = runtime
+        best_kwargs = kwargs
+    if not best_kwargs:
+      raise ValueError("No valid configuration found")
+
+    def ref_ragged_dot(lhs, rhs, group_sizes):
+      if transpose_rhs:
+        rhs = jnp.transpose(rhs, (0, 2, 1))
+      return jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
+
+    ref, ref_runtime = profiler.measure(ref_ragged_dot)(
+        lhs, rhs, group_sizes=group_sizes
+    )
+    result = ragged_dot(
+        lhs, rhs, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
+        **best_kwargs
+    )
+    np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
 
-  tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
-  ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
-  print(
-      "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
-  )
-  print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
-  print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
+    tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
+    ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
+    print(f"Transpose RHS: {transpose_rhs}")
+    print(
+        "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
+    )
+    print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
+    print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
 
 
 if __name__ == "__main__":
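
The commit message names the backward function as the motivation for the new flag. Below is a minimal sketch of how that wiring could look; it is an illustration under stated assumptions, not code from this patch. The wrapper name make_ragged_dot, the kernel_kwargs placeholder (standing in for the block_m/block_n/block_k/max_concurrent_steps/grid_block_n tuning values), and the segment_sum-based d_rhs are inventions of this sketch; only the d_lhs call exercises the new transpose_rhs=True path, which contracts over the last axis of rhs exactly like the ref_ragged_dot reference in main above. The forward tuning values are reused for the backward call purely for brevity; the transposed problem may need its own block sizes, and the usual divisibility constraints now apply to the swapped dimensions.

import jax
import jax.numpy as jnp

from jax.experimental.pallas.ops.gpu.ragged_dot_mgpu import ragged_dot


def make_ragged_dot(group_sizes, **kernel_kwargs):
  """Differentiable ragged dot whose lhs gradient reuses the kernel via transpose_rhs."""

  @jax.custom_vjp
  def rdot(lhs, rhs):
    # Forward: lhs is (m, k), rhs is (g, k, n), output is (m, n).
    return ragged_dot(lhs, rhs, group_sizes=group_sizes, **kernel_kwargs)

  def rdot_fwd(lhs, rhs):
    return rdot(lhs, rhs), (lhs, rhs)

  def rdot_bwd(residuals, g_out):
    lhs, rhs = residuals
    # d_lhs[m, k] = sum_n g_out[m, n] * rhs[group(m), k, n]: a ragged dot that
    # contracts over the last axis of rhs, i.e. the new transpose_rhs=True mode.
    d_lhs = ragged_dot(
        g_out, rhs, group_sizes=group_sizes, transpose_rhs=True, **kernel_kwargs
    )
    # d_rhs[i, k, n] = sum over the rows m of group i of lhs[m, k] * g_out[m, n].
    # A plain (memory-hungry) segment_sum reference, only to complete the VJP.
    group_ids = jnp.repeat(
        jnp.arange(group_sizes.shape[0]), group_sizes,
        total_repeat_length=lhs.shape[0],
    )
    d_rhs = jax.ops.segment_sum(
        lhs[:, :, None] * g_out[:, None, :], group_ids,
        num_segments=group_sizes.shape[0],
    )
    return d_lhs, d_rhs

  rdot.defvjp(rdot_fwd, rdot_bwd)
  return rdot

With a wrapper along these lines, jax.grad of a loss built on the returned rdot sends the lhs cotangent through the same Mosaic GPU kernel instead of first materializing a transposed copy of rhs, which appears to be the point of the option.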