Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,37 @@ def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]>
}];
}

// Region-carrying forward-mode differentiation op (`enzyme.fwddiff_region`).
// The code to differentiate lives in the single `$body` region rather than
// being referenced through a symbol.
def ForwardDiffRegionOp : Enzyme_Op<"fwddiff_region", [AutomaticAllocationScope]> {
let summary = "Perform forward mode AD on a child region";
// Operands and attributes, in builder order: the variadic inputs, the
// per-input activity array, the per-result activity array, `width`
// (defaults to 1), `strong_zero` (defaults to false) and an optional string
// attribute `fn`.
let arguments = (ins Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width, DefaultValuedAttr<BoolAttr, "false">:$strong_zero, OptionalAttr<StrAttr>:$fn);
let regions = (region AnyRegion:$body);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
`(` $inputs `)` $body attr-dict-with-keyword `:` functional-type($inputs, results)
}];

// Canonicalization patterns are registered in C++ via
// ForwardDiffRegionOp::getCanonicalizationPatterns.
let hasCanonicalizer = 1;
let extraClassDeclaration = [{

/// Collect all primal input values
::llvm::SmallVector<::mlir::Value> getPrimalInputs() {
return ::mlir::enzyme::detail::filterGradInputs<ForwardDiffRegionOp, false>(
*this);
}

/// Collect all input shadow values (for primals with activity marked as
/// `enzyme_dup`/`enzyme_dupnoneed`)
::llvm::SmallVector<::mlir::Value> getShadows() {
return ::mlir::enzyme::detail::filterGradInputs<ForwardDiffRegionOp, true,
true, false>(*this);
}

}];
}

