Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ if (TRITON_BUILD_PYTHON_MODULE)

LinalgToLinked
LinkedToHIVM
DiscreteMaskAccessConversion
TritonToUnstructure
)
target_link_libraries(tritonDicpTriton PRIVATE Python3::Module pybind11::headers)
endif()
8 changes: 8 additions & 0 deletions backend/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=True):

def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):
pm = ir.pass_manager(mod.context)
dicp_triton.passes.triton_shared_ascend.add_discrete_mask_access_conversion(
pm, False, False
)
dicp_triton.passes.triton_shared_ascend.add_triton_to_unstructure(pm)
dicp_triton.passes.triton_shared_ascend.add_bubble_up_operation(pm)
dicp_triton.passes.triton_shared_ascend.add_canonicalize_cmpi(pm)
dicp_triton.passes.triton_shared_ascend.add_canonicalize_triton_ir_ascend(pm)
dicp_triton.passes.triton_shared_ascend.add_triton_to_linalg_npu(pm)
Expand All @@ -404,6 +409,9 @@ def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):
cmd_list = [
_get_dicp_opt_path(),
"kernel.ttir.mlir",
"--discrete-mask-access-conversion",
"--triton-to-unstructure",
"--bubble-up-operation",
"--canonicalize-cmpi",
"--canonicalize-triton-ir-ascend",
"--triton-to-linalg-npu-conversion",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Run mlir-tablegen over Passes.td to emit the pass declarations for the
# DiscreteMaskAccessConversion pass group into Passes.h.inc.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name DiscreteMaskAccessConversion)
# Expose the generation step as a named target so that code including
# Passes.h.inc can declare a build dependency on it.
add_public_tablegen_target(DiscreteMaskAccessConversionPassIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@

#ifndef TRITON_ANALYSIS_MASKANALYSIS_H
#define TRITON_ANALYSIS_MASKANALYSIS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include <utility>

namespace mlir {

// this class helps build Operations
class OpBuilder;

namespace dicp {
// Decodes the pattern carried by a mask value that guards a load or a store.
class MaskState {
public:
  // Decoded components of the mask. A null OpFoldResult is treated as "unset"
  // by the predicates below. Per the note on parseCmp, scalar/start/end are
  // intermediate results that get fused into dims and offsets.
  OpFoldResult start;
  OpFoldResult end;
  SmallVector<OpFoldResult> dims;
  SmallVector<OpFoldResult> offsets;
  OpFoldResult scalar;

  // Rank of the state, i.e. the number of entries in `dims`.
  // `dims` and `offsets` are required to have the same length.
  int64_t getRank() const {
    assert(dims.size() == offsets.size() && "dims and offsets rank mismatch!");
    return static_cast<int64_t>(dims.size());
  }

  // True when nothing has been recorded: rank is zero and none of
  // scalar/start/end is set.
  bool isEmpty() const {
    if (getRank() != 0)
      return false;
    return !(scalar || start || end);
  }

  // True when only the per-dimension description is populated: start, end and
  // scalar are all unset while both dims and offsets are non-empty.
  bool isMask() const {
    if (start || end || scalar)
      return false;
    return !dims.empty() && !offsets.empty();
  }

  // Recursively parses `operand` to populate this state.
  LogicalResult parse(Value operand, const Location &loc, OpBuilder &builder);

  // Builds a tensor.extract_slice of `source`.
  // NOTE(review): presumably sized by this state's dims/offsets - confirm
  // against the implementation.
  tensor::ExtractSliceOp getExtractSlice(Value source, const Location &loc,
                                         OpBuilder &builder) const;

  // Builds a tensor.insert_slice of `source` into `dest`.
  tensor::InsertSliceOp getInsertSlice(Value source, Value dest,
                                       const Location &loc,
                                       OpBuilder &builder) const;

  // Builds a memref.subview of `source`.
  memref::SubViewOp getSubview(Value source, const Location &loc,
                               OpBuilder &builder) const;

  // NOTE(review): presumably removes the ops this analysis inserted for
  // `rawOp` - the behavior is not visible from this header, confirm.
  void eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter);

private:
  // --- Combinators: derive this state from already-parsed operand states ---

  // State plus a scalar.
  LogicalResult addStateScalar(const MaskState &state,
                               const OpFoldResult scalar, const Location &loc,
                               OpBuilder &builder);

  // State plus state.
  LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState,
                          const Location &loc, OpBuilder &builder);

  // State divided by a scalar.
  LogicalResult divStateScalar(const MaskState &state,
                               const OpFoldResult scalar, const Location &loc,
                               OpBuilder &builder);

  // State divided by state.
  LogicalResult divStates(const MaskState &lhsState, const MaskState &rhsState,
                          const Location &loc, OpBuilder &builder);

  // Combines two mask states joined by an `and`.
  LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState,
                          const Location &loc, OpBuilder &builder);

  // --- Per-operation parsers used to populate the MaskState ---

  // Operand is an arith.constant.
  LogicalResult parseConstant(arith::ConstantOp constOp, const Location &loc,
                              OpBuilder &builder);

  // Operand is an integer scalar value.
  LogicalResult parseIntScalar(Value scalar, const Location &loc,
                               OpBuilder &builder);

  // Operand is the result of addi.
  // TODO: carried over from the original, which gave no further detail.
  LogicalResult parseAdd(arith::AddIOp addOp, const Location &loc,
                         OpBuilder &builder);

  // Operand is the result of divsi.
  LogicalResult parseDiv(arith::DivSIOp divOp, const Location &loc,
                         OpBuilder &builder);

  // Operand is the result of andi.
  LogicalResult parseAnd(arith::AndIOp andOp, const Location &loc,
                         OpBuilder &builder);

  // Operand is the result of cmpi. This is the step that fuses scalar, start
  // and end into dims and offsets.
  LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location &loc,
                         OpBuilder &builder);

  // Operand is the result of select.
  LogicalResult parseSel(arith::SelectOp selOp, const Location &loc,
                         OpBuilder &builder);

  // Operand is the result of make_range.
  LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location &loc,
                               OpBuilder &builder);

  // Operand is the result of broadcast.
  LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp,
                               const Location &loc, OpBuilder &builder);

  // Operand is the result of splat.
  LogicalResult parseSplat(triton::SplatOp splatOp, const Location &loc,
                           OpBuilder &builder);

  // Operand is the result of expand_dims.
  LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp,
                                const Location &loc, OpBuilder &builder);
};

} // namespace dicp

} // namespace mlir

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef TRITON_DLC_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H
#define TRITON_DLC_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H

