@@ -6729,6 +6729,101 @@ def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
67296729)
67306730
67316731
6732+ def tile (operand : ArrayLike , reps : Sequence [int ]) -> Array :
6733+ """Tiles an array by repeating it along each dimension.
6734+
6735+ Args:
6736+ operand: an array to tile.
6737+ reps: a sequence of integers representing the number of repeats for each
6738+ dimension. Must have the same length as ``operand.ndim``.
6739+
6740+ Returns:
6741+ A tiled array with shape ``(operand.shape[0] * reps[0], ...,
6742+ operand.shape[-1] * reps[-1])``.
6743+
6744+ Examples:
6745+ >>> x = jnp.array([[1, 2], [3, 4]])
6746+ >>> lax.tile(x, (2, 3))
6747+ Array([[1, 2, 1, 2, 1, 2],
6748+ [3, 4, 3, 4, 3, 4],
6749+ [1, 2, 1, 2, 1, 2],
6750+ [3, 4, 3, 4, 3, 4]], dtype=int32)
6751+
6752+ >>> y = jnp.array([1, 2, 3])
6753+ >>> lax.tile(y, (2,))
6754+ Array([1, 2, 3, 1, 2, 3], dtype=int32)
6755+
6756+ >>> z = jnp.array([[1], [2]])
6757+ >>> lax.tile(z, (1, 3))
6758+ Array([[1, 1, 1],
6759+ [2, 2, 2]], dtype=int32)
6760+ """
6761+ return tile_p .bind (operand , reps = tuple (reps ))
6762+
6763+
6764+ def _tile_abstract_eval (operand , * , reps ):
6765+ if len (reps ) != operand .ndim :
6766+ raise ValueError (
6767+ 'tile reps must have length equal to operand.ndim, '
6768+ f'got reps={ reps } for operand.ndim={ operand .ndim } '
6769+ )
6770+ out_shape = tuple (d * r for d , r in zip (operand .shape , reps ))
6771+ return operand .update (shape = out_shape )
6772+
6773+
6774+ def _tile_impl (operand , * , reps ):
6775+ out_shape = tuple (d * r for d , r in zip (operand .shape , reps ))
6776+ bcast_shape = []
6777+ bcast_dims = []
6778+ for d , r in zip (operand .shape , reps ):
6779+ if d == 1 or r == 1 :
6780+ bcast_dims .append (len (bcast_shape ))
6781+ bcast_shape .append (d * r )
6782+ else :
6783+ bcast_dims .append (len (bcast_shape ) + 1 )
6784+ bcast_shape .extend ((r , d ))
6785+ bcast = broadcast_in_dim (operand , tuple (bcast_shape ), tuple (bcast_dims ))
6786+ return reshape (bcast , out_shape )
6787+
6788+
6789+ def _tile_transpose (ct , operand , * , reps ):
6790+ assert ad .is_undefined_primal (operand )
6791+ if type (ct ) is ad_util .Zero :
6792+ return ad_util .Zero (operand .aval )
6793+ reshape_shape = []
6794+ reduce_dims = []
6795+ for d , r in zip (operand .aval .shape , reps ):
6796+ if r == 1 :
6797+ reshape_shape .append (d )
6798+ elif d == 1 :
6799+ reduce_dims .append (len (reshape_shape ))
6800+ reshape_shape .append (r )
6801+ else :
6802+ reduce_dims .append (len (reshape_shape ))
6803+ reshape_shape .extend ((r , d ))
6804+ reshaped_ct = reshape (ct , tuple (reshape_shape ))
6805+ return [reduce_sum (reshaped_ct , tuple (reduce_dims ))]
6806+
6807+
6808+ def _tile_batching_rule (batched_args , batch_dims , * , reps ):
6809+ (operand ,) = batched_args
6810+ (bdim ,) = batch_dims
6811+ if bdim is None :
6812+ return tile (operand , reps ), None
6813+ reps = list (reps )
6814+ reps .insert (bdim , 1 )
6815+ return tile (operand , reps ), bdim
6816+
6817+
6818+ tile_p = core .Primitive ('tile' )
6819+ tile_p .def_impl (_tile_impl )
6820+ tile_p .def_abstract_eval (_tile_abstract_eval )
6821+ ad .deflinear2 (tile_p , _tile_transpose )
6822+ batching .primitive_batchers [tile_p ] = _tile_batching_rule
6823+ mlir .register_lowering (
6824+ tile_p , mlir .lower_fun (_tile_impl , multiple_results = False ))
6825+
6826+
67326827def _clamp_shape_rule (min , operand , max ):
67336828 if min .shape and min .shape != operand .shape :
67346829 raise TypeError ("clamp requires min.shape == operand.shape or min.shape == "
0 commit comments