def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["AutoDiffRegionOp", "LoopOp"]>]> {
ParentOneOf<["AutoDiffRegionOp", "ForwardDiffRegionOp", "LoopOp"]>]> {
let summary = "Yield values at the end of an autodiff_region or loop op";
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = [{
Expand Down
129 changes: 89 additions & 40 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,59 @@ ForwardDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// ForwardDiffOp
//===----------------------------------------------------------------------===//

// Templated helpers for rewriting Enzyme ops. `EnzymeOpCreator<SourceOp>` is
// only declared here; it is explicitly specialized per op type, and each
// specialization provides a static `create` that builds a replacement op of
// the same kind as `SourceOp`. Further specializations (or `create`
// overloads) can be added as and when necessary.
template <typename SourceOp> struct EnzymeOpCreator;

/// Creator for `AutoDiffOp`.
///
/// Builds a fresh `AutoDiffOp` at the location of `uop` using the supplied
/// result types, operands and activity arrays. The callee, width and
/// strong-zero attributes are carried over from `uop` unchanged. `uop` itself
/// is neither replaced nor erased here.
template <> struct EnzymeOpCreator<AutoDiffOp> {
  static AutoDiffOp create(PatternRewriter &rewriter, AutoDiffOp uop,
                           ArrayRef<Type> out_ty, ArrayRef<Value> in_args,
                           ArrayAttr newInActivity, ArrayAttr newRetActivity) {
    // Everything that is not being rewritten is copied from the source op.
    auto sourceLoc = uop.getLoc();
    auto calleeAttr = uop.getFnAttr();
    auto widthAttr = uop.getWidthAttr();
    auto strongZeroAttr = uop.getStrongZeroAttr();

    return AutoDiffOp::create(rewriter, sourceLoc, out_ty, calleeAttr, in_args,
                              newInActivity, newRetActivity, widthAttr,
                              strongZeroAttr);
  }
};

/// Creator for `AutoDiffRegionOp`.
///
/// Builds a fresh `AutoDiffRegionOp` at the location of `uop` using the
/// supplied result types, operands and activity arrays; width, strong-zero
/// and `fn` attributes are carried over from `uop`. The body region of `uop`
/// is then moved into the new op via `takeBody`. `uop` itself is neither
/// replaced nor erased here.
template <> struct EnzymeOpCreator<AutoDiffRegionOp> {
  static AutoDiffRegionOp create(PatternRewriter &rewriter,
                                 AutoDiffRegionOp uop, ArrayRef<Type> out_ty,
                                 ArrayRef<Value> in_args,
                                 ArrayAttr newInActivity,
                                 ArrayAttr newRetActivity) {
    // Attributes that are not being rewritten are copied from the source op.
    auto sourceLoc = uop.getLoc();
    auto widthAttr = uop.getWidthAttr();
    auto strongZeroAttr = uop.getStrongZeroAttr();
    auto calleeNameAttr = uop.getFnAttr();

    auto replacement = AutoDiffRegionOp::create(
        rewriter, sourceLoc, out_ty, in_args, newInActivity, newRetActivity,
        widthAttr, strongZeroAttr, calleeNameAttr);

    // Hand the existing body over to the replacement op.
    replacement.getBody().takeBody(uop.getBody());
    return replacement;
  }
};

/// Creator for `ForwardDiffOp`.
///
/// Builds a fresh `ForwardDiffOp` at the location of `uop` using the supplied
/// result types, operands and activity arrays. The callee, width and
/// strong-zero attributes are carried over from `uop` unchanged. `uop` itself
/// is neither replaced nor erased here.
template <> struct EnzymeOpCreator<ForwardDiffOp> {
  static ForwardDiffOp create(PatternRewriter &rewriter, ForwardDiffOp uop,
                              ArrayRef<Type> out_ty, ArrayRef<Value> in_args,
                              ArrayAttr newInActivity,
                              ArrayAttr newRetActivity) {
    // Everything that is not being rewritten is copied from the source op.
    auto sourceLoc = uop.getLoc();
    auto calleeAttr = uop.getFnAttr();
    auto widthAttr = uop.getWidthAttr();
    auto strongZeroAttr = uop.getStrongZeroAttr();

    return ForwardDiffOp::create(rewriter, sourceLoc, out_ty, calleeAttr,
                                 in_args, newInActivity, newRetActivity,
                                 widthAttr, strongZeroAttr);
  }
};

/// Creator for `ForwardDiffRegionOp`.
///
/// Builds a fresh `ForwardDiffRegionOp` at the location of `uop` using the
/// supplied result types, operands and activity arrays; width, strong-zero
/// and `fn` attributes are carried over from `uop`. The body region of `uop`
/// is then moved into the new op via `takeBody`. `uop` itself is neither
/// replaced nor erased here.
template <> struct EnzymeOpCreator<ForwardDiffRegionOp> {
  static ForwardDiffRegionOp
  create(PatternRewriter &rewriter, ForwardDiffRegionOp uop,
         ArrayRef<Type> out_ty, ArrayRef<Value> in_args,
         ArrayAttr newInActivity, ArrayAttr newRetActivity) {
    // Attributes that are not being rewritten are copied from the source op.
    auto sourceLoc = uop.getLoc();
    auto widthAttr = uop.getWidthAttr();
    auto strongZeroAttr = uop.getStrongZeroAttr();
    auto calleeNameAttr = uop.getFnAttr();

    auto replacement = ForwardDiffRegionOp::create(
        rewriter, sourceLoc, out_ty, in_args, newInActivity, newRetActivity,
        widthAttr, strongZeroAttr, calleeNameAttr);

    // Hand the existing body over to the replacement op.
    replacement.getBody().takeBody(uop.getBody());
    return replacement;
  }
};

// Helper: check if any input is mutable.
static inline bool isMutable(Type type) {
if (isa<mlir::MemRefType>(type) || isa<mlir::UnrankedMemRefType>(type) ||
Expand All @@ -188,11 +241,12 @@ static inline bool isMutable(Type type) {
* ------> enzyme_const
*
*/
class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
template <typename SourceOp>
class FwdInpOpt final : public OpRewritePattern<SourceOp> {
public:
using OpRewritePattern<ForwardDiffOp>::OpRewritePattern;
using OpRewritePattern<SourceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForwardDiffOp uop,
LogicalResult matchAndRewrite(SourceOp uop,
PatternRewriter &rewriter) const override {

if (uop.getOutputs().size() == 0)
Expand Down Expand Up @@ -282,12 +336,23 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
ArrayAttr::get(rewriter.getContext(),
llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
newInActivityArgs.end()));
rewriter.replaceOpWithNewOp<ForwardDiffOp>(
uop, uop->getResultTypes(), uop.getFnAttr(), in_args, newInActivity,
uop.getRetActivityAttr(), uop.getWidthAttr(), uop.getStrongZeroAttr());

if constexpr (std::is_same_v<SourceOp, ForwardDiffOp>) {

rewriter.replaceOpWithNewOp<ForwardDiffOp>(
uop, uop->getResultTypes(), uop.getFnAttr(), in_args, newInActivity,
uop.getRetActivityAttr(), uop.getWidthAttr(),
uop.getStrongZeroAttr());
} else {
rewriter.replaceOpWithNewOp<ForwardDiffRegionOp>(
uop, uop->getResultTypes(), in_args, newInActivity,
uop.getRetActivityAttr(), uop.getWidthAttr(), uop.getStrongZeroAttr(),
uop.getFnAttr());
}
return success();
}
};

/**
*
* Modifies return activities for the FwdDiffOp
Expand All @@ -301,11 +366,15 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
* ------> enzyme_const -----
*
*/
class FwdRetOpt final : public OpRewritePattern<ForwardDiffOp> {
template <typename SourceOp>
class FwdRetOpt final : public OpRewritePattern<SourceOp> {
private:
using SourceOpCreator = EnzymeOpCreator<SourceOp>;

public:
using OpRewritePattern<ForwardDiffOp>::OpRewritePattern;
using OpRewritePattern<SourceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForwardDiffOp uop,
LogicalResult matchAndRewrite(SourceOp uop,
PatternRewriter &rewriter) const override {

if (uop.getOutputs().size() == 0)
Expand Down Expand Up @@ -436,10 +505,9 @@ class FwdRetOpt final : public OpRewritePattern<ForwardDiffOp> {
llvm::ArrayRef<Attribute>(newRetActivityArgs.begin(),
newRetActivityArgs.end()));

ForwardDiffOp newOp = ForwardDiffOp::create(
rewriter, uop.getLoc(), out_ty, uop.getFnAttr(), uop.getInputs(),
uop.getActivityAttr(), newRetActivity, uop.getWidthAttr(),
uop.getStrongZeroAttr());
SmallVector<Value> in_args = uop.getInputs();
SourceOp newOp = SourceOpCreator::create(
rewriter, uop, out_ty, in_args, uop.getActivityAttr(), newRetActivity);

// Map old uses of uop to newOp
auto oldIdx = 0;
Expand Down Expand Up @@ -499,7 +567,13 @@ class FwdRetOpt final : public OpRewritePattern<ForwardDiffOp> {
void ForwardDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {

patterns.add<FwdRetOpt, FwdInpOpt>(context);
patterns.add<FwdRetOpt<ForwardDiffOp>, FwdInpOpt<ForwardDiffOp>>(context);
}

/// Registers the canonicalization patterns for `ForwardDiffRegionOp`: the
/// return-activity pattern first, then the input-activity pattern, both
/// instantiated for the region form of the op.
void ForwardDiffRegionOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<FwdRetOpt<ForwardDiffRegionOp>>(context);
  // NOTE(review): FwdInpOpt's region branch builds its replacement with
  // replaceOpWithNewOp and no visible takeBody -- confirm the body region is
  // carried over to the new op.
  patterns.add<FwdInpOpt<ForwardDiffRegionOp>>(context);
}

LogicalResult AutoDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
Expand Down Expand Up @@ -600,7 +674,7 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
template <typename SourceOp>
class ReverseRetOpt final : public OpRewritePattern<SourceOp> {
private:
struct SourceOpCreator;
using SourceOpCreator = EnzymeOpCreator<SourceOp>;

public:
using OpRewritePattern<SourceOp>::OpRewritePattern;
Expand Down Expand Up @@ -884,31 +958,6 @@ class ReverseRetOpt final : public OpRewritePattern<SourceOp> {
}
};

template <> struct ReverseRetOpt<AutoDiffOp>::SourceOpCreator {
static AutoDiffOp create(PatternRewriter &rewriter, AutoDiffOp uop,
ArrayRef<Type> out_ty, ArrayRef<Value> in_args,
ArrayAttr newInActivity, ArrayAttr newRetActivity) {

return AutoDiffOp::create(rewriter, uop.getLoc(), out_ty, uop.getFnAttr(),
in_args, newInActivity, newRetActivity,
uop.getWidthAttr(), uop.getStrongZeroAttr());
}
};

template <> struct ReverseRetOpt<AutoDiffRegionOp>::SourceOpCreator {
static AutoDiffRegionOp create(PatternRewriter &rewriter,
AutoDiffRegionOp uop, ArrayRef<Type> out_ty,
ArrayRef<Value> in_args,
ArrayAttr newInActivity,
ArrayAttr newRetActivity) {

auto newOp = AutoDiffRegionOp::create(
rewriter, uop.getLoc(), out_ty, in_args, newInActivity, newRetActivity,
uop.getWidthAttr(), uop.getStrongZeroAttr(), uop.getFnAttr());
newOp.getBody().takeBody(uop.getBody());
return newOp;
}
};
void AutoDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ReverseRetOpt<AutoDiffOp>>(context);
Expand Down
52 changes: 49 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ static StringRef getArgAttrsAttrName(Operation *operation) {
return "";
}

static void serializeFunctionAttributes(Operation *fn,
enzyme::AutoDiffRegionOp regionOp) {
static void serializeFunctionAttributes(Operation *fn, Operation *regionOp) {
SmallVector<NamedAttribute> fnAttrs;
fnAttrs.reserve(fn->getAttrDictionary().size());
for (auto attr : fn->getAttrs()) {
Expand Down Expand Up @@ -100,6 +99,52 @@ struct InlineEnzymeAutoDiff : public OpRewritePattern<enzyme::AutoDiffOp> {
}
};

/// Rewrites an `enzyme::ForwardDiffOp` that references its callee by symbol
/// into an `enzyme::ForwardDiffRegionOp` whose body is a clone of the
/// callee's function body.
///
/// The match fails when the symbol does not resolve to a
/// `FunctionOpInterface` or when the resolved function has an empty body.
struct InlineEnzymeForwardDiff
    : public OpRewritePattern<enzyme::ForwardDiffOp> {
  using OpRewritePattern<enzyme::ForwardDiffOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(enzyme::ForwardDiffOp op,
                                PatternRewriter &rewriter) const override {
    // Resolve the referenced symbol starting from the op itself.
    SymbolTableCollection symbolTable;
    Operation *calleeSymbol =
        symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr());
    FunctionOpInterface fn = dyn_cast_or_null<FunctionOpInterface>(calleeSymbol);
    if (!fn)
      return failure();

    // Nothing to inline when the callee has no body.
    Region &calleeBody = fn.getFunctionBody();
    if (calleeBody.empty())
      return failure();

    // Use a StringAttr rather than a SymbolRefAttr so the function can get
    // symbol-DCE'd
    auto fnAttr = StringAttr::get(op.getContext(), op.getFn());
    auto regionOp = rewriter.replaceOpWithNewOp<enzyme::ForwardDiffRegionOp>(
        op, op.getResultTypes(), op.getInputs(), op.getActivity(),
        op.getRetActivity(), op.getWidth(), op.getStrongZero(), fnAttr);

    // Record the callee's attributes on the region op, then copy the callee
    // body to the front of the (freshly created) region.
    serializeFunctionAttributes(fn, regionOp);
    Region &inlinedBody = regionOp.getBody();
    rewriter.cloneRegionBefore(calleeBody, inlinedBody, inlinedBody.begin());

    // Put an enzyme yield with the same operands in front of every
    // return-like op directly inside the inlined body. The return-like ops
    // are only collected here and erased after the walk (presumably so the
    // iteration is not disturbed).
    SmallVector<Operation *> returnLikeOps;
    for (Operation &nestedOp : inlinedBody.getOps()) {
      if (!nestedOp.hasTrait<OpTrait::ReturnLike>())
        continue;

      PatternRewriter::InsertionGuard insertionGuard(rewriter);
      rewriter.setInsertionPoint(&nestedOp);
      enzyme::YieldOp::create(rewriter, nestedOp.getLoc(),
                              nestedOp.getOperands());
      returnLikeOps.push_back(&nestedOp);
    }

    for (Operation *staleReturn : returnLikeOps)
      rewriter.eraseOp(staleReturn);

    return success();
  }
};

// Based on
// https://github.com/llvm/llvm-project/blob/665da0a1649814471739c41a702e0e9447316b20/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
static FailureOr<func::FuncOp>
Expand Down Expand Up @@ -248,7 +293,8 @@ struct InlineEnzymeIntoRegion
InlineEnzymeIntoRegion> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.insert<InlineEnzymeAutoDiff>(&getContext());
patterns.insert<InlineEnzymeAutoDiff, InlineEnzymeForwardDiff>(
&getContext());

GreedyRewriteConfig config;
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
Expand Down
Loading