Skip to content

Commit f6101c5

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
Add lax.tile_p
PiperOrigin-RevId: 825382043
1 parent 2d545c3 commit f6101c5

File tree

6 files changed

+180
-5
lines changed

6 files changed

+180
-5
lines changed

jax/_src/lax/lax.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6729,6 +6729,101 @@ def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
67296729
)
67306730

67316731

6732+
def tile(operand: ArrayLike, reps: Sequence[int]) -> Array:
6733+
"""Tiles an array by repeating it along each dimension.
6734+
6735+
Args:
6736+
operand: an array to tile.
6737+
reps: a sequence of integers representing the number of repeats for each
6738+
dimension. Must have the same length as ``operand.ndim``.
6739+
6740+
Returns:
6741+
A tiled array with shape ``(operand.shape[0] * reps[0], ...,
6742+
operand.shape[-1] * reps[-1])``.
6743+
6744+
Examples:
6745+
>>> x = jnp.array([[1, 2], [3, 4]])
6746+
>>> lax.tile(x, (2, 3))
6747+
Array([[1, 2, 1, 2, 1, 2],
6748+
[3, 4, 3, 4, 3, 4],
6749+
[1, 2, 1, 2, 1, 2],
6750+
[3, 4, 3, 4, 3, 4]], dtype=int32)
6751+
6752+
>>> y = jnp.array([1, 2, 3])
6753+
>>> lax.tile(y, (2,))
6754+
Array([1, 2, 3, 1, 2, 3], dtype=int32)
6755+
6756+
>>> z = jnp.array([[1], [2]])
6757+
>>> lax.tile(z, (1, 3))
6758+
Array([[1, 1, 1],
6759+
[2, 2, 2]], dtype=int32)
6760+
"""
6761+
return tile_p.bind(operand, reps=tuple(reps))
6762+
6763+
6764+
def _tile_abstract_eval(operand, *, reps):
6765+
if len(reps) != operand.ndim:
6766+
raise ValueError(
6767+
'tile reps must have length equal to operand.ndim, '
6768+
f'got reps={reps} for operand.ndim={operand.ndim}'
6769+
)
6770+
out_shape = tuple(d * r for d, r in zip(operand.shape, reps))
6771+
return operand.update(shape=out_shape)
6772+
6773+
6774+
def _tile_impl(operand, *, reps):
6775+
out_shape = tuple(d * r for d, r in zip(operand.shape, reps))
6776+
bcast_shape = []
6777+
bcast_dims = []
6778+
for d, r in zip(operand.shape, reps):
6779+
if d == 1 or r == 1:
6780+
bcast_dims.append(len(bcast_shape))
6781+
bcast_shape.append(d * r)
6782+
else:
6783+
bcast_dims.append(len(bcast_shape) + 1)
6784+
bcast_shape.extend((r, d))
6785+
bcast = broadcast_in_dim(operand, tuple(bcast_shape), tuple(bcast_dims))
6786+
return reshape(bcast, out_shape)
6787+
6788+
6789+
def _tile_transpose(ct, operand, *, reps):
6790+
assert ad.is_undefined_primal(operand)
6791+
if type(ct) is ad_util.Zero:
6792+
return ad_util.Zero(operand.aval)
6793+
reshape_shape = []
6794+
reduce_dims = []
6795+
for d, r in zip(operand.aval.shape, reps):
6796+
if r == 1:
6797+
reshape_shape.append(d)
6798+
elif d == 1:
6799+
reduce_dims.append(len(reshape_shape))
6800+
reshape_shape.append(r)
6801+
else:
6802+
reduce_dims.append(len(reshape_shape))
6803+
reshape_shape.extend((r, d))
6804+
reshaped_ct = reshape(ct, tuple(reshape_shape))
6805+
return [reduce_sum(reshaped_ct, tuple(reduce_dims))]
6806+
6807+
6808+
def _tile_batching_rule(batched_args, batch_dims, *, reps):
6809+
(operand,) = batched_args
6810+
(bdim,) = batch_dims
6811+
if bdim is None:
6812+
return tile(operand, reps), None
6813+
reps = list(reps)
6814+
reps.insert(bdim, 1)
6815+
return tile(operand, reps), bdim
6816+
6817+
6818+
tile_p = core.Primitive('tile')
6819+
tile_p.def_impl(_tile_impl)
6820+
tile_p.def_abstract_eval(_tile_abstract_eval)
6821+
ad.deflinear2(tile_p, _tile_transpose)
6822+
batching.primitive_batchers[tile_p] = _tile_batching_rule
6823+
mlir.register_lowering(
6824+
tile_p, mlir.lower_fun(_tile_impl, multiple_results=False))
6825+
6826+
67326827
def _clamp_shape_rule(min, operand, max):
67336828
if min.shape and min.shape != operand.shape:
67346829
raise TypeError("clamp requires min.shape == operand.shape or min.shape == "

jax/_src/numpy/lax_numpy.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4519,11 +4519,13 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
45194519
reps_tup = tuple(reps) # type: ignore[arg-type]
45204520
reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
45214521
for rep in reps_tup)
4522-
A_shape = (1,) * (len(reps_tup) - np.ndim(A)) + np.shape(A)
4523-
reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup
4524-
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
4525-
[k for pair in zip(reps_tup, A_shape) for k in pair])
4526-
return reshape(result, tuple(np.multiply(A_shape, reps_tup)))
4522+
4523+
# Prepend 1s to reps to match A.ndim
4524+
if len(reps_tup) < A.ndim:
4525+
reps_tup = (1,) * (A.ndim - len(reps_tup)) + reps_tup
4526+
if len(reps_tup) > A.ndim:
4527+
A = lax.expand_dims(A, list(range(len(reps_tup) - A.ndim)))
4528+
return lax.tile(A, reps_tup)
45274529

