diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 643db331b28c..12438196f436 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -19,7 +19,6 @@
 import contextlib
 import dataclasses
 import functools
-import operator
 import string
 from typing import Any, Protocol, Self, TypeVar, cast

@@ -33,8 +32,8 @@
 from jax._src import custom_derivatives
 from jax._src import debugging
 from jax._src import dtypes
-from jax._src import literals
 from jax._src import linear_util as lu
+from jax._src import literals
 from jax._src import mesh as mesh_lib
 from jax._src import pjit
 from jax._src import prng
@@ -50,7 +49,6 @@
 from jax._src.lax import control_flow
 from jax._src.lax import lax as lax_internal
 from jax._src.lax.control_flow import BranchesPlatforms
-
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
 from jax._src.lib.mlir.dialects import cf
@@ -1320,7 +1318,11 @@ def _index_to_start_size_stride(
   assert not isinstance(idx, slice)
   if isinstance(idx, indexing.Slice):
     start = _maybe_cast_to_index(cast_to_index, idx.start)
-    size = idx.size
+    size = (
+        _make_index(idx.size)
+        if isinstance(idx.size, ir.Value) and cast_to_index
+        else idx.size
+    )
     stride = idx.stride
     squeeze = False
   elif isinstance(idx, int):
@@ -1380,27 +1382,6 @@ def _indexer_to_start_size_stride(
   )


-def _compute_squeezed_dims(source_shape: Sequence[int], target_shape: Sequence[int]) -> Sequence[bool]:
-  # This function only exists to align the ``tpu.memref_squeeze`` layout
-  # inference logic between Python and MLIR.
-  result = []
-  source_index = len(source_shape) - 1
-  target_index = len(target_shape) - 1
-  while source_index >= 0 or target_index >= 0:
-    target_dim = target_shape[target_index] if target_index >= 0 else -1
-    assert source_index >= 0
-    if source_shape[source_index] == target_dim:
-      result.append(False)
-      source_index -= 1
-      target_index -= 1
-    else:
-      assert source_shape[source_index] == 1
-      result.append(True)
-      source_index -= 1
-  result.reverse()
-  return result
-
-
 def _slice_memref(
     ref: ir.Value,
     indexer: NDIndexer,
@@ -1410,23 +1391,13 @@ def _slice_memref(
   assert ref_block_shape is not None
   starts, sizes, strides, squeeze_dims, ref_block_shape = (
       _indexer_to_start_size_stride(
-          indexer,
-          ref_block_shape,
-          cast_to_index=False,
+          indexer, ref_block_shape, cast_to_index=False
       )
   )
   if not all((s is None or s == 1) for s in strides):
     raise NotImplementedError("Strided slices of references are unsupported.")

   ir_dynamic_size = ir.ShapedType.get_dynamic_size()
-  static_starts = []
-  for s in starts:
-    if not isinstance(s, ir.Value):
-      static_starts.append(s)
-    elif (v := _fold_and_get_constant_value(s)) is not None:
-      static_starts.append(v)
-    else:
-      static_starts.append(ir_dynamic_size)

   static_sizes = []
   dynamic_sizes = []
@@ -1440,36 +1411,19 @@ def _slice_memref(
       dynamic_sizes.append(s)

   ref_ty = ir.MemRefType(ref.type)
-  ref_strides, ref_offset = ref_ty.get_strides_and_offset()
-  if ref_offset == ir_dynamic_size or ir_dynamic_size in static_starts:
-    target_offset = ir_dynamic_size
-  else:
-    target_offset = sum(
-        map(operator.mul, static_starts, ref_strides), ref_offset
-    )
-  out_layout = ir.StridedLayoutAttr.get(target_offset, ref_strides)
   out_ty = ir.MemRefType.get(
-      static_sizes, ref_ty.element_type, out_layout, ref_ty.memory_space
+      static_sizes, ref_ty.element_type, memory_space=ref_ty.memory_space
   )
   out = tpu.memref_slice(out_ty, ref, starts, dynamic_sizes)

   if any(squeeze_dims):  # We need to squeeze out some dimensions.
     ref_ty = out_ty
     del out_ty
-    ref_strides, ref_offset = ref_ty.get_strides_and_offset()
-    target_sizes = [dim for i, dim in enumerate(ref_ty.shape) if not squeeze_dims[i]]
-    del squeeze_dims
-    # We re-infer the squeezed dimensions to align with the tpu.memref_squeeze
-    # verification logic in MLIR in ambiguous cases, e.g. when squeezing
-    # from [1, 1, 128] to [1, 128].
-    squeeze_dims = _compute_squeezed_dims(ref_ty.shape, target_sizes)
-    target_strides = [s for i, s in enumerate(ref_strides) if not squeeze_dims[i]]
-    out_layout = ir.StridedLayoutAttr.get(ref_offset, target_strides)
+    target_sizes = [
+        dim for i, dim in enumerate(ref_ty.shape) if not squeeze_dims[i]
+    ]
     out_ty = ir.MemRefType.get(
-        target_sizes,
-        ref_ty.element_type,
-        out_layout,
-        ref_ty.memory_space,
+        target_sizes, ref_ty.element_type, memory_space=ref_ty.memory_space
     )
     out = tpu.memref_squeeze(out_ty, out)
   return out, ref_block_shape
@@ -2636,7 +2590,7 @@ def _fold(x, fuel):
       "arith.minsi": min,
   }
   if op_name == "arith.constant":
-    if ir.IntegerType.isinstance(x.type):
+    if ir.IntegerType.isinstance(x.type) or ir.IndexType.isinstance(x.type):
       return ir.IntegerAttr(x.owner.attributes["value"]).value
     elif ir.FloatType.isinstance(x.type):
       return ir.FloatAttr(x.owner.attributes["value"]).value
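Note: the `_fold` change above is what keeps constant folding working now that slice sizes can be index-typed `ir.Value`s produced via `_make_index`; an `arith.constant` of type `index` carries an `IntegerAttr` payload just like an `i32`/`i64` one. A minimal standalone sketch of that branch (the helper name is hypothetical; the attribute accesses mirror the code above):

    from jax._src.lib.mlir import ir

    def constant_value(x: ir.Value):
      # Assumes x is an op result; mirrors the "arith.constant" branch of _fold.
      if x.owner.name != "arith.constant":
        return None
      if ir.IntegerType.isinstance(x.type) or ir.IndexType.isinstance(x.type):
        # index-typed constants also store their value as an IntegerAttr.
        return ir.IntegerAttr(x.owner.attributes["value"]).value
      if ir.FloatType.isinstance(x.type):
        return ir.FloatAttr(x.owner.attributes["value"]).value
      return None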
diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py
index 92c318afbf50..a5d8c4a9fff7 100644
--- a/jax/_src/pallas/mosaic/sc_lowering.py
+++ b/jax/_src/pallas/mosaic/sc_lowering.py
@@ -17,6 +17,7 @@
 import contextlib
 import dataclasses
 import functools
+import operator
 from typing import Any, NoReturn, cast

 import jax
@@ -82,7 +83,6 @@ class ScLoweringContext(tc_lowering.LoweringContext):


 LoweringRuleContext = tc_lowering.LoweringRuleContext
-_transform_ref = tc_lowering._transform_ref
 _dtype_to_ir_type = tc_lowering._dtype_to_ir_type
 # pylint: disable=protected-access

@@ -836,3 +836,87 @@ def _alloc_value(
       )
     return memref.alloca(out_type, [], [])
   return tc_lowering._alloc_value(aval, ctx=ctx)
+
+
+def _split_static_and_dynamic_values(
+    values: Sequence[ir.Value | Any],
+) -> tuple[Sequence[Any], Sequence[ir.Value]]:
+  static_values = []
+  dynamic_values = []
+  for v in values:
+    if not isinstance(v, ir.Value):
+      static_values.append(v)
+    elif (c := tc_lowering._fold_and_get_constant_value(v)) is not None:
+      static_values.append(c)
+    else:
+      static_values.append(ir.ShapedType.get_dynamic_size())
+      dynamic_values.append(v)
+  return static_values, dynamic_values
+
+
+def _slice_memref(
+    ref: ir.Value,
+    indexer: indexing.NDIndexer,
+    ref_dtype: jax.typing.DTypeLike,
+    ref_block_shape: tuple[int | pallas_core.Squeezed, ...] | None,
+) -> tuple[ir.Value, tuple[int | pallas_core.Squeezed, ...]]:
+  assert ref_block_shape is not None
+  starts, sizes, strides, squeeze_dims, ref_block_shape = (
+      tc_lowering._indexer_to_start_size_stride(
+          indexer, ref_block_shape, cast_to_index=True
+      )
+  )
+  if not all((s is None or s == 1) for s in strides):
+    raise NotImplementedError("Strided slices of references are unsupported.")
+
+  static_starts, dynamic_starts = _split_static_and_dynamic_values(starts)
+  static_sizes, dynamic_sizes = _split_static_and_dynamic_values(sizes)
+
+  ref_ty = ir.MemRefType(ref.type)
+  ref_strides, ref_offset = ref_ty.get_strides_and_offset()
+
+  ir_dynamic_size = ir.ShapedType.get_dynamic_size()
+  if ref_offset == ir_dynamic_size or ir_dynamic_size in static_starts:
+    out_offset = ir_dynamic_size
+  else:
+    out_offset = sum(
+        map(operator.mul, static_starts, ref_strides), ref_offset
+    )
+  out_sizes = [s for i, s in enumerate(static_sizes) if not squeeze_dims[i]]
+  out_strides = [s for i, s in enumerate(ref_strides) if not squeeze_dims[i]]
+  out_layout = ir.StridedLayoutAttr.get(out_offset, out_strides)
+  out_ty = ir.MemRefType.get(
+      out_sizes, ref_ty.element_type, out_layout, ref_ty.memory_space
+  )
+
+  # We bypass ``memref.subview``, because we want to precisely control how the
+  # static/dynamic split is performed, since it affects the result layout.
+  out = memref.SubViewOp(
+      out_ty,
+      ref,
+      dynamic_starts,
+      dynamic_sizes,
+      [],
+      static_starts,
+      static_sizes,
+      static_strides=[1] * len(ref_strides),
+  ).result
+  return out, ref_block_shape
+
+
+def _transform_ref(
+    ref: ir.Value,
+    ref_dtype: jax.typing.DTypeLike,
+    ref_block_shape: tuple[int | pallas_core.Squeezed, ...] | None,
+    transforms: Sequence[pallas_core.MemoryRefTransform],
+) -> tuple[ir.Value, tuple[int | pallas_core.Squeezed, ...] | None]:
+  for transform in transforms:
+    if isinstance(transform, indexing.NDIndexer):
+      ref, ref_block_shape = _slice_memref(
+          ref, transform, ref_dtype, ref_block_shape
+      )
+    else:
+      ref, ref_block_shape = tc_lowering._transform_ref(
+          ref, ref_dtype, ref_block_shape, [transform]
+      )
+  return ref, ref_block_shape
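Note: the core of the new `_slice_memref` is the offset arithmetic: the result offset is `ref_offset + sum(start_i * stride_i)`, and it collapses to the dynamic sentinel as soon as the source offset or any start is dynamic. A pure-Python sketch of just that rule (`DYNAMIC` stands in for `ir.ShapedType.get_dynamic_size()`, which is `INT64_MIN` in MLIR):

    import operator

    DYNAMIC = -2**63  # stand-in for ir.ShapedType.get_dynamic_size()

    def slice_offset(ref_offset, ref_strides, static_starts):
      # Any dynamic ingredient makes the whole result offset dynamic.
      if ref_offset == DYNAMIC or DYNAMIC in static_starts:
        return DYNAMIC
      return sum(map(operator.mul, static_starts, ref_strides), ref_offset)

    assert slice_offset(0, (1024, 128, 1), (2, 4, 0)) == 2 * 1024 + 4 * 128
    assert slice_offset(0, (1024, 128, 1), (2, DYNAMIC, 0)) == DYNAMIC

This is also why `memref.SubViewOp` is constructed directly rather than through a convenience wrapper: folding a start into the static operand list is exactly what turns the inferred result offset from dynamic into the concrete value computed above.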
diff --git a/jax/_src/pallas/mosaic/sc_primitives.py b/jax/_src/pallas/mosaic/sc_primitives.py
index 4cd276118425..d3923506ce15 100644
--- a/jax/_src/pallas/mosaic/sc_primitives.py
+++ b/jax/_src/pallas/mosaic/sc_primitives.py
@@ -270,7 +270,7 @@ def _gather_lowering_rule(
   )
   if transforms:
     ref_block_shape, *_ = ctx.block_shapes
-    ref, _ = tc_lowering._transform_ref(
+    ref, _ = sc_lowering._transform_ref(
         ref, ref_aval.dtype, ref_block_shape, transforms
     )
   [out_aval] = ctx.avals_out
@@ -344,7 +344,7 @@ def _scatter_lowering_rule(
   )
   if transforms:
     ref_block_shape, *_ = ctx.block_shapes
-    ref, _ = tc_lowering._transform_ref(
+    ref, _ = sc_lowering._transform_ref(
         ref, ref_aval.dtype, ref_block_shape, transforms
     )
   tpu.vector_store_idx(x, ref, indices, mask=mask, add=add)
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index 664548b22de1..8a56bda55faa 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -150,8 +150,6 @@ OpFoldResult BitcastVregOp::fold(FoldAdaptor adaptor) {
 LogicalResult MemRefSliceOp::verify() {
   auto source_type = getMemRefType(getMemRef());
   auto target_type = getType();
-  auto source_layout = source_type.getLayout();
-  auto target_layout = target_type.getLayout();
   auto target_memory_space = target_type.getMemorySpace();
   auto indices = getBaseIdx();
   auto slice_shape = getResult().getType().getShape();
@@ -184,43 +182,6 @@ LogicalResult MemRefSliceOp::verify() {
     return emitOpError(
         "Memory spaces must match if the target memory space is provided.");
   }
-  if (isa<TiledLayoutAttr>(source_layout) &&
-      !isa<TiledLayoutAttr>(target_layout)) {
-    // TODO(slebedev): Remove this special-case once we move layout propagation
-    // to the infer-memref-layout pass.
-  } else if (isa<StridedLayoutAttr>(target_layout)) {
-    SmallVector<int64_t> source_strides;
-    int64_t source_offset;
-    if (failed(
-            source_type.getStridesAndOffset(source_strides, source_offset))) {
-      return failure();
-    }
-    int64_t target_offset = source_offset;
-    if (target_offset != ShapedType::kDynamic) {
-      for (auto [base_idx, source_stride] :
-           llvm::zip(getBaseIdx(), source_strides)) {
-        if (auto idx = getConstantIntValue(base_idx)) {
-          target_offset += *idx * source_stride;
-        } else {
-          target_offset = ShapedType::kDynamic;
-          break;
-        }
-      }
-    }
-    auto expected_layout =
-        StridedLayoutAttr::get(getContext(), target_offset, source_strides);
-    if (target_layout != expected_layout) {
-      return emitOpError("Layout mismatch: got ")
-             << target_layout << ", expected " << expected_layout << ".";
-    }
-  } else {
-    bool is_target_layout_identity_map =
-        isa<AffineMapAttr>(target_layout) && target_layout.isIdentity();
-    if (!is_target_layout_identity_map && target_layout != source_layout) {
-      return emitOpError(
-          "Layouts must match if the target layout is not an identity map.");
-    }
-  }
   if (getDynamicSizes().size() != target_type.getNumDynamicDims()) {
     return emitOpError(
         "Number of provided dynamic dimensions sizes must match the number of "
@@ -325,33 +286,6 @@ LogicalResult MemRefSqueezeOp::verify() {
     return failure();
   }

-  auto source_layout = source_type.getLayout();
-  auto target_layout = target_type.getLayout();
-  if (isa<TiledLayoutAttr>(source_layout) &&
-      !isa<TiledLayoutAttr>(target_layout)) {
-    // TODO(slebedev): Remove this special-case once we move layout propagation
-    // to the infer-memref-layout pass.
-  } else if (isa<StridedLayoutAttr>(target_layout)) {
-    SmallVector<int64_t> source_strides;
-    int64_t source_offset;
-    if (failed(
-            source_type.getStridesAndOffset(source_strides, source_offset))) {
-      return failure();
-    }
-    SmallVector<int64_t> target_strides;
-    for (auto [i, stride] : llvm::enumerate(source_strides)) {
-      if (!llvm::is_contained(*squeezed_or, i)) {
-        target_strides.push_back(stride);
-      }
-    }
-    auto expected_layout =
-        StridedLayoutAttr::get(getContext(), source_offset, target_strides);
-    if (target_layout != expected_layout) {
-      return emitOpError("Layout mismatch: got ")
-             << target_layout << ", expected " << expected_layout << ".";
-    }
-  }
-
   auto erase_layout_op = getInput().getDefiningOp<EraseLayoutOp>();
   if (!erase_layout_op) {
     return success();
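Note: the squeeze invariant that the deleted `MemRefSqueezeOp::verify` logic checked (and which the Python lowering no longer has to mirror) is simple: keep the source offset, drop the strides of the squeezed unit dims. A sketch of that rule with made-up numbers (the helper is illustrative, not part of the codebase):

    def squeeze_layout(offset, strides, squeezed_dims):
      # Removing size-1 dims cannot address any new element, so the offset
      # is unchanged and only the squeezed dims' strides disappear.
      return offset, [s for i, s in enumerate(strides) if i not in squeezed_dims]

    # Squeezing dims 0 and 2 of a (1, 8, 1, 128) memref with strides
    # (1024, 128, 128, 1) laid out at offset 256:
    assert squeeze_layout(256, [1024, 128, 128, 1], {0, 2}) == (256, [128, 1])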
diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py
index 62faea36585d..dce5684901d9 100644
--- a/tests/pallas/tpu_pallas_distributed_test.py
+++ b/tests/pallas/tpu_pallas_distributed_test.py
@@ -750,9 +750,6 @@ def _():
 class VerificationTest(jtu.JaxTestCase):

   def test_verification(self):
-    self.skipTest(
-        'TODO(b/455847773): Fix MLIR layout mismatch in tpu.memref_slice (dynamic offset issue).'
-    )
     if (num_devices := jax.local_device_count()) <= 1:
       self.skipTest('Test requires multiple devices.')
     if not jtu.is_device_tpu_at_least(4) or jax.devices()[0].num_cores > 1: