72 changes: 13 additions & 59 deletions jax/_src/pallas/mosaic/lowering.py
@@ -19,7 +19,6 @@
import contextlib
import dataclasses
import functools
import operator
import string
from typing import Any, Protocol, Self, TypeVar, cast

@@ -33,8 +32,8 @@
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dtypes
from jax._src import literals
from jax._src import linear_util as lu
from jax._src import literals
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import prng
@@ -50,7 +49,6 @@
from jax._src.lax import control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import BranchesPlatforms

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import cf
@@ -1320,7 +1318,11 @@ def _index_to_start_size_stride(
assert not isinstance(idx, slice)
if isinstance(idx, indexing.Slice):
start = _maybe_cast_to_index(cast_to_index, idx.start)
size = idx.size
size = (
_make_index(idx.size)
if isinstance(idx.size, ir.Value) and cast_to_index
else idx.size
)
stride = idx.stride
squeeze = False
elif isinstance(idx, int):
@@ -1380,27 +1382,6 @@ def _indexer_to_start_size_stride(
)


def _compute_squeezed_dims(source_shape: Sequence[int], target_shape: Sequence[int]) -> Sequence[bool]:
# This function only exists to align the ``tpu.memref_squeeze`` layout
# inference logic between Python and MLIR.
result = []
source_index = len(source_shape) - 1
target_index = len(target_shape) - 1
while source_index >= 0 or target_index >= 0:
target_dim = target_shape[target_index] if target_index >= 0 else -1
assert source_index >= 0
if source_shape[source_index] == target_dim:
result.append(False)
source_index -= 1
target_index -= 1
else:
assert source_shape[source_index] == 1
result.append(True)
source_index -= 1
result.reverse()
return result


def _slice_memref(
ref: ir.Value,
indexer: NDIndexer,
@@ -1410,23 +1391,13 @@ def _slice_memref(
assert ref_block_shape is not None
starts, sizes, strides, squeeze_dims, ref_block_shape = (
_indexer_to_start_size_stride(
indexer,
ref_block_shape,
cast_to_index=False,
indexer, ref_block_shape, cast_to_index=False
)
)
if not all((s is None or s == 1) for s in strides):
raise NotImplementedError("Strided slices of references are unsupported.")

ir_dynamic_size = ir.ShapedType.get_dynamic_size()
static_starts = []
for s in starts:
if not isinstance(s, ir.Value):
static_starts.append(s)
elif (v := _fold_and_get_constant_value(s)) is not None:
static_starts.append(v)
else:
static_starts.append(ir_dynamic_size)

static_sizes = []
dynamic_sizes = []
@@ -1440,36 +1411,19 @@ def _slice_memref(
dynamic_sizes.append(s)

ref_ty = ir.MemRefType(ref.type)
ref_strides, ref_offset = ref_ty.get_strides_and_offset()
if ref_offset == ir_dynamic_size or ir_dynamic_size in static_starts:
target_offset = ir_dynamic_size
else:
target_offset = sum(
map(operator.mul, static_starts, ref_strides), ref_offset
)
out_layout = ir.StridedLayoutAttr.get(target_offset, ref_strides)
out_ty = ir.MemRefType.get(
static_sizes, ref_ty.element_type, out_layout, ref_ty.memory_space
static_sizes, ref_ty.element_type, memory_space=ref_ty.memory_space
)
out = tpu.memref_slice(out_ty, ref, starts, dynamic_sizes)
if any(squeeze_dims):
# We need to squeeze out some dimensions.
ref_ty = out_ty
del out_ty
ref_strides, ref_offset = ref_ty.get_strides_and_offset()
target_sizes = [dim for i, dim in enumerate(ref_ty.shape) if not squeeze_dims[i]]
del squeeze_dims
# We re-infer the squeezed dimensions to align with the tpu.memref_squeeze
# verification logic in MLIR in ambiguous cases, e.g. when squeezing
# from [1, 1, 128] to [1, 128].
squeeze_dims = _compute_squeezed_dims(ref_ty.shape, target_sizes)
target_strides = [s for i, s in enumerate(ref_strides) if not squeeze_dims[i]]
out_layout = ir.StridedLayoutAttr.get(ref_offset, target_strides)
target_sizes = [
dim for i, dim in enumerate(ref_ty.shape) if not squeeze_dims[i]
]
out_ty = ir.MemRefType.get(
target_sizes,
ref_ty.element_type,
out_layout,
ref_ty.memory_space,
target_sizes, ref_ty.element_type, memory_space=ref_ty.memory_space
)
out = tpu.memref_squeeze(out_ty, out)
return out, ref_block_shape
@@ -2636,7 +2590,7 @@ def _fold(x, fuel):
"arith.minsi": min,
}
if op_name == "arith.constant":
if ir.IntegerType.isinstance(x.type):
if ir.IntegerType.isinstance(x.type) or ir.IndexType.isinstance(x.type):
return ir.IntegerAttr(x.owner.attributes["value"]).value
elif ir.FloatType.isinstance(x.type):
return ir.FloatAttr(x.owner.attributes["value"]).value
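
Note: the deleted `_compute_squeezed_dims` helper existed only to mirror the `tpu.memref_squeeze` layout-inference logic in Python, and with the layout checks also dropped from the MLIR verifiers (see the `tpu_ops.cc` diff below), both copies go away. Here is a standalone sketch (illustrative, not part of the patch) of the ambiguity that logic resolved: squeezing `[1, 1, 128]` to `[1, 128]` could drop either unit dimension, and right-aligning the two shapes always squeezes the leading one:

```python
def squeezed_dims(source: list[int], target: list[int]) -> list[bool]:
    # Walk both shapes from the right; any mismatch must be a unit
    # dimension in `source`, which is marked as squeezed.
    result = []
    si, ti = len(source) - 1, len(target) - 1
    while si >= 0 or ti >= 0:
        assert si >= 0
        if ti >= 0 and source[si] == target[ti]:
            result.append(False)
            ti -= 1
        else:
            assert source[si] == 1  # only unit dims can be squeezed
            result.append(True)
        si -= 1
    return result[::-1]

# The ambiguous case from the removed comment: the leading unit dim goes.
assert squeezed_dims([1, 1, 128], [1, 128]) == [True, False, False]
```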
86 changes: 85 additions & 1 deletion jax/_src/pallas/mosaic/sc_lowering.py
@@ -17,6 +17,7 @@
import contextlib
import dataclasses
import functools
import operator
from typing import Any, NoReturn, cast

import jax
@@ -82,7 +83,6 @@ class ScLoweringContext(tc_lowering.LoweringContext):

LoweringRuleContext = tc_lowering.LoweringRuleContext

_transform_ref = tc_lowering._transform_ref
_dtype_to_ir_type = tc_lowering._dtype_to_ir_type

# pylint: disable=protected-access
@@ -836,3 +836,87 @@ def _alloc_value(
)
return memref.alloca(out_type, [], [])
return tc_lowering._alloc_value(aval, ctx=ctx)


def _split_static_and_dynamic_values(
values: Sequence[ir.Value | Any],
) -> tuple[Sequence[Any], Sequence[ir.Value]]:
static_values = []
dynamic_values = []
for v in values:
if not isinstance(v, ir.Value):
static_values.append(v)
elif (c := tc_lowering._fold_and_get_constant_value(v)) is not None:
static_values.append(c)
else:
static_values.append(ir.ShapedType.get_dynamic_size())
dynamic_values.append(v)
return static_values, dynamic_values


def _slice_memref(
ref: ir.Value,
indexer: indexing.NDIndexer,
ref_dtype: jax.typing.DTypeLike,
ref_block_shape: tuple[int | pallas_core.Squeezed, ...] | None,
) -> tuple[ir.Value, tuple[int | pallas_core.Squeezed, ...]]:
assert ref_block_shape is not None
starts, sizes, strides, squeeze_dims, ref_block_shape = (
tc_lowering._indexer_to_start_size_stride(
indexer, ref_block_shape, cast_to_index=True
)
)
if not all((s is None or s == 1) for s in strides):
raise NotImplementedError("Strided slices of references are unsupported.")

static_starts, dynamic_starts = _split_static_and_dynamic_values(starts)
static_sizes, dynamic_sizes = _split_static_and_dynamic_values(sizes)

ref_ty = ir.MemRefType(ref.type)
ref_strides, ref_offset = ref_ty.get_strides_and_offset()

ir_dynamic_size = ir.ShapedType.get_dynamic_size()
if ref_offset == ir_dynamic_size or ir_dynamic_size in static_starts:
out_offset = ir_dynamic_size
else:
out_offset = sum(
map(operator.mul, static_starts, ref_strides), ref_offset
)
out_sizes = [s for i, s in enumerate(static_sizes) if not squeeze_dims[i]]
out_strides = [s for i, s in enumerate(ref_strides) if not squeeze_dims[i]]
out_layout = ir.StridedLayoutAttr.get(out_offset, out_strides)
out_ty = ir.MemRefType.get(
out_sizes, ref_ty.element_type, out_layout, ref_ty.memory_space
)

# We bypass ``memref.subview``, because we want to precisely control how the
# static/dynamic split is performed, since it affects the result layout.
out = memref.SubViewOp(
out_ty,
ref,
dynamic_starts,
dynamic_sizes,
[],
static_starts,
static_sizes,
static_strides=[1] * len(ref_strides),
).result
return out, ref_block_shape


def _transform_ref(
ref: ir.Value,
ref_dtype: jax.typing.DTypeLike,
ref_block_shape: tuple[int | pallas_core.Squeezed, ...] | None,
transforms: Sequence[pallas_core.MemoryRefTransform],
) -> tuple[ir.Value, tuple[int | pallas_core.Squeezed, ...] | None]:
for transform in transforms:
if isinstance(transform, indexing.NDIndexer):
ref, ref_block_shape = _slice_memref(
ref, transform, ref_dtype, ref_block_shape
)
else:
ref, ref_block_shape = tc_lowering._transform_ref(
ref, ref_dtype, ref_block_shape, [transform]
)
return ref, ref_block_shape
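
A minimal sketch (illustrative names; `DYNAMIC` stands in for `ir.ShapedType.get_dynamic_size()`) of the layout arithmetic the new SC `_slice_memref` performs: static starts shift the strided-layout offset by `sum(start[i] * stride[i])`, any dynamic start (or an already-dynamic source offset) forces a dynamic result offset, and squeezed dimensions drop out of the stride list:

```python
import operator

DYNAMIC = object()  # placeholder for ir.ShapedType.get_dynamic_size()

def slice_layout(ref_offset, ref_strides, static_starts, squeeze_dims):
    if ref_offset is DYNAMIC or any(s is DYNAMIC for s in static_starts):
        out_offset = DYNAMIC  # offset is no longer statically known
    else:
        out_offset = sum(
            map(operator.mul, static_starts, ref_strides), ref_offset
        )
    # Squeezed unit dims disappear from the strides; the offset is unchanged.
    out_strides = [s for i, s in enumerate(ref_strides) if not squeeze_dims[i]]
    return out_offset, out_strides

# Slicing a row-major (8, 128) ref at [2, 4], then squeezing the first dim:
assert slice_layout(0, [128, 1], [2, 4], [True, False]) == (260, [1])
```

The direct `memref.SubViewOp` construction above follows the usual MLIR convention for mixed static/dynamic operands: the static arrays carry the dynamic-size sentinel at every position whose value is instead passed as an SSA operand, which is exactly the split `_split_static_and_dynamic_values` computes.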
4 changes: 2 additions & 2 deletions jax/_src/pallas/mosaic/sc_primitives.py
@@ -270,7 +270,7 @@ def _gather_lowering_rule(
)
if transforms:
ref_block_shape, *_ = ctx.block_shapes
ref, _ = tc_lowering._transform_ref(
ref, _ = sc_lowering._transform_ref(
ref, ref_aval.dtype, ref_block_shape, transforms
)
[out_aval] = ctx.avals_out
@@ -344,7 +344,7 @@ def _scatter_lowering_rule(
)
if transforms:
ref_block_shape, *_ = ctx.block_shapes
ref, _ = tc_lowering._transform_ref(
ref, _ = sc_lowering._transform_ref(
ref, ref_aval.dtype, ref_block_shape, transforms
)
tpu.vector_store_idx(x, ref, indices, mask=mask, add=add)
66 changes: 0 additions & 66 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -150,8 +150,6 @@ OpFoldResult BitcastVregOp::fold(FoldAdaptor adaptor) {
LogicalResult MemRefSliceOp::verify() {
auto source_type = getMemRefType(getMemRef());
auto target_type = getType();
auto source_layout = source_type.getLayout();
auto target_layout = target_type.getLayout();
auto target_memory_space = target_type.getMemorySpace();
auto indices = getBaseIdx();
auto slice_shape = getResult().getType().getShape();
@@ -184,43 +182,6 @@ LogicalResult MemRefSliceOp::verify() {
return emitOpError(
"Memory spaces must match if the target memory space is provided.");
}
if (isa<TiledLayoutAttr>(source_layout) &&
!isa<TiledLayoutAttr>(target_layout)) {
// TODO(slebedev): Remove this special-case once we move layout propagation
// to the infer-memref-layout pass.
} else if (isa<StridedLayoutAttr>(target_layout)) {
SmallVector<int64_t> source_strides;
int64_t source_offset;
if (failed(
source_type.getStridesAndOffset(source_strides, source_offset))) {
return failure();
}
int64_t target_offset = source_offset;
if (target_offset != ShapedType::kDynamic) {
for (auto [base_idx, source_stride] :
llvm::zip(getBaseIdx(), source_strides)) {
if (auto idx = getConstantIntValue(base_idx)) {
target_offset += *idx * source_stride;
} else {
target_offset = ShapedType::kDynamic;
break;
}
}
}
auto expected_layout =
StridedLayoutAttr::get(getContext(), target_offset, source_strides);
if (target_layout != expected_layout) {
return emitOpError("Layout mismatch: got ")
<< target_layout << ", expected " << expected_layout << ".";
}
} else {
bool is_target_layout_identity_map =
isa<AffineMapAttr>(target_layout) && target_layout.isIdentity();
if (!is_target_layout_identity_map && target_layout != source_layout) {
return emitOpError(
"Layouts must match if the target layout is not an identity map.");
}
}
if (getDynamicSizes().size() != target_type.getNumDynamicDims()) {
return emitOpError(
"Number of provided dynamic dimensions sizes must match the number of "
@@ -325,33 +286,6 @@ LogicalResult MemRefSqueezeOp::verify() {
return failure();
}

auto source_layout = source_type.getLayout();
auto target_layout = target_type.getLayout();
if (isa<TiledLayoutAttr>(source_layout) &&
!isa<TiledLayoutAttr>(target_layout)) {
// TODO(slebedev): Remove this special-case once we move layout propagation
// to the infer-memref-layout pass.
} else if (isa<StridedLayoutAttr>(target_layout)) {
SmallVector<int64_t> source_strides;
int64_t source_offset;
if (failed(
source_type.getStridesAndOffset(source_strides, source_offset))) {
return failure();
}
SmallVector<int64_t> target_strides;
for (auto [i, stride] : llvm::enumerate(source_strides)) {
if (!llvm::is_contained(*squeezed_or, i)) {
target_strides.push_back(stride);
}
}
auto expected_layout =
StridedLayoutAttr::get(getContext(), source_offset, target_strides);
if (target_layout != expected_layout) {
return emitOpError("Layout mismatch: got ")
<< target_layout << ", expected " << expected_layout << ".";
}
}

auto erase_layout_op = getInput().getDefiningOp<tpu::EraseLayoutOp>();
if (!erase_layout_op) {
return success();
3 changes: 0 additions & 3 deletions tests/pallas/tpu_pallas_distributed_test.py
@@ -750,9 +750,6 @@ def _():
class VerificationTest(jtu.JaxTestCase):

def test_verification(self):
self.skipTest(
'TODO(b/455847773): Fix MLIR layout mismatch in tpu.memref_slice (dynamic offset issue).'
)
if (num_devices := jax.local_device_count()) <= 1:
self.skipTest('Test requires multiple devices.')
if not jtu.is_device_tpu_at_least(4) or jax.devices()[0].num_cores > 1: