diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7e52f92d..693c96bb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -58,6 +58,8 @@ if (TRITON_BUILD_PYTHON_MODULE)
     LinalgToLinked
     LinkedToHIVM
+    DiscreteMaskAccessConversion
+    TritonToUnstructure
   )
   target_link_libraries(tritonDicpTriton PRIVATE Python3::Module pybind11::headers)
 endif()
\ No newline at end of file
diff --git a/backend/npu.py b/backend/npu.py
index 3be3d0af..7171c20a 100644
--- a/backend/npu.py
+++ b/backend/npu.py
@@ -396,6 +396,11 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=True):
 
 def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):
     pm = ir.pass_manager(mod.context)
+    dicp_triton.passes.triton_shared_ascend.add_discrete_mask_access_conversion(
+        pm, False, False
+    )
+    dicp_triton.passes.triton_shared_ascend.add_triton_to_unstructure(pm)
+    dicp_triton.passes.triton_shared_ascend.add_bubble_up_operation(pm)
     dicp_triton.passes.triton_shared_ascend.add_canonicalize_cmpi(pm)
     dicp_triton.passes.triton_shared_ascend.add_canonicalize_triton_ir_ascend(pm)
     dicp_triton.passes.triton_shared_ascend.add_triton_to_linalg_npu(pm)
@@ -404,6 +409,9 @@ def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):
     cmd_list = [
         _get_dicp_opt_path(),
         "kernel.ttir.mlir",
+        "--discrete-mask-access-conversion",
+        "--triton-to-unstructure",
+        "--bubble-up-operation",
         "--canonicalize-cmpi",
         "--canonicalize-triton-ir-ascend",
         "--triton-to-linalg-npu-conversion",
diff --git a/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/CMakeLists.txt b/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/CMakeLists.txt
new file mode 100644
index 00000000..567a119e
--- /dev/null
+++ b/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name DiscreteMaskAccessConversion)
+add_public_tablegen_target(DiscreteMaskAccessConversionPassIncGen)
\ No newline at end of file
diff --git a/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/MaskAnalysis.h b/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/MaskAnalysis.h
new file mode 100644
index 00000000..10ba119e
--- /dev/null
+++ b/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/MaskAnalysis.h
@@ -0,0 +1,124 @@
+
+#ifndef TRITON_ANALYSIS_MASKANALYSIS_H
+#define TRITON_ANALYSIS_MASKANALYSIS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include <utility>
+
+namespace mlir {
+
+// Forward declaration; this class is used to build operations.
+class OpBuilder;
+
+namespace dicp {
+// Used to decode the pattern in a mask used for load and store.
+
+class MaskState {
+public:
+  OpFoldResult start;
+  OpFoldResult end;
+  SmallVector<OpFoldResult> dims;
+  SmallVector<OpFoldResult> offsets;
+  OpFoldResult scalar;
+
+  int64_t getRank() const {
+    assert(dims.size() == offsets.size() && "dims and offsets rank mismatch!");
+    return dims.size();
+  }
+
+  bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; }
+
+  bool isMask() const {
+    return !start && !end && !scalar && dims.size() != 0 &&
+           offsets.size() != 0;
+  }
+
+  // Parse a value recursively.
+  LogicalResult parse(Value operand, const Location &loc, OpBuilder &builder);
+
+  tensor::ExtractSliceOp getExtractSlice(Value source, const Location &loc,
+                                         OpBuilder &builder) const;
+
+  tensor::InsertSliceOp getInsertSlice(Value source, Value dest,
+                                       const Location &loc,
OpBuilder &builder) const; + + memref::SubViewOp getSubview(Value source, const Location &loc, + OpBuilder &builder) const; + + void eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter); + +private: + LogicalResult addStateScalar(const MaskState &state, + const OpFoldResult scalar, const Location &loc, + OpBuilder &builder); + + LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, + const Location &loc, OpBuilder &builder); + + LogicalResult divStateScalar(const MaskState &state, + const OpFoldResult scalar, const Location &loc, + OpBuilder &builder); + + LogicalResult divStates(const MaskState &lhsState, const MaskState &rhsState, + const Location &loc, OpBuilder &builder); + + // Helper function to handle operator `and` both mask state + LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState, + const Location &loc, OpBuilder &builder); + + // Helper functions to parse values to populate MaskState + + LogicalResult parseConstant(arith::ConstantOp constOp, const Location &loc, + OpBuilder &builder); + + // Operand is an integer scalar + LogicalResult parseIntScalar(Value scalar, const Location &loc, + OpBuilder &builder); + + // TODO + LogicalResult parseAdd(arith::AddIOp addOp, const Location &loc, + OpBuilder &builder); + + // operand is the result of divsi + LogicalResult parseDiv(arith::DivSIOp divOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of andi + LogicalResult parseAnd(arith::AndIOp andOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of cmpi, necessary method to fuse scalar, start and + // end into dims and offset + LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of select + LogicalResult parseSel(arith::SelectOp selOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of make_range + LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of broadcast + LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp, + const Location &loc, OpBuilder &builder); + + // Operand is the result of splat + LogicalResult parseSplat(triton::SplatOp splatOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of expand_dims + LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location &loc, OpBuilder &builder); +}; + +} // namespace dicp + +} // namespace mlir + +#endif diff --git a/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/Passes.h b/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/Passes.h new file mode 100644 index 00000000..d7c1cf1f --- /dev/null +++ b/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/Passes.h @@ -0,0 +1,33 @@ +#ifndef TRITON_DLC_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H +#define TRITON_DLC_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" + +#define GEN_PASS_DECL_DISCRETEMASKACCESSCONVERSION +#include "dicp/Conversion//DiscreteMaskAccessConversion/Passes.h.inc" + +#define GEN_PASS_DEF_DISCRETEMASKACCESSCONVERSION +#include "dicp/Conversion//DiscreteMaskAccessConversion/Passes.h.inc" + +namespace mlir { +namespace triton { + +std::unique_ptr> createDiscreteMaskAccessConversionPass( + const DiscreteMaskAccessConversionOptions &options = {}); + +} // namespace triton +} // namespace mlir + +namespace mlir { 
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "dicp/Conversion/DiscreteMaskAccessConversion/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_DLC_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H
diff --git a/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/Passes.td b/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/Passes.td
new file mode 100644
index 00000000..343127b8
--- /dev/null
+++ b/compiler/include/dicp/Conversion/DiscreteMaskAccessConversion/Passes.td
@@ -0,0 +1,21 @@
+
+
+#ifndef DISCRETE_MASK_ACCESS_CONVERSION_PASSES
+#define DISCRETE_MASK_ACCESS_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def DiscreteMaskAccessConversion : Pass<"discrete-mask-access-conversion", "mlir::ModuleOp"> {
+  let summary = "Recognize and convert discrete mask memory access";
+  let constructor = "triton::createDiscreteMaskAccessConversionPass()";
+  let options = [
+    Option<"compileOn91095", "compile-on-910-95",
+           "bool", /*default*/"false",
+           "compile on 910_95">,
+    Option<"forceSimtTemplate", "force-simt-template",
+           "bool", /*default*/"false",
+           "force to use simt template">
+  ];
+}
+
+#endif // DISCRETE_MASK_ACCESS_CONVERSION_PASSES
diff --git a/compiler/include/dicp/Conversion/LinalgToLinked/TritonOpConverter.h b/compiler/include/dicp/Conversion/LinalgToLinked/TritonOpConverter.h
index 9233f9c9..ceb0e458 100644
--- a/compiler/include/dicp/Conversion/LinalgToLinked/TritonOpConverter.h
+++ b/compiler/include/dicp/Conversion/LinalgToLinked/TritonOpConverter.h
@@ -1,5 +1,5 @@
-#ifndef TRITON_ADAPTER_TRITONOPCONVERTER_H
-#define TRITON_ADAPTER_TRITONOPCONVERTER_H
+#ifndef TRITON_DLC_TRITONOPCONVERTER_H
+#define TRITON_DLC_TRITONOPCONVERTER_H
 
 #include "triton/Dialect/Triton/IR/Dialect.h"
diff --git a/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/CMakeLists.txt b/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/CMakeLists.txt
new file mode 100644
index 00000000..5d8615fe
--- /dev/null
+++ b/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name MemRefCopyGatherToTensorInsert)
+add_public_tablegen_target(MemRefCopyGatherToTensorInsertPassIncGen)
diff --git a/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h b/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h
new file mode 100644
index 00000000..d6664621
--- /dev/null
+++ b/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h
@@ -0,0 +1,17 @@
+#ifndef MEMREF_COPY_GATHER_TO_TENSOR_INSERT_PASSES_H
+#define MEMREF_COPY_GATHER_TO_TENSOR_INSERT_PASSES_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::dicp::linked {
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createMemRefCopyGatherToTensorInsertPass();
+
+#define GEN_PASS_REGISTRATION
+#include "dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h.inc"
+
+} // namespace mlir::dicp::linked
+
+#endif
diff --git a/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.td b/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.td
new file mode 100644
index 00000000..a0f9e4a4
--- /dev/null
+++ b/compiler/include/dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.td
@@ -0,0 +1,24 @@
+#ifndef
MEMREF_COPY_GATHER_TO_TENSOR_INSERT_PASSES +#define MEMREF_COPY_GATHER_TO_TENSOR_INSERT_PASSES + +include "mlir/Pass/PassBase.td" + +def MemRefCopyGatherToTensorInsert : Pass<"discrete-gather-to-direct-insert", "mlir::ModuleOp"> { + let summary = "Converts alloc+copy based gather loops to direct tensor insertions"; + let description = [{ + Identifies patterns where a temporary memref is allocated, populated via + index-based gathers (subview + copy) in a loop, and then converted to a tensor. + Replaces this with a tensor.empty + scf.for (iter_args) + tensor.insert + sequence to avoid stack allocation and enable register-level operation. + }]; + let constructor = "::mlir::dicp::linked::createMemRefCopyGatherToTensorInsertPass()"; + let dependentDialects = [ + "mlir::scf::SCFDialect", + "mlir::memref::MemRefDialect", + "mlir::tensor::TensorDialect", + "mlir::arith::ArithDialect", + "mlir::bufferization::BufferizationDialect" + ]; +} + +#endif diff --git a/compiler/include/dicp/Conversion/TritonToUnstructure/BubbleUpOperation.h b/compiler/include/dicp/Conversion/TritonToUnstructure/BubbleUpOperation.h new file mode 100644 index 00000000..91990030 --- /dev/null +++ b/compiler/include/dicp/Conversion/TritonToUnstructure/BubbleUpOperation.h @@ -0,0 +1,31 @@ +#pragma once + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" + +#define GEN_PASS_DECL_BUBBLEUPOPERATION +#include "dicp/Conversion/TritonToUnstructure/Passes.h.inc" + +#define GEN_PASS_DEF_BUBBLEUPOPERATION +#include "dicp/Conversion/TritonToUnstructure/Passes.h.inc" + +namespace mlir { +namespace triton { + +std::unique_ptr> +createBubbleUpOperationPass(const BubbleUpOperationOptions &options = {}); + +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace triton; + +class BubbleUpOperationPass + : public ::impl::BubbleUpOperationBase { +public: + explicit BubbleUpOperationPass(const BubbleUpOperationOptions &options); + void runOnOperation() override; +}; diff --git a/compiler/include/dicp/Conversion/TritonToUnstructure/CMakeLists.txt b/compiler/include/dicp/Conversion/TritonToUnstructure/CMakeLists.txt new file mode 100644 index 00000000..ecdd8d93 --- /dev/null +++ b/compiler/include/dicp/Conversion/TritonToUnstructure/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToUnstructure) +add_public_tablegen_target(TritonToUnstructureConversionPassIncGen) \ No newline at end of file diff --git a/compiler/include/dicp/Conversion/TritonToUnstructure/OffsetAnalysis.h b/compiler/include/dicp/Conversion/TritonToUnstructure/OffsetAnalysis.h new file mode 100644 index 00000000..2c13880f --- /dev/null +++ b/compiler/include/dicp/Conversion/TritonToUnstructure/OffsetAnalysis.h @@ -0,0 +1,237 @@ +#ifndef TRITON_ANALYSIS_OFFSETANALYSIS_H +#define TRITON_ANALYSIS_OFFSETANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace triton { + +struct PtrOffsetInfo { + /** + Possible status of the ptr offset: + - ScalarLike: + - Tensor's elements are all the same such as [[2.0,2.0,2.0],[2.0,2.0,2.0]] + - Constant integer or floating-point such as 2, 2.0, and `load + tensor<1xptr>` + - 
Unstructured: + - Not a `ScalarLike` ptr offset + - Or satisfy any below conditions: + - Incontinuous stride such as + - `muli [0,1,2,3] [0,1,2,3]` => [0,1,4,9] + - `divsi [9,8,7] [3,2,1]` => [3,4,7] + - `minsi [3,4,5] [5,4,3]` => [3,4,3] + - From non-`scalarLike` floating point element type such as + - `fptosi [1.0,2.0,3.0]` => [1,2,3] + - Compilation time unknown value + - `load %ptr, %offset` => %value + - Structured: + - orthongonal to `Unstructured` + - if PtrOffsetInfo isn't `Unstructured`, it is `Structured` + + In short: + ScalarLike ⊆ Structured + Unstructured = {x| x ∉ Structured} + + Example: + ``` + %y = sitofp %x + %z = fptosi %y + ``` + If %x is scalarLike (structured), %z will be scalar (structured) as well. + If %x is non-scalarLike structured, %z will be unstructured. + */ + +public: + explicit PtrOffsetInfo(); + PtrOffsetInfo(const PtrOffsetInfo &other); + + explicit PtrOffsetInfo(const Value &ptr); + explicit PtrOffsetInfo(ArrayRef structured); + explicit PtrOffsetInfo(const Value &ptr, bool structured); + explicit PtrOffsetInfo(const Value &ptr, ArrayRef structured); + explicit PtrOffsetInfo(const Value &ptr, const Value &offset, + bool structured); + explicit PtrOffsetInfo(const Value &ptr, const Value &offset, + ArrayRef structured); + + PtrOffsetInfo &operator=(const PtrOffsetInfo &other); + + Value getPtr() const; + Value getOffset() const; + SmallVector getOffsets() const; + SmallVector &getOffsetsRef(); + bool isScalarLike() const; + SmallVector &getStructuredRef(); + const SmallVector &getStructured() const; + int getRank() const; + + void setPtr(const Value &ptr); + void setOffset(const Value &offset); + void setOffsets(ValueRange offsets); + void setStructured(); + void setStructured(int rank); + void setUnstructured(); + void setUnstructured(int rank); + void setStructured(ArrayRef structured); + void setStructured(const PtrOffsetInfo &other); + void setScalarLike(bool scalarLike); + + bool isStructured(int dim) const; + bool isStructured() const; + bool isUnstructured() const; + + void setZeroOffset(); + +private: + Value ptr; + Value offset; + SmallVector tptOffsets; + + bool scalarLike = false; + + SmallVector structured; +}; + +PtrOffsetInfo combineInfo(const PtrOffsetInfo &lhs, const PtrOffsetInfo &rhs); + +void parse(Value operand, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseLoopRegionIterArg(LoopLikeOpInterface loopOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap, + BlockArgument regionIterArg); + +void parseArithOp(Operation *arithOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseTritonOp(Operation *tritonOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseTritonOp(Operation *tritonOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseAddPtr(triton::AddPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseSplat(triton::SplatOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +template +void parseBinaryOp(BinOpTy op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseAddI(arith::AddIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseSubI(arith::SubIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseIndexCast(arith::IndexCastOp op, const Location &loc, + 
RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +template +void parseConstantOp(ConstOpTy dst, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseMakeRange(triton::MakeRangeOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseExtSI(arith::ExtSIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseBitcast(triton::BitcastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseLoad(triton::LoadOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseMulI(arith::MulIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseBroadcast(triton::BroadcastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseExpandDims(triton::ExpandDimsOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseClampF(triton::ClampFOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseSelect(arith::SelectOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseFPToSI(arith::FPToSIOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseSIToFP(arith::SIToFPOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseMakeTensorDesc(triton::MakeTensorDescOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseMakeTensorPtr(triton::MakeTensorPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseAdvance(triton::AdvanceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseReduce(triton::ReduceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseReduceReturn(triton::ReduceReturnOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseIf(scf::IfOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap, Value dst); + +void parseYield(scf::YieldOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseLoopOp(LoopLikeOpInterface op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap, Value dst); + +void parseExtractSlice(tensor::ExtractSliceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseExtract(tensor::ExtractOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseIntToPtr(triton::IntToPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); +} // namespace triton + +} // namespace mlir + +#endif diff --git a/compiler/include/dicp/Conversion/TritonToUnstructure/Passes.h b/compiler/include/dicp/Conversion/TritonToUnstructure/Passes.h new file mode 100644 index 00000000..e30f6441 --- /dev/null +++ b/compiler/include/dicp/Conversion/TritonToUnstructure/Passes.h @@ -0,0 +1,18 @@ + + +#ifndef TRITON_DLC_TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES_H +#define TRITON_DLC_TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES_H + +#include "BubbleUpOperation.h" +#include "UnstructureConversionPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "dicp/Conversion/TritonToUnstructure/Passes.h.inc" + +} // namespace triton +} // 
namespace mlir + +#endif // TRITON_DLC_TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES_H diff --git a/compiler/include/dicp/Conversion/TritonToUnstructure/Passes.td b/compiler/include/dicp/Conversion/TritonToUnstructure/Passes.td new file mode 100644 index 00000000..a9abee65 --- /dev/null +++ b/compiler/include/dicp/Conversion/TritonToUnstructure/Passes.td @@ -0,0 +1,24 @@ +#ifndef TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES +#define TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToUnstructure : Pass<"triton-to-unstructure", "mlir::ModuleOp"> { + let summary = "Convert Triton for unstructure case"; + let constructor = "triton::createTritonToUnstructurePass()"; + let options = [ + Option<"forceScalarizeMode", "force-scalarize-mode", "bool", "false", + "Scalarize unstructured memory access even if structured dimensions are mixed."> + ]; +} + +def BubbleUpOperation : Pass<"bubble-up-operation", "mlir::ModuleOp"> { + let summary = "Apply bubble up operation optimization"; + let constructor = "triton::createBubbleUpOperationPass()"; + let options = [ + Option<"enableAggressiveMode", "enable-aggressive-mode", "bool", "true", + "Enable aggressive bubble up operation.">, + ]; +} + +#endif // TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES diff --git a/compiler/include/dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h b/compiler/include/dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h new file mode 100644 index 00000000..e7e07f4a --- /dev/null +++ b/compiler/include/dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h @@ -0,0 +1,120 @@ +#ifndef TRITON_DLC_UNSTRUCTURECONVERSION_H +#define TRITON_DLC_UNSTRUCTURECONVERSION_H + +#include "dicp/Conversion/TritonToUnstructure/OffsetAnalysis.h" +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" + +#define GEN_PASS_DECL_TRITONTOUNSTRUCTURE +#include "dicp/Conversion/TritonToUnstructure/Passes.h.inc" + +#define GEN_PASS_DEF_TRITONTOUNSTRUCTURE +#include "dicp/Conversion/TritonToUnstructure/Passes.h.inc" + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToUnstructurePass(); + +} // namespace triton +} // namespace mlir + +namespace { + +using namespace mlir; +using namespace triton; + +// For example, in unstructured load case +// %0 = tt.load %structured : tensor<128x128x!tt.ptr> +// %ptr_2 = tt.splat %arg1 : !tt.ptr -> tensor<128x128x!tt.ptr> +// %1 = tt.addptr %ptr_2, %0 : tensor<128x128x!tt.ptr>, +// tensor<128x128xi32> %2 = tt.load %1 : tensor<128x128x!tt.ptr> tt.store +// %output %2 : tensor<128x128x!tt.ptr> +// +// +// In this case, this will be converted to +// +// %0 = tt.load %structured : tensor<128x128x!tt.ptr> +// %1 = tensor.empty() : tensor<128x128xf32> +// %2 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %1) -> +// (tensor<128x128xf32>) { +// %4 = scf.for %arg4 = %c0 to %c128 step %c1 iter_args(%arg5 = %arg3) -> +// (tensor<128x128xf32>) { +// %extracted = tensor.extract %10[%arg3, %arg5] {DiscreteMemAccess} : +// tensor<128x128xi32> %5 = arith.extsi %extracted : i32 to i64 %6 = +// tt.addptr %arg1, %5 : !tt.ptr, i64 %7 = tt.load %6 +// {DiscreteMemAccess} : tt.ptr %inserted_slice = tensor.insert_slice +// %7 into %arg5[%arg2, %arg4] [1, 1] [128, 1] {DiscreteMemAccess} : +// tensor<1x1xf32> into tensor<128x128xf32> scf.yield %inserted_slice : +// tensor<128x128xf32> +// } +// scf.yield %4 : tensor<128x128xf32> +// } +// tt.store %output %2 : tensor<128x128x!tt.ptr> +template +class 
UnstructuredMemAccessConverter : public OpRewritePattern { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v); + +public: + using OpRewritePattern::OpRewritePattern; + + explicit UnstructuredMemAccessConverter( + MLIRContext *context, bool forceScalarizeMode, + const llvm::DenseMap &offsetMap, + const llvm::SmallDenseMap &fromTensorArg); + LogicalResult matchAndRewrite(MemAccOpTy op, + PatternRewriter &rewriter) const override; + +private: + bool checkUnstructureAnnotated(MemAccOpTy op, + PatternRewriter &rewriter) const; + Value createExtractOp(Location loc, Value value, PatternRewriter &rewriter, + ArrayRef iterIdx) const; + Value createExtractOp(Location loc, Value value, PatternRewriter &rewriter, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) const; + template + typename std::enable_if, void>::type + splatAndLoadScenario(MemAccOpTy op, int rank, + PatternRewriter &rewriter) const; + + template + MemAccOpTy createMemAccOp(MemAccOpTy op, Value ptrToAccess, Location loc, + PatternRewriter &rewriter, + Args &&...args) const = delete; + + const llvm::DenseMap &offsetMap; + const llvm::SmallDenseMap &fromTensorArg; + bool forceScalarizeMode; +}; + +class TritonToUnstructurePass + : public ::impl::TritonToUnstructureBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override; + + void runOnOperation() override; + +private: + void runPreparse(LoopLikeOpInterface op); + template || + std::is_same_v || + std::is_same_v || + std::is_same_v>> + void runParse(MemAccOpTy op); + llvm::DenseMap offsetMap; + llvm::DenseMap offsetMapForLoopArgs; + llvm::SmallDenseMap fromTensorArg; +}; + +} // namespace + +#endif // TRITON_DLC_UNSTRUCTURECONVERSION_H diff --git a/compiler/include/dicp/Dialect/TritonExt/Transforms/CanonicalizerPattern.h b/compiler/include/dicp/Dialect/TritonExt/Transforms/CanonicalizerPattern.h index 517ed6fd..b6967c2a 100644 --- a/compiler/include/dicp/Dialect/TritonExt/Transforms/CanonicalizerPattern.h +++ b/compiler/include/dicp/Dialect/TritonExt/Transforms/CanonicalizerPattern.h @@ -1,6 +1,5 @@ - -#ifndef TRITON_ADAPTER_LOADSTORECONVERTER_H -#define TRITON_ADAPTER_LOADSTORECONVERTER_H +#ifndef TRITON_DLC_LOADSTORECONVERTER_H +#define TRITON_DLC_LOADSTORECONVERTER_H #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AffineMap.h" diff --git a/compiler/include/dicp/Utils/Utils.h b/compiler/include/dicp/Utils/Utils.h index b6fbb04e..e50253dd 100644 --- a/compiler/include/dicp/Utils/Utils.h +++ b/compiler/include/dicp/Utils/Utils.h @@ -35,6 +35,8 @@ namespace mlir::dicp { const std::string GeneratedByMakeTensorPtrTAG = "GeneratedByMakeTensorPtr"; const std::string MayImplicitTransposeWithLastAxisTAG = "MayImplicitTransposeWithLastAxis"; +const std::string discreteMaskAttrName = "DiscreteMask"; +const std::string discreteAttrName = "DiscreteMemAccess"; // Gets the string attribute "dicp.backend" from the module if it exists. 
llvm::StringRef getBackend(ModuleOp module); @@ -118,6 +120,11 @@ scf::ForOp createNestedLoops( function_ref &, ValueRange)> bodyBuilder); +enum class TypelessValue { Undefined = 0, Zero = 1, Min = 2, Max = 3 }; + +FailureOr specializeTypelessValueToConstant(TypelessValue, Type, + Location, OpBuilder &); + } // namespace mlir::dicp #endif // TRITONNPU_UTILS_UTILS_H \ No newline at end of file diff --git a/compiler/lib/Conversion/DiscreteMaskAccessConversion/CMakeLists.txt b/compiler/lib/Conversion/DiscreteMaskAccessConversion/CMakeLists.txt new file mode 100644 index 00000000..a5b764ae --- /dev/null +++ b/compiler/lib/Conversion/DiscreteMaskAccessConversion/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(DiscreteMaskAccessConversion + DiscreteMaskAccessConversionPass.cpp + MaskAnalysis.cpp + + DEPENDS + DiscreteMaskAccessConversionPassIncGen + + LINK_LIBS + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + TritonIR + MLIRAnalysis +) diff --git a/compiler/lib/Conversion/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.cpp b/compiler/lib/Conversion/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.cpp new file mode 100644 index 00000000..c82a378a --- /dev/null +++ b/compiler/lib/Conversion/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.cpp @@ -0,0 +1,181 @@ +#include "dicp/Conversion/DiscreteMaskAccessConversion/MaskAnalysis.h" +#include "dicp/Conversion/DiscreteMaskAccessConversion/Passes.h" + +#include "dicp/Utils/Utils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" + +bool compileOn91095Flag = false; +bool forceSimtTemplateFlag = false; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DISCRETEMASKACCESSCONVERSION +#include "dicp/Conversion/DiscreteMaskAccessConversion/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::dicp; + +LogicalResult isDiscreteMask(Operation *op, Value mask, + PatternRewriter &rewriter) { + if (!mask) + return failure(); + + mlir::dicp::MaskState mstate; + auto isContMask = mstate.parse(mask, op->getLoc(), rewriter); + if (!isContMask.failed()) { + mstate.eraseInsertedOps(op, rewriter); + return failure(); + } + return success(); +} + +struct DiscreteMaskStoreConversion : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const final { + auto mask = op.getMask(); + auto loc = op.getLoc(); + auto dst = op.getPtr(); + auto src = op.getValue(); + + if (failed(isDiscreteMask(op, mask, rewriter))) + return failure(); + + auto loadFromDstOp = rewriter.create( + loc, dst, op.getCache(), op.getEvict(), false); + + auto selOp = rewriter.create(loc, mask, src, + loadFromDstOp.getResult()); + auto newStore = rewriter.create( + loc, dst, selOp, op.getCache(), op.getEvict()); + newStore->setAttr(mlir::dicp::discreteMaskAttrName, + UnitAttr::get(rewriter.getContext())); + rewriter.replaceOp(op, newStore); + return success(); + } +}; + +struct DiscreteMaskLoadConversion : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::LoadOp op, + PatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + auto other = op.getOther(); + auto mask = 
op.getMask();
+    auto ptr = op.getPtr();
+
+    if (failed(isDiscreteMask(op, mask, rewriter)))
+      return failure();
+    if (compileOn91095Flag && forceSimtTemplateFlag)
+      return failure();
+
+    if (!other) {
+      FailureOr<Value> constant = specializeTypelessValueToConstant(
+          TypelessValue::Zero, ptr.getType(), loc, rewriter);
+      // TODO: fix me
+      if (failed(constant)) {
+        ptr.getType().dump();
+        op->emitRemark() << " Unsupported type for constant creation";
+        return failure();
+      }
+      other = *constant;
+    }
+
+    auto newLoadOp = rewriter.create<triton::LoadOp>(
+        loc, ptr, op.getCache(), op.getEvict(), op.getIsVolatile());
+    auto discreteMaskOp =
+        rewriter.create<arith::SelectOp>(loc, mask, newLoadOp, other);
+    rewriter.replaceOp(op, discreteMaskOp);
+    return success();
+  }
+};
+
+struct DiscreteMaskAtomicConversion : OpRewritePattern<triton::AtomicRMWOp> {
+  using OpRewritePattern<triton::AtomicRMWOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(triton::AtomicRMWOp op,
+                                PatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    auto ptr = op.getPtr();
+    auto src = op.getVal();
+    auto mask = op.getMask();
+    auto rmwOp = op.getAtomicRmwOp();
+
+    if (failed(isDiscreteMask(op, mask, rewriter)))
+      return failure();
+
+    const std::map<RMWOp, TypelessValue> initMap = {
+        {RMWOp::FADD, TypelessValue::Zero},
+        {RMWOp::ADD, TypelessValue::Zero},
+        {RMWOp::UMAX, TypelessValue::Zero},
+        {RMWOp::OR, TypelessValue::Zero},
+        {RMWOp::MIN, TypelessValue::Max},
+        {RMWOp::UMIN, TypelessValue::Max},
+        {RMWOp::AND, TypelessValue::Max},
+        {RMWOp::MAX, TypelessValue::Min},
+        {RMWOp::XOR, TypelessValue::Zero},
+        {RMWOp::XCHG, TypelessValue::Undefined},
+    };
+    assert(initMap.find(rmwOp) != initMap.end());
+    auto typelessVal = initMap.at(rmwOp);
+    if (typelessVal == TypelessValue::Undefined) {
+      // Undefined default value atomic op will be decomposed in AscendNPU-IR
+      op->setAttr(mlir::dicp::discreteMaskAttrName,
+                  UnitAttr::get(rewriter.getContext()));
+      return failure();
+    }
+
+    FailureOr<Value> fill = specializeTypelessValueToConstant(
+        typelessVal, src.getType(), loc, rewriter);
+    if (failed(fill))
+      op->emitError("Unsupported atomic operation.");
+
+    auto maskedValue =
+        rewriter.create<arith::SelectOp>(loc, mask, src, *fill);
+    auto newAtomicOp = rewriter.create<triton::AtomicRMWOp>(
+        loc, src.getType(), rmwOp, ptr, maskedValue, mlir::Value(), op.getSem(),
+        op.getScope());
+    rewriter.replaceOp(op, newAtomicOp);
+    return success();
+  }
+};
+
+struct DiscreteMaskAccessConversionPass
+    : mlir::triton::impl::DiscreteMaskAccessConversionBase<
+          DiscreteMaskAccessConversionPass> {
+
+  DiscreteMaskAccessConversionPass(
+      const DiscreteMaskAccessConversionOptions &options)
+      : DiscreteMaskAccessConversionBase(options) {}
+
+  void runOnOperation() override {
+    compileOn91095Flag = this->compileOn91095;
+    forceSimtTemplateFlag = this->forceSimtTemplate;
+
+    auto moduleOp = getOperation();
+
+    RewritePatternSet patterns(&getContext());
+    patterns.add<DiscreteMaskStoreConversion, DiscreteMaskLoadConversion,
+                 DiscreteMaskAtomicConversion>(patterns.getContext());
+    if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
+      moduleOp->emitError("failed to apply discrete mask access patterns");
+      signalPassFailure();
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::triton::createDiscreteMaskAccessConversionPass(
+    const DiscreteMaskAccessConversionOptions &options) {
+  return std::make_unique<DiscreteMaskAccessConversionPass>(options);
+}
diff --git a/compiler/lib/Conversion/DiscreteMaskAccessConversion/MaskAnalysis.cpp b/compiler/lib/Conversion/DiscreteMaskAccessConversion/MaskAnalysis.cpp
new file mode 100644
index 00000000..e9e9ea55
--- /dev/null
+++ b/compiler/lib/Conversion/DiscreteMaskAccessConversion/MaskAnalysis.cpp
@@ -0,0 +1,848 @@
+#include
"dicp/Conversion/DiscreteMaskAccessConversion/MaskAnalysis.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define DEBUG_TYPE "dicp-mask-analysis" + +namespace mlir { + +static Value createConstIndexValueOp(const Location &loc, OpBuilder &b, + int64_t value) { + return b.create(loc, b.getIndexAttr(value)).getResult(); +} + +static std::optional getConstantOfAttr(const OpFoldResult &arg) { + if (isa(arg)) { + return getConstantIntValue(arg); + } + + return std::nullopt; +} + +static bool isZeroIndex(OpFoldResult v) { + if (!v) + return false; + if (auto attr = dyn_cast(v)) { + IntegerAttr intAttr = dyn_cast(attr); + return intAttr && intAttr.getValue().isZero(); + } + return false; +} + +OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() + rhsInt.value()); + + if (!lhsInt && rhsInt && rhsInt.value() == 0) + return lhs; + if (!rhsInt && lhsInt && lhsInt.value() == 0) + return rhs; + + auto lhsValue = dyn_cast(lhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + auto rhsValue = dyn_cast(rhs); + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() - rhsInt.value()); + + if (!lhsInt && rhsInt && rhsInt.value() == 0) + return lhs; + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() * rhsInt.value()); + + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + if (lhsInt.value() == 1) + return rhs; + } + if (rhsInt) { + if (rhsInt.value() == 0) + return rhs; + if (rhsInt.value() == 1) + return lhs; + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) 
{ + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (rhsInt && rhsInt.value() == 0) { + emitError(loc) << "cannot div 0!"; + return OpFoldResult(); + } + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() / rhsInt.value()); + + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + } + + if (rhsInt) { + if (rhsInt.value() == 1) + return lhs; + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (rhsInt && rhsInt.value() == 0) { + emitError(loc) << "cannot remainder by 0!"; + return OpFoldResult(); + } + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() % rhsInt.value()); + + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult minOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + if (lhsInt && rhsInt) + return b.getIndexAttr(std::min(lhsInt.value(), rhsInt.value())); + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + if (lhsInt && rhsInt) + return b.getIndexAttr(std::max(lhsInt.value(), rhsInt.value())); + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +// Fold layout constant info to attr, otherwise convert to index type value +OpFoldResult getOpFoldResultOfLayoutInfo(Value value, OpBuilder &builder) { + OpFoldResult constantFold = getAsOpFoldResult(value); + if (llvm::isa(constantFold)) { + assert(isa(cast(constantFold))); + return constantFold; + } + + if (!isa(value.getType())) + llvm_unreachable("Illegal data type when parse block data layout info"); + + if (!isa(value.getType())) { + if (value.getType().isInteger(/*width*/ 1)) + value = builder.create( + value.getLoc(), builder.getIndexType(), value); + else + value = builder.create(value.getLoc(), + builder.getIndexType(), value); + } + + return 
value; +} + +namespace dicp { + +LogicalResult MaskState::parse(Value operand, const Location &loc, + OpBuilder &builder) { + if (isa(operand.getType())) { + return parseIntScalar(operand, loc, builder); + } + + if (auto blockArgument = dyn_cast(operand)) { + auto parentOp = blockArgument.getOwner()->getParentOp(); + if (auto loopOp = dyn_cast(parentOp)) { + OpOperand *initArgOperand = loopOp.getTiedLoopInit(blockArgument); + if (initArgOperand) { + Value initArg = initArgOperand->get(); + return parse(initArg, loc, builder); + } + } + } + + auto definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + + LLVM_DEBUG({ + llvm::dbgs() << "[MaskState]==> parse op\n" + << *definingOp << "\n[MaskState]<==\n"; + }); + return TypeSwitch(definingOp) + .Case( + [&](auto op) { return this->parseConstant(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseAdd(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseAnd(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseCmp(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseMakeRange(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseBroadcast(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseSplat(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseExpandDims(op, loc, builder); }) + .Case( + [&](auto op) { return this->parse(op.getIn(), loc, builder); }) + .Case( + [&](auto op) { return this->parseDiv(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseSel(op, loc, builder); }) + .Default([&](Operation *op) { return failure(); }); +} + +// extractSlice +tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, + const Location &loc, + OpBuilder &builder) const { + auto sourceRType = cast(source.getType()); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + + auto dstRType = tensor::ExtractSliceOp::inferResultType(sourceRType, offsets, + dims, strides); + return builder.create(loc, dstRType, source, offsets, + dims, strides); +} + +tensor::InsertSliceOp MaskState::getInsertSlice(Value source, Value dest, + const Location &loc, + OpBuilder &builder) const { + SmallVector strides(getRank(), builder.getIndexAttr(1)); + return builder.create(loc, source, dest, offsets, dims, + strides); +} + +memref::SubViewOp MaskState::getSubview(Value source, const Location &loc, + OpBuilder &builder) const { + auto sourceType = cast(source.getType()); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); + return builder.create(loc, cast(dstType), + source, offsets, dims, strides); +} + +static memref::SubViewOp createSubview(Value src, const Location &loc, + OpBuilder &builder, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + auto srcType = cast(src.getType()); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return builder.create(loc, cast(dstType), src, + offsets, sizes, strides); +} + +LogicalResult MaskState::addStateScalar(const MaskState &state, + const OpFoldResult scalar, + const Location &loc, + OpBuilder &builder) { + start = addOpFoldResult(state.start, scalar, loc, builder); + end = addOpFoldResult(state.end, scalar, loc, builder); + dims = state.dims; + offsets = state.offsets; + return success(); +} + +LogicalResult MaskState::addStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (lhsState.scalar 
&& rhsState.scalar) { + InFlightDiagnostic diag = + emitWarning(loc) + << "Unexpected case where both lhs and rhs are scalars"; + return failure(); + } + if (!lhsState.scalar && !rhsState.scalar) { + InFlightDiagnostic diag = + emitWarning(loc) + << "Unsupported scenario where neither lhs nor rhs is a scalar"; + return failure(); + } + + if (lhsState.scalar) { + return addStateScalar(rhsState, lhsState.scalar, loc, builder); + } else { + return addStateScalar(lhsState, rhsState.scalar, loc, builder); + } +} + +LogicalResult MaskState::divStateScalar(const MaskState &state, + const OpFoldResult scalar, + const Location &loc, + OpBuilder &builder) { + start = divOpFoldResult(state.start, scalar, loc, builder); + end = divOpFoldResult(state.end, scalar, loc, builder); + dims = state.dims; + offsets = state.offsets; + return success(); +} + +LogicalResult MaskState::divStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (!lhsState.scalar && rhsState.scalar) { + if (isZeroIndex(rhsState.scalar)) { + InFlightDiagnostic diag = + emitError(loc) + << "Unsupported scenario where rhs is zero constant in divide!"; + return failure(); + } + + return divStateScalar(lhsState, rhsState.scalar, loc, builder); + } + + InFlightDiagnostic diag = emitWarning(loc) + << "Supported scenario where only rhs is a scalar"; + return failure(); +} + +LogicalResult MaskState::minStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (lhsState.getRank() != rhsState.getRank()) { + InFlightDiagnostic diag = + emitError(loc) + << "Unexpected case where lhs and rhs have different ranks"; + return failure(); + } + + for (uint32_t i = 0; i < lhsState.getRank(); i++) { + auto lhsOffset = lhsState.offsets[i]; + auto rhsOffset = rhsState.offsets[i]; + auto newOffset = maxOpFoldResult(lhsOffset, rhsOffset, loc, builder); + auto lhsDim = lhsState.dims[i]; + auto rhsDim = rhsState.dims[i]; + auto lhsEnd = addOpFoldResult(lhsOffset, lhsDim, loc, builder); + auto rhsEnd = addOpFoldResult(rhsOffset, rhsDim, loc, builder); + auto newEnd = minOpFoldResult(lhsEnd, rhsEnd, loc, builder); + auto newDim = subOpFoldResult(newEnd, newOffset, loc, builder); + + offsets.push_back(newOffset); + dims.push_back(newDim); + } + return success(); +} + +// Helper func for MaskState::parse() +LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (isa(constOp.getValue())) { + auto attr = cast(constOp.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType) && + "All elements must share a single integer constant value"); + this->scalar = builder.getIndexAttr( + attr.getSplatValue().getValue().getSExtValue()); + } else { + auto value = cast(constOp.getValue()).getInt(); + this->scalar = builder.getIndexAttr(value); + } + return success(); +} + +// parseIntScalar +LogicalResult MaskState::parseIntScalar(Value scalar, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + this->scalar = getOpFoldResultOfLayoutInfo(scalar, builder); + return success(); +} + +LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) { + return 
failure(); + } + return this->addStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseDiv(arith::DivSIOp divOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(divOp.getLhs(), loc, builder))) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(divOp.getRhs(), loc, builder))) { + return failure(); + } + return this->divStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || + !lhsState.isMask()) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || + !rhsState.isMask()) { + return failure(); + } + + if (!lhsState.isMask() && !rhsState.isMask()) { + return failure(); + } + + // Only support both lhs and rhs satisfy `isMask` condition + return this->minStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseSel(arith::SelectOp selOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto trueValue = selOp.getTrueValue(); + auto falseValue = selOp.getFalseValue(); + + MaskState condState; + auto condition = selOp.getCondition(); + auto cmpOp = condition.getDefiningOp(); + if (!cmpOp || failed(condState.parse(condition, loc, builder))) { + return failure(); + } + + MaskState trueState; + if (failed(trueState.parse(trueValue, loc, builder)) || !trueState.scalar) { + return failure(); + } + + MaskState falseState; + if (failed(falseState.parse(falseValue, loc, builder)) || + !falseState.scalar) { + return failure(); + } + + auto trueScalar = dyn_cast(cast(trueState.scalar)); + auto falseScalar = dyn_cast(cast(falseState.scalar)); + + if (trueScalar && falseScalar) { + if (trueScalar.getInt() == 1 && falseScalar.getInt() == 0) { + start = condState.start; + end = condState.end; + dims = condState.dims; + offsets = condState.offsets; + return success(); + } + } + + return failure(); +} + +LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto predicate = cmpOp.getPredicate(); + // Only support <, <=, >=, =, != + if (predicate != arith::CmpIPredicate::slt && + predicate != arith::CmpIPredicate::sle && + predicate != arith::CmpIPredicate::sge && + predicate != arith::CmpIPredicate::eq && + predicate != arith::CmpIPredicate::ne) { + LLVM_DEBUG({ llvm::dbgs() << "Unsupported cmpi predicate\n"; }); + return failure(); + } + + MaskState lhsState; + MaskState rhsState; + auto lhs = cmpOp.getLhs(); + auto rhs = cmpOp.getRhs(); + + if (predicate == arith::CmpIPredicate::ne) { + auto selOp = lhs.getDefiningOp(); + auto constantOp = rhs.getDefiningOp(); + if (!selOp || !constantOp) { + return failure(); + } + } + + if (failed(lhsState.parse(lhs, loc, builder))) { + return failure(); + } + + if (failed(rhsState.parse(rhs, loc, builder))) { + return failure(); + } + + if (!(!lhsState.scalar && rhsState.scalar)) { + InFlightDiagnostic diag = emitWarning(loc) + << "[MaskState] Unsupported cmpi scenario"; + return failure(); + } + + int32_t cmpDim = -1; + for (int32_t i = 0; i < lhsState.getRank(); i++) { + auto constDimLength = getConstantIntValue(lhsState.dims[i]); + if (!constDimLength || constDimLength.value() != 1) { + if (cmpDim != -1) { + InFlightDiagnostic diag = emitWarning(loc) + << "Unsupported cmpi 
with more than one " + "dimension with size larger than 1"; + return failure(); + } + cmpDim = i; + } + } + + assert(cmpDim != -1 && + "Unexpected case where no dimension has size larger than 1"); + + this->offsets = lhsState.offsets; + this->dims = lhsState.dims; + switch (predicate) { + case arith::CmpIPredicate::slt: { + auto realBound = + maxOpFoldResult(lhsState.start, rhsState.scalar, loc, builder); + auto newEnd = minOpFoldResult(lhsState.end, realBound, loc, builder); + auto newDim = subOpFoldResult(newEnd, lhsState.start, loc, builder); + + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::sle: { + // lhs <= rhs <=> lhs < rhs + 1 + auto rhsPlusOne = + addOpFoldResult(rhsState.scalar, builder.getIndexAttr(1), loc, builder); + auto realBound = maxOpFoldResult(lhsState.start, rhsPlusOne, loc, builder); + auto newEnd = minOpFoldResult(lhsState.end, realBound, loc, builder); + auto newDim = subOpFoldResult(newEnd, lhsState.start, loc, builder); + + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::sge: { + auto realBound = + maxOpFoldResult(lhsState.start, rhsState.scalar, loc, builder); + auto newStart = minOpFoldResult(lhsState.end, realBound, loc, builder); + auto newOffset = subOpFoldResult(newStart, lhsState.start, loc, builder); + auto newDim = subOpFoldResult(lhsState.end, newStart, loc, builder); + + this->offsets[cmpDim] = newOffset; + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::eq: { + auto newOffset = + subOpFoldResult(rhsState.scalar, lhsState.start, loc, builder); + auto newDim = builder.getIndexAttr(1); + + this->offsets[cmpDim] = newOffset; + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::ne: { + // only support lhs != 0 + auto rhsScalar = dyn_cast(cast(rhsState.scalar)); + if (!rhsScalar || rhsScalar.getInt() != 0) { + return failure(); + } + + start = lhsState.start; + end = lhsState.end; + break; + } + default: + return failure(); + } + return success(); +} + +LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto shape = cast(rangeOp.getType()).getShape(); + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + + if (stride != 1) { + InFlightDiagnostic diag = + emitWarning(loc) + << "stride must be 1 for make_range whose result is used " + "as load or store masks"; + return failure(); + } + + this->start = builder.getIndexAttr(start); + this->end = builder.getIndexAttr(end); + this->dims.push_back(builder.getIndexAttr(shape[0])); + this->offsets.push_back(builder.getIndexAttr(0)); + return success(); +} + +LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + assert(isa(src.getType()) && + "input to tt.broadcast should be a tensor"); + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + if (failed(parse(src, loc, builder))) { + return failure(); + } + for (size_t i = 0; i < srcShape.size(); i++) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) + this->dims[i] = builder.getIndexAttr(dstShape[i]); + else + llvm_unreachable("unexpected dimensions used in broadcast"); + } 
+ return success(); +} + +LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, + const Location &loc, OpBuilder &builder) { + assert(this->isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + if (!isa(src.getType())) { + InFlightDiagnostic diag = + emitWarning(loc) + << "splat source must be an integer scalar for load/store masks"; + return failure(); + } + + if (failed(this->parse(src, loc, builder))) + return failure(); + + auto splatAsMask = [&](Operation *userOp) -> bool { + return TypeSwitch(userOp) + .Case([&](arith::AndIOp andOp) { return true; }) + .Case([&](arith::SelectOp selectOp) { + return selectOp.getCondition() == dst; + }) + .Case( + [&](triton::LoadOp loadOp) { return loadOp.getMask() == dst; }) + .Case( + [&](triton::StoreOp storeOp) { return storeOp.getMask() == dst; }) + .Default([&](Operation *op) { return false; }); + }; + + if (src.getType().isInteger(1) && !splatOp->use_empty() && + llvm::all_of(splatOp->getUsers(), splatAsMask)) { + for (auto s : dstShape) { + auto currentDim = + mulOpFoldResult(builder.getIndexAttr(s), this->scalar, loc, builder); + this->dims.push_back(currentDim); + this->offsets.push_back(builder.getIndexAttr(0)); + } + + this->scalar = nullptr; + return success(); + } + + for (auto s : dstShape) { + this->dims.push_back(builder.getIndexAttr(s)); + this->offsets.push_back(builder.getIndexAttr(0)); + } + return success(); +} + +LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (failed(this->parse(expandDimsOp.getSrc(), loc, builder))) { + return failure(); + } + + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + assert(dstShape[axis] == 1 && + "Expect changed dimention to be 1 in expand_dims"); + this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1)); + this->offsets.insert(this->offsets.begin() + axis, builder.getIndexAttr(0)); + + return success(); +} + +void MaskState::eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter) { + auto moduleOp = rawOp->getParentOfType(); + SmallVector worklist; + moduleOp->walk([&](Operation *op) { + if (isOpTriviallyDead(op)) + worklist.push_back(op); + }); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + if (!isOpTriviallyDead(op)) + continue; + for (Value value : op->getOperands()) { + if (auto defOp = value.getDefiningOp()) + worklist.push_back(defOp); + } + LLVM_DEBUG({ + llvm::dbgs() << "[MaskState]==> inserted op: \n" + << *op << "\n[MaskState]<== is removed\n"; + }); + rewriter.eraseOp(op); + } +} + +} // namespace dicp + +} // namespace mlir diff --git a/compiler/lib/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/CMakeLists.txt b/compiler/lib/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/CMakeLists.txt new file mode 100644 index 00000000..9f9fd662 --- /dev/null +++ b/compiler/lib/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(MemRefCopyGatherToTensorInsert + MemRefCopyGatherToTensorInsert.cpp + + DEPENDS + MemRefCopyGatherToTensorInsertPassIncGen + + LINK_LIBS + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + TritonIR +) diff --git a/compiler/lib/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/MemRefCopyGatherToTensorInsert.cpp 
b/compiler/lib/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/MemRefCopyGatherToTensorInsert.cpp new file mode 100644 index 00000000..467adbeb --- /dev/null +++ b/compiler/lib/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/MemRefCopyGatherToTensorInsert.cpp @@ -0,0 +1,283 @@ +#include "dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace mlir::dicp::linked { +#define GEN_PASS_DEF_MEMREFCOPYGATHERTOTENSORINSERT +#include "dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h.inc" +} // namespace mlir::dicp::linked + +namespace { + +/// Helper to convert OpFoldResult to Value. +/// If it is an attribute, creates an arith.constant. +static Value getValueOrCreateConstantIndexOp(PatternRewriter &rewriter, + Location loc, OpFoldResult ofr) { + if (auto val = dyn_cast(ofr)) + return val; + return rewriter.create( + loc, cast(cast(ofr)).getInt()); +} + +/// Helper function to check if a Value is derived from a Tensor ExtractOp. +/// It supports two chains: +/// 1. ExtractOp -> Value (Target) +/// 2. ExtractOp -> IndexCastOp -> Value (Target) +/// Returns the defining ExtractOp if a match is found involving the loop IV. +static tensor::ExtractOp findSourceExtractOp(Value val, Value loopIV) { + Operation *defOp = val.getDefiningOp(); + if (!defOp) + return nullptr; + + // Case 1: Direct ExtractOp + if (auto extractOp = dyn_cast(defOp)) { + for (Value idx : extractOp.getIndices()) { + if (idx == loopIV) + return extractOp; + } + return nullptr; + } + + // Case 2: ExtractOp -> IndexCastOp + if (auto indexCastOp = dyn_cast(defOp)) { + if (auto extractOp = + indexCastOp.getIn().getDefiningOp()) { + for (Value idx : extractOp.getIndices()) { + if (idx == loopIV) + return extractOp; + } + } + } + + return nullptr; +} + +/// Pattern to convert a specific memory gather loop into a tensor insertion +/// loop. +/// +/// Source Pattern: +/// scf.for %iv ... { +/// %idx = tensor.extract[%iv] <-- (Optional IndexCast) +/// %view = memref.reinterpret_cast %src to offset: [%idx] ... +/// %subview = memref.subview %alloc[%iv] ... +/// memref.copy %view, %subview +/// } +/// +/// Target Pattern: +/// %res = scf.for %iv ... iter_args(%acc = %empty) { +/// %idx = tensor.extract[%iv] +/// %cast_idx = arith.index_cast %idx <-- (If needed) +/// %view = memref.reinterpret_cast %src to offset: [%cast_idx] ... +/// %val = memref.load %view[0, 0, ...] <-- Matches rank +/// %next = tensor.insert %val into %acc[%iv, ...] <-- Matches subview +/// offsets scf.yield %next +/// } +struct MemRefCopyGatherToTensorInsertPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + // 1. Analyze Loop Body: Must contain exactly one memref.copy + memref::CopyOp copyOp; + int copyCount = 0; + forOp.getBody()->walk([&](memref::CopyOp op) { + copyOp = op; + copyCount++; + }); + + if (copyCount != 1 || !copyOp) + return failure(); + + // 2. 
Validate Target Semantics (Write Side) + // Expectation: Copy -> SubView -> Alloc -> ToTensor + auto subViewOp = copyOp.getTarget().getDefiningOp(); + if (!subViewOp) + return failure(); + + auto allocOp = subViewOp.getSource().getDefiningOp(); + if (!allocOp) + return failure(); + + bufferization::ToTensorOp toTensorOp; + for (Operation *user : allocOp->getUsers()) { + if (auto op = dyn_cast(user)) { + toTensorOp = op; + break; + } + } + + if (!toTensorOp || toTensorOp->getBlock() != forOp->getBlock() || + !forOp->isBeforeInBlock(toTensorOp)) { + return failure(); + } + + // 3. Validate Source Semantics (Read Side) + // Expectation: (Extract -> Optional Cast) -> ReinterpretCast -> Copy Source + auto reinterpretOp = + copyOp.getSource().getDefiningOp(); + if (!reinterpretOp) + return failure(); + + tensor::ExtractOp extractOp; + bool patternFound = false; + + // Check dynamic offsets to find the one driven by the loop induction + // variable. + for (OpFoldResult ofr : reinterpretOp.getOffsets()) { + if (auto val = dyn_cast(ofr)) { + extractOp = findSourceExtractOp(val, forOp.getInductionVar()); + if (extractOp) { + patternFound = true; + break; + } + } + } + + if (!patternFound) + return failure(); + + // ==================================================== + // Rewrite Phase + // ==================================================== + Location loc = forOp.getLoc(); + + // A. Prepare Accumulator (tensor.empty) + auto resultType = cast(toTensorOp.getResult().getType()); + Value initTensor = rewriter.create( + loc, resultType.getShape(), resultType.getElementType()); + + // B. Create New Loop + auto newForOp = rewriter.create( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + ValueRange{initTensor}); + + newForOp->setAttr("ExtractedLoadOrStore", rewriter.getUnitAttr()); + + // C. Populate New Loop Body + rewriter.setInsertionPointToStart(newForOp.getBody()); + + Value iv = newForOp.getInductionVar(); + Value acc = newForOp.getRegionIterArgs()[0]; + + // C.1. Recreate Index Calculation + // We clone the indices from the original extractOp. + // If an index was the old loop's IV, replace it with the new loop's IV. + SmallVector extractIndices; + for (Value idx : extractOp.getIndices()) { + if (idx == forOp.getInductionVar()) + extractIndices.push_back(iv); + else + extractIndices.push_back(idx); + } + + Value newExtract = rewriter.create( + extractOp.getLoc(), extractOp.getTensor(), extractIndices); + + Value newOffsetIdx = newExtract; + if (!newOffsetIdx.getType().isIndex()) { + newOffsetIdx = rewriter.create( + loc, rewriter.getIndexType(), newExtract); + } + + // C.2. Recreate ReinterpretCast + // Map the matched dynamic offset to our new calculated index. + OpFoldResult newOffsetOfr = rewriter.getIndexAttr(0); + + if (!reinterpretOp.getMixedOffsets().empty()) { + OpFoldResult oldOfr = reinterpretOp.getMixedOffsets()[0]; + + // If the old offset matches our pattern, replace it. + bool isTargetOffset = false; + if (auto val = dyn_cast(oldOfr)) { + if (findSourceExtractOp(val, forOp.getInductionVar())) { + isTargetOffset = true; + } + } + newOffsetOfr = isTargetOffset ? newOffsetIdx : oldOfr; + } + + Value newSrcMemref = rewriter.create( + reinterpretOp.getLoc(), reinterpretOp.getType(), + reinterpretOp.getSource(), newOffsetOfr, reinterpretOp.getMixedSizes(), + reinterpretOp.getMixedStrides()); + + // C.3. Load from Source + // FIX: Ensure we provide indices for ALL dimensions of the reinterpreted + // memref. We assume we are loading from the base (0, 0, ...). 
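+    // The reinterpret_cast built below already folds the gathered offset into
+    // the view, so the element of interest sits at the view's base; loading
+    // with rank-many zero indices (e.g. %v = memref.load %view[%c0, %c0] for
+    // a rank-2 view) is therefore sufficient. Illustrative sketch only; the
+    // actual rank comes from the reinterpreted memref type.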
+ auto memRefType = cast(newSrcMemref.getType()); + Value c0 = rewriter.create(loc, 0); + SmallVector loadIndices(memRefType.getRank(), c0); + + Value loadedVal = + rewriter.create(loc, newSrcMemref, loadIndices); + + // C.4. Insert into Accumulator + // FIX: Derive insertion indices from the original subview offsets. + // This handles multi-dimensional tensors correctly by mapping the subview + // logic. + SmallVector insertIndices; + for (OpFoldResult ofr : subViewOp.getMixedOffsets()) { + if (auto val = dyn_cast(ofr)) { + // If the offset is the old IV, use the new IV. + if (val == forOp.getInductionVar()) + insertIndices.push_back(iv); + else + insertIndices.push_back(val); + } else { + // Materialize static offset as a constant index. + insertIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + } + + Value nextTensor = + rewriter.create(loc, loadedVal, acc, insertIndices); + + // C.5. Yield + auto yieldOp = rewriter.create(loc, nextTensor); + yieldOp->setAttr("DiscreteMemAccess", rewriter.getUnitAttr()); + + // D. Finalize + rewriter.replaceOp(toTensorOp, newForOp.getResult(0)); + + // Cleanup + rewriter.eraseOp(forOp); + rewriter.eraseOp(allocOp); + + return success(); + } +}; + +struct MemRefCopyGatherToTensorInsertPass + : mlir::dicp::linked::impl::MemRefCopyGatherToTensorInsertBase< + MemRefCopyGatherToTensorInsertPass> { + + void runOnOperation() override { + auto moduleOp = getOperation(); + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::linked::createMemRefCopyGatherToTensorInsertPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/CMakeLists.txt b/compiler/lib/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/CMakeLists.txt index 091bd7d3..52072492 100644 --- a/compiler/lib/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/CMakeLists.txt +++ b/compiler/lib/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/CMakeLists.txt @@ -28,4 +28,6 @@ add_triton_library(TritonToLinalgNPUCoversion TritonToUnstructured TritonPtrToMemref UnstructuredToMemref + MemRefCopyGatherToTensorInsert + TritonToUnstructure ) diff --git a/compiler/lib/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/TritonToLinalgNPUConversionPass.cpp b/compiler/lib/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/TritonToLinalgNPUConversionPass.cpp index 78d524cb..237d36eb 100644 --- a/compiler/lib/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/TritonToLinalgNPUConversionPass.cpp +++ b/compiler/lib/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/TritonToLinalgNPUConversionPass.cpp @@ -1,4 +1,6 @@ +#include "dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h" #include "dicp/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/TritonToLinalgNPUCoversion.h" +#include "dicp/Conversion/TritonToUnstructure/BubbleUpOperation.h" #include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h" #include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h" @@ -58,6 +60,8 @@ class TritonToLinalgNPUCoversionPass pm.addPass(createTritonArithToLinalgPass(true, false)); pm.addPass(createStructuredToMemrefPass()); + pm.addPass(createMemRefCopyGatherToTensorInsertPass()); + pm.addPass(createUnstructuredToMemrefPass()); 
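+    // Ordering note: the gather-loop rewrite runs after StructuredToMemref,
+    // which is expected to have materialized the memref.copy loops it
+    // matches, and before UnstructuredToMemref lowers the remaining
+    // unstructured accesses.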
pm.addPass(createTritonPtrToMemrefPass()); pm.addPass(createTritonToPtrPass()); @@ -76,7 +80,6 @@ class TritonToLinalgNPUCoversionPass // collapseShape pass pm.addPass(createCollapseShapePass()); } - if (failed(runPipeline(pm, getOperation()))) { signalPassFailure(); } diff --git a/compiler/lib/Conversion/TritonToUnstructure/BubbleUpOperation.cpp b/compiler/lib/Conversion/TritonToUnstructure/BubbleUpOperation.cpp new file mode 100644 index 00000000..43677f9e --- /dev/null +++ b/compiler/lib/Conversion/TritonToUnstructure/BubbleUpOperation.cpp @@ -0,0 +1,537 @@ +#include "dicp/Conversion/TritonToUnstructure/BubbleUpOperation.h" +#include "dicp/Utils/Utils.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#define DEBUG_TYPE "triton-bubble-up-operation" + +template +class BubbleUpExtract : public OpRewritePattern { + static_assert(std::is_same_v || + std::is_same_v); + +public: + using OpRewritePattern::OpRewritePattern; + BubbleUpExtract(MLIRContext *context, bool enableAggressiveMode); + LogicalResult matchAndRewrite(ExtractOpTy op, + PatternRewriter &rewriter) const override; + +private: + Value createExtractOp(ExtractOpTy op, Value value, Location loc, + PatternRewriter &rewriter) const; + template + void bubbleUpIntBinaryOp(ExtractOpTy op, BinOpTy binOp, Location loc, + PatternRewriter &rewriter) const; + template + void bubbleUpFloatBinaryOp(ExtractOpTy op, BinOpTy binOp, Location loc, + PatternRewriter &rewriter) const; + + void bubbleUpOperation(ExtractOpTy op, arith::ExtSIOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::CmpIOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::TruncFOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::ExtFOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::FPToSIOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::SIToFPOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::ClampFOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::CmpFOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::BroadcastOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::ExpandDimsOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::SplatOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::MakeRangeOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::AddPtrOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, math::FloorOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, math::CeilOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, tensor::ExtractSliceOp parentOp, + Location loc, PatternRewriter &rewriter) const; + + bool enableAggressiveMode; +}; + +template +BubbleUpExtract::BubbleUpExtract(MLIRContext *context, + bool enableAggressiveMode) + 
: OpRewritePattern(context), + enableAggressiveMode(enableAggressiveMode) {} + +template +LogicalResult +BubbleUpExtract::matchAndRewrite(ExtractOpTy op, + PatternRewriter &rewriter) const { + Value tensorValue; + if constexpr (std::is_same_v) { + tensorValue = op.getTensor(); + } else if constexpr (std::is_same_v) { + tensorValue = op.getSource(); + if (tensorValue.getType() == op.getResult().getType()) { + rewriter.replaceAllUsesWith(op.getResult(), tensorValue); + rewriter.eraseOp(op); + return success(); + } + } else { + llvm_unreachable("Unhandled case"); + } + auto funcOp = op->template getParentOfType(); + auto parentOp = tensorValue.getDefiningOp(); + auto loc = op.getLoc(); + + if (!parentOp || (!enableAggressiveMode && !parentOp->hasOneUse())) { + return failure(); + } + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Before bubble up\n" << op << '\n' << funcOp << "\n"; + }); + + if (auto extsiOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, extsiOp, loc, rewriter); + } else if (auto addIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, addIOp, loc, rewriter); + } else if (auto subIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, subIOp, loc, rewriter); + } else if (auto mulIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, mulIOp, loc, rewriter); + } else if (auto divSIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, divSIOp, loc, rewriter); + } else if (auto remSIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, remSIOp, loc, rewriter); + } else if (auto maxSIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, maxSIOp, loc, rewriter); + } else if (auto minSIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, minSIOp, loc, rewriter); + } else if (auto andIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, andIOp, loc, rewriter); + } else if (auto orIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, orIOp, loc, rewriter); + } else if (auto cmpIOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, cmpIOp, loc, rewriter); + } else if (auto truncFOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, truncFOp, loc, rewriter); + } else if (auto extFOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, extFOp, loc, rewriter); + } else if (auto fpTosiOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, fpTosiOp, loc, rewriter); + } else if (auto siTofpOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, siTofpOp, loc, rewriter); + } else if (auto clampFOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, clampFOp, loc, rewriter); + } else if (auto addFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, addFOp, loc, rewriter); + } else if (auto subFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, subFOp, loc, rewriter); + } else if (auto mulFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, mulFOp, loc, rewriter); + } else if (auto divFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, divFOp, loc, rewriter); + } else if (auto minNumFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, minNumFOp, loc, rewriter); + } else if (auto maxNumFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, maxNumFOp, loc, rewriter); + } else if (auto cmpFOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, cmpFOp, loc, rewriter); + } else if (auto broadCastOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, broadCastOp, loc, rewriter); + } else if (auto expandDimsOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, expandDimsOp, loc, rewriter); + } else if (auto splatOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, splatOp, loc, rewriter); + } else if (auto makeRangeOp = 
dyn_cast(parentOp)) { + bubbleUpOperation(op, makeRangeOp, loc, rewriter); + } else if (auto addPtrOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, addPtrOp, loc, rewriter); + } else if (auto floorOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, floorOp, loc, rewriter); + } else if (auto ceilOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, ceilOp, loc, rewriter); + } else if (auto extractSliceOp = dyn_cast(parentOp)) { + if constexpr (std::is_same_v) { + bubbleUpOperation(op, extractSliceOp, loc, rewriter); + } else { + return failure(); + } + } else { + return failure(); + } + if (parentOp->use_empty()) + rewriter.eraseOp(parentOp); + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After bubble up\n" << funcOp << '\n'; + }); + + return success(); +} + +template +Value BubbleUpExtract::createExtractOp( + ExtractOpTy op, Value value, Location loc, + PatternRewriter &rewriter) const { + llvm_unreachable("Unhandled extract operation"); +} + +template <> +Value BubbleUpExtract::createExtractOp( + tensor::ExtractOp op, Value value, Location loc, + PatternRewriter &rewriter) const { + auto extractedOp = + rewriter.create(loc, value, op.getIndices()); + extractedOp->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + return extractedOp; +} + +template <> +Value BubbleUpExtract::createExtractOp( + tensor::ExtractSliceOp op, Value value, Location loc, + PatternRewriter &rewriter) const { + auto extractedOp = rewriter.create( + loc, value, op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()); + extractedOp->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + return extractedOp; +} + +template +template +void BubbleUpExtract::bubbleUpIntBinaryOp( + ExtractOpTy op, BinOpTy binOp, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(op, binOp.getLhs(), loc, rewriter); + auto rhs = createExtractOp(op, binOp.getRhs(), loc, rewriter); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Binary\n" << *op << '\n' << binOp << '\n'; + }); + rewriter.replaceOpWithNewOp(op, lhs, rhs); +} + +template +template +void BubbleUpExtract::bubbleUpFloatBinaryOp( + ExtractOpTy op, BinOpTy binOp, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(op, binOp.getLhs(), loc, rewriter); + auto rhs = createExtractOp(op, binOp.getRhs(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, lhs, rhs); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::ExtSIOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), in); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::CmpIOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(op, parentOp.getLhs(), loc, rewriter); + auto rhs = createExtractOp(op, parentOp.getRhs(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, parentOp.getPredicateAttr(), + lhs, rhs); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, triton::BroadcastOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + auto srcShape = cast(src.getType()).getShape(); + SmallVector newIndices; + for (const auto &[index, shape] : + llvm::zip_equal(op.getIndices(), srcShape)) { + if (shape == 1) { + newIndices.push_back( + rewriter.create(loc, rewriter.getIndexAttr(0))); + } else { + 
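+        // Non-broadcast dimension (source extent > 1): the original extract
+        // index addresses the source tensor directly, so it is reused
+        // unchanged.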
newIndices.push_back(index); + } + } + auto extractedOp = rewriter.create(loc, src, newIndices); + extractedOp->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + rewriter.replaceOp(op, extractedOp); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractSliceOp op, triton::BroadcastOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + auto srcShape = cast(src.getType()).getShape(); + SmallVector newOffsets; + SmallVector newSizes; + bool isScalarLikeSrc = true; + for (const auto &[offset, size, shape] : + llvm::zip_equal(op.getMixedOffsets(), op.getMixedSizes(), srcShape)) { + if (shape == 1) { + newOffsets.push_back(rewriter.getIndexAttr(0)); + newSizes.push_back(rewriter.getIndexAttr(1)); + } else { + newOffsets.push_back(offset); + newSizes.push_back(size); + } + if (getConstantIntValue(newSizes.back()).value_or(-1) != 1) + isScalarLikeSrc = false; + } + auto extractedOp = rewriter.create( + loc, src, newOffsets, newSizes, op.getMixedStrides()); + extractedOp->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + if (isScalarLikeSrc) { + SmallVector indices( + srcShape.size(), + rewriter.create(loc, rewriter.getIndexAttr(0))); + auto extractedValue = + rewriter.create(loc, extractedOp, indices); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + extractedValue); + } else { + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), extractedOp); + } +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, triton::ExpandDimsOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + SmallVector newIndices; + for (const auto index : llvm::enumerate(op.getIndices())) { + if (index.index() != parentOp.getAxis()) + newIndices.push_back(index.value()); + } + auto extractedOp = rewriter.create(loc, src, newIndices); + extractedOp->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + rewriter.replaceOp(op, extractedOp); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractSliceOp op, triton::ExpandDimsOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + auto srcShape = cast(src.getType()).getShape(); + SmallVector newOffsets; + SmallVector newSizes; + SmallVector newStrides; + for (size_t i = 0; i <= srcShape.size(); i++) { + if (i != parentOp.getAxis()) { + newOffsets.push_back(op.getMixedOffsets()[i]); + newSizes.push_back(op.getMixedSizes()[i]); + newStrides.push_back(op.getMixedStrides()[i]); + } + } + auto extractedOp = rewriter.create( + loc, src, newOffsets, newSizes, newStrides); + extractedOp->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + rewriter.replaceOpWithNewOp(op, extractedOp, + parentOp.getAxisAttr()); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, triton::SplatOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + rewriter.replaceOp(op, src); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractSliceOp op, triton::SplatOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + rewriter.replaceOpWithNewOp( + op, cast(op.getResult().getType()), src); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, triton::MakeRangeOp parentOp, Location loc, + PatternRewriter &rewriter) const 
{ + auto resultType = cast(parentOp.getResult().getType()); + rewriter.replaceOpWithNewOp( + op, resultType.getElementType(), op.getIndices()[0]); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractSliceOp op, triton::MakeRangeOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto resultType = cast(parentOp.getResult().getType()); + Value idx; + if (auto offsetVal = dyn_cast(op.getMixedOffsets()[0])) { + idx = offsetVal; + } else { + idx = rewriter.create( + op.getLoc(), rewriter.getIndexAttr( + getConstantIntValue(op.getMixedOffsets()[0]).value())); + } + idx = rewriter.create(op.getLoc(), + resultType.getElementType(), idx); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + idx); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, triton::AddPtrOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto ptr = createExtractOp(op, parentOp.getPtr(), loc, rewriter); + auto offset = createExtractOp(op, parentOp.getOffset(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::TruncFOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + in); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::ExtFOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), in); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::FPToSIOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + in); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::SIToFPOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + in); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, triton::ClampFOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto x = createExtractOp(op, parentOp.getX(), loc, rewriter); + auto min = createExtractOp(op, parentOp.getMin(), loc, rewriter); + auto max = createExtractOp(op, parentOp.getMax(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, x, min, max, + parentOp.getPropagateNan()); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::CmpFOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(op, parentOp.getLhs(), loc, rewriter); + auto rhs = createExtractOp(op, parentOp.getRhs(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, parentOp.getPredicateAttr(), + lhs, rhs); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, math::FloorOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto operand = createExtractOp(op, parentOp.getOperand(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, operand, + parentOp.getFastmath()); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, math::CeilOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto operand = createExtractOp(op, parentOp.getOperand(), loc, rewriter); + 
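+  // Same recipe as the other unary cases: extract the operand at the original
+  // indices/slice, then re-apply the scalar op to the extracted value.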
rewriter.replaceOpWithNewOp(op, operand, + parentOp.getFastmath()); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, tensor::ExtractSliceOp parentOp, Location loc, + PatternRewriter &rewriter) const { + SmallVector newIndices; + for (const auto &[offset, index] : + llvm::zip_equal(parentOp.getMixedOffsets(), op.getIndices())) { + Value offsetVal; + if (auto v = dyn_cast(offset)) { + offsetVal = v; + } else { + offsetVal = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(*getConstantIntValue(offset))); + } + newIndices.push_back( + rewriter.create(op.getLoc(), offsetVal, index)); + } + rewriter + .replaceOpWithNewOp(op, parentOp.getSource(), + newIndices) + ->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); +} + +BubbleUpOperationPass::BubbleUpOperationPass( + const BubbleUpOperationOptions &options) + : BubbleUpOperationBase(options) {} + +void BubbleUpOperationPass::runOnOperation() { + ModuleOp moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add, + BubbleUpExtract>(ctx, + enableAggressiveMode); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + moduleOp->emitError("failed to apply Patterns"); + signalPassFailure(); + } + + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } +} + +std::unique_ptr> +triton::createBubbleUpOperationPass(const BubbleUpOperationOptions &options) { + return std::make_unique(options); +} diff --git a/compiler/lib/Conversion/TritonToUnstructure/CMakeLists.txt b/compiler/lib/Conversion/TritonToUnstructure/CMakeLists.txt new file mode 100644 index 00000000..802f5130 --- /dev/null +++ b/compiler/lib/Conversion/TritonToUnstructure/CMakeLists.txt @@ -0,0 +1,20 @@ +add_triton_library(TritonToUnstructure + UnstructureConversionPass.cpp + OffsetAnalysis.cpp + BubbleUpOperation.cpp + + DEPENDS + TritonToUnstructureConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonAnalysis + MLIRSCFTransforms +) \ No newline at end of file diff --git a/compiler/lib/Conversion/TritonToUnstructure/OffsetAnalysis.cpp b/compiler/lib/Conversion/TritonToUnstructure/OffsetAnalysis.cpp new file mode 100644 index 00000000..cf66046b --- /dev/null +++ b/compiler/lib/Conversion/TritonToUnstructure/OffsetAnalysis.cpp @@ -0,0 +1,965 @@ +#include "dicp/Conversion/TritonToUnstructure/OffsetAnalysis.h" +#include "dicp/Utils/Utils.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-offset-analysis" + +namespace mlir { +namespace triton { + +PtrOffsetInfo::PtrOffsetInfo() : ptr(nullptr), offset(nullptr) {} + +PtrOffsetInfo::PtrOffsetInfo(const PtrOffsetInfo &other) { *this = other; } + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr) : ptr(ptr) { setZeroOffset(); } + +PtrOffsetInfo::PtrOffsetInfo(ArrayRef structured) + : ptr(nullptr), offset(nullptr) { + setStructured(structured); +} + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr, bool structured) : ptr(ptr) { + setZeroOffset(); + if (auto tensorType = dyn_cast(ptr.getType())) + this->structured.resize(tensorType.getRank(), structured); +} + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr, ArrayRef structured) + : ptr(ptr) { + setStructured(structured); +} + 
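+// Rough usage sketch (illustrative IR, not taken from a test): a
+// PtrOffsetInfo tracks, for each SSA value, the base pointer it derives from
+// (`ptr`), the accumulated i64 offset (`offset`), one structured/unstructured
+// flag per tensor dimension (`structured`), and whether the value behaves
+// like a broadcast scalar (`scalarLike`). For a data-dependent gather such as
+//   %offs = tt.load %idx_ptrs : tensor<128x!tt.ptr<i64>>
+//   %ptrs = tt.addptr %base, %offs
+// parseLoad marks %offs unstructured in every dimension, and parseAddPtr
+// combines that with the info of %base, so %ptrs ends up all-unstructured and
+// is later handled as a discrete access.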
+PtrOffsetInfo::PtrOffsetInfo(const Value &ptr, const Value &offset, + bool structured) + : ptr(ptr), offset(offset) { + if (auto tensorType = dyn_cast(ptr.getType())) + this->structured.resize(tensorType.getRank(), structured); +} + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr, const Value &offset, + ArrayRef structured) + : ptr(ptr), offset(offset) { + setStructured(structured); +} + +PtrOffsetInfo &PtrOffsetInfo::operator=(const PtrOffsetInfo &other) { + setPtr(other.getPtr()); + setOffset(other.getOffset()); + setOffsets(other.getOffsets()); + setStructured(other.getStructured()); + setScalarLike(other.isScalarLike()); + return *this; +} + +Value PtrOffsetInfo::getPtr() const { return this->ptr; } +Value PtrOffsetInfo::getOffset() const { return this->offset; } +SmallVector PtrOffsetInfo::getOffsets() const { + return this->tptOffsets; +} +SmallVector &PtrOffsetInfo::getOffsetsRef() { return this->tptOffsets; } + +bool PtrOffsetInfo::isScalarLike() const { return this->scalarLike; } + +SmallVector &PtrOffsetInfo::getStructuredRef() { + return this->structured; +} +const SmallVector &PtrOffsetInfo::getStructured() const { + return this->structured; +} + +int PtrOffsetInfo::getRank() const { return structured.size(); } + +void PtrOffsetInfo::setPtr(const Value &ptr) { this->ptr = ptr; } +void PtrOffsetInfo::setOffset(const Value &offset) { this->offset = offset; } + +void PtrOffsetInfo::setOffsets(ValueRange offsets) { + tptOffsets.clear(); + for (auto offset : offsets) + tptOffsets.push_back(offset); +} + +void PtrOffsetInfo::setStructured() { + assert(ptr && "ptr Should be to infer rank"); + this->structured.clear(); + if (auto tensorType = dyn_cast(ptr.getType())) + this->structured.resize(tensorType.getRank(), true); +} + +void PtrOffsetInfo::setStructured(int rank) { + this->structured.clear(); + this->structured.resize(rank, true); +} + +void PtrOffsetInfo::setUnstructured() { + assert(ptr && "ptr Should be to infer rank"); + this->structured.clear(); + if (auto tensorType = dyn_cast(ptr.getType())) + this->structured.resize(tensorType.getRank(), false); +} + +void PtrOffsetInfo::setUnstructured(int rank) { + this->structured.clear(); + this->structured.resize(rank, false); +} + +void PtrOffsetInfo::setStructured(ArrayRef structured) { + this->structured.resize(structured.size()); + for (size_t i = 0; i < structured.size(); i++) + this->structured[i] = structured[i]; +} + +void PtrOffsetInfo::setStructured(const PtrOffsetInfo &other) { + this->setStructured(other.getStructured()); +} + +void PtrOffsetInfo::setScalarLike(bool scalarLike) { + this->scalarLike = scalarLike; +} + +bool PtrOffsetInfo::isStructured(int dim) const { + return this->scalarLike || structured[dim]; +} + +bool PtrOffsetInfo::isStructured() const { + return this->scalarLike || + llvm::all_of(structured, [](auto dim) { return dim; }); +} + +bool PtrOffsetInfo::isUnstructured() const { + return llvm::all_of(structured, [](auto dim) { return !dim; }); +} + +void PtrOffsetInfo::setZeroOffset() { + if (!ptr) + return; + Value offset; + OpBuilder builder(ptr.getContext()); + builder.setInsertionPointToStart(ptr.getParentBlock()); + if (auto tensorType = dyn_cast(ptr.getType())) { + offset = builder.create( + ptr.getLoc(), DenseElementsAttr::get( + RankedTensorType::get(tensorType.getShape(), + builder.getIntegerType(64)), + builder.getZeroAttr(builder.getIntegerType(64)))); + } else { + offset = builder.create(ptr.getLoc(), + builder.getI64IntegerAttr(0)); + } + setOffset(offset); +} + +PtrOffsetInfo 
combineInfo(const PtrOffsetInfo &lhs, const PtrOffsetInfo &rhs) { + PtrOffsetInfo info; + assert(lhs.getRank() == rhs.getRank() && "Rank must be same to be combined"); + + info.setScalarLike(lhs.isScalarLike() && rhs.isScalarLike()); + SmallVector &structuredRef = info.getStructuredRef(); + structuredRef.resize(lhs.getRank()); + for (size_t i = 0; i < structuredRef.size(); i++) + structuredRef[i] = lhs.isStructured(i) && rhs.isStructured(i); + return info; +} + +void parse(Value operand, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + if (offsetMap.contains(operand)) { + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "found\n" << operand << '\n'; + }); + return; + } + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "parse\n" << operand << '\n'; + }); + LLVM_DEBUG({}); + + if (auto *defOp = operand.getDefiningOp()) { + if (isa(defOp->getDialect())) { + parseArithOp(defOp, loc, rewriter, offsetMap); + } else if (isa(defOp->getDialect())) { + parseTritonOp(defOp, loc, rewriter, offsetMap); + } else { + if (auto ifOp = dyn_cast(defOp)) { + parseIf(ifOp, loc, rewriter, offsetMap, operand); + } else if (auto yieldOp = dyn_cast(defOp)) { + parseYield(yieldOp, loc, rewriter, offsetMap); + } else if (auto loopOp = dyn_cast(defOp)) { + parseLoopOp(loopOp, loc, rewriter, offsetMap, operand); + } else if (auto extractOp = dyn_cast(defOp)) { + parseExtract(extractOp, loc, rewriter, offsetMap); + } + } + } else if (auto blockArgument = dyn_cast(operand)) { + auto parentOp = blockArgument.getOwner()->getParentOp(); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Handling block argument\n" << *blockArgument.getOwner() << '\n'; + }); + if (isa(parentOp)) { + if (auto ptrType = dyn_cast(operand.getType())) { + offsetMap[operand] = PtrOffsetInfo(operand, true); + } else { + offsetMap[operand] = PtrOffsetInfo(); + } + } else if (auto loopOp = dyn_cast(parentOp)) { + parseLoopRegionIterArg(loopOp, loc, rewriter, offsetMap, blockArgument); + } + } else { + llvm_unreachable("Unreachable"); + } + + if (!offsetMap.contains(operand)) { + offsetMap[operand] = PtrOffsetInfo(); + if (auto tensorType = dyn_cast(operand.getType())) + offsetMap[operand].setUnstructured(tensorType.getRank()); + } + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "finish parse\n" << operand << '\n'; + auto data = offsetMap.at(operand); + for (auto s : data.getStructuredRef()) + os << s; + os << "\n"; + }); +} + +void parseLoopRegionIterArg(LoopLikeOpInterface loopOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap, + BlockArgument regionIterArg) { + if (auto whileOp = dyn_cast(loopOp.getOperation()); + whileOp && whileOp.getAfterBody() == regionIterArg.getOwner()) { + auto argNum = regionIterArg.getArgNumber(); + auto conditionArg = whileOp.getConditionOp().getArgs()[argNum]; + parse(conditionArg, loc, rewriter, offsetMap); + offsetMap[regionIterArg] = offsetMap[conditionArg]; + return; + } + OpOperand *initArgOperand = loopOp.getTiedLoopInit(regionIterArg); + if (!initArgOperand) + return; + Value initArg = initArgOperand->get(); + parse(initArg, loc, rewriter, offsetMap); + offsetMap[regionIterArg] = offsetMap[initArg]; +} + +void parseArithOp(Operation *arithOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + assert(isa(arithOp->getDialect())); + if (auto addIOp = dyn_cast(arithOp)) { + parseAddI(addIOp, loc, rewriter, offsetMap); + } else if (auto subIOp = dyn_cast(arithOp)) { + parseSubI(subIOp, loc, rewriter, offsetMap); + } else if (auto 
indexCastOp = dyn_cast(arithOp)) { + parseIndexCast(indexCastOp, loc, rewriter, offsetMap); + } else if (auto constantFloatOp = dyn_cast(arithOp)) { + parseConstantOp(constantFloatOp, loc, rewriter, offsetMap); + } else if (auto constantIntOp = dyn_cast(arithOp)) { + parseConstantOp(constantIntOp, loc, rewriter, offsetMap); + } else if (auto constantOp = dyn_cast(arithOp)) { + parseConstantOp(constantOp, loc, rewriter, offsetMap); + } else if (auto extSIOp = dyn_cast(arithOp)) { + parseExtSI(extSIOp, loc, rewriter, offsetMap); + } else if (auto mulIOp = dyn_cast(arithOp)) { + parseMulI(mulIOp, loc, rewriter, offsetMap); + } else if (auto remSIOp = dyn_cast(arithOp)) { + parseBinaryOp(remSIOp, loc, rewriter, offsetMap); + } else if (auto divSIOp = dyn_cast(arithOp)) { + parseBinaryOp(divSIOp, loc, rewriter, offsetMap); + } else if (auto selectOp = dyn_cast(arithOp)) { + parseSelect(selectOp, loc, rewriter, offsetMap); + } else if (auto fPToSIOp = dyn_cast(arithOp)) { + parseFPToSI(fPToSIOp, loc, rewriter, offsetMap); + } else if (auto sIToFPOp = dyn_cast(arithOp)) { + parseSIToFP(sIToFPOp, loc, rewriter, offsetMap); + } else if (auto mulFOp = dyn_cast(arithOp)) { + parseBinaryOp(mulFOp, loc, rewriter, offsetMap); + } else if (auto divFOp = dyn_cast(arithOp)) { + parseBinaryOp(divFOp, loc, rewriter, offsetMap); + } else if (auto addFOp = dyn_cast(arithOp)) { + parseBinaryOp(addFOp, loc, rewriter, offsetMap); + } else if (auto subFOp = dyn_cast(arithOp)) { + parseBinaryOp(subFOp, loc, rewriter, offsetMap); + } else if (auto minNumFOp = dyn_cast(arithOp)) { + parseBinaryOp(minNumFOp, loc, rewriter, offsetMap); + } else if (auto maxNumFOp = dyn_cast(arithOp)) { + parseBinaryOp(maxNumFOp, loc, rewriter, offsetMap); + } else if (auto maxSIOp = dyn_cast(arithOp)) { + parseBinaryOp(maxSIOp, loc, rewriter, offsetMap); + } else if (auto minSIOp = dyn_cast(arithOp)) { + parseBinaryOp(minSIOp, loc, rewriter, offsetMap); + } else if (auto cmpIOp = dyn_cast(arithOp)) { + parseBinaryOp(cmpIOp, loc, rewriter, offsetMap); + } else if (auto andIOp = dyn_cast(arithOp)) { + parseBinaryOp(andIOp, loc, rewriter, offsetMap); + } else if (auto orIOp = dyn_cast(arithOp)) { + parseBinaryOp(orIOp, loc, rewriter, offsetMap); + } +} + +void parseTritonOp(Operation *tritonOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + assert(isa(tritonOp->getDialect())); + if (auto addPtrOp = dyn_cast(tritonOp)) { + parseAddPtr(addPtrOp, loc, rewriter, offsetMap); + } else if (auto splatOp = dyn_cast(tritonOp)) { + parseSplat(splatOp, loc, rewriter, offsetMap); + } else if (auto getProgramIdOp = dyn_cast(tritonOp)) { + parseConstantOp(getProgramIdOp, loc, rewriter, offsetMap); + } else if (auto getNumProgramsOp = + dyn_cast(tritonOp)) { + parseConstantOp(getNumProgramsOp, loc, rewriter, offsetMap); + } else if (auto makeRangeOp = dyn_cast(tritonOp)) { + parseMakeRange(makeRangeOp, loc, rewriter, offsetMap); + } else if (auto bitcastOp = dyn_cast(tritonOp)) { + parseBitcast(bitcastOp, loc, rewriter, offsetMap); + } else if (auto loadOp = dyn_cast(tritonOp)) { + parseLoad(loadOp, loc, rewriter, offsetMap); + } else if (auto broadcastOp = dyn_cast(tritonOp)) { + parseBroadcast(broadcastOp, loc, rewriter, offsetMap); + } else if (auto expandDimsOp = dyn_cast(tritonOp)) { + parseExpandDims(expandDimsOp, loc, rewriter, offsetMap); + } else if (auto clampFOp = dyn_cast(tritonOp)) { + parseClampF(clampFOp, loc, rewriter, offsetMap); + } else if (auto makeTensorDescOp = + dyn_cast(tritonOp)) { + 
parseMakeTensorDesc(makeTensorDescOp, loc, rewriter, offsetMap); + } else if (auto makeTensorPtrOp = + dyn_cast(tritonOp)) { + parseMakeTensorPtr(makeTensorPtrOp, loc, rewriter, offsetMap); + } else if (auto reduceOp = dyn_cast(tritonOp)) { + parseReduce(reduceOp, loc, rewriter, offsetMap); + } else if (auto reduceReturnOp = dyn_cast(tritonOp)) { + parseReduceReturn(reduceReturnOp, loc, rewriter, offsetMap); + } else if (auto advanceOp = dyn_cast(tritonOp)) { + parseAdvance(advanceOp, loc, rewriter, offsetMap); + } else if (auto intToPtrOp = dyn_cast(tritonOp)) { + parseIntToPtr(intToPtrOp, loc, rewriter, offsetMap); + } +} + +void parseAddPtr(triton::AddPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get addPtr base_ptr + Value ptr = op.getPtr(); + parse(ptr, op.getLoc(), rewriter, offsetMap); + // Get addPtr offset + Value offsetValue = op.getOffset(); + parse(offsetValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo ptrOffsetInfo = offsetMap.at(ptr); + PtrOffsetInfo offsetOffsetInfo = offsetMap.at(offsetValue); + // Modify IR + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto offsetType = dyn_cast(offsetValue.getType())) { + auto offsetElementType = cast(offsetType.getElementType()); + if (offsetElementType.getWidth() != 64) { + auto newOffsetType = RankedTensorType::get(offsetType.getShape(), + rewriter.getIntegerType(64)); + offsetValue = rewriter.create(op.getLoc(), newOffsetType, + offsetValue); + } + } else { + auto offsetIntType = cast(offsetValue.getType()); + if (offsetIntType.getWidth() != 64) { + offsetValue = rewriter.create( + op.getLoc(), rewriter.getIntegerType(64), offsetValue); + } + } + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseAddPtr] Adding offset\n"; + os << ptrOffsetInfo.getOffset() << '\n' << offsetValue << '\n'; + }); + Value offset = rewriter.create( + op.getLoc(), ptrOffsetInfo.getOffset(), offsetValue); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseAddPtr] offset is\n" << offset << '\n'; + }); + // Set addPtr offset map + auto dst = op.getResult(); + auto dstOffsetInfo = combineInfo(ptrOffsetInfo, offsetOffsetInfo); + dstOffsetInfo.setPtr(ptrOffsetInfo.getPtr()); + dstOffsetInfo.setOffset(offset); + offsetMap[dst] = dstOffsetInfo; + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + SmallVector &ptrStructured = ptrOffsetInfo.getStructuredRef(); + SmallVector &offsetStructured = offsetOffsetInfo.getStructuredRef(); + os << "[parseAddPtr] ptrStructured: "; + for (size_t i = 0; i < ptrStructured.size(); i++) + os << ptrStructured[i]; + os << "\n"; + os << "[parseAddPtr] offsetStructured: "; + for (size_t i = 0; i < offsetStructured.size(); i++) + os << offsetStructured[i]; + os << "\n"; + }); +} + +void parseSplat(triton::SplatOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get splat src + auto src = op.getSrc(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + auto dst = op.getResult(); + auto dstType = cast(dst.getType()); + PtrOffsetInfo dstOffsetInfo(srcOffsetInfo.getPtr()); + // Modify IR + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseSplat] dst is\n" << dst << '\n'; + }); + if (isa(dstType.getElementType())) { + RewriterBase::InsertionGuard guard(rewriter); + auto dstShape = dstType.getShape(); + rewriter.setInsertionPoint(op); + Value valueOffset = srcOffsetInfo.getOffset(); + Value offset = rewriter.create( + loc, RankedTensorType::get(dstShape, 
rewriter.getIntegerType(64)), + valueOffset); + dstOffsetInfo.setOffset(offset); + } + // Set addPtr offset map + + dstOffsetInfo.setStructured(dstType.getRank()); + dstOffsetInfo.setScalarLike(true); + offsetMap[dst] = dstOffsetInfo; +} + +template +void parseBinaryOp(BinOpTy op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + auto lhs = op.getLhs(); + parse(lhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo lhsOffsetInfo = offsetMap.at(lhs); + SmallVector &lhsStructured = lhsOffsetInfo.getStructuredRef(); + auto rhs = op.getRhs(); + parse(rhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo rhsOffsetInfo = offsetMap.at(rhs); + SmallVector &rhsStructured = rhsOffsetInfo.getStructuredRef(); + auto dst = op->getResult(0); + PtrOffsetInfo dstOffsetInfo; + dstOffsetInfo.setScalarLike(lhsOffsetInfo.isScalarLike() && + rhsOffsetInfo.isScalarLike()); + if (dstOffsetInfo.isScalarLike()) + dstOffsetInfo.setStructured(lhsStructured.size()); + else + dstOffsetInfo.setUnstructured(lhsStructured.size()); + offsetMap[dst] = dstOffsetInfo; +} + +void parseAddI(arith::AddIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get addi lhs + auto lhs = op.getLhs(); + parse(lhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo lhsOffsetInfo = offsetMap.at(lhs); + // Get addi rhs + auto rhs = op.getRhs(); + parse(rhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo rhsOffsetInfo = offsetMap.at(rhs); + // Set addi offset map + auto dst = op.getResult(); + offsetMap[dst] = combineInfo(lhsOffsetInfo, rhsOffsetInfo); +} + +void parseSubI(arith::SubIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get addi lhs + auto lhs = op.getLhs(); + parse(lhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo lhsOffsetInfo = offsetMap.at(lhs); + // Get addi rhs + auto rhs = op.getRhs(); + parse(rhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo rhsOffsetInfo = offsetMap.at(rhs); + // Set addi offset map + auto dst = op.getResult(); + offsetMap[dst] = combineInfo(lhsOffsetInfo, rhsOffsetInfo); + if (!(lhsOffsetInfo.isStructured() && rhsOffsetInfo.isScalarLike())) { + offsetMap[dst].setUnstructured(offsetMap[dst].getRank()); + } +} + +void parseIndexCast(arith::IndexCastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get indexCast input + auto src = op.getIn(); + parse(src, op.getLoc(), rewriter, offsetMap); + // Set indexCast offset map + auto dst = op.getOut(); + offsetMap[dst] = offsetMap.at(src); +} + +template +void parseConstantOp(ConstOpTy dst, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set constant offset map + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(true); + if (auto tensorType = dyn_cast(dst->getResult(0).getType())) + offsetMap[dst].setStructured(tensorType.getRank()); +} + +void parseMakeRange(triton::MakeRangeOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set makeRange offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setStructured(1); +} + +void parseExtSI(arith::ExtSIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get extSI input + auto src = op.getIn(); + parse(src, op.getLoc(), rewriter, offsetMap); + // Set extSI offset map + auto dst = op.getOut(); + offsetMap[dst] = offsetMap.at(src); +} + +void parseBitcast(triton::BitcastOp op, const Location &loc, + 
RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get bitcast src + auto src = op.getSrc(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Set extSI offset map + auto dst = op.getResult(); + if (auto ptr = srcOffsetInfo.getPtr()) { + Type ptrType = dst.getType(); + if (auto tensorType = dyn_cast(ptrType)) + ptrType = tensorType.getElementType(); + rewriter.setInsertionPoint(op); + ptr = rewriter.create(loc, ptrType, ptr); + offsetMap[dst] = + PtrOffsetInfo(ptr, srcOffsetInfo.getOffset(), srcStructured); + } else { + offsetMap[dst] = PtrOffsetInfo(srcStructured); + } + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); +} + +void parseLoad(triton::LoadOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get load ptr + auto ptr = op.getPtr(); + parse(ptr, op.getLoc(), rewriter, offsetMap); + // Set load offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(offsetMap[ptr].isScalarLike()); + auto tensorType = dyn_cast(dst.getType()); + if (!tensorType) + return; + offsetMap[dst].setUnstructured(tensorType.getRank()); +} + +void parseMulI(arith::MulIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get muli lhs + auto lhs = op.getLhs(); + parse(lhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo lhsOffsetInfo = offsetMap.at(lhs); + SmallVector &lhsStructured = lhsOffsetInfo.getStructuredRef(); + bool lhsScalarLike = lhsOffsetInfo.isScalarLike(); + // Get muli rhs + auto rhs = op.getRhs(); + parse(rhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo rhsOffsetInfo = offsetMap.at(rhs); + SmallVector &rhsStructured = rhsOffsetInfo.getStructuredRef(); + bool rhsScalarLike = rhsOffsetInfo.isScalarLike(); + // Set muli offset map + size_t maxSize = std::max(lhsStructured.size(), rhsStructured.size()); + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(lhsScalarLike && rhsScalarLike); + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + dstStructured.resize(maxSize); + for (size_t i = 0; i < maxSize; i++) + if (lhsScalarLike) + dstStructured[i] = rhsStructured[i]; + else if (rhsScalarLike) + dstStructured[i] = lhsStructured[i]; + else + dstStructured[i] = false; +} + +void parseBroadcast(triton::BroadcastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get broadcast src + auto src = op.getSrcMutable().get(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Get broadcast dim + auto dst = op.getResult(); + assert(isa(src.getType()) && + "tt.broadcast's input should be a tensor"); + auto srcType = cast(src.getType()); + auto dstType = cast(dst.getType()); + assert(srcType.getRank() == dstType.getRank() && + "rank of source shoule be equal to destnation"); + auto broadcastDim = mlir::dicp::getBroadcastDims(srcType, dstType); + // Set broadcast offset map + offsetMap[dst] = PtrOffsetInfo(srcOffsetInfo.getPtr()); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + + if (srcOffsetInfo.getPtr()) { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + Value valueOffset = srcOffsetInfo.getOffset(); + Value offset = rewriter.create( + loc, + RankedTensorType::get(dstType.getShape(), 
rewriter.getIntegerType(64)), + valueOffset); + + offsetMap[dst].setOffset(offset); + } + + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + dstStructured.resize(srcStructured.size()); + for (size_t i = 0; i < dstStructured.size(); i++) + if (llvm::find(broadcastDim, i) != broadcastDim.end()) + dstStructured[i] = true; + else + dstStructured[i] = srcStructured[i]; +} + +void parseExpandDims(triton::ExpandDimsOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get expandDims src + auto src = op.getSrcMutable().get(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Set expandDims offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(srcOffsetInfo.getPtr()); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + if (srcOffsetInfo.getPtr()) { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + Value valueOffset = srcOffsetInfo.getOffset(); + Value offset = rewriter.create(loc, valueOffset, + op.getAxisAttr()); + + offsetMap[dst].setOffset(offset); + } + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + dstStructured.resize(srcStructured.size() + 1); + size_t j = 0; + for (size_t i = 0; i < dstStructured.size(); i++) + if (i == op.getAxis()) { + dstStructured[i] = true; + } else { + dstStructured[i] = srcStructured[j]; + j++; + } +} + +void parseClampF(triton::ClampFOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get clampF src + auto src = op.getX(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + // Get clampF min + auto clampMin = op.getX(); + parse(clampMin, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo minOffsetInfo = offsetMap.at(clampMin); + // Get clampF max + auto clampMax = op.getX(); + parse(clampMax, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo maxOffsetInfo = offsetMap.at(clampMax); + // Set clampF offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike() && + minOffsetInfo.isScalarLike() && + maxOffsetInfo.isScalarLike()); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + offsetMap[dst].setUnstructured(dstType.getRank()); +} + +void parseSelect(arith::SelectOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get select condition + auto condition = op.getCondition(); + parse(condition, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo conditionOffsetInfo = offsetMap.at(condition); + bool conditionScalarLike = conditionOffsetInfo.isScalarLike(); + // Get select trueValue + auto trueValue = op.getTrueValue(); + parse(trueValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo trueValueOffsetInfo = offsetMap.at(trueValue); + SmallVector &trueValueStructured = + trueValueOffsetInfo.getStructuredRef(); + bool trueValueScalarLike = trueValueOffsetInfo.isScalarLike(); + // Get select falseValue + auto falseValue = op.getFalseValue(); + parse(falseValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo falseValueOffsetInfo = offsetMap.at(falseValue); + SmallVector &falseValueStructured = + falseValueOffsetInfo.getStructuredRef(); + bool falseValueScalarLike = falseValueOffsetInfo.isScalarLike(); + // Set select offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + auto dstType 
= dyn_cast(dst.getType()); + if (!dstType) + return; + offsetMap[dst].setUnstructured(dstType.getRank()); +} + +void parseFPToSI(arith::FPToSIOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get FPToSI src + auto src = op.getIn(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + // Set FPToSI offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + if (offsetMap[dst].isScalarLike()) + offsetMap[dst].setStructured(dstType.getRank()); + else + offsetMap[dst].setUnstructured(dstType.getRank()); +} + +void parseSIToFP(arith::SIToFPOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get SIToFP src + auto src = op.getIn(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + // Set SIToFP offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + if (offsetMap[dst].isScalarLike()) + offsetMap[dst].setStructured(dstType.getRank()); + else + offsetMap[dst].setUnstructured(dstType.getRank()); +} + +void parseMakeTensorDesc(triton::MakeTensorDescOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set MakeTensorDesc offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + offsetMap[dst].setStructured(dstType.getRank()); +} + +void parseMakeTensorPtr(triton::MakeTensorPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set MakeTensorPtr offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(dst); + auto dstType = dyn_cast( + cast(dst.getType()).getPointeeType()); + if (!dstType) + return; + offsetMap[dst].setStructured(dstType.getRank()); + offsetMap[dst].setOffsets(op.getOffsets()); +} + +void parseAdvance(triton::AdvanceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set Advance offset map + auto ptr = op.getPtr(); + parse(ptr, op.getLoc(), rewriter, offsetMap); + auto dst = op.getResult(); + offsetMap[dst] = offsetMap.at(ptr); + auto dstType = dyn_cast( + cast(dst.getType()).getPointeeType()); + if (!dstType) + return; + offsetMap[dst].setStructured(dstType.getRank()); + auto &offsets = offsetMap[dst].getOffsetsRef(); + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + for (auto [curOffset, opOffset] : llvm::zip(offsets, op.getOffsets())) { + curOffset = + rewriter.create(op.getLoc(), curOffset, opOffset); + } +} + +void parseReduce(triton::ReduceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get reduce src + Value src = op->getOperand(0); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Set reduce offset map + Value dst = op->getResult(0); + auto dstType = dyn_cast(dst.getType()); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + if (!dstType) + return; + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + auto dstShape = dstType.getShape(); + 
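+  // Result dimensions of extent 1 are treated as structured; every other
+  // dimension inherits its flag from the corresponding source dimension.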
dstStructured.resize(dstShape.size()); + for (size_t i = 0; i < dstStructured.size(); i++) + if (dstShape[i] == 1) + dstStructured[i] = true; + else + dstStructured[i] = srcStructured[i]; +} + +void parseReduceReturn(triton::ReduceReturnOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get reduce src + Value src = op->getOperand(0); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Set reduce offset map + Value dst = op->getResult(0); + auto dstType = dyn_cast(dst.getType()); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + if (!dstType) + return; + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + auto dstShape = dstType.getShape(); + dstStructured.resize(dstShape.size()); + for (size_t i = 0; i < dstStructured.size(); i++) + if (dstShape[i] == 1) + dstStructured[i] = true; + else + dstStructured[i] = srcStructured[i]; +} + +void parseIf(scf::IfOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap, Value dst) { + const unsigned int index = cast(dst).getResultNumber(); + // Get if then region + Block &thenBlock = op.getThenRegion().front(); + Value thenYieldedValue = thenBlock.getTerminator()->getOperand(index); + parse(thenYieldedValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo thenOffsetInfo = offsetMap.at(thenYieldedValue); + SmallVector &thenStructured = thenOffsetInfo.getStructuredRef(); + // Get if else region + bool dstIsScalar = thenOffsetInfo.isScalarLike(); + SmallVector elseStructured; + if (op.elseBlock()) { + Block &elseBlock = op.getElseRegion().front(); + Value elseYieldedValue = elseBlock.getTerminator()->getOperand(index); + parse(elseYieldedValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo elseOffsetInfo = offsetMap.at(elseYieldedValue); + elseStructured = elseOffsetInfo.getStructuredRef(); + dstIsScalar = dstIsScalar && elseOffsetInfo.isScalarLike(); + } + // Set if offset map + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(dstIsScalar); + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + dstStructured.resize(thenStructured.size()); + for (size_t i = 0; i < dstStructured.size(); i++) + if (op.elseBlock()) + dstStructured[i] = thenStructured[i] && elseStructured[i]; + else + dstStructured[i] = thenStructured[i]; +} + +void parseYield(scf::YieldOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get yield src + for (auto src : op->getOperands()) + parse(src, op.getLoc(), rewriter, offsetMap); +} + +void parseLoopOp(LoopLikeOpInterface op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap, Value dst) { + auto resNum = cast(dst).getResultNumber(); + Value yieldedValue = nullptr; + if (auto whileOp = dyn_cast(op.getOperation())) { + yieldedValue = whileOp.getConditionOp().getArgs()[resNum]; + } else { + yieldedValue = op.getYieldedValues()[resNum]; + } + parse(yieldedValue, op.getLoc(), rewriter, offsetMap); + offsetMap[dst] = offsetMap.at(yieldedValue); +} + +void parseExtractSlice(tensor::ExtractSliceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get extractSlice src + auto src = op.getOperand(0); + parse(src, op.getLoc(), rewriter, offsetMap); + // Set extractSlice offset map + auto dst = op.getResult(); + offsetMap[dst] = offsetMap.at(src); +} + +void 
parseExtract(tensor::ExtractOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + auto parentValue = op.getTensor(); + parse(parentValue, op.getLoc(), rewriter, offsetMap); + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + if (isa(dst.getType())) { + offsetMap[dst].setPtr(dst); + } + offsetMap[dst].setScalarLike(true); +} + +void parseIntToPtr(triton::IntToPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(dst); + offsetMap[dst].setScalarLike(true); +} + +} // namespace triton +} // namespace mlir diff --git a/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp b/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp new file mode 100644 index 00000000..3f69b378 --- /dev/null +++ b/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp @@ -0,0 +1,882 @@ +#include "dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h" +#include "dicp/Utils/Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/STLExtras.h" + +#include + +#define DEBUG_TYPE "triton-unstructure-converter" + +using namespace mlir; +using namespace triton; + +#include "llvm/Support/Debug.h" + +template +bool UnstructuredMemAccessConverter::checkUnstructureAnnotated( + MemAccOpTy op, PatternRewriter &rewriter) const { + return llvm::any_of(op->getUsers(), [&rewriter](Operation *user) { + auto annotationOp = dyn_cast(user); + if (annotationOp && annotationOp->hasAttr("mayDiscretememaccess")) { + rewriter.eraseOp(annotationOp); + return true; + } + return false; + }); +} + +template <> +bool UnstructuredMemAccessConverter::checkUnstructureAnnotated( + triton::StoreOp op, PatternRewriter &rewriter) const { + return llvm::any_of(op.getValue().getUsers(), [&rewriter](Operation *user) { + auto annotationOp = dyn_cast(user); + if (annotationOp && annotationOp->hasAttr("mayDiscretememaccess")) { + rewriter.eraseOp(annotationOp); + return true; + } + return false; + }); +} + +template +Value UnstructuredMemAccessConverter::createExtractOp( + Location loc, Value value, PatternRewriter &rewriter, + ArrayRef iterIdx) const { + if (!value) + return value; + SmallVector indices; + for (auto idx : iterIdx) { + if (auto val = dyn_cast(idx)) { + indices.push_back(val); + } else { + auto idxVal = rewriter.create( + loc, rewriter.getIndexAttr(*getConstantIntValue(idx))); + indices.push_back(idxVal); + } + } + auto extractedOp = rewriter.create(loc, value, indices); + extractedOp->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + return extractedOp; +} + +template +Value UnstructuredMemAccessConverter::createExtractOp( + Location loc, Value value, PatternRewriter &rewriter, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides) const { + if (!value) + return value; + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Extracting\n"; + os << value << "\n"; + }); + auto extractedOp = rewriter.create( + loc, value, offsets, sizes, strides); + extractedOp->setAttr(mlir::dicp::discreteAttrName, + 
UnitAttr::get(rewriter.getContext())); + return extractedOp; +} + +template <> +template +triton::LoadOp UnstructuredMemAccessConverter::createMemAccOp( + triton::LoadOp op, Value ptrToAccess, Location loc, + PatternRewriter &rewriter, Args &&...args) const { + return rewriter.create(loc, ptrToAccess, op.getCache(), + op.getEvict(), op.getIsVolatile()); +} + +template <> +template +triton::AtomicRMWOp +UnstructuredMemAccessConverter::createMemAccOp( + triton::AtomicRMWOp op, Value ptrToAccess, Location loc, + PatternRewriter &rewriter, Args &&...args) const { + auto extractedValue = + createExtractOp(loc, op.getVal(), rewriter, std::forward(args)...); + auto extractedMask = + createExtractOp(loc, op.getMask(), rewriter, std::forward(args)...); + Type targetType = ptrToAccess.getType(); + if (auto tensorType = dyn_cast(targetType)) { + auto ptrType = cast(tensorType.getElementType()); + targetType = + RankedTensorType::get(tensorType.getShape(), ptrType.getPointeeType()); + } else { + auto resultType = cast(op.getResult().getType()); + SmallVector scalarLikeShape(resultType.getRank(), 1); + targetType = + RankedTensorType::get(scalarLikeShape, resultType.getElementType()); + ptrToAccess = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, ptrToAccess.getType()), + ptrToAccess); + extractedValue = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedValue.getType()), + extractedValue); + if (extractedMask) { + extractedMask = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedMask.getType()), + extractedMask); + } + } + return rewriter.create( + loc, targetType, op.getAtomicRmwOpAttr(), ptrToAccess, extractedValue, + extractedMask, op.getSemAttr(), op.getScopeAttr()); +} + +template <> +template +triton::AtomicCASOp +UnstructuredMemAccessConverter::createMemAccOp( + triton::AtomicCASOp op, Value ptrToAccess, Location loc, + PatternRewriter &rewriter, Args &&...args) const { + auto extractedCmp = + createExtractOp(loc, op.getCmp(), rewriter, std::forward(args)...); + auto extractedValue = + createExtractOp(loc, op.getVal(), rewriter, std::forward(args)...); + Type targetType = ptrToAccess.getType(); + if (auto tensorType = dyn_cast(targetType)) { + auto ptrType = cast(tensorType.getElementType()); + targetType = + RankedTensorType::get(tensorType.getShape(), ptrType.getPointeeType()); + } else { + auto resultType = cast(op.getResult().getType()); + SmallVector scalarLikeShape(resultType.getRank(), 1); + targetType = + RankedTensorType::get(scalarLikeShape, resultType.getElementType()); + ptrToAccess = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, ptrToAccess.getType()), + ptrToAccess); + extractedCmp = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedCmp.getType()), + extractedCmp); + extractedValue = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedValue.getType()), + extractedValue); + } + return rewriter.create( + loc, targetType, ptrToAccess, extractedCmp, extractedValue, + op.getSemAttr(), op.getScopeAttr()); +} + +template <> +template +triton::StoreOp UnstructuredMemAccessConverter::createMemAccOp( + triton::StoreOp op, Value ptrToAccess, Location loc, + PatternRewriter &rewriter, Args &&...args) const { + auto extractedValue = createExtractOp(loc, op.getValue(), rewriter, + std::forward(args)...); + auto extractedMask = + createExtractOp(loc, op.getMask(), rewriter, std::forward(args)...); + return rewriter.create(loc, ptrToAccess, extractedValue, + extractedMask); +} + 
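+// splatAndLoadScenario: when the pointer operand is scalar-like (all lanes +// address the same element), extract the pointer at index (0, ..., 0), emit a +// single-element load, and splat the loaded value back to the original result +// shape; if the load carries a mask, the splatted value is selected against +// `other`.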
+template <> +template <> +void UnstructuredMemAccessConverter::splatAndLoadScenario< + triton::LoadOp>(triton::LoadOp op, int rank, + PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector idx(rank, rewriter.getIndexAttr(0)); + auto extractedPtr = createExtractOp(loc, op.getPtr(), rewriter, idx); + Value mask = op.getMask(); + Value other = op.getOther(); + Value loadedValue = rewriter.create( + loc, extractedPtr, /*mask=*/nullptr, /*other=*/nullptr, + /*boundaryCheck=*/ArrayRef(), + /*PaddingOptionAttr=*/nullptr); + loadedValue = rewriter.create(loc, op.getResult().getType(), + loadedValue); + if (mask) + rewriter.replaceOpWithNewOp(op, mask, loadedValue, other); + else + rewriter.replaceOp(op, loadedValue); +} + +template +UnstructuredMemAccessConverter::UnstructuredMemAccessConverter( + MLIRContext *context, bool forceScalarizeMode, + const llvm::DenseMap &offsetMap, + const llvm::SmallDenseMap &fromTensorArg) + : OpRewritePattern(context), + forceScalarizeMode(forceScalarizeMode), offsetMap(offsetMap), + fromTensorArg(fromTensorArg) {} + +template +LogicalResult UnstructuredMemAccessConverter::matchAndRewrite( + MemAccOpTy op, PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + + auto ptr = op.getPtr(); + auto ptrType = dyn_cast(ptr.getType()); + + if (auto ptrPtrType = dyn_cast(ptr.getType())) { + if (auto ptrTensorType = + dyn_cast_or_null(ptrPtrType.getPointeeType())) + ptrType = ptrTensorType; + } + + if (!ptrType || op->hasAttr(mlir::dicp::discreteAttrName)) + return failure(); + if (!offsetMap.contains(ptr)) + return op.emitError() << "PtrOffsetInfo should be computed\n" << ptr; + + auto ptrOffsetInfo = offsetMap.at(ptr); + + if (checkUnstructureAnnotated(op, rewriter)) + ptrOffsetInfo.setUnstructured(ptrOffsetInfo.getRank()); + + if (ptrOffsetInfo.isStructured() && + (!ptrOffsetInfo.isScalarLike() || + llvm::all_of(ptrType.getShape(), [](int64_t dim) { return dim == 1; }))) + return failure(); + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Converting " << op->getName() << "\n"; + os << op << "\n"; + os << ptrOffsetInfo.isStructured() << "\n"; + os << ptrOffsetInfo.isScalarLike() << "\n"; + }); + + if constexpr (std::is_same_v) + if (ptrOffsetInfo.isScalarLike()) { + splatAndLoadScenario(op, ptrOffsetInfo.getRank(), rewriter); + return success(); + } + + if (op->hasAttr(mlir::dicp::discreteMaskAttrName)) { + if constexpr (std::is_same_v) { + auto selectOp = op.getValue().template getDefiningOp(); + op = rewriter.replaceOpWithNewOp( + op, op.getPtr(), selectOp.getTrueValue(), selectOp.getCondition(), + op.getCache(), op.getEvict()); + rewriter.setInsertionPoint(op); + ptrOffsetInfo.setUnstructured(ptrOffsetInfo.getRank()); + } else if constexpr (std::is_same_v) { + auto selectOp = op.getVal().template getDefiningOp(); + op = rewriter.replaceOpWithNewOp( + op, op.getType(), op.getAtomicRmwOp(), op.getPtr(), + selectOp.getTrueValue(), selectOp.getCondition(), op.getSem(), + op.getScope()); + } + rewriter.setInsertionPoint(op); + ptrOffsetInfo.setUnstructured(ptrOffsetInfo.getRank()); + } + + if (forceScalarizeMode || ptrOffsetInfo.isScalarLike() || + fromTensorArg.at(ptr)) { + ptrOffsetInfo.setUnstructured(ptrOffsetInfo.getRank()); + } + + auto srcPtr = ptrOffsetInfo.getPtr(); + auto ptrOffset = ptrOffsetInfo.getOffset(); + + // LoadLike is operation with result + bool isLoadLike = !op->use_empty(); + + Value zeroIdx = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value oneIdx = + rewriter.create(loc, rewriter.getIndexAttr(1)); + auto 
resultShape = ptrType.getShape(); + auto resultElementType = ptrType.getElementType(); + if (auto pointerType = + dyn_cast(ptrType.getElementType())) { + resultElementType = pointerType.getPointeeType(); + } + + int64_t sizeInByte; + if (auto intType = dyn_cast(resultElementType)) { + sizeInByte = intType.getWidth() / 8; + } else if (auto floatType = dyn_cast(resultElementType)) { + sizeInByte = floatType.getWidth() / 8; + } else { + llvm_unreachable("Unhandled element type of tensor"); + } + + for (int i = ptrOffsetInfo.getRank() - 1; i >= 0; i--) { + if (!ptrOffsetInfo.isStructured(i)) + break; + sizeInByte *= resultShape[i]; + } + + // Force scalarize if memory is not aligned + if (sizeInByte % 32 != 0) + ptrOffsetInfo.setUnstructured(ptrOffsetInfo.getRank()); + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "UnStructured Flag check:\n"; + os << "ptrOffsetInfo.isStructured: " << ptrOffsetInfo.isStructured() + << "\n"; + }); + + Value iterArg = nullptr; + + // Only load case + if (isLoadLike) { + iterArg = + rewriter.create(loc, resultShape, resultElementType); + } + Value newOpResult = nullptr; + + auto insertPoint = rewriter.saveInsertionPoint(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + SmallVector extractedShape; + + for (size_t i = 0; i < resultShape.size(); i++) { + auto size = resultShape[i]; + auto structured = ptrOffsetInfo.getStructuredRef()[i]; + // handle indirect dimension + strides.push_back(rewriter.getIndexAttr(1)); + Value sizeVal = + rewriter.create(loc, rewriter.getIndexAttr(size)); + if (structured) { + offsets.push_back(rewriter.getIndexAttr(0)); + sizes.push_back(rewriter.getIndexAttr(size)); + extractedShape.push_back(size); + } else { + scf::ForOp forOp; + if (auto mtptOp = + srcPtr.template getDefiningOp()) { + auto tptShape = mtptOp.getShape()[i]; + if (tptShape.getType() != rewriter.getIndexType()) { + tptShape = rewriter.create( + loc, rewriter.getIndexType(), tptShape); + } + sizeVal = rewriter.create(loc, sizeVal, tptShape); + } + if (isLoadLike) { + forOp = rewriter.create(loc, zeroIdx, sizeVal, oneIdx, + ValueRange({iterArg})); + if (!newOpResult) { + newOpResult = forOp->getResult(0); + } else { + rewriter.create(loc, forOp->getResult(0)); + } + iterArg = forOp.getRegionIterArg(0); + } else { + forOp = rewriter.create(loc, zeroIdx, sizeVal, oneIdx); + } + sizes.push_back(rewriter.getIndexAttr(1)); + offsets.push_back(forOp.getInductionVar()); + extractedShape.push_back(1); + forOp->setAttr("ExtractedLoadOrStore", + UnitAttr::get(rewriter.getContext())); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + } + + bool fullyUnstructured = ptrOffsetInfo.isUnstructured(); + auto extractedType = RankedTensorType::get(extractedShape, resultElementType); + + Value extractedOffset; + if (fullyUnstructured) { + if (auto mtptOp = + srcPtr.template getDefiningOp()) { + auto I64Type = rewriter.getIntegerType(64); + srcPtr = mtptOp.getBase(); + extractedOffset = rewriter.create(loc, 0, 64); + for (auto [indVar, offset, stride] : llvm::zip_equal( + offsets, ptrOffsetInfo.getOffsets(), mtptOp.getStrides())) { + Value inductionVar = rewriter.create( + loc, I64Type, cast(indVar)); + Value tptOffset = rewriter.create(loc, I64Type, offset); + Value tptStride = rewriter.create(loc, I64Type, stride); + tptOffset = rewriter.create(loc, tptStride, tptOffset); + tptStride = + rewriter.create(loc, tptStride, inductionVar); + extractedOffset = + rewriter.create(loc, extractedOffset, tptOffset); + extractedOffset = + rewriter.create(loc, 
extractedOffset, tptStride); + } + } else { + extractedOffset = createExtractOp(loc, ptrOffset, rewriter, offsets); + } + } else { + extractedOffset = + createExtractOp(loc, ptrOffset, rewriter, offsets, sizes, strides); + } + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Extracted offset\n"; + os << extractedOffset << "\n"; + }); + + assert(isa(srcPtr.getType()) && "src must be ptr type"); + if (!fullyUnstructured) { + srcPtr = rewriter.create( + loc, RankedTensorType::get(extractedShape, srcPtr.getType()), srcPtr); + } + Value ptrToAccess = rewriter.create( + loc, srcPtr.getType(), srcPtr, extractedOffset); + + MemAccOpTy accessedOp; + if (fullyUnstructured) { + accessedOp = createMemAccOp(op, ptrToAccess, loc, rewriter, offsets); + } else { + accessedOp = + createMemAccOp(op, ptrToAccess, loc, rewriter, offsets, sizes, strides); + } + + accessedOp->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + + if (isLoadLike) { + assert(iterArg && "Load case must have iterArg in for loop"); + + Value value = accessedOp->getResult(0); + Value result; + if (!isa(value.getType()) && + (std::is_same_v || + std::is_same_v)) { + value = rewriter.create(loc, extractedType, value); + } + if (!isa(value.getType())) { + SmallVector indices; + for (auto idx : offsets) { + if (auto val = dyn_cast(idx)) { + indices.push_back(val); + } else { + auto idxVal = rewriter.create( + loc, rewriter.getIndexAttr(*getConstantIntValue(idx))); + indices.push_back(idxVal); + } + } + result = rewriter.create(loc, value, iterArg, indices); + } else { + result = rewriter.create(loc, value, iterArg, + offsets, sizes, strides); + } + rewriter.create(loc, result) + ->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + rewriter.restoreInsertionPoint(insertPoint); + if constexpr (std::is_same_v) { + if (op.getMask() && op.getOther()) { + rewriter + .replaceOpWithNewOp(op, op.getMask(), newOpResult, + op.getOther()) + ->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + } else { + rewriter.replaceOp(op, newOpResult); + } + } else { + rewriter.replaceOp(op, newOpResult); + } + } else { + if constexpr (std::is_same_v) { + if (fullyUnstructured && accessedOp.getMask()) { + auto mask = createExtractOp( + loc, accessedOp.getMask(), rewriter, + SmallVector(ptrOffsetInfo.getRank(), + rewriter.getIndexAttr(0))); + rewriter.create(loc, mask, [&](OpBuilder &b, Location loc) { + b.create( + loc, accessedOp.getType(), accessedOp.getAtomicRmwOp(), + accessedOp.getPtr(), accessedOp.getVal(), nullptr, + accessedOp.getSem(), accessedOp.getScope()) + ->setAttr(mlir::dicp::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + b.create(loc); + }); + rewriter.eraseOp(accessedOp); + } + } + rewriter.eraseOp(op); + } + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After conversion\n" + << ptrToAccess.getDefiningOp() + ->template getParentOfType() + << "\n"; + }); + return success(); +} + +void replaceOperands(MutableArrayRef oprs, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + for (auto it = oprs.begin(); it != oprs.end(); ++it) { + auto &opr = *it; + auto operand = opr.get(); + if (auto tensorType = dyn_cast(operand.getType()); + tensorType && isa(tensorType.getElementType())) { + parse(operand, operand.getLoc(), rewriter, offsetMap); + opr.set(offsetMap.at(operand).getOffset()); + } else if (auto ptrType = + dyn_cast(operand.getType())) { + parse(operand, operand.getLoc(), rewriter, offsetMap); + if (auto tensorType = + 
dyn_cast(ptrType.getPointeeType())) { + for (auto offset : offsetMap.at(operand).getOffsets()) { + it->set(offset); + ++it; + } + --it; + } else { + opr.set(offsetMap.at(operand).getOffset()); + } + } + } +} + +void replaceArgs(ValueRange args, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + for (auto it = args.begin(); it != args.end(); ++it) { + auto arg = *it; + if (auto tensorType = dyn_cast(arg.getType()); + tensorType && isa(tensorType.getElementType())) { + RewriterBase::InsertionGuard guard(rewriter); + if (auto blockArg = dyn_cast(arg)) { + rewriter.setInsertionPointToStart(blockArg.getOwner()); + } else { + rewriter.setInsertionPointAfterValue(arg); + } + auto tempVar = rewriter + .create( + arg.getLoc(), arg.getType(), ValueRange({})) + ->getResult(0); + parse(arg, arg.getLoc(), rewriter, offsetMap); + auto src = offsetMap.at(arg).getPtr(); + rewriter.replaceAllUsesWith(arg, tempVar); + arg.setType(RankedTensorType::get(tensorType.getShape(), + rewriter.getIntegerType(64))); + src = rewriter.create(arg.getLoc(), tempVar.getType(), + src); + rewriter.replaceOpWithNewOp( + tempVar.getDefiningOp(), tempVar.getType(), src, arg); + } else if (auto ptrType = dyn_cast(arg.getType())) { + RewriterBase::InsertionGuard guard(rewriter); + if (auto blockArg = dyn_cast(arg)) { + rewriter.setInsertionPointToStart(blockArg.getOwner()); + } else { + rewriter.setInsertionPointAfterValue(arg); + } + auto tempVar = rewriter + .create( + arg.getLoc(), arg.getType(), ValueRange({})) + ->getResult(0); + parse(arg, arg.getLoc(), rewriter, offsetMap); + rewriter.replaceAllUsesWith(arg, tempVar); + if (auto tensorType = + dyn_cast(ptrType.getPointeeType())) { + auto srcOp = + offsetMap.at(arg).getPtr().getDefiningOp(); + arg.setType(rewriter.getIntegerType(32)); + SmallVector newOffsets; + for (auto offset : offsetMap.at(arg).getOffsets()) { + newOffsets.push_back(*it); + ++it; + } + --it; + rewriter.replaceOpWithNewOp( + tempVar.getDefiningOp(), tempVar.getType(), srcOp.getBase(), + srcOp.getShape(), srcOp.getStrides(), newOffsets, srcOp.getOrder()); + } else { + auto src = offsetMap.at(arg).getPtr(); + arg.setType(rewriter.getIntegerType(64)); + rewriter.replaceOpWithNewOp( + tempVar.getDefiningOp(), tempVar.getType(), src, arg); + } + } + } +} + +void convertTensorPtrPre(LoopLikeOpInterface op, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + if (auto whileOp = dyn_cast(op.getOperation())) { + replaceArgs(whileOp.getBeforeArguments(), rewriter, offsetMap); + replaceOperands(whileOp.getInitsMutable(), rewriter, offsetMap); + replaceArgs(whileOp.getAfterArguments(), rewriter, offsetMap); + replaceArgs(whileOp->getResults(), rewriter, offsetMap); + replaceOperands(whileOp.getConditionOp().getArgsMutable(), rewriter, + offsetMap); + } else { + replaceArgs(op.getRegionIterArgs(), rewriter, offsetMap); + replaceOperands(op.getInitsMutable(), rewriter, offsetMap); + } +} + +void convertTensorPtrPost(LoopLikeOpInterface op, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + if (auto whileOp = dyn_cast(op.getOperation())) { + replaceOperands(whileOp.getYieldOp()->getOpOperands(), rewriter, offsetMap); + } else { + replaceArgs(op->getResults(), rewriter, offsetMap); + replaceOperands(*op.getYieldedValuesMutable(), rewriter, offsetMap); + } +} + +int getPtrTensorRank(Type type) { + if (auto ptrType = dyn_cast(type)) { + if (auto tensorType = + dyn_cast(ptrType.getPointeeType())) { + return tensorType.getRank(); + } + } + return 0; +} + +SmallVector constructOperands(ValueRange operands, 
Value tempVar, + IRMapping mapping) { + SmallVector newOperands; + for (auto opr : operands) { + opr = mapping.lookupOrDefault(opr); + newOperands.push_back(opr); + auto numAppend = getPtrTensorRank(opr.getType()) - 1; + if (numAppend > 0) + newOperands.append(numAppend, tempVar); + } + return newOperands; +} + +SmallVector constructTypes(TypeRange types) { + SmallVector newTypes; + for (auto type : types) { + newTypes.push_back(type); + if (auto ptrType = dyn_cast(type)) { + if (auto tensorType = + dyn_cast(ptrType.getPointeeType())) { + if (tensorType.getRank() > 0) + newTypes.append(tensorType.getRank() - 1, + IntegerType::get(type.getContext(), 32)); + } + } + } + return newTypes; +} + +void replacePtrLoopArguments(Operation *rootOp, + llvm::DenseMap &offsetMap) { + std::function convertTensorPtr = + [&](LoopLikeOpInterface op) { + IRRewriter rewriter(op.getContext()); + IRMapping mapping; + LoopLikeOpInterface newOp; + rewriter.setInsertionPointAfter(op); + Value tempVar = + rewriter + .create( + op.getLoc(), rewriter.getI32Type(), ValueRange({})) + ->getResult(0); + if (auto forOp = dyn_cast(op.getOperation())) { + newOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), + constructOperands(forOp.getInitArgs(), tempVar, mapping), + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + mapping.map(forOp.getInductionVar(), iv); + auto newArgIter = args.begin(); + for (auto oldArg : forOp.getRegionIterArgs()) { + mapping.map(oldArg, *newArgIter); + std::advance(newArgIter, + std::max(getPtrTensorRank(oldArg.getType()), 1)); + } + for (auto &bodyOp : forOp.getBody()->without_terminator()) { + b.clone(bodyOp, mapping); + } + auto yieldOp = + cast(forOp.getBody()->getTerminator()); + b.create( + yieldOp.getLoc(), + constructOperands(yieldOp.getOperands(), tempVar, mapping)); + }); + } else if (auto whileOp = dyn_cast(op.getOperation())) { + newOp = rewriter.create( + whileOp.getLoc(), constructTypes(whileOp->getResultTypes()), + constructOperands(whileOp.getInits(), tempVar, mapping), + [&](OpBuilder &b, Location loc, ValueRange args) { + auto newArgIter = args.begin(); + for (auto oldArg : whileOp.getBeforeArguments()) { + mapping.map(oldArg, *newArgIter); + std::advance(newArgIter, + std::max(getPtrTensorRank(oldArg.getType()), 1)); + } + for (auto &bodyOp : + whileOp.getBeforeBody()->without_terminator()) { + b.clone(bodyOp, mapping); + } + auto conditionOp = whileOp.getConditionOp(); + b.create( + conditionOp.getLoc(), + mapping.lookup(conditionOp.getCondition()), + constructOperands(conditionOp.getArgs(), tempVar, mapping)); + }, + [&](OpBuilder &b, Location loc, ValueRange args) { + auto newArgIter = args.begin(); + for (auto oldArg : whileOp.getAfterArguments()) { + mapping.map(oldArg, *newArgIter); + std::advance(newArgIter, + std::max(getPtrTensorRank(oldArg.getType()), 1)); + } + for (auto &bodyOp : + whileOp.getAfterBody()->without_terminator()) { + b.clone(bodyOp, mapping); + } + auto yieldOp = whileOp.getYieldOp(); + b.create( + yieldOp.getLoc(), + constructOperands(yieldOp.getOperands(), tempVar, mapping)); + }); + } else { + llvm_unreachable("Unsupported loop op"); + } + auto resIter = newOp->result_begin(); + for (auto res : op->getResults()) { + rewriter.replaceAllUsesWith(res, *resIter); + std::advance(resIter, std::max(getPtrTensorRank(res.getType()), 1)); + } + rewriter.eraseOp(op); + op = newOp; + convertTensorPtrPre(op, rewriter, offsetMap); + for (auto *region : op.getLoopRegions()) + 
region->walk(convertTensorPtr); + convertTensorPtrPost(op, rewriter, offsetMap); + return WalkResult::skip(); + }; + + rootOp->walk(convertTensorPtr); +} + +void TritonToUnstructurePass::runPreparse(LoopLikeOpInterface op) { + IRRewriter rewriter(&getContext()); + auto loc = op.getLoc(); + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Pre-parsing " << op->getName() << "\n" << op << "\n"; + }); + + Block::BlockArgListType args; + ValueRange yields; + if (auto whileOp = dyn_cast(op.getOperation())) { + args = whileOp.getBeforeArguments(); + yields = whileOp.getYieldOp().getOperands(); + } else { + args = op.getRegionIterArgs(); + yields = op.getYieldedValues(); + } + + for (auto [arg, yield] : llvm::zip_equal(args, yields)) { + if (auto tensorType = dyn_cast(yield.getType())) { + parse(yield, loc, rewriter, offsetMapForLoopArgs); + offsetMap[arg] = offsetMapForLoopArgs.at(yield); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Pre-parsing result of\n" << arg << "\nis "; + for (auto structured : offsetMap[arg].getStructuredRef()) + os << structured; + os << '\n'; + }); + } + } +} + +static bool isFromTensorArg(Value v, + llvm::SmallDenseMap &fromTensorArg) { + if (fromTensorArg.contains(v)) + return fromTensorArg.at(v); + auto *defOp = v.getDefiningOp(); + if (!defOp) { + fromTensorArg[v] = isa(v.getType()); + return isa(v.getType()); + } + for (auto opr : defOp->getOperands()) { + if (isFromTensorArg(opr, fromTensorArg)) { + fromTensorArg[v] = true; + return true; + } + } + fromTensorArg[v] = false; + return false; +} + +template +void TritonToUnstructurePass::runParse(MemAccOpTy op) { + IRRewriter rewriter(&getContext()); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Parsing " << op->getName() << "\n" << op << "\n"; + }); + parse(op.getPtr(), op.getLoc(), rewriter, offsetMap); + isFromTensorArg(op.getPtr(), fromTensorArg); +} + +void TritonToUnstructurePass::runOnOperation() { + + ModuleOp moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + replacePtrLoopArguments(moduleOp, offsetMapForLoopArgs); + offsetMapForLoopArgs.clear(); + moduleOp->walk([this](LoopLikeOpInterface op) { runPreparse(op); }); + moduleOp->walk([this](Operation *op) { + if (auto loadOp = dyn_cast(op)) { + runParse(loadOp); + } else if (auto storeOp = dyn_cast(op)) { + runParse(storeOp); + } else if (auto atomicRMWOp = dyn_cast(op)) { + runParse(atomicRMWOp); + } else if (auto atomicCASOp = dyn_cast(op)) { + runParse(atomicCASOp); + } + }); + + RewritePatternSet patterns(ctx); + + patterns.add, + UnstructuredMemAccessConverter, + UnstructuredMemAccessConverter, + UnstructuredMemAccessConverter>( + ctx, forceScalarizeMode, offsetMap, fromTensorArg); + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Parsing done\n"; + }); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + moduleOp->emitError("failed to apply Patterns"); + signalPassFailure(); + } + + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } +} + +void TritonToUnstructurePass::getDependentDialects( + DialectRegistry ®istry) const { + registry.insert(); +} + +std::unique_ptr> +triton::createTritonToUnstructurePass() { + return std::make_unique(); +} diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/LinalgGenericToSCF.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/LinalgGenericToSCF.cpp index b994651a..083ea25f 100644 --- 
a/compiler/lib/Dialect/LinalgExt/Transforms/LinalgGenericToSCF.cpp +++ b/compiler/lib/Dialect/LinalgExt/Transforms/LinalgGenericToSCF.cpp @@ -246,6 +246,69 @@ struct LinalgGenericToScfForPattern } }; +/// Pattern that matches: tensor.empty -> linalg.fill -> tensor.extract +/// and optimizes it to directly use the scalar value. +class SimplifySingleElementFillExtractPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + // 1. Verify all indices are constant zero + for (Value index : extractOp.getIndices()) { + auto constIndex = index.getDefiningOp(); + if (!constIndex || constIndex.value() != 0) { + return failure(); + } + } + + // 2. Get the source tensor and check if it comes from linalg.fill + auto fillOp = extractOp.getTensor().getDefiningOp(); + if (!fillOp) { + return failure(); + } + + // 3. Verify the filled tensor has exactly one element + auto filledTensorType = + dyn_cast(fillOp.getOutputs()[0].getType()); + if (!filledTensorType || filledTensorType.getNumElements() != 1) { + return failure(); + } + + // 4. Verify fill value is a scalar (not a tensor) + Value fillValue = fillOp.getInputs()[0]; + if (isa(fillValue.getType())) { + return failure(); + } + + // 5. For safety, ensure the filled tensor is only used by this extract + // This prevents breaking other uses of the same tensor + if (!fillOp.getResult(0).hasOneUse()) { + return failure(); + } + + // 6. Try to find the tensor.empty operation + // We'll check its uses after we replace the extract operation + auto emptyOp = + fillOp.getDpsInitOperand(0)->get().getDefiningOp(); + + // 7. Replace extract with the scalar fill value + rewriter.replaceOp(extractOp, fillValue); + + // 8. Now that extract is replaced, fillOp's result should have no uses + // We can safely erase it + rewriter.eraseOp(fillOp); + + // 9. After fillOp is erased, check if emptyOp still exists and has no uses + if (emptyOp && emptyOp->use_empty()) { + rewriter.eraseOp(emptyOp); + } + + return success(); + } +}; + struct LinalgGenericToSCFPass : mlir::dicp::LinalgExt::impl::LinalgGenericToSCFBase< LinalgGenericToSCFPass> { @@ -256,6 +319,7 @@ struct LinalgGenericToSCFPass { RewritePatternSet patterns(context); patterns.add(context); + patterns.add(context); populateRemoveSingleIterationLoopPattern(patterns); if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { signalPassFailure(); diff --git a/compiler/lib/Utils/Utils.cpp b/compiler/lib/Utils/Utils.cpp index 89001d47..cced4713 100644 --- a/compiler/lib/Utils/Utils.cpp +++ b/compiler/lib/Utils/Utils.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #define DEBUG_TYPE "Dicp-Utils" using namespace mlir; @@ -241,4 +242,266 @@ scf::ForOp createNestedLoops( return loop; } + +FailureOr specializeTypelessValueToAttr(TypelessValue value, + Type type, OpBuilder &b) { + // Common float and integer MLIR types used as map keys. 
+ mlir::Type f16Ty = Float16Type::get(b.getContext()); + mlir::Type f32Ty = Float32Type::get(b.getContext()); + mlir::Type bf16Ty = BFloat16Type::get(b.getContext()); + + mlir::Type i8TySL = IntegerType::get( + b.getContext(), 8, IntegerType::SignednessSemantics::Signless); + mlir::Type i8TyS = IntegerType::get(b.getContext(), 8, + IntegerType::SignednessSemantics::Signed); + mlir::Type i8TyU = IntegerType::get( + b.getContext(), 8, IntegerType::SignednessSemantics::Unsigned); + + mlir::Type i16TySL = IntegerType::get( + b.getContext(), 16, IntegerType::SignednessSemantics::Signless); + mlir::Type i16TyS = IntegerType::get( + b.getContext(), 16, IntegerType::SignednessSemantics::Signed); + mlir::Type i16TyU = IntegerType::get( + b.getContext(), 16, IntegerType::SignednessSemantics::Unsigned); + + mlir::Type i32TySL = IntegerType::get( + b.getContext(), 32, IntegerType::SignednessSemantics::Signless); + mlir::Type i32TyS = IntegerType::get( + b.getContext(), 32, IntegerType::SignednessSemantics::Signed); + mlir::Type i32TyU = IntegerType::get( + b.getContext(), 32, IntegerType::SignednessSemantics::Unsigned); + + mlir::Type i64TySL = IntegerType::get( + b.getContext(), 64, IntegerType::SignednessSemantics::Signless); + mlir::Type i64TyS = IntegerType::get( + b.getContext(), 64, IntegerType::SignednessSemantics::Signed); + mlir::Type i64TyU = IntegerType::get( + b.getContext(), 64, IntegerType::SignednessSemantics::Unsigned); + + // Create APFloat values for float semantics (half, single, bfloat). + llvm::APFloat halfZero = llvm::APFloat::getZero(llvm::APFloat::IEEEhalf()); + llvm::APFloat halfOne(llvm::APFloat::IEEEhalf(), 1); + llvm::APFloat halfMax = llvm::APFloat::getInf(llvm::APFloat::IEEEhalf()); + llvm::APFloat halfMin = + llvm::APFloat::getInf(llvm::APFloat::IEEEhalf(), /*Negative=*/true); + + llvm::APFloat floatZero = llvm::APFloat::getZero(llvm::APFloat::IEEEsingle()); + llvm::APFloat floatOne(llvm::APFloat::IEEEsingle(), 1); + llvm::APFloat floatMax = llvm::APFloat::getInf(llvm::APFloat::IEEEsingle()); + llvm::APFloat floatMin = + llvm::APFloat::getInf(llvm::APFloat::IEEEsingle(), /*Negative=*/true); + + // BF16 (bfloat16) semantics via APFloat. + llvm::APFloat bfloatZero = llvm::APFloat::getZero(llvm::APFloat::BFloat()); + llvm::APFloat bfloatOne(llvm::APFloat::BFloat(), 1); + llvm::APFloat bfloatMax = llvm::APFloat::getInf(llvm::APFloat::BFloat()); + llvm::APFloat bfloatMin = + llvm::APFloat::getInf(llvm::APFloat::BFloat(), /*Negative=*/true); + + // Helper to use the opaque pointer of a Type as a stable key. + auto toPtr = [](mlir::Type ty) { return ty.getAsOpaquePointer(); }; + + // Store initialization values. Use signed and unsigned integer variants to + // avoid narrowing/overflow problems. + using InitValVariant = + std::variant; + + std::map, InitValVariant> initMap = { + // Zero values (floats and integers). 
+ {{TypelessValue::Zero, toPtr(f16Ty)}, halfZero}, + {{TypelessValue::Zero, toPtr(f32Ty)}, floatZero}, + {{TypelessValue::Zero, toPtr(bf16Ty)}, bfloatZero}, + + {{TypelessValue::Zero, toPtr(i8TySL)}, (int8_t)0}, + {{TypelessValue::Zero, toPtr(i8TyS)}, (int8_t)0}, + {{TypelessValue::Zero, toPtr(i8TyU)}, (uint8_t)0}, + + {{TypelessValue::Zero, toPtr(i16TySL)}, (int16_t)0}, + {{TypelessValue::Zero, toPtr(i16TyS)}, (int16_t)0}, + {{TypelessValue::Zero, toPtr(i16TyU)}, (uint16_t)0}, + + {{TypelessValue::Zero, toPtr(i32TySL)}, (int32_t)0}, + {{TypelessValue::Zero, toPtr(i32TyS)}, (int32_t)0}, + {{TypelessValue::Zero, toPtr(i32TyU)}, (uint32_t)0}, + + {{TypelessValue::Zero, toPtr(i64TySL)}, (int64_t)0}, + {{TypelessValue::Zero, toPtr(i64TyS)}, (int64_t)0}, + {{TypelessValue::Zero, toPtr(i64TyU)}, (uint64_t)0}, + + // Min values (floats and integers). + {{TypelessValue::Min, toPtr(f16Ty)}, halfMin}, + {{TypelessValue::Min, toPtr(f32Ty)}, floatMin}, + {{TypelessValue::Min, toPtr(bf16Ty)}, bfloatMin}, + + {{TypelessValue::Min, toPtr(i8TySL)}, std::numeric_limits::min()}, + {{TypelessValue::Min, toPtr(i8TyS)}, std::numeric_limits::min()}, + {{TypelessValue::Min, toPtr(i8TyU)}, std::numeric_limits::min()}, + + {{TypelessValue::Min, toPtr(i16TySL)}, + std::numeric_limits::min()}, + {{TypelessValue::Min, toPtr(i16TyS)}, + std::numeric_limits::min()}, + {{TypelessValue::Min, toPtr(i16TyU)}, + std::numeric_limits::min()}, + + {{TypelessValue::Min, toPtr(i32TySL)}, + std::numeric_limits::min()}, + {{TypelessValue::Min, toPtr(i32TyS)}, + std::numeric_limits::min()}, + {{TypelessValue::Min, toPtr(i32TyU)}, + std::numeric_limits::min()}, + + {{TypelessValue::Min, toPtr(i64TySL)}, + std::numeric_limits::min()}, + {{TypelessValue::Min, toPtr(i64TyS)}, + std::numeric_limits::min()}, + {{TypelessValue::Min, toPtr(i64TyU)}, + std::numeric_limits::min()}, // 0 + + // Max values (floats and integers). + {{TypelessValue::Max, toPtr(f16Ty)}, halfMax}, + {{TypelessValue::Max, toPtr(f32Ty)}, floatMax}, + {{TypelessValue::Max, toPtr(bf16Ty)}, bfloatMax}, + + {{TypelessValue::Max, toPtr(i8TySL)}, std::numeric_limits::max()}, + {{TypelessValue::Max, toPtr(i8TyS)}, std::numeric_limits::max()}, + {{TypelessValue::Max, toPtr(i8TyU)}, std::numeric_limits::max()}, + + {{TypelessValue::Max, toPtr(i16TySL)}, + std::numeric_limits::max()}, + {{TypelessValue::Max, toPtr(i16TyS)}, + std::numeric_limits::max()}, + {{TypelessValue::Max, toPtr(i16TyU)}, + std::numeric_limits::max()}, + + {{TypelessValue::Max, toPtr(i32TySL)}, + std::numeric_limits::max()}, + {{TypelessValue::Max, toPtr(i32TyS)}, + std::numeric_limits::max()}, + {{TypelessValue::Max, toPtr(i32TyU)}, + std::numeric_limits::max()}, + + {{TypelessValue::Max, toPtr(i64TySL)}, + std::numeric_limits::max()}, + {{TypelessValue::Max, toPtr(i64TyS)}, + std::numeric_limits::max()}, + {{TypelessValue::Max, toPtr(i64TyU)}, + std::numeric_limits::max()}, + }; + + // Lookup key for the requested typeless value + concrete type. + std::pair key = + std::make_pair(value, toPtr(type)); + auto it = initMap.find(key); + if (it == initMap.end()) + return failure(); + + // Integer handling: prefer using the provided 'type' for IntegerAttr so + // signedness/width are preserved. + if (type.isInteger(8) || type.isInteger(16) || type.isInteger(32) || + type.isInteger(64)) { + unsigned bitWidth = type.getIntOrFloatBitWidth(); + + // Signed integers: extract signed variant and create IntegerAttr directly. 
+ if (type.isSignedInteger(bitWidth)) { + switch (bitWidth) { + case 8: + return success(IntegerAttr::get(type, std::get(it->second))); + case 16: + return success(IntegerAttr::get(type, std::get(it->second))); + case 32: + return success(IntegerAttr::get(type, std::get(it->second))); + case 64: + return success(IntegerAttr::get(type, std::get(it->second))); + default: + return failure(); + } + } + + // Unsigned integers: extract unsigned variant. For 64-bit unsigned use + // APInt to avoid overflow of signed int64_t. + if (type.isUnsignedInteger(bitWidth)) { + switch (bitWidth) { + case 8: + return success(IntegerAttr::get( + type, static_cast(std::get(it->second)))); + case 16: + return success(IntegerAttr::get( + type, static_cast(std::get(it->second)))); + case 32: + return success(IntegerAttr::get( + type, static_cast(std::get(it->second)))); + case 64: { + uint64_t uval = std::get(it->second); + llvm::APInt apv(/*numBits=*/64, uval, /*isSigned=*/false); + return success(IntegerAttr::get(type, apv)); + } + default: + return failure(); + } + } + + // Signless integers: treat as signless using the signed variants (original + // code used signless integers everywhere for constants). + switch (bitWidth) { + case 8: + return success(IntegerAttr::get(type, std::get(it->second))); + case 16: + return success(IntegerAttr::get(type, std::get(it->second))); + case 32: + return success(IntegerAttr::get(type, std::get(it->second))); + case 64: + return success(IntegerAttr::get(type, std::get(it->second))); + default: + return failure(); + } + } + + // Floating-point handling (half, bf16, single). + if (isa(type)) + return success(FloatAttr::get(f16Ty, std::get(it->second))); + if (isa(type)) + return success(FloatAttr::get(f32Ty, std::get(it->second))); + if (isa(type)) + return success(FloatAttr::get(bf16Ty, std::get(it->second))); + + return failure(); +} + +// Specialize the Typeless Value (Zero, Min, Max) into a mlir constant value +FailureOr specializeTypelessValueToConstant(TypelessValue value, + Type type, Location loc, + OpBuilder &b) { + std::function getElemType = [&](mlir::Type ty) { + if (auto ptrType = dyn_cast(getElementTypeOrSelf(ty))) + return getElemType(ptrType.getPointeeType()); + if (auto tensorType = mlir::dyn_cast(ty)) + return getElemType(tensorType.getElementType()); + return ty; + }; + + if (value == TypelessValue::Undefined) + return failure(); + if (auto tensorType = mlir::dyn_cast(type)) { + auto elemType = getElemType(tensorType); + FailureOr typedAttr = + specializeTypelessValueToAttr(value, elemType, b); + if (failed(typedAttr)) + return failure(); + auto otherTensorType = + RankedTensorType::get(tensorType.getShape(), elemType); + auto denseAttr = DenseElementsAttr::get(otherTensorType, *typedAttr); + return b.create(loc, denseAttr).getResult(); + } + if (mlir::isa(type) || mlir::isa(type)) { + FailureOr typedAttr = + specializeTypelessValueToAttr(value, type, b); + if (failed(typedAttr)) + return failure(); + return b.create(loc, *typedAttr).getResult(); + } + return failure(); +} + } // namespace mlir::dicp diff --git a/patch/ttshared/triton_shared.patch b/patch/ttshared/triton_shared.patch index 85a06d17..b2f6c863 100644 --- a/patch/ttshared/triton_shared.patch +++ b/patch/ttshared/triton_shared.patch @@ -1,3 +1,18 @@ +diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td +index 17c3ce9..30b88eb 100644 +--- 
a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td ++++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td +@@ -233,8 +233,8 @@ def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSe + let summary = "Placeholder for the structured pointer states computed during PtrAnalysis."; + let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites."; + +- let arguments = (ins AnyTypeOf<[TT_PtrLike, I32Tensor]>:$input); +- let results = (outs AnyTypeOf<[TT_PtrLike, I32Tensor]>:$structured, Variadic:$offsets, Variadic:$strides); ++ let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_IndexTensorLike]>:$input); ++ let results = (outs AnyTypeOf<[TT_PtrLike, TT_IndexTensorLike]>:$structured, Variadic:$offsets, Variadic:$strides); + + let builders = [ + OpBuilder<(ins "Value":$input)>, diff --git a/lib/Analysis/UseAnalysis.cpp b/lib/Analysis/UseAnalysis.cpp index 62e4508..2d81db4 100644 --- a/lib/Analysis/UseAnalysis.cpp diff --git a/test/ascend/failed_tests/test_ldst.py b/test/ascend/failed_tests/test_ldst.py deleted file mode 100644 index 02b0f977..00000000 --- a/test/ascend/failed_tests/test_ldst.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import torch, torch_npu -import triton -import triton.language as tl -import triton.language.math as tl_math -import pytest - - -def test_ldst_indirect_03(): - - @triton.jit - def triton_ldst_indirect_03_kernel(out_ptr0, in_ptr0, in_ptr1, XS: tl.constexpr): - pid = tl.program_id(0) - in_idx0 = pid * XS + tl.arange(0, XS) - tmp0 = tl.load(in_ptr0 + in_idx0) - tmp1 = tl.load(in_ptr1 + tmp0) - tmp2 = tl_math.exp(tmp1) - out0_idx = pid * XS + tl.arange(0, XS) - tl.store(out_ptr0 + out0_idx, tmp2) - - def triton_ldst_indirect_03_func(x0, x1, xs): - n0 = x0.numel() - assert n0 == xs, "test only single core" - y0 = torch.empty((n0,), dtype=x1.dtype, device=x1.device) - triton_ldst_indirect_03_kernel[n0 // xs, 1, 1](y0, x0, x1, XS=xs) - return y0 - - def torch_ldst_indirect_03_func(x0, x1): - return torch.exp(x1[x0]) - - DEV = "npu" - DTYPE = torch.float32 - offset = 8 - N0, N1 = 16, 32 - blocksize = 16 - assert N1 >= N0 + offset, "N1 must be >= N0+offset" - assert N0 == blocksize, "N0 must be == blocksize" - x0 = offset + torch.arange(0, N0, device=DEV) # int64 - x1 = torch.randn((N1,), dtype=DTYPE, device=DEV) - torch_ref = torch_ldst_indirect_03_func(x0, x1) - triton_cal = triton_ldst_indirect_03_func(x0, x1, blocksize) - torch.testing.assert_close(triton_cal, torch_ref) - - -def test_ldst_indirect_04(): - - @triton.jit - def triton_ldst_indirect_04_kernel(out_ptr0, in_ptr0, in_ptr1, XS: tl.constexpr): - pid = tl.program_id(0) - in_idx0 = pid * XS + tl.arange(0, XS) - tmp0 = tl.load(in_ptr0 + in_idx0) - tmp0min = tl.min(tmp0, axis=0) - tmp0max = tl.max(tmp0, axis=0) - tmp0 = tmp0 * 2.0 - tmp0 = tl.clamp(tmp0, tmp0min, tmp0max) - tmp0 = tmp0.to(tl.int32) - tmp1 = tl.load(in_ptr1 + tmp0) - tmp2 = tl_math.exp(tmp1) - out0_idx = pid * XS + tl.arange(0, XS) - tl.store(out_ptr0 + out0_idx, tmp2) - - def triton_ldst_indirect_04_func(x0, x1, xs): - n0 = x0.numel() - assert n0 == xs, "test only single core" - y0 = torch.empty((n0,), dtype=x1.dtype, device=x1.device) - triton_ldst_indirect_04_kernel[n0 // xs, 1, 1](y0, x0, x1, XS=xs) - return y0 - - def torch_ldst_indirect_04_func(x0, x1): - x0min = torch.min(x0) - x0max = torch.max(x0) - idx = torch.clamp(x0 * 2, x0min, x0max) - return torch.exp(x1[idx.to(torch.int32)]) - - DEV = "npu" - DTYPE = torch.float32 - offset = 8 - N0, N1 = 16, 32 - blocksize = 16 - assert N1 >= N0 + offset, "N1 must be >= N0+offset" - assert N0 == blocksize, "N0 must be == blocksize" - x0 = offset + torch.arange(0, N0, dtype=torch.float32, device=DEV) - x1 = torch.randn((N1,), dtype=DTYPE, device=DEV) - torch_ref = torch_ldst_indirect_04_func(x0, x1) - triton_cal = triton_ldst_indirect_04_func(x0, x1, blocksize) - torch.testing.assert_close(triton_cal, torch_ref) - - -def test_ldst_indirect_05(): - - @triton.jit - def triton_ldst_indirect_05_kernel( - out_ptr0, in_ptr1, in_ptr2, stride_in_r, XS: tl.constexpr, RS: tl.constexpr - ): - pid = tl.program_id(0) - in_idx0 = pid * XS + tl.arange(0, XS) - in_idx1 = tl.arange(0, RS) - tmp0 = tl.arange(0, XS) - tmp1 = tl.load(in_ptr1 + in_idx1) - in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :] - tmp2 = tl.load(in_ptr2 + in_idx2) - tmp2 = tl_math.exp(tmp2) - out0_idx = in_idx0[:, None] * RS + in_idx1[None, :] - tl.store(out_ptr0 + out0_idx, tmp2) - - def triton_ldst_indirect_05_func(xc, x2, xs, rs): - nr = x2.size()[0] - nc = xc.numel() - stride_in_r = x2.stride()[0] - assert nr == xs, "test only single core" - y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device) - triton_ldst_indirect_05_kernel[nr // xs, 1, 1]( - y0, xc, 
x2, stride_in_r, XS=xs, RS=rs - ) - return y0 - - def torch_ldst_indirect_05_func(xr, xc, x2): - flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten() - extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()]) - return torch.exp(extracted) - - DEV = "npu" - DTYPE = torch.float32 - offset = 8 - N0, N1 = 16, 32 - blocksize = 8 - lowdimsize = N0 - assert N1 >= N0 + offset, "N1 must be >= N0+offset" - assert N0 == lowdimsize, "N0 must be == lowdimsize" - xc = offset + torch.arange(0, N0, device=DEV) - xr = torch.arange(0, blocksize, device=DEV) - x2 = torch.randn((blocksize, N1), dtype=DTYPE, device=DEV) - torch_ref = torch_ldst_indirect_05_func(xr, xc, x2) - triton_cal = triton_ldst_indirect_05_func(xc, x2, blocksize, lowdimsize) - torch.testing.assert_close(triton_cal, torch_ref) - - -if __name__ == "__main__": - test_ldst_indirect_03() - print("success: test_ldst_indirect_05") diff --git a/test/ascend/passed_tests/test_fused_rms_norm_rope.py b/test/ascend/passed_tests/test_fused_rms_norm_rope.py new file mode 100644 index 00000000..74cf64ea --- /dev/null +++ b/test/ascend/passed_tests/test_fused_rms_norm_rope.py @@ -0,0 +1,456 @@ +import torch +import triton +import triton.language as tl +import triton.language.extra.deeplink as dl +import pytest + + +def _compute_inv_freq(base: float, head_size) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** (torch.arange(0, head_size, 2, dtype=torch.float, device="npu") / head_size) + ) + return inv_freq + + +def _compute_cos_sin_cache(head_size) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = _compute_inv_freq(10000.0, head_size) + t = torch.arange(4096, dtype=torch.float32, device="npu") + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +@triton.jit +def _compute_rotary_emb( + x1, + x2, + cos, + sin, +): + cos = tl.expand_dims(cos, -2) + sin = tl.expand_dims(sin, -2) + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + return o1, o2 + + +@triton.jit +def rms_norm_rope_kernel( + q1, + q2, + k1, + k2, + v, + weight, + cos, + sin, + q1_out, + q2_out, + k1_out, + k2_out, + v_out, + q_size: tl.constexpr, + kv_size: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + SUB_BLK: tl.constexpr, +): + id = tl.program_id(0) + num_q_offsets = tl.arange(0, q_size // head_dim) + num_kv_offsets = tl.arange(0, kv_size // head_dim) + head_offsets = tl.arange(0, head_dim) + half_head_offsets = tl.arange(0, head_dim // 2) + num_token_offsets = tl.arange(0, BLOCK_SIZE) + id * BLOCK_SIZE + half_q_size = q_size // 2 + half_kv_size = kv_size // 2 + half_dim = head_dim // 2 + + q1_data = tl.load( + q1 + + num_token_offsets[:, None, None] * half_q_size + + num_q_offsets[None, :, None] * half_dim + + half_head_offsets[None, None, :], + ) + q2_data = tl.load( + q2 + + num_token_offsets[:, None, None] * half_q_size + + num_q_offsets[None, :, None] * half_dim + + half_head_offsets[None, None, :], + ) + k1_data = tl.load( + k1 + + num_token_offsets[:, None, None] * half_kv_size + + num_kv_offsets[None, :, None] * half_dim + + half_head_offsets[None, None, :], + ) + k2_data = tl.load( + k2 + + 
num_token_offsets[:, None, None] * half_kv_size + + num_kv_offsets[None, :, None] * half_dim + + half_head_offsets[None, None, :], + ) + v_data = tl.load( + v + + num_token_offsets[:, None, None] * kv_size + + num_kv_offsets[None, :, None] * head_dim + + head_offsets[None, None, :] + ) + cos_data = tl.load( + cos + num_token_offsets[:, None] * head_dim // 2 + half_head_offsets[None, :] + ) + sin_data = tl.load( + sin + num_token_offsets[:, None] * head_dim // 2 + half_head_offsets[None, :] + ) + weight_data = tl.load(weight + half_head_offsets) + + for s in dl.parallel(0, 2, bind_sub_block=True): + q1_sub_data = dl.extract_slice( + q1_data, + (s * SUB_BLK, 0, 0), + (SUB_BLK, q_size // head_dim, head_dim // 2), + (1, 1, 1), + ) + q2_sub_data = dl.extract_slice( + q2_data, + (s * SUB_BLK, 0, 0), + (SUB_BLK, q_size // head_dim, head_dim // 2), + (1, 1, 1), + ) + k1_sub_data = dl.extract_slice( + k1_data, + (s * SUB_BLK, 0, 0), + (SUB_BLK, kv_size // head_dim, head_dim // 2), + (1, 1, 1), + ) + k2_sub_data = dl.extract_slice( + k2_data, + (s * SUB_BLK, 0, 0), + (SUB_BLK, kv_size // head_dim, head_dim // 2), + (1, 1, 1), + ) + cos_sub_data = dl.extract_slice( + cos_data, (s * SUB_BLK, 0), (SUB_BLK, head_dim // 2), (1, 1) + ) + sin_sub_data = dl.extract_slice( + sin_data, (s * SUB_BLK, 0), (SUB_BLK, head_dim // 2), (1, 1) + ) + + # rms norm + var_q1 = tl.sum(q1_sub_data * q1_sub_data, -1) / head_dim + var_q2 = tl.sum(q2_sub_data * q2_sub_data, -1) / head_dim + var_q = var_q1 + var_q2 + var_q = tl.expand_dims(var_q, -1) + q1_sub_data = q1_sub_data * tl.math.rsqrt(var_q + 1e-5) + q2_sub_data = q2_sub_data * tl.math.rsqrt(var_q + 1e-5) + q1_sub_data = weight_data * q1_sub_data + q2_sub_data = weight_data * q2_sub_data + var_k1 = tl.sum(k1_sub_data * k1_sub_data, axis=-1) / head_dim + var_k2 = tl.sum(k2_sub_data * k2_sub_data, axis=-1) / head_dim + var_k = var_k1 + var_k2 + var_k = tl.expand_dims(var_k, -1) + k1_sub_data = k1_sub_data * tl.math.rsqrt(var_k + 1e-5) + k2_sub_data = k2_sub_data * tl.math.rsqrt(var_k + 1e-5) + k1_sub_data = weight_data * k1_sub_data + k2_sub_data = weight_data * k2_sub_data + + # rotary embedding + q1_rope, q2_rope = _compute_rotary_emb( + q1_sub_data, q2_sub_data, cos_sub_data, sin_sub_data + ) + k1_rope, k2_rope = _compute_rotary_emb( + k1_sub_data, k2_sub_data, cos_sub_data, sin_sub_data + ) + num_token_sub_offsets = tl.arange(0, SUB_BLK) + id * BLOCK_SIZE + s * SUB_BLK + + tl.store( + q1_out + + num_token_sub_offsets[:, None, None] * half_q_size + + num_q_offsets[None, :, None] * half_dim + + half_head_offsets[None, None, :], + q1_rope, + ) + tl.store( + q2_out + + num_token_sub_offsets[:, None, None] * half_q_size + + num_q_offsets[None, :, None] * half_dim + + half_head_offsets[None, None, :], + q2_rope, + ) + tl.store( + k1_out + + num_token_sub_offsets[:, None, None] * half_kv_size + + num_kv_offsets[None, :, None] * half_dim + + half_head_offsets[None, None, :], + k1_rope, + ) + tl.store( + k2_out + + num_token_sub_offsets[:, None, None] * half_kv_size + + num_kv_offsets[None, :, None] * half_dim + + half_head_offsets[None, None, :], + k2_rope, + ) + + tl.store( + v_out + + num_token_offsets[:, None, None] * kv_size + + num_kv_offsets[None, :, None] * head_dim + + head_offsets[None, None, :], + v_data, + ) + + +def rms_norm_rope( + qkv, + positions, + num_heads_q, + num_heads_kv, + head_dim, + num_tokens, + BLOCK_SIZE, +): + q_size = num_heads_q * head_dim + kv_size = num_heads_kv * head_dim + + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + q1, q2 = 
+
+
+def rms_norm_rope(
+    qkv,
+    positions,
+    num_heads_q,
+    num_heads_kv,
+    head_dim,
+    num_tokens,
+    BLOCK_SIZE,
+):
+    q_size = num_heads_q * head_dim
+    kv_size = num_heads_kv * head_dim
+
+    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
+    q1, q2 = q.view(num_tokens, num_heads_q, head_dim).chunk(2, dim=-1)
+    k1, k2 = k.view(num_tokens, num_heads_kv, head_dim).chunk(2, dim=-1)
+    weight = torch.ones(head_dim // 2, dtype=torch.float32, device="npu")
+    cache = _compute_cos_sin_cache(head_dim)
+
+    positions = positions.flatten()
+    cos_sin = cache.index_select(0, positions)
+    cos, sin = cos_sin.chunk(2, dim=-1)
+
+    q1_out = torch.empty(
+        (num_tokens, num_heads_q, head_dim // 2), dtype=torch.float32, device="npu"
+    )
+    q2_out = torch.empty(
+        (num_tokens, num_heads_q, head_dim // 2), dtype=torch.float32, device="npu"
+    )
+    k1_out = torch.empty(
+        (num_tokens, num_heads_kv, head_dim // 2), dtype=torch.float32, device="npu"
+    )
+    k2_out = torch.empty(
+        (num_tokens, num_heads_kv, head_dim // 2), dtype=torch.float32, device="npu"
+    )
+    v_out = torch.empty(
+        (num_tokens, num_heads_kv, head_dim), dtype=torch.float32, device="npu"
+    )
+    grid = (num_tokens // BLOCK_SIZE,)
+
+    rms_norm_rope_kernel[grid](
+        q1.contiguous(),
+        q2.contiguous(),
+        k1.contiguous(),
+        k2.contiguous(),
+        v.contiguous(),
+        weight.contiguous(),
+        cos.contiguous(),
+        sin.contiguous(),
+        q1_out,
+        q2_out,
+        k1_out,
+        k2_out,
+        v_out,
+        q_size,
+        kv_size,
+        head_dim,
+        BLOCK_SIZE,
+        BLOCK_SIZE // 2,
+    )
+    q_out = torch.cat([q1_out, q2_out], dim=-1)
+    k_out = torch.cat([k1_out, k2_out], dim=-1)
+    q_out = q_out.view(num_tokens, q_size)
+    k_out = k_out.view(num_tokens, kv_size)
+    v_out = v_out.view(num_tokens, kv_size)
+    return torch.cat([q_out, k_out, v_out], dim=-1)
+
+
+# ------------------ PyTorch reference implementation (for accuracy verification) ------------------
+def _apply_qk_norm_rope(qkv, positions, num_heads_q, num_heads_kv, head_dim):
+    """Pure PyTorch implementation mirroring rms_norm_rope for correctness checks.
+
+    Args:
+        qkv: (num_tokens, total_dim)
+        positions: (num_tokens,)
+        num_heads_q, num_heads_kv, head_dim: ints
+
+    Returns:
+        concatenated tensor of shape (num_tokens, q_size + kv_size + kv_size)
+    """
+    device = qkv.device
+    num_tokens = qkv.shape[0]
+    q_size = num_heads_q * head_dim
+    kv_size = num_heads_kv * head_dim
+
+    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
+
+    # reshape
+    q = q.view(num_tokens, num_heads_q, head_dim)
+    k = k.view(num_tokens, num_heads_kv, head_dim)
+
+    # split halves
+    q1, q2 = q.chunk(2, dim=-1)
+    k1, k2 = k.chunk(2, dim=-1)
+
+    # RMS norm scale (same as the kernel: sum of squares over the full head_dim / head_dim)
+    var_q = (q1.pow(2).sum(-1) + q2.pow(2).sum(-1)) / head_dim
+    scale_q = torch.rsqrt(var_q + 1e-5).unsqueeze(-1)
+    q1 = q1 * scale_q
+    q2 = q2 * scale_q
+
+    var_k = (k1.pow(2).sum(-1) + k2.pow(2).sum(-1)) / head_dim
+    scale_k = torch.rsqrt(var_k + 1e-5).unsqueeze(-1)
+    k1 = k1 * scale_k
+    k2 = k2 * scale_k
+
+    # weight (the kernel uses a per-dim weight of ones)
+    weight = torch.ones(head_dim // 2, dtype=torch.float32, device=device)
+    q1 = q1 * weight
+    q2 = q2 * weight
+    k1 = k1 * weight
+    k2 = k2 * weight
+
+    # rotary
+    cache = _compute_cos_sin_cache(head_dim)
+    cos_sin = cache.index_select(0, positions.flatten())
+    cos, sin = cos_sin.chunk(2, dim=-1)
+    # reshape cos/sin to broadcast: (num_tokens, 1, head_dim // 2)
+    cos = cos.unsqueeze(1)
+    sin = sin.unsqueeze(1)
+
+    q1_rope = q1 * cos - q2 * sin
+    q2_rope = q2 * cos + q1 * sin
+    k1_rope = k1 * cos - k2 * sin
+    k2_rope = k2 * cos + k1 * sin
+
+    # reconstruct
+    q_out = torch.cat([q1_rope, q2_rope], dim=-1).view(num_tokens, q_size)
+    k_out = torch.cat([k1_rope, k2_rope], dim=-1).view(num_tokens, kv_size)
+    v_out = v  # v is unchanged
+
+    return torch.cat([q_out, k_out, v_out], dim=-1)
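+
+# NOTE: the reference above (and, by analogy, _compute_rotary_emb in the kernel) uses
+# the paired-halves rotary convention; as a plain-tensor sketch:
+#     x1' = x1 * cos - x2 * sin
+#     x2' = x2 * cos + x1 * sin
+# so the two returned halves are expected in the same (x1, x2) order.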
+
+
+# ------------------ Tests and benchmarks ------------------
+
+
+def test_rms_norm_rope():
+    """Check rms_norm_rope against the PyTorch reference and benchmark both paths."""
+    num_heads, num_kv_heads, head_dim = 16, 4, 128
+    num_tokens = 4
+
+    total_dim = (num_heads + 2 * num_kv_heads) * head_dim
+    qkv_base = torch.randn(num_tokens, total_dim, dtype=torch.float32, device="npu")
+    qkv_base1 = qkv_base.clone()
+    positions = torch.arange(num_tokens, dtype=torch.long, device="npu")
+    positions1 = positions.clone()
+
+    torch_output = _apply_qk_norm_rope(
+        qkv=qkv_base,
+        positions=positions,
+        num_heads_q=num_heads,
+        num_heads_kv=num_kv_heads,
+        head_dim=head_dim,
+    )
+
+    triton_output = rms_norm_rope(
+        qkv=qkv_base1,
+        positions=positions1,
+        num_heads_q=num_heads,
+        num_heads_kv=num_kv_heads,
+        head_dim=head_dim,
+        num_tokens=num_tokens,
+        BLOCK_SIZE=2,
+    )
+    assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=0)
+    print("test rms_norm_rope passed!")
+
+    # performance benchmarking
+    try:
+        import triton.testing as tt
+
+        def benchmark_fn(fn, *args):
+            return tt.do_bench(lambda: fn(*args), warmup=10, rep=20)
+
+        # Triton kernel timing
+        tri_time = benchmark_fn(
+            rms_norm_rope,
+            qkv_base1,
+            positions1,
+            num_heads,
+            num_kv_heads,
+            head_dim,
+            num_tokens,
+            2,
+        )
+
+        # PyTorch reference timing
+        torch_time = benchmark_fn(
+            _apply_qk_norm_rope, qkv_base, positions, num_heads, num_kv_heads, head_dim
+        )
+
+        # print the comparison
+        print("\n=== Performance comparison ===")
+        print(
+            f"Triton: {tri_time:.4f} ms | PyTorch: {torch_time:.4f} ms | speedup: {torch_time/tri_time:.2f}x"
+        )
+    except Exception:
+        print("triton.testing.do_bench unavailable: skipping the benchmark.")
+
+
+@pytest.mark.parametrize(
+    "num_heads,num_kv_heads,head_dim,num_tokens,BLOCK_SIZE",
+    [
+        (16, 4, 128, 4, 2),
+        (8, 2, 64, 8, 4),
+        (32, 16, 64, 256, 4),
+        (48, 16, 64, 192, 4),
+    ],
+)
+def test_rms_norm_rope_correctness(
+    num_heads,
+    num_kv_heads,
+    head_dim,
+    num_tokens,
+    BLOCK_SIZE,
+):
+    total_dim = (num_heads + 2 * num_kv_heads) * head_dim
+
+    qkv = torch.randn(num_tokens, total_dim, dtype=torch.float32, device="npu")
+    positions = torch.arange(num_tokens, dtype=torch.long, device="npu")
+
+    torch_out = _apply_qk_norm_rope(
+        qkv=qkv,
+        positions=positions,
+        num_heads_q=num_heads,
+        num_heads_kv=num_kv_heads,
+        head_dim=head_dim,
+    )
+
+    triton_out = rms_norm_rope(
+        qkv=qkv.clone(),
+        positions=positions.clone(),
+        num_heads_q=num_heads,
+        num_heads_kv=num_kv_heads,
+        head_dim=head_dim,
+        num_tokens=num_tokens,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    assert torch.allclose(torch_out, triton_out, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+    test_rms_norm_rope()
diff --git a/test/ascend/passed_tests/test_ldst.py b/test/ascend/passed_tests/test_ldst.py
index 22a09195..f99e1b61 100644
--- a/test/ascend/passed_tests/test_ldst.py
+++ b/test/ascend/passed_tests/test_ldst.py
@@ -23,6 +23,8 @@ import triton.language as tl
 import triton.language.math as tl_math
 import pytest
+import test_common
+import random
 
 
 def test_ldst_indirect_00():
@@ -140,7 +142,6 @@ def torch_ldst_indirect_02_func(x0, x1):
     torch.testing.assert_close(triton_cal, torch_ref)
 
 
-@pytest.mark.skip(reason="runtime error. Waiting for Huawei to fix it.")
 def test_ldst_indirect_03():
 
     @triton.jit
@@ -177,7 +178,6 @@ def torch_ldst_indirect_03_func(x0, x1):
     torch.testing.assert_close(triton_cal, torch_ref)
 
 
-@pytest.mark.skip(reason="runtime error. Waiting for Huawei to fix it.")
 def test_ldst_indirect_04():
 
     @triton.jit
@@ -222,7 +222,6 @@ def torch_ldst_indirect_04_func(x0, x1):
     torch.testing.assert_close(triton_cal, torch_ref)
 
 
-@pytest.mark.skip(reason="runtime error. Waiting for Huawei to fix it.")
 def test_ldst_indirect_05():
 
     @triton.jit
@@ -381,48 +380,45 @@ def torch_ldst_indirect_07_func(xr, xc, x2):
     torch.testing.assert_close(triton_cal, torch_ref)
 
 
-@pytest.mark.skip(reason="Indirect store to be supported")
 def test_ldst_indirect_08():
 
     @triton.jit
     def triton_ldst_indirect_08_kernel(
         out_ptr0,
-        in_ptr1,
-        in_ptr2,
-        in_ptr3,
+        in_ptr_xc,
+        in_ptr_x2,
         stride_in_r,
+        OUT_COLS: tl.constexpr,
         XS: tl.constexpr,
         RS: tl.constexpr,
     ):
         pid = tl.program_id(0)
-        in_idx0 = pid * XS + tl.arange(0, XS)
-        in_idx1 = tl.arange(0, RS)
-        tmp0 = tl.arange(0, XS)
-        tmp1 = tl.load(in_ptr1 + in_idx1)
-        in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :]
-        tmp2 = tl.load(in_ptr2 + in_idx2)
-        tmp2 = tl_math.exp(tmp2)
-        tmp3 = tl.load(in_ptr3 + in_idx1)
-        tmp3 = tmp3 + 1
-        out0_idx = in_idx0[:, None] * RS + tmp3[None, :]
-        tl.store(out_ptr0 + out0_idx, tmp2)
+        row_idx_full = pid * XS + tl.arange(0, XS)
+        col_pos = tl.arange(0, RS)
+        xc_vals = tl.load(in_ptr_xc + col_pos)
+        row_arange = tl.arange(0, XS)
+        gather_flat = row_arange[:, None] * stride_in_r + xc_vals[None, :]
+        vals = tl.load(in_ptr_x2 + gather_flat)
+        vals = tl_math.exp(vals)
+        out_flat = row_idx_full[:, None] * OUT_COLS + xc_vals[None, :]
+        tl.store(out_ptr0 + out_flat, vals)
 
     def triton_ldst_indirect_08_func(xc, x2, xs, rs):
-        nr = x2.size()[0]
-        nc = xc.numel()
-        stride_in_r = x2.stride()[0]
+        nr = x2.size(0)
+        out_cols = x2.size(1)
+        stride_in_r = x2.stride(0)
         assert nr == xs, "test only single core"
-        y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device)
-        xc1 = xc - 1
+        y0 = torch.zeros((nr, out_cols), dtype=x2.dtype, device=x2.device)
         triton_ldst_indirect_08_kernel[nr // xs, 1, 1](
-            y0, xc, x2, xc1, stride_in_r, XS=xs, RS=rs
+            y0, xc, x2, stride_in_r, OUT_COLS=out_cols, XS=xs, RS=rs
        )
         return y0
 
     def torch_ldst_indirect_08_func(xr, xc, x2):
-        flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten()
-        extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()])
-        return torch.exp(extracted)
+        out = torch.zeros((xr.numel(), x2.size(1)), dtype=x2.dtype, device=x2.device)
+        gathered = torch.exp(x2[xr[:, None], xc[None, :]])
+        out.scatter_(1, xc.expand(xr.numel(), -1), gathered)
+        return out
 
     DEV = "npu"
     DTYPE = torch.float32
@@ -494,6 +490,101 @@ def torch_ldst_indirect_09_func(xr, xc, x2):
     torch.testing.assert_close(triton_cal, torch_ref)
 
 
+@triton.jit
+def unstructured_mask_2d_kernel(
+    in_ptr, out_ptr, mask_m_ptr, mask_n_ptr, m, n, M: tl.constexpr, N: tl.constexpr
+):
+    offs_m = tl.arange(0, M)
+    offs_n = tl.arange(0, N)
+
+    mask_m = tl.load(mask_m_ptr + offs_m, mask=offs_m < m, other=0) != 0
+    mask_n = tl.load(mask_n_ptr + offs_n, mask=offs_n < n, other=0) != 0
+
+    in_ptrs = in_ptr + offs_m[:, None] * N + offs_n[None, :]
+    # dim 0 with unstructured mask.
+    v = tl.load(in_ptrs, mask=mask_m[:, None] and offs_n[None, :] < n, other=-2)
+    out_ptrs = out_ptr + offs_m[:, None] * N + offs_n[None, :]
+    # dim 1 with unstructured mask.
+    tl.store(out_ptrs, v, mask=offs_m[:, None] < m and mask_n[None, :])
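+
+# NOTE: in the kernel above, the load combines an unstructured (data-dependent) row
+# mask with a structured column bound, and the store does the opposite; this is
+# presumably the pattern exercised by the new DiscreteMaskAccessConversion /
+# TritonToUnstructure passes. As a scalar sketch of the semantics checked below:
+#     load  keeps  in[i, j]  iff mask_m[i] != 0 and j < n   (otherwise the `other` value -2)
+#     store writes out[i, j] iff i < m and mask_n[j] != 0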
+
+
+# helper to get torch dtype from string
+def torch_dtype(dtype_str):
+    return getattr(torch, dtype_str)
+
+
+@pytest.mark.parametrize(
+    "param_list",
+    [
+        ["float32", (8, 16)],
+    ],
+)
+def test_unstructured_mask_2d(param_list):
+    dtype_str, shape = param_list
+    dtype = torch_dtype(dtype_str)
+    M, N = shape
+
+    # make deterministic
+    random.seed(0)
+    torch.manual_seed(0)
+
+    # input: use distinct values per element for easy checking
+    # (arange, then cast to the requested dtype)
+    total = M * N
+    if dtype.is_floating_point:
+        in_tensor = (
+            torch.arange(total, dtype=torch.float32).reshape(M, N).to(dtype).npu()
+        )
+    else:
+        in_tensor = torch.arange(total, dtype=torch.int64).reshape(M, N).to(dtype).npu()
+
+    # masks: random 0/1 tensors (1D)
+    mask_m = torch.randint(0, 2, (M,), dtype=torch.int32).npu()  # rows
+    mask_n = torch.randint(0, 2, (N,), dtype=torch.int32).npu()  # cols
+
+    # out: initialize with a sentinel so we can tell which positions are untouched
+    if dtype.is_floating_point:
+        sentinel = torch.tensor(-999.0, dtype=torch.float32).to(dtype)
+    else:
+        sentinel = torch.tensor(-999, dtype=torch.int64).to(dtype)
+
+    out_init = torch.full((M, N), sentinel.item(), dtype=dtype).npu()
+    out = out_init.clone()
+
+    # call the kernel: a single program covers the full matrix; M, N are constexpr
+    # signature: (in_ptr, out_ptr, mask_m_ptr, mask_n_ptr, m, n, M: tl.constexpr, N: tl.constexpr)
+    # pass m=M and n=N so the structured bounds never mask anything out
+    unstructured_mask_2d_kernel[1, 1](in_tensor, out, mask_m, mask_n, M, N, M=M, N=N)
+
+    # construct the reference output according to the kernel logic:
+    # when mask_n[j] == 0 -> out keeps the initial sentinel (the kernel does not store)
+    # when mask_n[j] == 1:
+    #     if mask_m[i] == 1 -> out[i, j] == in[i, j]
+    #     else              -> out[i, j] == -2 (the load's `other` value)
+    expected = out_init.clone()
+    for i in range(M):
+        row_mask = bool(mask_m[i].item())
+        for j in range(N):
+            col_mask = bool(mask_n[j].item())
+            if not col_mask:
+                # kernel does not store here; keep the initial sentinel
+                expected[i, j] = out_init[i, j]
+            else:
+                if row_mask:
+                    expected[i, j] = in_tensor[i, j]
+                else:
+                    # -2 with the same dtype
+                    if dtype.is_floating_point:
+                        expected[i, j] = torch.tensor(-2.0, dtype=torch.float32).to(
+                            dtype
+                        )
+                    else:
+                        expected[i, j] = torch.tensor(-2, dtype=expected.dtype)
+
+    # validate using the project's common validator
+    test_common.validate_cmp(dtype_str, out, expected)
+
+
 if __name__ == "__main__":
-    test_ldst_indirect_05()
-    print("success: test_ldst_indirect_05")
+    test_ldst_indirect_08()
+    print("success: test_ldst_indirect_08")
diff --git a/test/ascend/passed_tests/test_load_store.py b/test/ascend/passed_tests/test_load_store.py
index 797406bd..04390b2b 100644
--- a/test/ascend/passed_tests/test_load_store.py
+++ b/test/ascend/passed_tests/test_load_store.py
@@ -213,7 +213,7 @@ def test_load_store_sle_mask(param_list):
         ["int8", (8, 8, 4), 2, 128, 64],
     ],
 )
-def test_load_store_sle_mask(param_list):
+def test_load_store_sge_mask(param_list):
     # 生成数据
     dtype, shape, ncore, xblock, xblock_sub = param_list
     x0 = test_common.generate_tensor(shape, dtype).npu()
diff --git a/tools/dicp_triton_opt/CMakeLists.txt b/tools/dicp_triton_opt/CMakeLists.txt
index 02597a0a..bd00d88a 100644
--- a/tools/dicp_triton_opt/CMakeLists.txt
+++ b/tools/dicp_triton_opt/CMakeLists.txt
@@ -21,6 +21,7 @@ target_link_libraries(dicp_opt PRIVATE
   LinalgExtTransforms
   LinkedToHIVM
   DICPLinalgExt
+  DiscreteMaskAccessConversion
 
   TritonToLinalg
   TritonTilingExtIR
diff --git a/tools/dicp_triton_opt/dicp_triton_opt.cpp b/tools/dicp_triton_opt/dicp_triton_opt.cpp
index fe87b666..dccd884e 100644
--- a/tools/dicp_triton_opt/dicp_triton_opt.cpp
+++ b/tools/dicp_triton_opt/dicp_triton_opt.cpp
@@ -1,7 +1,10 @@
+#include "dicp/Conversion/DiscreteMaskAccessConversion/Passes.h"
 #include "dicp/Conversion/LinalgToLinked/Passes.h"
 #include "dicp/Conversion/LinalgToNPU/Passes.h"
 #include "dicp/Conversion/LinkedToHIVM/Passes.h"
+#include "dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h"
 #include "dicp/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/Passes.h"
+#include "dicp/Conversion/TritonToUnstructure/Passes.h"
 #include "dicp/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "dicp/Dialect/LinalgExt/Transforms/Passes.h"
 #include "dicp/Dialect/NPU/IR/NPUDialect.h"
@@ -86,12 +89,17 @@ inline void registerDICPDialects(mlir::DialectRegistry &registry) {
   mlir::registerAllPasses();
   mlir::registerLinalgPasses();
 
+  mlir::triton::registerDiscreteMaskAccessConversionPass();
+  mlir::triton::registerTritonToUnstructurePass();
+  mlir::triton::registerBubbleUpOperationPass();
+
   dicp::npu::registerLinalgToNPUPass();
   dicp::linked::registerLinalgToLinkedPass();
   dicp::trtion_ext::registerCanonicalizeTritonIRAscendPass();
   dicp::trtion_ext::registerCanonicalizeCmpiPass();
   dicp::linked::registerLinkedToHIVMPass();
   dicp::linked::registerTritonToLinalgNPUCoversionPass();
+  dicp::linked::registerMemRefCopyGatherToTensorInsertPass();
 
   dicp::LinalgExt::registerLinalgIfToSelectPass();
   dicp::LinalgExt::registerLinalgGenericToSCFPass();
diff --git a/triton_dicp_triton.cc b/triton_dicp_triton.cc
index 9a9d0992..979e0071 100644
--- a/triton_dicp_triton.cc
+++ b/triton_dicp_triton.cc
@@ -1,8 +1,11 @@
+#include "dicp/Conversion/DiscreteMaskAccessConversion/Passes.h"
 #include "dicp/Conversion/LinalgToLinked/LinalgToLinked.h"
 #include "dicp/Conversion/LinalgToLinked/Passes.h"
 #include "dicp/Conversion/LinalgToNPU/Passes.h"
 #include "dicp/Conversion/LinkedToHIVM/Passes.h"
 #include "dicp/Conversion/TritonToLinalgNPU/TritonToLinalgNPUCoversion/Passes.h"
+#include "dicp/Conversion/TritonToUnstructure/BubbleUpOperation.h"
+#include "dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h"
 
 #include "dicp/Dialect/LinalgExt/Transforms/Passes.h"
 #include "dicp/Dialect/TritonExt/Transforms/Passes.h"
@@ -44,6 +47,13 @@ void init_triton_dicp_triton_pass_triton_shared_ascend(py::module &&m) {
                      dicp::trtion_ext::createCanonicalizeTritonIRAscendPass);
   ADD_PASS_WRAPPER_0("add_triton_to_linalg_npu",
                      dicp::linked::createTritonToLinalgNPUCoversionPass);
+  ADD_PASS_OPTION_WRAPPER_2("add_discrete_mask_access_conversion",
+                            triton::createDiscreteMaskAccessConversionPass,
+                            bool, bool);
+  ADD_PASS_WRAPPER_0("add_triton_to_unstructure",
+                     triton::createTritonToUnstructurePass);
+  ADD_PASS_WRAPPER_0("add_bubble_up_operation",
+                     triton::createBubbleUpOperationPass);
 }
 
 void init_triton_dicp_triton_pass_linked_npu(py::module &&m) {