From b148f6a1e3db648d096fe721da5c9c973eb422b7 Mon Sep 17 00:00:00 2001 From: Vimarsh Sathia Date: Mon, 1 Dec 2025 00:41:10 -0600 Subject: [PATCH 1/3] Add a fwddiff region op --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 30 ++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 5e197c21128..f847101a248 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -243,8 +243,36 @@ def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> }]; } +def ForwardDiffRegionOp : Enzyme_Op<"fwddiff_region", [AutomaticAllocationScope]> { + let summary = "Perform forward mode AD on a child region"; + let arguments = (ins Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width, DefaultValuedAttr:$strong_zero, OptionalAttr:$fn); + let regions = (region AnyRegion:$body); + let results = (outs Variadic:$outputs); + + let assemblyFormat = [{ + `(` $inputs `)` $body attr-dict-with-keyword `:` functional-type($inputs, results) + }]; + + let extraClassDeclaration = [{ + + /// Collect all primal input values + ::llvm::SmallVector<::mlir::Value> getPrimalInputs() { + return ::mlir::enzyme::detail::filterGradInputs( + *this); + } + + /// Collect all input shadow values(for primals with activity marked as + /// `enzyme_dup`/`enzyme_dupnoneed` + ::llvm::SmallVector<::mlir::Value> getShadows() { + return ::mlir::enzyme::detail::filterGradInputs(*this); + } + + }]; +} + def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["AutoDiffRegionOp", "LoopOp"]>]> { + ParentOneOf<["AutoDiffRegionOp", "ForwardDiffRegionOp", "LoopOp"]>]> { let summary = "Yield values at the end of an autodiff_region or loop op"; let arguments = (ins Variadic:$operands); let assemblyFormat = [{ From c7b930611cbe14559371861ae6ffda9302a4c539 Mon Sep 17 00:00:00 2001 From: Vimarsh Sathia Date: Mon, 1 Dec 2025 20:43:37 -0600 Subject: [PATCH 2/3] add canonicalizer --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 3 +- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 129 ++++++++++++++++-------- 2 files changed, 91 insertions(+), 41 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index f847101a248..8942454d511 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -252,7 +252,8 @@ def ForwardDiffRegionOp : Enzyme_Op<"fwddiff_region", [AutomaticAllocationScope] let assemblyFormat = [{ `(` $inputs `)` $body attr-dict-with-keyword `:` functional-type($inputs, results) }]; - + + let hasCanonicalizer = 1; let extraClassDeclaration = [{ /// Collect all primal input values diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 02a361bcda1..602f7d719d1 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -165,6 +165,59 @@ ForwardDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // ForwardDiffOp //===----------------------------------------------------------------------===// +// Some templated helpers for rewriting EnzymeOps(we can overload the create +// definitions as and when necessary) +template struct EnzymeOpCreator; + +template <> struct EnzymeOpCreator { + static AutoDiffOp create(PatternRewriter &rewriter, AutoDiffOp uop, + ArrayRef out_ty, ArrayRef in_args, + ArrayAttr newInActivity, ArrayAttr newRetActivity) { + + return AutoDiffOp::create(rewriter, uop.getLoc(), out_ty, uop.getFnAttr(), + in_args, newInActivity, newRetActivity, + uop.getWidthAttr(), uop.getStrongZeroAttr()); + } +}; + +template <> struct EnzymeOpCreator { + static AutoDiffRegionOp create(PatternRewriter &rewriter, + AutoDiffRegionOp uop, ArrayRef out_ty, + ArrayRef in_args, + ArrayAttr newInActivity, + ArrayAttr newRetActivity) { + auto newOp = AutoDiffRegionOp::create( + rewriter, uop.getLoc(), out_ty, in_args, newInActivity, newRetActivity, + uop.getWidthAttr(), uop.getStrongZeroAttr(), uop.getFnAttr()); + newOp.getBody().takeBody(uop.getBody()); + return newOp; + } +}; + +template <> struct EnzymeOpCreator { + static ForwardDiffOp create(PatternRewriter &rewriter, ForwardDiffOp uop, + ArrayRef out_ty, ArrayRef in_args, + ArrayAttr newInActivity, + ArrayAttr newRetActivity) { + return ForwardDiffOp::create( + rewriter, uop.getLoc(), out_ty, uop.getFnAttr(), in_args, newInActivity, + newRetActivity, uop.getWidthAttr(), uop.getStrongZeroAttr()); + } +}; + +template <> struct EnzymeOpCreator { + static ForwardDiffRegionOp + create(PatternRewriter &rewriter, ForwardDiffRegionOp uop, + ArrayRef out_ty, ArrayRef in_args, + ArrayAttr newInActivity, ArrayAttr newRetActivity) { + auto newOp = ForwardDiffRegionOp::create( + rewriter, uop.getLoc(), out_ty, in_args, newInActivity, newRetActivity, + uop.getWidthAttr(), uop.getStrongZeroAttr(), uop.getFnAttr()); + newOp.getBody().takeBody(uop.getBody()); + return newOp; + } +}; + // Helper: check if any input is mutable. static inline bool isMutable(Type type) { if (isa(type) || isa(type) || @@ -188,11 +241,12 @@ static inline bool isMutable(Type type) { * ------> enzyme_const * */ -class FwdInpOpt final : public OpRewritePattern { +template +class FwdInpOpt final : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ForwardDiffOp uop, + LogicalResult matchAndRewrite(SourceOp uop, PatternRewriter &rewriter) const override { if (uop.getOutputs().size() == 0) @@ -282,12 +336,23 @@ class FwdInpOpt final : public OpRewritePattern { ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(newInActivityArgs.begin(), newInActivityArgs.end())); - rewriter.replaceOpWithNewOp( - uop, uop->getResultTypes(), uop.getFnAttr(), in_args, newInActivity, - uop.getRetActivityAttr(), uop.getWidthAttr(), uop.getStrongZeroAttr()); + + if constexpr (std::is_same_v) { + + rewriter.replaceOpWithNewOp( + uop, uop->getResultTypes(), uop.getFnAttr(), in_args, newInActivity, + uop.getRetActivityAttr(), uop.getWidthAttr(), + uop.getStrongZeroAttr()); + } else { + rewriter.replaceOpWithNewOp( + uop, uop->getResultTypes(), in_args, newInActivity, + uop.getRetActivityAttr(), uop.getWidthAttr(), uop.getStrongZeroAttr(), + uop.getFnAttr()); + } return success(); } }; + /** * * Modifies return activites for the FwdDiffOp @@ -301,11 +366,15 @@ class FwdInpOpt final : public OpRewritePattern { * ------> enzyme_const ----- * */ -class FwdRetOpt final : public OpRewritePattern { +template +class FwdRetOpt final : public OpRewritePattern { +private: + using SourceOpCreator = EnzymeOpCreator; + public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ForwardDiffOp uop, + LogicalResult matchAndRewrite(SourceOp uop, PatternRewriter &rewriter) const override { if (uop.getOutputs().size() == 0) @@ -436,10 +505,9 @@ class FwdRetOpt final : public OpRewritePattern { llvm::ArrayRef(newRetActivityArgs.begin(), newRetActivityArgs.end())); - ForwardDiffOp newOp = ForwardDiffOp::create( - rewriter, uop.getLoc(), out_ty, uop.getFnAttr(), uop.getInputs(), - uop.getActivityAttr(), newRetActivity, uop.getWidthAttr(), - uop.getStrongZeroAttr()); + SmallVector in_args = uop.getInputs(); + SourceOp newOp = SourceOpCreator::create( + rewriter, uop, out_ty, in_args, uop.getActivityAttr(), newRetActivity); // Map old uses of uop to newOp auto oldIdx = 0; @@ -499,7 +567,13 @@ class FwdRetOpt final : public OpRewritePattern { void ForwardDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add, FwdInpOpt>(context); +} + +void ForwardDiffRegionOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add, FwdInpOpt>( + context); } LogicalResult AutoDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) { @@ -600,7 +674,7 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input, template class ReverseRetOpt final : public OpRewritePattern { private: - struct SourceOpCreator; + using SourceOpCreator = EnzymeOpCreator; public: using OpRewritePattern::OpRewritePattern; @@ -884,31 +958,6 @@ class ReverseRetOpt final : public OpRewritePattern { } }; -template <> struct ReverseRetOpt::SourceOpCreator { - static AutoDiffOp create(PatternRewriter &rewriter, AutoDiffOp uop, - ArrayRef out_ty, ArrayRef in_args, - ArrayAttr newInActivity, ArrayAttr newRetActivity) { - - return AutoDiffOp::create(rewriter, uop.getLoc(), out_ty, uop.getFnAttr(), - in_args, newInActivity, newRetActivity, - uop.getWidthAttr(), uop.getStrongZeroAttr()); - } -}; - -template <> struct ReverseRetOpt::SourceOpCreator { - static AutoDiffRegionOp create(PatternRewriter &rewriter, - AutoDiffRegionOp uop, ArrayRef out_ty, - ArrayRef in_args, - ArrayAttr newInActivity, - ArrayAttr newRetActivity) { - - auto newOp = AutoDiffRegionOp::create( - rewriter, uop.getLoc(), out_ty, in_args, newInActivity, newRetActivity, - uop.getWidthAttr(), uop.getStrongZeroAttr(), uop.getFnAttr()); - newOp.getBody().takeBody(uop.getBody()); - return newOp; - } -}; void AutoDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add>(context); From ba7074c4e52955df665488925c2a7b2ef6d0575b Mon Sep 17 00:00:00 2001 From: Vimarsh Sathia Date: Tue, 2 Dec 2025 02:29:04 -0600 Subject: [PATCH 3/3] Added region inlining --- .../MLIR/Passes/InlineEnzymeRegions.cpp | 52 +++++++++++++++++-- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp b/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp index 200ee379a47..cb2ad0eccbe 100644 --- a/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp +++ b/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp @@ -46,8 +46,7 @@ static StringRef getArgAttrsAttrName(Operation *operation) { return ""; } -static void serializeFunctionAttributes(Operation *fn, - enzyme::AutoDiffRegionOp regionOp) { +static void serializeFunctionAttributes(Operation *fn, Operation *regionOp) { SmallVector fnAttrs; fnAttrs.reserve(fn->getAttrDictionary().size()); for (auto attr : fn->getAttrs()) { @@ -100,6 +99,52 @@ struct InlineEnzymeAutoDiff : public OpRewritePattern { } }; +struct InlineEnzymeForwardDiff + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(enzyme::ForwardDiffOp op, + PatternRewriter &rewriter) const override { + SymbolTableCollection symbolTable; + + FunctionOpInterface fn = dyn_cast_or_null( + symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr())); + + if (!fn) + return failure(); + + Region &targetRegion = fn.getFunctionBody(); + + if (targetRegion.empty()) + return failure(); + + // Use a StringAttr rather than a SymbolRefAttr so the function can get + // symbol-DCE'd + auto fnAttr = StringAttr::get(op.getContext(), op.getFn()); + auto regionOp = rewriter.replaceOpWithNewOp( + op, op.getResultTypes(), op.getInputs(), op.getActivity(), + op.getRetActivity(), op.getWidth(), op.getStrongZero(), fnAttr); + + serializeFunctionAttributes(fn, regionOp); + rewriter.cloneRegionBefore(targetRegion, regionOp.getBody(), + regionOp.getBody().begin()); + + SmallVector toErase; + for (Operation &bodyOp : regionOp.getBody().getOps()) { + if (bodyOp.hasTrait()) { + PatternRewriter::InsertionGuard insertionGuard(rewriter); + rewriter.setInsertionPoint(&bodyOp); + enzyme::YieldOp::create(rewriter, bodyOp.getLoc(), + bodyOp.getOperands()); + toErase.push_back(&bodyOp); + } + } + + for (Operation *opToErase : toErase) + rewriter.eraseOp(opToErase); + return success(); + } +}; + // Based on // https://github.com/llvm/llvm-project/blob/665da0a1649814471739c41a702e0e9447316b20/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp static FailureOr @@ -248,7 +293,8 @@ struct InlineEnzymeIntoRegion InlineEnzymeIntoRegion> { void runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); + patterns.insert( + &getContext()); GreedyRewriteConfig config; (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);