
Commit d8f9a27

Added a transpose_rhs option to the ragged dot kernel, useful for implementing the backward function.
PiperOrigin-RevId: 826220830
1 parent: 95468af
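Why this helps the backward pass: the kernel computes out[i, :] = lhs[i, :] @ rhs[g(i)] per group, so the gradient with respect to lhs contracts the incoming cotangent with rhs over the n dimension, i.e. d_lhs = d_out @ rhs[g(i)].T. With transpose_rhs=True the kernel reads rhs with its last two dimensions swapped (output dimension first, contraction dimension second), which is exactly the layout the stored (g, k, n) rhs already has when it is reused for d_lhs, so no transposed copy needs to be materialized. Below is a minimal sketch of that use; it is not part of this commit, and the block sizes and the block_cfg name are illustrative assumptions.

from jax.experimental.pallas.ops.gpu.ragged_dot_mgpu import ragged_dot

def d_lhs(d_out, rhs, group_sizes):
  # rhs is stored as (g, k, n): with transpose_rhs=True, axis 1 (k) becomes the
  # output dimension and axis 2 (n) is contracted, which is what the lhs
  # cotangent d_out @ rhs[g].T needs.
  block_cfg = dict(block_m=64, block_n=128, block_k=64,
                   max_concurrent_steps=4, grid_block_n=1)  # assumed tuning
  return ragged_dot(d_out, rhs, group_sizes=group_sizes,
                    transpose_rhs=True, **block_cfg)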

1 file changed

jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py

Lines changed: 84 additions & 62 deletions
@@ -99,6 +99,7 @@ def ragged_dot(
     block_k: int,
     max_concurrent_steps: int,
     grid_block_n: int,
+    transpose_rhs: bool = False,
 ) -> jax.Array:
   if lhs.dtype != rhs.dtype:
     raise NotImplementedError(
@@ -107,6 +108,9 @@ def ragged_dot(
   m, k = lhs.shape
   g, k2, n = rhs.shape
 
+  if transpose_rhs:
+    k2, n = n, k2
+
   if group_sizes.shape[0] != g:
     raise ValueError(
         f"Expected group_sizes to have shape {g} but got {group_sizes.shape}"
@@ -134,7 +138,9 @@ def mn_loop(loop_info: plgpu.NDLoopInfo):  # pylint: disable=unused-variable
       def acc_scope(acc_ref):
         plgpu.emit_pipeline(
             lambda _, lhs_smem, rhs_smem: plgpu.wgmma(
-                acc_ref, lhs_smem, rhs_smem
+                acc_ref,
+                lhs_smem,
+                plgpu.transpose_ref(rhs_smem, (1, 0)) if transpose_rhs else rhs_smem,
             ),
             grid=(k // block_k,),
             in_specs=[
@@ -144,8 +150,8 @@ def acc_scope(acc_ref):
                     delay_release=1,
                 ),
                 plgpu.BlockSpec(
-                    (block_k, block_n),
-                    lambda k: (k, ni),
+                    (block_n, block_k) if transpose_rhs else (block_k, block_n),
+                    lambda k: (ni, k) if transpose_rhs else (k, ni),
                     delay_release=1,
                 ),
             ],
@@ -233,68 +239,84 @@ def _():
 
 
 def main(unused_argv):
-  m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
-  kx, ky, kz = random.split(random.key(1234), num=3)
+  for transpose_rhs in [False, True]:
+    m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
+    kx, ky, kz = random.split(random.key(1234), num=3)
 
-  lhs = jax.random.normal(kx, (m, k), jnp.float16)
-  rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
-  group_boundaries = jax.lax.sort(
-      jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
-  )
-  group_starts = lax.concatenate(
-      [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
-  )
-  group_ends = lax.concatenate(
-      [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
-  )
-  group_sizes = group_ends - group_starts
-  assert group_sizes.shape == (num_groups,)
-
-  block_m = block_n = (64, 128, 192)
-  block_k = (64,)
-  max_concurrent_steps = (2, 4, 5, 6)
-  grid_block_n = (1, 2, 4, 8, 16)
-  configs = itertools.product(
-      block_m, block_n, block_k, max_concurrent_steps, grid_block_n
-  )
-  names = (
-      "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
-  )
-  best_runtime = float("inf")
-  best_kwargs = {}
-  for config in configs:
-    kwargs = dict(zip(names, config))
-    if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
-      continue
-    try:
-      f = functools.partial(ragged_dot, group_sizes=group_sizes, **kwargs)
-      _, runtime = profiler.measure(f)(lhs, rhs)
-    except ValueError as e:
-      if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
-        raise
-      runtime = float("inf")
-    # Enable this to get more detailed information.
+    lhs = jax.random.normal(kx, (m, k), jnp.float16)
+    if transpose_rhs:
+      rhs = jax.random.normal(ky, (num_groups, n, k), jnp.float16)
     else:
-      print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
-      if runtime < best_runtime:  # pytype: disable=unsupported-operands
-        best_runtime = runtime
-        best_kwargs = kwargs
-  if not best_kwargs:
-    raise ValueError("No valid configuration found")
-
-  ref, ref_runtime = profiler.measure(jax.lax.ragged_dot)(
-      lhs, rhs, group_sizes=group_sizes
-  )
-  result = ragged_dot(lhs, rhs, group_sizes=group_sizes, **best_kwargs)
-  np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
+      rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
+    group_boundaries = jax.lax.sort(
+        jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
+    )
+    group_starts = lax.concatenate(
+        [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
+    )
+    group_ends = lax.concatenate(
+        [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
+    )
+    group_sizes = group_ends - group_starts
+    assert group_sizes.shape == (num_groups,)
+
+    block_m = block_n = (64, 128, 192)
+    block_k = (64,)
+    max_concurrent_steps = (2, 4, 5, 6)
+    grid_block_n = (1, 2, 4, 8, 16)
+    configs = itertools.product(
+        block_m, block_n, block_k, max_concurrent_steps, grid_block_n
+    )
+    names = (
+        "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
+    )
+    best_runtime = float("inf")
+    best_kwargs = {}
+    for config in configs:
+      kwargs = dict(zip(names, config))
+      if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
+        continue
+      try:
+        f = functools.partial(
+            ragged_dot, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
+            **kwargs
+        )
+        _, runtime = profiler.measure(f)(lhs, rhs)
+      except ValueError as e:
+        if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
+          raise
+        runtime = float("inf")
+      # Enable this to get more detailed information.
+      else:
+        print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
+        if runtime < best_runtime:  # pytype: disable=unsupported-operands
+          best_runtime = runtime
+          best_kwargs = kwargs
+    if not best_kwargs:
+      raise ValueError("No valid configuration found")
+
+    def ref_ragged_dot(lhs, rhs, group_sizes):
+      if transpose_rhs:
+        rhs = jnp.transpose(rhs, (0, 2, 1))
+      return jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
+
+    ref, ref_runtime = profiler.measure(ref_ragged_dot)(
+        lhs, rhs, group_sizes=group_sizes
+    )
+    result = ragged_dot(
+        lhs, rhs, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
+        **best_kwargs
+    )
+    np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
 
-  tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
-  ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
-  print(
-      "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
-  )
-  print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
-  print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
+    tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
+    ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
+    print(f"Transpose RHS: {transpose_rhs}")
+    print(
+        "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
+    )
+    print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
+    print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
 
 
 if __name__ == "__main__":
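For a quick end-to-end check of the new flag's semantics outside of main(), the same comparison that the ref_ragged_dot helper above makes can be written directly. This is only a sketch under assumed shapes and block parameters (it requires a GPU supported by Mosaic GPU, and the tuning values may need adjusting); it is not code from this commit.

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.experimental.pallas.ops.gpu.ragged_dot_mgpu import ragged_dot

m, k, n, g = 1024, 512, 1024, 8            # assumed problem sizes
kx, ky = random.split(random.key(0))
lhs = random.normal(kx, (m, k), jnp.float16)
rhs_t = random.normal(ky, (g, n, k), jnp.float16)   # note (g, n, k), not (g, k, n)
group_sizes = jnp.full((g,), m // g, jnp.int32)      # equal-sized groups for simplicity

# With transpose_rhs=True the kernel contracts over the last axis of rhs_t.
out = ragged_dot(
    lhs, rhs_t, group_sizes=group_sizes, transpose_rhs=True,
    block_m=64, block_n=128, block_k=64,
    max_concurrent_steps=4, grid_block_n=1,          # assumed tuning
)
# Reference: transpose back to (g, k, n) and use jax.lax.ragged_dot.
ref = jax.lax.ragged_dot(lhs, jnp.transpose(rhs_t, (0, 2, 1)), group_sizes=group_sizes)
np.testing.assert_allclose(out, ref, atol=1e-3, rtol=1e-3)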
