diff --git a/stablehlo/conversions/linalg/tests/scatter.mlir b/stablehlo/conversions/linalg/tests/scatter.mlir
new file mode 100644
index 0000000000..a29c59cfbc
--- /dev/null
+++ b/stablehlo/conversions/linalg/tests/scatter.mlir
@@ -0,0 +1,111 @@
+// RUN: stablehlo-opt %s --stablehlo-legalize-to-linalg --split-input-file --canonicalize | FileCheck %s
+
+func.func @matching_update_tensor(%arg0: tensor<1x32x32x128xf32>, %arg1: tensor<1x32x1x128xf32>, %arg2: tensor<1x1xi64>) -> tensor<1x32x32x128xf32> {
+  // CHECK-NOT: stablehlo.scatter
+  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: %[[EXT:.*]] = tensor.extract %arg2[%[[ZERO]], %[[ZERO]]] : tensor<1x1xi64>
+  // CHECK: %[[IDX:.*]] = arith.index_cast %[[EXT]] : i64 to index
+  // CHECK: tensor.insert_slice %arg1 into %arg0[0, 0, %[[IDX]], 0] [1, 32, 1, 128] [1, 1, 1, 1] : tensor<1x32x1x128xf32> into tensor<1x32x32x128xf32>
+  %0 = "stablehlo.scatter"(%arg0, %arg2, %arg1) <{
+    indices_are_sorted = false,
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [0, 1, 3],
+      inserted_window_dims = [2],
+      scatter_dims_to_operand_dims = [2],
+      index_vector_dim = 1>,
+    unique_indices = false}> ({
+  ^bb0(%arg3: tensor, %arg4: tensor):
+    stablehlo.return %arg4 : tensor
+  }) : (tensor<1x32x32x128xf32>, tensor<1x1xi64>, tensor<1x32x1x128xf32>) -> tensor<1x32x32x128xf32>
+  return %0 : tensor<1x32x32x128xf32>
+
+
+}
+
+// -----
+
+func.func @smaller_update_tensor() -> tensor<9x7x5xf64> {
+  // CHECK-DAG: %[[scatter_indices:.*]] = tensor.empty() : tensor<1xi32>
+  // CHECK-DAG: %[[inputs:.*]] = tensor.empty() : tensor<9x[[dim1:.*]]x[[dim0:.*]]xf64>
+  // CHECK-DAG: %[[updates:.*]] = tensor.empty() : tensor<[[dim1]]x[[dim0]]xf64>
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0 : index
+  %scatter_indices = tensor.empty() : tensor<1xi32>
+  %inputs = tensor.empty() : tensor<9x7x5xf64>
+  %updates = tensor.empty() : tensor<7x5xf64>
+
+  // CHECK-DAG: %[[ext:.*]] = tensor.extract %[[scatter_indices]][%[[zero]]] : tensor<1xi32>
+  // CHECK-DAG: %[[idx:.*]] = arith.index_cast %[[ext]] : i32 to index
+  // CHECK-DAG: %[[inserted_slice:.*]] = tensor.insert_slice %[[updates]] into %[[inputs]][%[[idx]], 0, 0] [1, [[dim1]], [[dim0]]] [1, 1, 1] : tensor<7x5xf64> into tensor<9x7x5xf64>
+
+  %3 = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{
+    indices_are_sorted = true,
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [0, 1],
+      inserted_window_dims = [0],
+      scatter_dims_to_operand_dims = [0]>,
+    unique_indices = true}> ({
+  ^bb0(%arg0: tensor, %arg1: tensor):
+    stablehlo.return %arg1 : tensor
+  }) : (tensor<9x7x5xf64>, tensor<1xi32>, tensor<7x5xf64>) -> tensor<9x7x5xf64>
+  return %3 : tensor<9x7x5xf64>
+}
+
+// -----
+
+func.func @non_matching_scatter(%arg0: tensor<2x3x4x2xi64>, %arg1: tensor<2x2x3x2xi64>, %arg2: tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64> {
+  // CHECK: stablehlo.scatter
+  %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({
+  ^bb0(%arg3: tensor, %arg4: tensor):
+    %1 = stablehlo.add %arg3, %arg4 : tensor
+    stablehlo.return %1 : tensor
+  }) : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
+  return %0 : tensor<2x3x4x2xi64>
+}
+
+// -----
+
+func.func @scatter_with_batching_dims(%input_tensor: tensor<5x200x100x300xf32>,
+    %scatter_indices: tensor<5x10x2xi32>, %updates: tensor<5x10x300xf32>) ->
+    tensor<5x200x100x300xf32> {
+  // CHECK: stablehlo.scatter
+  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+    ^bb0(%lhs: tensor, %rhs: tensor):
+      %add = stablehlo.add %lhs, %rhs : tensor
+      "stablehlo.return"(%add) : (tensor) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [2],
+      inserted_window_dims = [1, 2],
+      input_batching_dims = [0],
+      scatter_indices_batching_dims = [0],
+      scatter_dims_to_operand_dims = [1, 2],
+      index_vector_dim = 2
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<5x200x100x300xf32>, tensor<5x10x2xi32>, tensor<5x10x300xf32>) ->
+      tensor<5x200x100x300xf32>
+  func.return %0 : tensor<5x200x100x300xf32>
+}
+
+// -----
+
+func.func @valid_scatter_dimensions_with_dynamic_index_vector_dim(
+    %input_tensor: tensor, %scatter_indices: tensor<10x?xi32>,
+    %updates: tensor) -> tensor {
+  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+    ^bb0(%lhs: tensor, %rhs: tensor):
+      %add = stablehlo.add %lhs, %rhs : tensor
+      "stablehlo.return"(%add) : (tensor) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0, 1],
+      scatter_dims_to_operand_dims = [0, 1, 2],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor, tensor<10x?xi32>, tensor) -> tensor
+  func.return %0 : tensor
+}
diff --git a/stablehlo/conversions/linalg/transforms/CMakeLists.txt b/stablehlo/conversions/linalg/transforms/CMakeLists.txt
index d65809a5b0..a8fcc8c63b 100644
--- a/stablehlo/conversions/linalg/transforms/CMakeLists.txt
+++ b/stablehlo/conversions/linalg/transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_library(StablehloLinalgTransforms
   StablehloToLinalgPointwise.cpp
   StablehloToLinalgRandom.cpp
   StablehloToLinalgReduce.cpp
+  StablehloToLinalgScatter.cpp
   TypeConversion.cpp
 
   DEPENDS
diff --git a/stablehlo/conversions/linalg/transforms/Rewriters.h b/stablehlo/conversions/linalg/transforms/Rewriters.h
index 9db1a021ea..833a7616d0 100644
--- a/stablehlo/conversions/linalg/transforms/Rewriters.h
+++ b/stablehlo/conversions/linalg/transforms/Rewriters.h
@@ -66,6 +66,12 @@ void populateStablehloReductionToLinalgConversionPatterns(
     MLIRContext *context, TypeConverter &typeConverter,
     RewritePatternSet *patterns, bool enablePrimitiveOps);
 
+/// Populates the patterns that convert from scatter StableHLO ops to Linalg
+/// on tensors.
+void populateStablehloScatterToLinalgConversionPatterns(
+    MLIRContext *context, TypeConverter &typeConverter,
+    RewritePatternSet *patterns, bool enablePrimitiveOps);
+
 /// Populates the patterns that convert scalar StableHLO ops to Arith ops.
 void populateScalarHloToArithConversionPatterns(
     MLIRContext *context, TypeConverter &typeConverter,
diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
index 45f3cbbb1c..464278e452 100644
--- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
@@ -2699,6 +2699,8 @@ void populateStablehloToLinalgConversionPatterns(MLIRContext *context,
       context, typeConverter, patterns);
   detail::populateStablehloReductionToLinalgConversionPatterns(
       context, typeConverter, patterns, enablePrimitiveOps);
+  detail::populateStablehloScatterToLinalgConversionPatterns(
+      context, typeConverter, patterns, enablePrimitiveOps);
   detail::populateScalarHloToArithConversionPatterns(
       context, typeConverter, patterns, isInBodyOfLinalgOps);
   linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns);
diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgScatter.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgScatter.cpp
new file mode 100644
index 0000000000..50548262c2
--- /dev/null
+++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgScatter.cpp
@@ -0,0 +1,245 @@
+/* Copyright 2019 The IREE Authors
+   Copyright 2023 OpenXLA Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implements logic for lowering StableHLO scatter ops to Linalg dialect.
+// These patterns are separated out to their own file to save on the compilation
+// times.
+
+#include <cstdint>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/conversions/linalg/transforms/Rewriters.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+namespace mlir::stablehlo {
+namespace {
+/// Returns true if the update computation of `op` is a plain assignment, i.e.
+/// its body is nothing but a `stablehlo.return` of the second block argument:
+///
+///   ^bb0(%current: T, %update: T):
+///     stablehlo.return %update : T
+bool isAssignment(stablehlo::ScatterOp op) {
+  Region &region = op.getUpdateComputation();
+  if (region.empty()) {
+    return false;
+  }
+
+  // The terminator must be the only operation in the block.
+  Block &block = region.front();
+  if (!llvm::hasSingleElement(block.getOperations())) {
+    return false;
+  }
+
+  auto returnOp = dyn_cast<stablehlo::ReturnOp>(&block.front());
+  if (!returnOp || returnOp->getNumOperands() != 1 ||
+      block.getNumArguments() != 2) {
+    return false;
+  }
+
+  return returnOp->getOperand(0) == block.getArgument(1);
+}
+
+/// Returns true if the update tensor of `op` is statically shaped and every
+/// update dimension that is not listed in `update_window_dims` has size 1,
+/// i.e. the scatter writes the whole update tensor as one single slice.
+///
+/// Expects `op` to have at least one update operand.
+bool singleFullSlices(stablehlo::ScatterOp op) {
+  auto updateTy =
+      dyn_cast<RankedTensorType>(op.getUpdates().front().getType());
+  if (!updateTy || !updateTy.hasStaticShape()) {
+    return false;  // Can't verify without static shape
+  }
+
+  ArrayRef<int64_t> updateWindowDims =
+      op.getScatterDimensionNumbers().getUpdateWindowDims();
+  ArrayRef<int64_t> updateShape = updateTy.getShape();
+  for (int64_t dim = 0, rank = updateTy.getRank(); dim < rank; ++dim) {
+    // Every non-window (scatter) dimension must be a unit dimension.
+    if (updateShape[dim] != 1 &&
+        !llvm::is_contained(updateWindowDims, dim)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/// Returns true if `op` is equivalent to a single `tensor.insert_slice` of its
+/// update tensor into its input tensor at a dynamically computed offset.
+bool isInsertSliceScatter(stablehlo::ScatterOp op) {
+  // Requirement 1: has exactly one input, one update and one result tensor
+  if (op.getInputs().size() != 1 || op.getUpdates().size() != 1 ||
+      op.getResults().size() != 1) {
+    return false;
+  }
+
+  // Requirement 2: is assignment (see isAssignment)
+  if (!isAssignment(op)) {
+    return false;
+  }
+
+  // Requirement 3: no batching
+  //     input_batching_dims = []
+  //     scatter_indices_batching_dims = []
+  auto scatterDimNumbers = op.getScatterDimensionNumbers();
+  if (!scatterDimNumbers.getInputBatchingDims().empty() ||
+      !scatterDimNumbers.getScatterIndicesBatchingDims().empty()) {
+    return false;
+  }
+
+  // Requirement 4: we are inserting the whole %update into a dimension of
+  // %input
+  if (!singleFullSlices(op)) {
+    return false;
+  }
+
+  // Requirement 5: scatter indices is a static tensor of size 1
+  auto indicesTy =
+      dyn_cast<RankedTensorType>(op.getScatterIndices().getType());
+  if (!indicesTy || !indicesTy.hasStaticShape() ||
+      indicesTy.getNumElements() != 1) {
+    return false;
+  }
+
+  // Requirement 6: the single index addresses the single inserted window
+  // dimension. Operand dimensions that are not listed in
+  // scatter_dims_to_operand_dims always start at offset 0, so any other
+  // combination cannot be expressed by offsetting the inserted dimension.
+  ArrayRef<int64_t> insertedWindowDims =
+      scatterDimNumbers.getInsertedWindowDims();
+  ArrayRef<int64_t> scatterDimsToOperandDims =
+      scatterDimNumbers.getScatterDimsToOperandDims();
+  return insertedWindowDims.size() == 1 &&
+         scatterDimsToOperandDims.size() == 1 &&
+         insertedWindowDims.front() == scatterDimsToOperandDims.front();
+}
+
+/// Pattern to lower relevant stablehlo::ScatterOps to tensor.insert_slice ops.
+///
+/// The single scatter index becomes the dynamic offset of the inserted window
+/// dimension; every other offset is 0, every stride is 1 and the slice sizes
+/// are taken from the update window dimensions.
+struct ScatterOpToInsertSliceConverter final
+    : public OpConversionPattern<stablehlo::ScatterOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  // NOTE(review): operands are read from `op` rather than from the adaptor;
+  // confirm that no operand type conversion is expected for this pattern.
+  LogicalResult matchAndRewrite(
+      stablehlo::ScatterOp op, OpAdaptor /*adaptor*/,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!isInsertSliceScatter(op)) {
+      return failure();
+    }
+
+    Value input = op.getInputs().front();
+    Value update = op.getUpdates().front();
+    Value scatterIndices = op.getScatterIndices();
+
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    if (!inputTy) {
+      return rewriter.notifyMatchFailure(op, "expected a ranked input");
+    }
+    // isInsertSliceScatter already established that both types are ranked.
+    ArrayRef<int64_t> updateShape =
+        cast<RankedTensorType>(update.getType()).getShape();
+    int64_t indicesRank =
+        cast<RankedTensorType>(scatterIndices.getType()).getRank();
+
+    auto scatterDimNumbers = op.getScatterDimensionNumbers();
+    ArrayRef<int64_t> updateWindowDims =
+        scatterDimNumbers.getUpdateWindowDims();
+    int64_t insertedDim = scatterDimNumbers.getInsertedWindowDims().front();
+    int64_t inputRank = inputTy.getRank();
+    auto updateRank = static_cast<int64_t>(updateShape.size());
+    if (insertedDim < 0 || insertedDim >= inputRank ||
+        static_cast<int64_t>(updateWindowDims.size()) != inputRank - 1) {
+      return rewriter.notifyMatchFailure(
+          op, "inconsistent scatter dimension numbers");
+    }
+
+    // Slice sizes in operand space: 1 along the inserted dimension, the size
+    // of the matching update window dimension everywhere else.
+    // `windowSizes` is the same list without the inserted unit dimension.
+    SmallVector<int64_t> staticSizes, windowSizes;
+    for (int64_t dim = 0, window = 0; dim < inputRank; ++dim) {
+      if (dim == insertedDim) {
+        staticSizes.push_back(1);
+        continue;
+      }
+      int64_t updateDim = updateWindowDims[window++];
+      if (updateDim < 0 || updateDim >= updateRank) {
+        return rewriter.notifyMatchFailure(
+            op, "update_window_dims entry is out of range");
+      }
+      staticSizes.push_back(updateShape[updateDim]);
+      windowSizes.push_back(updateShape[updateDim]);
+    }
+
+    // The update must be usable as the insert_slice source as is: either it
+    // already has the slice shape, or it only lacks the inserted unit dim.
+    if (updateShape != ArrayRef<int64_t>(staticSizes) &&
+        updateShape != ArrayRef<int64_t>(windowSizes)) {
+      return rewriter.notifyMatchFailure(
+          op, "update shape does not match the inserted slice shape");
+    }
+
+    // Offset is dynamic along the inserted dimension, based on the index we
+    // extract; all other offsets are 0.
+    SmallVector<int64_t> staticOffsets(inputRank, 0);
+    staticOffsets[insertedDim] = ShapedType::kDynamic;
+    SmallVector<int64_t> staticStrides(inputRank, 1);
+
+    // All validation is done; only now start creating IR.
+    Location loc = op.getLoc();
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> zeroIndices(indicesRank, zero);
+    Value extracted =
+        rewriter.create<tensor::ExtractOp>(loc, scatterIndices, zeroIndices);
+    Value offset = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), extracted);
+
+    SmallVector<Value> dynOffsets = {offset};
+    SmallVector<Value> dynSizes, dynStrides;
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+        op, update, input, dynOffsets, dynSizes, dynStrides, staticOffsets,
+        staticSizes, staticStrides);
+    return success();
+  }
+};
+}  // namespace
+
+namespace detail {
+void populateStablehloScatterToLinalgConversionPatterns(
+    MLIRContext *context, TypeConverter &typeConverter,
+    RewritePatternSet *patterns, bool /*enablePrimitiveOps*/) {
+  // Ensure specialized patterns are higher priority than their generic
+  // versions.
+  patterns->add<ScatterOpToInsertSliceConverter>(typeConverter, context,
+                                                 PatternBenefit(2));
+}
+}  // namespace detail
+}  // namespace mlir::stablehlo