17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/Interfaces/FunctionInterfaces.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21#include "mlir/Transforms/RegionUtils.h"
22#include "llvm/ADT/TypeSwitch.h"
28#define GEN_PASS_DEF_INLINEENZYMEINTOREGIONPASS
29#define GEN_PASS_DEF_OUTLINEENZYMEFROMREGIONPASS
30#include "Passes/Passes.h.inc"
// Key under which a callee's attributes are stored, as a DictionaryAttr, on a
// diff-region op when the callee is inlined into it. They are read back and
// re-applied to the function created when the region is outlined again.
35constexpr static llvm::StringLiteral kFnAttrsName =
"fn_attrs";
// Returns the name of the attribute that holds the function type of
// `operation`, queried through enzyme::AutoDiffFunctionInterface.
// Operations that do not implement the interface fall through the `if`.
// NOTE(review): confirm the fallback value (presumably an empty StringRef).
37static StringRef getFunctionTypeAttrName(Operation *operation) {
38 if (
auto iface = dyn_cast<mlir::enzyme::AutoDiffFunctionInterface>(operation))
39 return iface.getFunctionTypeAttrName();
// Returns the name of the attribute that holds the per-argument attribute
// array of `operation`, queried through enzyme::AutoDiffFunctionInterface.
// Operations that do not implement the interface fall through the `if`.
// NOTE(review): confirm the fallback value (presumably an empty StringRef).
43static StringRef getArgAttrsAttrName(Operation *operation) {
44 if (
auto iface = dyn_cast<mlir::enzyme::AutoDiffFunctionInterface>(operation))
45 return iface.getArgAttrsAttrName();
// Snapshots the attributes of `fn` into a DictionaryAttr and stores it on
// `regionOp` under kFnAttrsName, so they can be restored on the function
// produced when the region is outlined later.
// The function-type attribute and the symbol-name attribute are singled out by
// the check below. They are presumably skipped, because the outlined function
// is created with its own name and type — confirm.
49static void serializeFunctionAttributes(Operation *fn, Operation *regionOp) {
50 SmallVector<NamedAttribute> fnAttrs;
51 fnAttrs.reserve(fn->getAttrDictionary().size());
52 for (
auto attr : fn->getAttrs()) {
55 if (attr.getName() == getFunctionTypeAttrName(fn) ||
56 attr.getName() == SymbolTable::getSymbolAttrName())
58 fnAttrs.push_back(attr);
// fn->getAttrs() yields the entries of the op's attribute dictionary, which
// are sorted by name. Filtering keeps that order, so getWithSorted's
// precondition holds.
61 regionOp->setAttr(kFnAttrsName,
62 DictionaryAttr::getWithSorted(fn->getContext(), fnAttrs));
// Restores onto `outlinedFunc` the attributes that serializeFunctionAttributes
// stored on `op` under kFnAttrsName.
// If an argument-attributes entry is present, it is extended with
// `addedArgCount` empty dictionaries. This makes it cover the extra trailing
// arguments of the outlined function; the caller passes the number of
// captured free values.
// When `op` has no such dictionary, the check below presumably returns early —
// confirm.
// `op->template ...` is required because `op` has the dependent type
// DiffRegionOp.
65template <
typename DiffRegionOp>
66static void deserializeFunctionAttributes(DiffRegionOp op,
67 Operation *outlinedFunc,
68 unsigned addedArgCount) {
69 if (!op->template hasAttrOfType<DictionaryAttr>(kFnAttrsName))
72 MLIRContext *ctx = op->getContext();
73 SmallVector<NamedAttribute> fnAttrs;
74 for (
auto attr : op->template getAttrOfType<DictionaryAttr>(kFnAttrsName)) {
// Pad the per-argument attribute array with empty dictionaries for the
// arguments appended after the region's original block arguments.
78 if (attr.getName() == getArgAttrsAttrName(outlinedFunc)) {
79 SmallVector<Attribute> argAttrs(
80 cast<ArrayAttr>(attr.getValue())
81 .template getAsRange<DictionaryAttr>());
82 for (
unsigned i = 0; i < addedArgCount; ++i)
83 argAttrs.push_back(DictionaryAttr::getWithSorted(ctx, {}));
85 NamedAttribute(attr.getName(), ArrayAttr::get(ctx, argAttrs)));
87 fnAttrs.push_back(attr);
// Install the restored attribute list on the outlined function.
89 outlinedFunc->setAttrs(fnAttrs);
// Rewrite hook for enzyme::AutoDiffOp. It creates a fresh
// SymbolTableCollection for resolving the callee.
// NOTE(review): the rest of the body presumably hands it to inlineAutodiffOp —
// confirm.
94 LogicalResult matchAndRewrite(enzyme::AutoDiffOp op,
95 PatternRewriter &rewriter)
const override {
96 SymbolTableCollection symbolTable;
// Rewrite pattern that turns an enzyme::ForwardDiffOp into an
// enzyme::ForwardDiffRegionOp. The callee's body is cloned into the region,
// and each ReturnLike terminator is replaced by an enzyme.yield with the same
// operands.
104struct InlineEnzymeForwardDiff
107 LogicalResult matchAndRewrite(enzyme::ForwardDiffOp op,
108 PatternRewriter &rewriter)
const override {
109 SymbolTableCollection symbolTable;
// Resolve the callee symbol. The result is null if the symbol is missing or is
// not a FunctionOpInterface.
110 FunctionOpInterface fn = dyn_cast_or_null<FunctionOpInterface>(
111 symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()));
116 Region &targetRegion = fn.getFunctionBody();
// An empty body means there is no definition to clone (e.g. a declaration).
118 if (targetRegion.empty())
// Build the region op with the same results, inputs and differentiation
// settings. The callee name is recorded in `fnAttr`.
123 auto fnAttr = StringAttr::get(op.getContext(), op.getFn());
124 auto regionOp = rewriter.replaceOpWithNewOp<enzyme::ForwardDiffRegionOp>(
125 op, op.getResultTypes(), op.getInputs(), op.getActivity(),
126 op.getRetActivity(), op.getWidth(), op.getStrongZero(), fnAttr);
// Keep the callee's attributes so that outlining can restore them later.
128 serializeFunctionAttributes(fn, regionOp);
129 rewriter.cloneRegionBefore(targetRegion, regionOp.getBody(),
130 regionOp.getBody().begin());
// Replace every ReturnLike op in the cloned body with an enzyme.yield. Erasure
// is deferred to a second loop so that the op iteration is not invalidated.
132 SmallVector<Operation *> toErase;
133 for (Operation &bodyOp : regionOp.getBody().getOps()) {
134 if (bodyOp.hasTrait<OpTrait::ReturnLike>()) {
135 PatternRewriter::InsertionGuard insertionGuard(rewriter);
136 rewriter.setInsertionPoint(&bodyOp);
137 enzyme::YieldOp::create(rewriter, bodyOp.getLoc(),
138 bodyOp.getOperands());
139 toErase.push_back(&bodyOp);
143 for (Operation *opToErase : toErase)
144 rewriter.eraseOp(opToErase);
// Creates a func.func named `funcName` from the body of the diff-region op
// `op`.
// - Arguments: the region's block arguments, followed by one argument per
//   value captured from above the region.
// - Results: the operand types of the enzyme.yield that is found.
// Each captured value is appended to `inputs` and gets an enzyme_const entry in
// `argActivities`, so the caller can pass it to the new call.
// Returns failure (after emitting an error) if the region has no enzyme.yield.
151template <
typename DiffRegionOp>
152static FailureOr<func::FuncOp> outlineAutoDiffFunc(
153 DiffRegionOp op, StringRef funcName, SmallVectorImpl<Value> &inputs,
154 SmallVectorImpl<enzyme::Activity> &argActivities, OpBuilder &builder) {
155 Region &autodiffRegion = op.getBody();
// Start from the region's own block-argument types and locations. Captured
// values are appended to these below.
156 SmallVector<Type> argTypes(autodiffRegion.getArgumentTypes()), resultTypes;
157 SmallVector<Location> argLocs(autodiffRegion.getNumArguments(), op.getLoc());
// Result types come from the first enzyme.yield found; the walk stops there.
160 autodiffRegion.walk([&](enzyme::YieldOp yieldOp) {
162 llvm::append_range(resultTypes, yieldOp.getOperandTypes());
163 return WalkResult::interrupt();
// Presumably reached only when the walk was not interrupted (no yield) —
// confirm.
166 return op.emitError()
167 <<
"enzyme.yield was not found in enzyme.autodiff_region";
// Values defined outside the region but used inside it.
169 llvm::SetVector<Value> freeValues;
170 getUsedValuesDefinedAbove(autodiffRegion, freeValues);
// A captured value that is one of the op's primal inputs is not turned into a
// new argument. Its uses inside `op` are rewired (presumably to the block
// argument paired with it positionally by llvm::zip — confirm), and it is then
// dropped from freeValues. make_early_inc_range keeps both loops valid while
// uses are modified and entries are removed.
172 llvm::SmallVector<Value> primalValuesAbove = op.getPrimalInputs();
173 llvm::SmallVector<Value> blockArgs(autodiffRegion.getArguments());
174 for (Value value : llvm::make_early_inc_range(freeValues)) {
175 bool isPrimal =
false;
176 for (
auto [pval, bval] : llvm::zip(primalValuesAbove, blockArgs)) {
179 for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) {
180 if (op->isProperAncestor(use.getOwner()))
183 freeValues.remove(value);
// Each remaining captured value becomes a trailing function argument. The
// caller passes it as an enzyme_const input.
188 inputs.push_back(value);
189 argTypes.push_back(value.getType());
190 argLocs.push_back(value.getLoc());
191 argActivities.push_back(enzyme::Activity::enzyme_const);
194 auto fnType = builder.getFunctionType(argTypes, resultTypes);
// NOTE(review): the outlined function is given UnknownLoc rather than
// op.getLoc(); the reason is not stated here.
199 Location loc = UnknownLoc::get(op.getContext());
200 auto outlinedFunc = func::FuncOp::create(builder, loc, funcName, fnType);
201 Region &outlinedBody = outlinedFunc.getBody();
// Restore the inlined callee's attributes, padding the argument attributes for
// the freeValues.size() appended arguments.
202 deserializeFunctionAttributes(op, outlinedFunc, freeValues.size());
206 Block *entryBlock = builder.createBlock(&outlinedBody, outlinedBody.begin(),
// Pre-map the region's block arguments, then the captured values (starting at
// originalArgCount), onto the new entry block's arguments.
208 unsigned originalArgCount = autodiffRegion.getNumArguments();
209 for (
const auto &arg : autodiffRegion.getArguments())
210 map.map(arg, entryBlock->getArgument(arg.getArgNumber()));
211 for (
const auto &operand : enumerate(freeValues))
212 map.map(operand.value(),
213 entryBlock->getArgument(originalArgCount + operand.index()));
// Clone the region after entryBlock. Because the arguments were pre-mapped,
// the cloned ops reference entryBlock's arguments.
214 autodiffRegion.cloneInto(&outlinedBody, map);
// In each cloned block, replace an enzyme.yield terminator with a func.return
// that has the same operands.
217 for (Block &block : autodiffRegion) {
218 Block *clonedBlock = map.lookup(&block);
219 auto terminator = dyn_cast<enzyme::YieldOp>(clonedBlock->getTerminator());
222 OpBuilder replacer(terminator);
223 func::ReturnOp::create(replacer, terminator->getLoc(),
224 terminator->getOperands());
// Merge the cloned entry block into entryBlock: move its operations over, then
// delete the now-empty clone.
231 Block *clonedEntry = map.lookup(&autodiffRegion.front());
232 entryBlock->getOperations().splice(entryBlock->getOperations().end(),
233 clonedEntry->getOperations());
234 clonedEntry->erase();
// Outlines the body of an enzyme::AutoDiffRegionOp into a new func.func and
// replaces the region op's results with those of an equivalent
// enzyme::AutoDiffOp that calls it.
// The trailing inputs that feed seeds are kept after the (possibly extended)
// primal/shadow inputs in the new call's operand list.
238LogicalResult outlineEnzymeAutoDiffRegion(enzyme::AutoDiffRegionOp op,
240 OpBuilder &builder) {
// Insert the outlined function right after the enclosing symbol op. The guard
// restores the caller's insertion point on return.
241 OpBuilder::InsertionGuard insertionGuard(builder);
242 builder.setInsertionPointAfter(op->getParentOfType<SymbolOpInterface>());
244 SmallVector<enzyme::Activity> argActivities =
245 llvm::map_to_vector(op.getActivity().getAsRange<enzyme::ActivityAttr>(),
246 [](
auto attr) { return attr.getValue(); });
// Count the return values that take a seed input. enzyme_active and
// enzyme_activenoneed are the visible cases; confirm that all other cases are
// not counted.
247 size_t numSeeds = llvm::count_if(
248 op.getRetActivity().getAsRange<enzyme::ActivityAttr>(), [](
auto attr) {
249 switch (attr.getValue()) {
250 case enzyme::Activity::enzyme_active:
251 case enzyme::Activity::enzyme_activenoneed:
// Seeds are the last `numSeeds` inputs; everything before them is a primal or
// shadow operand.
257 size_t numPrimalsAndShadows = op.getInputs().size() - numSeeds;
258 SmallVector<Value> primalsAndShadows(
259 op.getInputs().begin(), op.getInputs().begin() + numPrimalsAndShadows);
260 SmallVector<Value> seedInputs(op.getInputs().begin() + numPrimalsAndShadows,
261 op.getInputs().end());
// Outlining may append captured values to primalsAndShadows and matching
// enzyme_const entries to argActivities.
268 FailureOr<func::FuncOp> outlinedFunc = outlineAutoDiffFunc(
269 op, funcName, primalsAndShadows, argActivities, builder);
270 if (failed(outlinedFunc))
// New operand order: primals/shadows (including any captured values), then
// seeds.
273 SmallVector<Value> allInputs;
274 allInputs.append(primalsAndShadows);
275 allInputs.append(seedInputs);
// Emit the replacement call at the position of the region op.
277 builder.setInsertionPoint(op);
278 ArrayAttr argActivityAttr = builder.getArrayAttr(llvm::map_to_vector(
279 argActivities, [&op](enzyme::Activity actv) -> Attribute {
280 return enzyme::ActivityAttr::get(op.getContext(), actv);
282 auto newOp = enzyme::AutoDiffOp::create(
283 builder, op.getLoc(), op.getResultTypes(), outlinedFunc->getName(),
284 allInputs, argActivityAttr, op.getRetActivity(), op.getWidth(),
286 op.replaceAllUsesWith(newOp.getResults());
// Outlines the body of an enzyme::ForwardDiffRegionOp into a new func.func and
// replaces the region op's results with those of an equivalent
// enzyme::ForwardDiffOp that calls it.
// Unlike the reverse-mode variant, every input is treated as a primal/shadow
// operand; no seed inputs are split off.
291LogicalResult outlineEnzymeForwardDiffRegion(enzyme::ForwardDiffRegionOp op,
293 OpBuilder &builder) {
// Insert the outlined function right after the enclosing symbol op. The guard
// restores the caller's insertion point on return.
294 OpBuilder::InsertionGuard insertionGuard(builder);
295 builder.setInsertionPointAfter(op->getParentOfType<SymbolOpInterface>());
297 SmallVector<enzyme::Activity> argActivities =
298 llvm::map_to_vector(op.getActivity().getAsRange<enzyme::ActivityAttr>(),
299 [](
auto attr) { return attr.getValue(); });
// Outlining may append captured values to this list and matching enzyme_const
// entries to argActivities.
300 SmallVector<Value> primalsAndShadows(op.getInputs());
306 FailureOr<func::FuncOp> outlinedFunc = outlineAutoDiffFunc(
307 op, funcName, primalsAndShadows, argActivities, builder);
308 if (failed(outlinedFunc))
// NOTE(review): allInputs is built here, but the create call below passes
// primalsAndShadows directly. The two hold the same values, so allInputs looks
// unused — confirm and consider removing it.
311 SmallVector<Value> allInputs;
312 allInputs.append(primalsAndShadows);
// Emit the replacement call at the position of the region op.
314 builder.setInsertionPoint(op);
315 ArrayAttr argActivityAttr = builder.getArrayAttr(llvm::map_to_vector(
316 argActivities, [&op](enzyme::Activity actv) -> Attribute {
317 return enzyme::ActivityAttr::get(op.getContext(), actv);
319 auto newOp = enzyme::ForwardDiffOp::create(
320 builder, op.getLoc(), op.getResultTypes(), outlinedFunc->getName(),
321 primalsAndShadows, argActivityAttr, op.getRetActivity(), op.getWidth(),
323 op.replaceAllUsesWith(newOp.getResults());
// Pass that applies the InlineEnzymeAutoDiff and InlineEnzymeForwardDiff
// patterns greedily over the operation it runs on. This turns enzyme
// autodiff/fwddiff calls into region ops that contain the inlined callee body.
328struct InlineEnzymeIntoRegion
329 :
public enzyme::impl::InlineEnzymeIntoRegionPassBase<
330 InlineEnzymeIntoRegion> {
331 void runOnOperation()
override {
332 RewritePatternSet patterns(&getContext());
333 patterns.insert<InlineEnzymeAutoDiff, InlineEnzymeForwardDiff>(
336 GreedyRewriteConfig config;
// The driver's result is deliberately discarded, so the pass does not fail
// when the greedy rewrite stops before reaching a fixpoint.
337 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
// Pass that outlines every enzyme::AutoDiffRegionOp and
// enzyme::ForwardDiffRegionOp into a new function and replaces it with the
// matching call op. Generated names are "<parent symbol>_to_diff<N>" and
// "<parent symbol>_to_fwddiff<N>". Any outlining failure fails the pass.
341struct OutlineEnzymeFromRegion
342 :
public enzyme::impl::OutlineEnzymeFromRegionPassBase<
343 OutlineEnzymeFromRegion> {
344 void runOnOperation()
override {
// Collect the region ops first and outline them afterwards. Outlining inserts
// functions and replaces ops, which must not happen during the walk.
345 SmallVector<enzyme::AutoDiffRegionOp> toOutlineRev;
346 getOperation()->walk(
347 [&](enzyme::AutoDiffRegionOp op) { toOutlineRev.push_back(op); });
349 SmallVector<enzyme::ForwardDiffRegionOp> toOutlineFwd;
350 getOperation()->walk(
351 [&](enzyme::ForwardDiffRegionOp op) { toOutlineFwd.push_back(op); });
353 OpBuilder builder(getOperation());
// Numeric suffix that keeps generated function names distinct.
// NOTE(review): the line that updates it is not visible here; confirm it is
// incremented after each outlined region.
354 unsigned increment = 0;
355 for (
auto regionOp : toOutlineRev) {
356 auto symbol = regionOp->getParentOfType<SymbolOpInterface>();
357 std::string defaultName =
358 (Twine(symbol.getName(),
"_to_diff") + Twine(increment)).
str();
359 if (failed(outlineEnzymeAutoDiffRegion(regionOp, defaultName, builder)))
360 return signalPassFailure();
365 for (
auto regionOp : toOutlineFwd) {
366 auto symbol = regionOp->getParentOfType<SymbolOpInterface>();
367 std::string defaultName =
368 (Twine(symbol.getName(),
"_to_fwddiff") + Twine(increment)).
str();
370 outlineEnzymeForwardDiffRegion(regionOp, defaultName, builder)))
371 return signalPassFailure();
381 RewriterBase &rewriter,
382 SymbolTableCollection &symbolTable) {
// Replaces the enzyme::AutoDiffOp `op` with an enzyme::AutoDiffRegionOp that
// holds a clone of the callee's body. The callee's ReturnLike terminators
// become enzyme.yield ops, and the callee's attributes are stored on the
// region op for later outlining.
// Resolve the callee symbol. The result is null if the symbol is missing or is
// not a FunctionOpInterface.
383 FunctionOpInterface fn = dyn_cast_or_null<FunctionOpInterface>(
384 symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()));
387 Region &targetRegion = fn.getFunctionBody();
// An empty body means there is no definition to clone (e.g. a declaration).
388 if (targetRegion.empty())
// Build the region op with the same results, inputs and differentiation
// settings. The callee name is recorded in `fnAttr`.
393 auto fnAttr = StringAttr::get(op.getContext(), op.getFn());
394 auto regionOp = rewriter.replaceOpWithNewOp<enzyme::AutoDiffRegionOp>(
395 op, op.getResultTypes(), op.getInputs(), op.getActivity(),
396 op.getRetActivity(), op.getWidth(), op.getStrongZero(), fnAttr);
397 serializeFunctionAttributes(fn, regionOp);
398 rewriter.cloneRegionBefore(targetRegion, regionOp.getBody(),
399 regionOp.getBody().begin());
// Replace every ReturnLike op in the cloned body with an enzyme.yield. Erasure
// is deferred to a second loop so that the op iteration is not invalidated.
400 SmallVector<Operation *> toErase;
401 for (Operation &bodyOp : regionOp.getBody().getOps()) {
402 if (bodyOp.hasTrait<OpTrait::ReturnLike>()) {
403 PatternRewriter::InsertionGuard insertionGuard(rewriter);
404 rewriter.setInsertionPoint(&bodyOp);
405 enzyme::YieldOp::create(rewriter, bodyOp.getLoc(), bodyOp.getOperands());
406 toErase.push_back(&bodyOp);
410 for (Operation *opToErase : toErase)
411 rewriter.eraseOp(opToErase);
static std::string str(AugmentedStruct c)
bool inlineAutodiffOp(enzyme::AutoDiffOp &op, RewriterBase &rewriter, SymbolTableCollection &symbolTable)