// Public entry points for the DiscreteMaskAccessConversion pass: the
// tablegen-generated declarations, the factory function, and registration.

#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/IR/PatternMatch.h"

// Generated declarations for the pass; this is where
// DiscreteMaskAccessConversionOptions (used by the factory below) comes from.
// The include path uses a single '/' separator, matching the other pass
// headers in this tree.
#define GEN_PASS_DECL_DISCRETEMASKACCESSCONVERSION
#include "dicp/Conversion/DiscreteMaskAccessConversion/Passes.h.inc"

// Generated pass base class.
// NOTE(review): GEN_PASS_DEF_* is more commonly enabled only in the .cpp that
// implements the pass; kept here because other files may rely on it - confirm
// before moving.
#define GEN_PASS_DEF_DISCRETEMASKACCESSCONVERSION
#include "dicp/Conversion/DiscreteMaskAccessConversion/Passes.h.inc"

namespace mlir {
namespace triton {

/// Creates the DiscreteMaskAccessConversion pass, which runs on a ModuleOp.
/// `options` defaults to a value-initialized options struct.
std::unique_ptr<OperationPass<ModuleOp>> createDiscreteMaskAccessConversionPass(
    const DiscreteMaskAccessConversionOptions &options = {});

} // namespace triton
} // namespace mlir

