jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py (146 changes: 84 additions & 62 deletions)
@@ -99,6 +99,7 @@ def ragged_dot(
     block_k: int,
     max_concurrent_steps: int,
     grid_block_n: int,
+    transpose_rhs: bool = False,
 ) -> jax.Array:
   if lhs.dtype != rhs.dtype:
     raise NotImplementedError(
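Note (editorial aside, not part of the PR): transpose_rhs switches the expected rhs layout from (g, k, n) to (g, n, k). A minimal reference-semantics sketch, mirroring the ref_ragged_dot helper this PR adds in main() below; the function name here is illustrative:

import jax
import jax.numpy as jnp

def ragged_dot_reference(lhs, rhs, group_sizes, transpose_rhs=False):
  # With transpose_rhs=True, rhs arrives as (g, n, k); transposing the two
  # trailing dims recovers the (g, k, n) layout jax.lax.ragged_dot expects.
  if transpose_rhs:
    rhs = jnp.transpose(rhs, (0, 2, 1))
  return jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)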
@@ -107,6 +108,9 @@ def ragged_dot(
   m, k = lhs.shape
   g, k2, n = rhs.shape
 
+  if transpose_rhs:
+    k2, n = n, k2
+
   if group_sizes.shape[0] != g:
     raise ValueError(
         f"Expected group_sizes to have shape {g} but got {group_sizes.shape}"
@@ -134,7 +138,9 @@ def mn_loop(loop_info: plgpu.NDLoopInfo):  # pylint: disable=unused-variable
     def acc_scope(acc_ref):
       plgpu.emit_pipeline(
           lambda _, lhs_smem, rhs_smem: plgpu.wgmma(
-              acc_ref, lhs_smem, rhs_smem
+              acc_ref,
+              lhs_smem,
+              plgpu.transpose_ref(rhs_smem, (1, 0)) if transpose_rhs else rhs_smem,
           ),
           grid=(k // block_k,),
           in_specs=[
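Note: as I read it, plgpu.transpose_ref(rhs_smem, (1, 0)) hands wgmma a transposed view of the SMEM tile rather than copying the data. Numerically, each pipeline step then accumulates as in this plain-jnp sketch (names are illustrative, not the kernel's API):

import jax.numpy as jnp

def wgmma_step_reference(acc, lhs_tile, rhs_tile, transpose_rhs):
  # lhs_tile: (block_m, block_k); rhs_tile: (block_n, block_k) when
  # transpose_rhs, else (block_k, block_n). Either way the step adds
  # lhs_tile @ B with B of shape (block_k, block_n).
  b = rhs_tile.T if transpose_rhs else rhs_tile
  return acc + jnp.matmul(lhs_tile, b)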
@@ -144,8 +150,8 @@ def acc_scope(acc_ref):
                   delay_release=1,
               ),
               plgpu.BlockSpec(
-                  (block_k, block_n),
-                  lambda k: (k, ni),
+                  (block_n, block_k) if transpose_rhs else (block_k, block_n),
+                  lambda k: (ni, k) if transpose_rhs else (k, ni),
                   delay_release=1,
               ),
           ],
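Note: a BlockSpec pairs a block shape with an index map from pipeline-grid indices to block coordinates, and both must follow the stored layout; Pallas multiplies the returned block coordinates by the block shape to locate the tile. A hypothetical helper showing which HBM slice of rhs each step fetches under the two layouts:

def rhs_tile_slices(k_step, ni, block_k, block_n, transpose_rhs):
  # (row, col) slices of the global rhs operand copied into SMEM this step.
  if transpose_rhs:  # rhs stored (n, k): tile (block_n, block_k) at (ni, k_step)
    return (slice(ni * block_n, (ni + 1) * block_n),
            slice(k_step * block_k, (k_step + 1) * block_k))
  # rhs stored (k, n): tile (block_k, block_n) at (k_step, ni)
  return (slice(k_step * block_k, (k_step + 1) * block_k),
          slice(ni * block_n, (ni + 1) * block_n))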
@@ -233,68 +239,84 @@ def _():
 
 
 def main(unused_argv):
-  m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
-  kx, ky, kz = random.split(random.key(1234), num=3)
+  for transpose_rhs in [False, True]:
+    m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
+    kx, ky, kz = random.split(random.key(1234), num=3)
 
-  lhs = jax.random.normal(kx, (m, k), jnp.float16)
-  rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
-  group_boundaries = jax.lax.sort(
-      jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
-  )
-  group_starts = lax.concatenate(
-      [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
-  )
-  group_ends = lax.concatenate(
-      [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
-  )
-  group_sizes = group_ends - group_starts
-  assert group_sizes.shape == (num_groups,)
-
-  block_m = block_n = (64, 128, 192)
-  block_k = (64,)
-  max_concurrent_steps = (2, 4, 5, 6)
-  grid_block_n = (1, 2, 4, 8, 16)
-  configs = itertools.product(
-      block_m, block_n, block_k, max_concurrent_steps, grid_block_n
-  )
-  names = (
-      "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
-  )
-  best_runtime = float("inf")
-  best_kwargs = {}
-  for config in configs:
-    kwargs = dict(zip(names, config))
-    if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
-      continue
-    try:
-      f = functools.partial(ragged_dot, group_sizes=group_sizes, **kwargs)
-      _, runtime = profiler.measure(f)(lhs, rhs)
-    except ValueError as e:
-      if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
-        raise
-      runtime = float("inf")
-    # Enable this to get more detailed information.
+    lhs = jax.random.normal(kx, (m, k), jnp.float16)
+    if transpose_rhs:
+      rhs = jax.random.normal(ky, (num_groups, n, k), jnp.float16)
     else:
-      print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
-    if runtime < best_runtime:  # pytype: disable=unsupported-operands
-      best_runtime = runtime
-      best_kwargs = kwargs
-  if not best_kwargs:
-    raise ValueError("No valid configuration found")
-
-  ref, ref_runtime = profiler.measure(jax.lax.ragged_dot)(
-      lhs, rhs, group_sizes=group_sizes
-  )
-  result = ragged_dot(lhs, rhs, group_sizes=group_sizes, **best_kwargs)
-  np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
+      rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
+    group_boundaries = jax.lax.sort(
+        jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
+    )
+    group_starts = lax.concatenate(
+        [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
+    )
+    group_ends = lax.concatenate(
+        [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
+    )
+    group_sizes = group_ends - group_starts
+    assert group_sizes.shape == (num_groups,)
+
+    block_m = block_n = (64, 128, 192)
+    block_k = (64,)
+    max_concurrent_steps = (2, 4, 5, 6)
+    grid_block_n = (1, 2, 4, 8, 16)
+    configs = itertools.product(
+        block_m, block_n, block_k, max_concurrent_steps, grid_block_n
+    )
+    names = (
+        "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
+    )
+    best_runtime = float("inf")
+    best_kwargs = {}
+    for config in configs:
+      kwargs = dict(zip(names, config))
+      if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
+        continue
+      try:
+        f = functools.partial(
+            ragged_dot, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
+            **kwargs
+        )
+        _, runtime = profiler.measure(f)(lhs, rhs)
+      except ValueError as e:
+        if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
+          raise
+        runtime = float("inf")
+      # Enable this to get more detailed information.
+      else:
+        print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
+      if runtime < best_runtime:  # pytype: disable=unsupported-operands
+        best_runtime = runtime
+        best_kwargs = kwargs
+    if not best_kwargs:
+      raise ValueError("No valid configuration found")
+
+    def ref_ragged_dot(lhs, rhs, group_sizes):
+      if transpose_rhs:
+        rhs = jnp.transpose(rhs, (0, 2, 1))
+      return jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
+
+    ref, ref_runtime = profiler.measure(ref_ragged_dot)(
+        lhs, rhs, group_sizes=group_sizes
+    )
+    result = ragged_dot(
+        lhs, rhs, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
+        **best_kwargs
+    )
+    np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
 
-  tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
-  ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
-  print(
-      "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
-  )
-  print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
-  print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
+    tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
+    ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
+    print(f"Transpose RHS: {transpose_rhs}")
+    print(
+        "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
+    )
+    print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
+    print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
 
 
 if __name__ == "__main__":
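Note: profiler.measure appears to report milliseconds here (consistent with the * 1000 conversion to us in the prints), and the ragged matmul is credited with the same work as one dense (m, k) @ (k, n) product, since the groups partition M. The script's throughput formula as a standalone sketch:

def tflops_from_ms(runtime_ms, m, k, n):
  # 2 * m * k * n FLOPs (multiply + add per output element per K step),
  # divided by runtime in seconds, expressed in TFLOPS.
  return (2.0 * m * k * n) / (runtime_ms / 1e3) / 1e12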