Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <array>
#include <cstdint>
#include <memory>
Expand Down Expand Up @@ -132,8 +133,25 @@ FailureOr<TypedValue<VectorType>> relayout(
// dst. We will only change bitwidth when both src and dst are in the safe
// layout. Necessary relayout will be inserted to relayout from src to the
// safe layout or from the safe layout to dst.
auto safe_tiling =
src.tiling()[0] < dst.tiling()[0] ? src.tiling() : dst.tiling();
// Consider cases like src is 32-bit with (8, 128) tiling and dst is 8-bit
// with (32, 128) tiling, sublane tiling shouldn't be too large, so we
// choose the smaller one between src and dst. On the other hand, for cases
// like src is 32-bit with (1, 128) tiling and dst is 8-bit with (4, 128)
// tiling, a safe sublane tiling should be at least the larger packing
// between src and dst.
int64_t safe_packing = std::max(src.packing(), dst.packing());
// TODO(yueshengys): this is still not safe for cases with 2-bit, because we
// may end with (16, 128) tiling for 32-bit. In those cases, we should do
// multiple rounds of the process.
if (safe_packing * src.bitwidth() > 32 * target_shape[0] ||
safe_packing * dst.bitwidth() > 32 * target_shape[0]) {
return emitError(v.getLoc(),
"Not implemented: changeBitwidth when src bitwidth and "
"dst bitwidth differs too much.");
}
std::array<int64_t, 2> safe_tiling = {
std::max(std::min(src.tiling()[0], dst.tiling()[0]), safe_packing),
target_shape[1]};
auto safe_vreg_slice =
VectorLayout::vregSlice(target_shape, dst.bitwidth(), safe_tiling);
auto safe_offsets = LayoutOffsets{
Expand Down
6 changes: 4 additions & 2 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,13 @@ def kernel(x, out):
np.testing.assert_array_equal(result, expected)

@parameterized.product(
shape=[(129, 129), (1, 129), (2, 129), (4, 129)],
msk_dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
dtype=[jnp.float32, jnp.bfloat16],
)
def test_i1_relayout_bw(self, msk_dtype, dtype):
shape = (129, 129)
def test_i1_relayout_bw(self, shape, msk_dtype, dtype):
if shape[0] < 8 and not jtu.if_cloud_tpu_at_least(2025, 11, 9):
self.skipTest("Requires libtpu built after 2025-11-09")
msk_bitwidth = dtypes.bit_width(msk_dtype)
bitwidth = dtypes.bit_width(dtype)
if jtu.get_tpu_version() < 5 and msk_bitwidth < 32:
Expand Down
Loading