namespace mlir {
namespace triton {

// Generated registration helpers. Kept inside mlir::triton so that the
// constructor string "triton::createDiscreteMaskAccessConversionPass()" from
// Passes.td resolves.
#define GEN_PASS_REGISTRATION
#include "dicp/Conversion/DiscreteMaskAccessConversion/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif // TRITON_DLC_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@


#ifndef DISCRETE_MASK_ACCESS_CONVERSION_PASSES
#define DISCRETE_MASK_ACCESS_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

// Pass record for "discrete-mask-access-conversion", anchored on
// mlir::ModuleOp. Instances are created through
// triton::createDiscreteMaskAccessConversionPass().
def DiscreteMaskAccessConversion : Pass<"discrete-mask-access-conversion", "mlir::ModuleOp"> {
let summary = "Recognize and convert discrete mask memory access";
let constructor = "triton::createDiscreteMaskAccessConversionPass()";
let options = [
// Boolean option exposed as "compile-on-910-95"; defaults to false.
Option<"compileOn91095", "compile-on-910-95",
"bool", /*default*/"false",
"compile on 910_95">,
// Boolean option exposed as "force-simt-template"; defaults to false.
Option<"forceSimtTemplate", "force-simt-template",
"bool", /*default*/"false",
"force to use simt template">
];
}

#endif // DISCRETE_MASK_ACCESS_CONVERSION_PASSES
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef TRITON_ADAPTER_TRITONOPCONVERTER_H
#define TRITON_ADAPTER_TRITONOPCONVERTER_H
#ifndef TRITON_DLC_TRITONOPCONVERTER_H
#define TRITON_DLC_TRITONOPCONVERTER_H

#include "triton/Dialect/Triton/IR/Dialect.h"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Run mlir-tablegen over Passes.td to emit the pass declarations for the
# MemRefCopyGatherToTensorInsert pass group into Passes.h.inc.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name MemRefCopyGatherToTensorInsert)
# Expose the generation step as a named target so that code including
# Passes.h.inc can declare a build dependency on it.
add_public_tablegen_target(MemRefCopyGatherToTensorInsertPassIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef MEMREF_COPY_GATHER_TO_TENSOR_INSERT_PASSES_H
#define MEMREF_COPY_GATHER_TO_TENSOR_INSERT_PASSES_H

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir::dicp::linked {

/// Factory for the MemRefCopyGatherToTensorInsert pass, which runs on a
/// ModuleOp. Takes no options and returns an owning pointer to the new pass.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createMemRefCopyGatherToTensorInsertPass();

#define GEN_PASS_REGISTRATION
#include "dicp/Conversion/TritonToLinalgNPU/MemRefCopyGatherToTensorInsert/Passes.h.inc"

} // namespace mlir::dicp::linked

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef MEMREF_COPY_GATHER_TO_TENSOR_INSERT_PASSES
#define MEMREF_COPY_GATHER_TO_TENSOR_INSERT_PASSES

include "mlir/Pass/PassBase.td"

// Pass record for "discrete-gather-to-direct-insert", anchored on
// mlir::ModuleOp. Note that the command-line name differs from the record
// name.
def MemRefCopyGatherToTensorInsert : Pass<"discrete-gather-to-direct-insert", "mlir::ModuleOp"> {
let summary = "Converts alloc+copy based gather loops to direct tensor insertions";
let description = [{
Identifies patterns where a temporary memref is allocated, populated via
index-based gathers (subview + copy) in a loop, and then converted to a tensor.
Replaces this with a tensor.empty + scf.for (iter_args) + tensor.insert
sequence to avoid stack allocation and enable register-level operation.
}];
// Factory function used to instantiate the pass.
let constructor = "::mlir::dicp::linked::createMemRefCopyGatherToTensorInsertPass()";
// Dialects declared as dependencies of this pass.
let dependentDialects = [
"mlir::scf::SCFDialect",
"mlir::memref::MemRefDialect",
"mlir::tensor::TensorDialect",
"mlir::arith::ArithDialect",
"mlir::bufferization::BufferizationDialect"
];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/IR/PatternMatch.h"

#define GEN_PASS_DECL_BUBBLEUPOPERATION
#include "dicp/Conversion/TritonToUnstructure/Passes.h.inc"

#define GEN_PASS_DEF_BUBBLEUPOPERATION
#include "dicp/Conversion/TritonToUnstructure/Passes.h.inc"

namespace mlir {
namespace triton {

/// Factory for the BubbleUpOperation pass, which runs on a ModuleOp.
/// `options` defaults to a value-initialized BubbleUpOperationOptions.
std::unique_ptr<OperationPass<ModuleOp>>
createBubbleUpOperationPass(const BubbleUpOperationOptions &options = {});

} // namespace triton
} // namespace mlir

// NOTE(review): these using-directives sit at file scope in a header, so every
// translation unit that includes it gets mlir:: and triton:: names injected.
// Other files may already rely on that, so it is only flagged here.
using namespace mlir;
using namespace triton;

// Pass class built on the tablegen-generated BubbleUpOperationBase.
// Both members are only declared here; their definitions live elsewhere.
class BubbleUpOperationPass
: public ::impl::BubbleUpOperationBase<BubbleUpOperationPass> {
public:
// Constructs the pass from an explicit options struct.
explicit BubbleUpOperationPass(const BubbleUpOperationOptions &options);
// Pass entry point, overriding the base-class hook.
void runOnOperation() override;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Run mlir-tablegen over Passes.td to emit the pass declarations for the
# TritonToUnstructure pass group into Passes.h.inc.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToUnstructure)
# Expose the generation step as a named target so that code including
# Passes.h.inc can declare a build dependency on it.
add_public_tablegen_target(TritonToUnstructureConversionPassIncGen)
Loading