45284530
def _concatenate_array(arr: ArrayLike, axis: int | None,
45294531
dtype: DTypeLike | None = None) -> Array:

jax/lax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@
229229
tan_p as tan_p,
230230
tanh as tanh,
231231
tanh_p as tanh_p,
232+
tile as tile,
233+
tile_p as tile_p,
232234
top_k as top_k,
233235
top_k_p as top_k_p,
234236
transpose as transpose,

tests/lax_autodiff_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,5 +1180,21 @@ def testPowShapeMismatch(self):
11801180
self.assertArraysEqual(actual, expected)
11811181

11821182

1183+
@jtu.sample_product(
1184+
[
1185+
dict(arg_shape=arg_shape, reps=reps)
1186+
for arg_shape, reps in [
1187+
[(3,), (2,)],
1188+
[(2, 3), (1, 2)],
1189+
]
1190+
],
1191+
dtype=grad_float_dtypes,
1192+
)
1193+
def testTileAutodiff(self, arg_shape, reps, dtype):
1194+
rng = jtu.rand_default(self.rng())
1195+
args_maker = lambda: [rng(arg_shape, dtype)]
1196+
op = lambda x: lax.tile(x, reps)
1197+
check_grads(op, args_maker(), order=3, modes=["fwd", "rev"], eps=1.)
1198+
11831199
if __name__ == '__main__':
11841200
absltest.main(testLoader=jtu.JaxTestLoader())

tests/lax_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,26 @@ def testBroadcastInDimAgainstNumpy(self, inshape, dtype, outshape, dimensions):
15261526
numpy_op = lambda x: lax_reference.broadcast_in_dim(x, outshape, dimensions)
15271527
self._CheckAgainstNumpy(numpy_op, op, args_maker)
15281528

1529+
@jtu.sample_product(
1530+
[
1531+
dict(arg_shape=arg_shape, reps=reps)
1532+
for arg_shape, reps in [
1533+
[(3,), (2,)],
1534+
[(2, 3), (1, 2)],
1535+
[(2, 3), (2, 1)],
1536+
[(2, 1, 3), (1, 2, 3)],
1537+
]
1538+
],
1539+
dtype=lax_test_util.default_dtypes,
1540+
)
1541+
def testTile(self, arg_shape, reps, dtype):
1542+
rng = jtu.rand_default(self.rng())
1543+
args_maker = lambda: [rng(arg_shape, dtype)]
1544+
op = lambda x: lax.tile(x, reps)
1545+
numpy_op = lambda x: np.tile(x, reps)
1546+
self._CompileAndCheck(op, args_maker)
1547+
self._CheckAgainstNumpy(numpy_op, op, args_maker)
1548+
15291549
@parameterized.parameters(
15301550
{"inshape": inshape, "dimensions": dimensions, "error_type": error_type,
15311551
"err_msg": err_msg}

tests/lax_vmap_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,5 +791,45 @@ def g(a, b):
791791
self.assertAllClose(output, expected, check_dtypes=False)
792792

793793

794+
@jtu.sample_product(
795+
[
796+
dict(arg_shape=arg_shape, reps=reps)
797+
for arg_shape, reps in [
798+
[(3,), (2,)],
799+
[(2, 3), (1, 2)],
800+
[(2, 3), (2, 1)],
801+
[(2, 1, 3), (1, 2, 3)],
802+
]
803+
],
804+
in_axes=[0, 1, -1],
805+
out_axes=[0, 1, -1],
806+
)
807+
def testTileBatching(self, arg_shape, reps, in_axes, out_axes):
808+
rng = jtu.rand_default(self.rng())
809+
dtype = np.float32
810+
args_maker = lambda: [rng(arg_shape, dtype)]
811+
op = lambda x: lax.tile(x, reps)
812+
args = args_maker()
813+
814+
# Construct batched arguments based on in_axes
815+
if in_axes == 0:
816+
batched_args = [jnp.stack([arg, arg], axis=0) for arg in args]
817+
elif in_axes == 1:
818+
batched_args = [jnp.stack([arg, arg], axis=1) for arg in args]
819+
else: # in_axes == -1
820+
batched_args = [jnp.stack([arg, arg], axis=-1) for arg in args]
821+
822+
# Compute expected output
823+
out = op(*args)
824+
if out_axes == 0:
825+
expected = jnp.stack([out, out], axis=0)
826+
elif out_axes == 1:
827+
expected = jnp.stack([out, out], axis=1)
828+
else: # out_axes == -1
829+
expected = jnp.stack([out, out], axis=-1)
830+
831+
actual = jax.vmap(op, in_axes=in_axes, out_axes=out_axes)(*batched_args)
832+
self.assertAllClose(expected, actual)
833+
794834
if __name__ == '__main__':
795835
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)