Skip to content

Commit 619dbe1

Browse files
WindQAQGoogle-ML-Automation
authored andcommitted
Add dynamic grid with smem output test.
PiperOrigin-RevId: 826115707
1 parent 78fce19 commit 619dbe1

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/pallas/tpu_pallas_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4001,6 +4001,32 @@ def kernel(x_ref, y_ref):
40014001
)(x)
40024002
np.testing.assert_array_equal(out, x.reshape(out_shape))
40034003

4004+
def test_dynamic_grid_with_smem_output(self):
4005+
if self.INTERPRET:
4006+
self.skipTest('Fail on interpreter.')
4007+
if not jtu.if_cloud_tpu_at_least(2025, 11, 3):
4008+
self.skipTest('Needs a newer libTPU')
4009+
4010+
def body(_, o_ref):
4011+
o_ref[0] = lax.cond(
4012+
pl.program_id(0) == 0, lambda: 1, lambda: o_ref[0] + 1
4013+
)
4014+
4015+
def wrapper_dynamic(n):
4016+
return self.pallas_call(
4017+
body,
4018+
out_shape=pltpu.SMEM((1,), dtype=jnp.int32),
4019+
grid_spec=pl.GridSpec(
4020+
grid=(n,),
4021+
in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
4022+
out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
4023+
),
4024+
)(n)
4025+
4026+
n = jax.random.randint(jax.random.key(0), (1,), 1, 10, dtype=jnp.int32)
4027+
compiled_kernel = jax.jit(wrapper_dynamic).lower(n).compile()
4028+
np.testing.assert_array_equal(compiled_kernel(n), n)
4029+
40044030

40054031
class MiscellaneousInterpretTest(MiscellaneousTest):
40064032
INTERPRET: bool = True

0 commit comments

Comments
 (0)