Enzyme main
Loading...
Searching...
No Matches
InlineEnzymeRegions.cpp
Go to the documentation of this file.
1//===- InlineEnzymeRegions.cpp - Inline/outline enzyme.autodiff ------------ //
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements inlining and outlining passes that convert
10// between enzyme.autodiff and enzyme.autodiff_region ops.
11//
12//===----------------------------------------------------------------------===//
13#include "Dialect/Ops.h"
15#include "Passes/Passes.h"
16
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/Interfaces/FunctionInterfaces.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21#include "mlir/Transforms/RegionUtils.h"
22#include "llvm/ADT/TypeSwitch.h"
23
24using namespace mlir;
25
26namespace mlir {
27namespace enzyme {
28#define GEN_PASS_DEF_INLINEENZYMEINTOREGIONPASS
29#define GEN_PASS_DEF_OUTLINEENZYMEFROMREGIONPASS
30#include "Passes/Passes.h.inc"
31} // namespace enzyme
32} // namespace mlir
33
34namespace {
35constexpr static llvm::StringLiteral kFnAttrsName = "fn_attrs";
36
37static StringRef getFunctionTypeAttrName(Operation *operation) {
38 if (auto iface = dyn_cast<mlir::enzyme::AutoDiffFunctionInterface>(operation))
39 return iface.getFunctionTypeAttrName();
40 return "";
41}
42
43static StringRef getArgAttrsAttrName(Operation *operation) {
44 if (auto iface = dyn_cast<mlir::enzyme::AutoDiffFunctionInterface>(operation))
45 return iface.getArgAttrsAttrName();
46 return "";
47}
48
49static void serializeFunctionAttributes(Operation *fn, Operation *regionOp) {
50 SmallVector<NamedAttribute> fnAttrs;
51 fnAttrs.reserve(fn->getAttrDictionary().size());
52 for (auto attr : fn->getAttrs()) {
53 // Don't store the function type or sym_name because they may change when
54 // outlining
55 if (attr.getName() == getFunctionTypeAttrName(fn) ||
56 attr.getName() == SymbolTable::getSymbolAttrName())
57 continue;
58 fnAttrs.push_back(attr);
59 }
60
61 regionOp->setAttr(kFnAttrsName,
62 DictionaryAttr::getWithSorted(fn->getContext(), fnAttrs));
63}
64
65template <typename DiffRegionOp>
66static void deserializeFunctionAttributes(DiffRegionOp op,
67 Operation *outlinedFunc,
68 unsigned addedArgCount) {
69 if (!op->template hasAttrOfType<DictionaryAttr>(kFnAttrsName))
70 return;
71
72 MLIRContext *ctx = op->getContext();
73 SmallVector<NamedAttribute> fnAttrs;
74 for (auto attr : op->template getAttrOfType<DictionaryAttr>(kFnAttrsName)) {
75 // New arguments are potentially added when outlining due to references to
76 // values outside the region. Insert an empty arg attr for each newly
77 // added argument.
78 if (attr.getName() == getArgAttrsAttrName(outlinedFunc)) {
79 SmallVector<Attribute> argAttrs(
80 cast<ArrayAttr>(attr.getValue())
81 .template getAsRange<DictionaryAttr>());
82 for (unsigned i = 0; i < addedArgCount; ++i)
83 argAttrs.push_back(DictionaryAttr::getWithSorted(ctx, {}));
84 fnAttrs.push_back(
85 NamedAttribute(attr.getName(), ArrayAttr::get(ctx, argAttrs)));
86 } else
87 fnAttrs.push_back(attr);
88 }
89 outlinedFunc->setAttrs(fnAttrs);
90}
91
92struct InlineEnzymeAutoDiff : public OpRewritePattern<enzyme::AutoDiffOp> {
93 using OpRewritePattern<enzyme::AutoDiffOp>::OpRewritePattern;
94 LogicalResult matchAndRewrite(enzyme::AutoDiffOp op,
95 PatternRewriter &rewriter) const override {
96 SymbolTableCollection symbolTable;
97 if (!mlir::enzyme::inlineAutodiffOp(op, rewriter, symbolTable))
98 return failure();
99
100 return success();
101 }
102};
103
104struct InlineEnzymeForwardDiff
105 : public OpRewritePattern<enzyme::ForwardDiffOp> {
106 using OpRewritePattern<enzyme::ForwardDiffOp>::OpRewritePattern;
107 LogicalResult matchAndRewrite(enzyme::ForwardDiffOp op,
108 PatternRewriter &rewriter) const override {
109 SymbolTableCollection symbolTable;
110 FunctionOpInterface fn = dyn_cast_or_null<FunctionOpInterface>(
111 symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()));
112
113 if (!fn)
114 return failure();
115
116 Region &targetRegion = fn.getFunctionBody();
117
118 if (targetRegion.empty())
119 return failure();
120
121 // Use a StringAttr rather than a SymbolRefAttr so the function can get
122 // symbol-DCE'd
123 auto fnAttr = StringAttr::get(op.getContext(), op.getFn());
124 auto regionOp = rewriter.replaceOpWithNewOp<enzyme::ForwardDiffRegionOp>(
125 op, op.getResultTypes(), op.getInputs(), op.getActivity(),
126 op.getRetActivity(), op.getWidth(), op.getStrongZero(), fnAttr);
127
128 serializeFunctionAttributes(fn, regionOp);
129 rewriter.cloneRegionBefore(targetRegion, regionOp.getBody(),
130 regionOp.getBody().begin());
131
132 SmallVector<Operation *> toErase;
133 for (Operation &bodyOp : regionOp.getBody().getOps()) {
134 if (bodyOp.hasTrait<OpTrait::ReturnLike>()) {
135 PatternRewriter::InsertionGuard insertionGuard(rewriter);
136 rewriter.setInsertionPoint(&bodyOp);
137 enzyme::YieldOp::create(rewriter, bodyOp.getLoc(),
138 bodyOp.getOperands());
139 toErase.push_back(&bodyOp);
140 }
141 }
142
143 for (Operation *opToErase : toErase)
144 rewriter.eraseOp(opToErase);
145 return success();
146 }
147};
148
149// Based on
150// https://github.com/llvm/llvm-project/blob/665da0a1649814471739c41a702e0e9447316b20/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
151template <typename DiffRegionOp>
152static FailureOr<func::FuncOp> outlineAutoDiffFunc(
153 DiffRegionOp op, StringRef funcName, SmallVectorImpl<Value> &inputs,
154 SmallVectorImpl<enzyme::Activity> &argActivities, OpBuilder &builder) {
155 Region &autodiffRegion = op.getBody();
156 SmallVector<Type> argTypes(autodiffRegion.getArgumentTypes()), resultTypes;
157 SmallVector<Location> argLocs(autodiffRegion.getNumArguments(), op.getLoc());
158 // Infer the result types from an enzyme.yield op
159 bool found = false;
160 autodiffRegion.walk([&](enzyme::YieldOp yieldOp) {
161 found = true;
162 llvm::append_range(resultTypes, yieldOp.getOperandTypes());
163 return WalkResult::interrupt();
164 });
165 if (!found)
166 return op.emitError()
167 << "enzyme.yield was not found in enzyme.autodiff_region";
168
169 llvm::SetVector<Value> freeValues;
170 getUsedValuesDefinedAbove(autodiffRegion, freeValues);
171
172 llvm::SmallVector<Value> primalValuesAbove = op.getPrimalInputs();
173 llvm::SmallVector<Value> blockArgs(autodiffRegion.getArguments());
174 for (Value value : llvm::make_early_inc_range(freeValues)) {
175 bool isPrimal = false;
176 for (auto [pval, bval] : llvm::zip(primalValuesAbove, blockArgs)) {
177 if (value == pval) {
178 isPrimal = true;
179 for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) {
180 if (op->isProperAncestor(use.getOwner()))
181 use.assign(bval);
182 }
183 freeValues.remove(value);
184 }
185 }
186
187 if (!isPrimal) {
188 inputs.push_back(value);
189 argTypes.push_back(value.getType());
190 argLocs.push_back(value.getLoc());
191 argActivities.push_back(enzyme::Activity::enzyme_const);
192 }
193 }
194 auto fnType = builder.getFunctionType(argTypes, resultTypes);
195
196 // FIXME: making this location the location of the
197 // enzyme.autodiff_region op causes translation to LLVM IR to fail due
198 // to some issue with the dbg info.
199 Location loc = UnknownLoc::get(op.getContext());
200 auto outlinedFunc = func::FuncOp::create(builder, loc, funcName, fnType);
201 Region &outlinedBody = outlinedFunc.getBody();
202 deserializeFunctionAttributes(op, outlinedFunc, freeValues.size());
203
204 // Copy over the function body.
205 IRMapping map;
206 Block *entryBlock = builder.createBlock(&outlinedBody, outlinedBody.begin(),
207 argTypes, argLocs);
208 unsigned originalArgCount = autodiffRegion.getNumArguments();
209 for (const auto &arg : autodiffRegion.getArguments())
210 map.map(arg, entryBlock->getArgument(arg.getArgNumber()));
211 for (const auto &operand : enumerate(freeValues))
212 map.map(operand.value(),
213 entryBlock->getArgument(originalArgCount + operand.index()));
214 autodiffRegion.cloneInto(&outlinedBody, map);
215
216 // Replace the terminators with returns
217 for (Block &block : autodiffRegion) {
218 Block *clonedBlock = map.lookup(&block);
219 auto terminator = dyn_cast<enzyme::YieldOp>(clonedBlock->getTerminator());
220 if (!terminator)
221 continue;
222 OpBuilder replacer(terminator);
223 func::ReturnOp::create(replacer, terminator->getLoc(),
224 terminator->getOperands());
225 terminator->erase();
226 }
227
228 // cloneInto results in two blocks, the actual outlined entry block and the
229 // cloned autodiff_region entry block. Splice the cloned entry block into
230 // the actual entry block, then erase the cloned autodiff_region entry.
231 Block *clonedEntry = map.lookup(&autodiffRegion.front());
232 entryBlock->getOperations().splice(entryBlock->getOperations().end(),
233 clonedEntry->getOperations());
234 clonedEntry->erase();
235 return outlinedFunc;
236}
237
238LogicalResult outlineEnzymeAutoDiffRegion(enzyme::AutoDiffRegionOp op,
239 StringRef funcName,
240 OpBuilder &builder) {
241 OpBuilder::InsertionGuard insertionGuard(builder);
242 builder.setInsertionPointAfter(op->getParentOfType<SymbolOpInterface>());
243
244 SmallVector<enzyme::Activity> argActivities =
245 llvm::map_to_vector(op.getActivity().getAsRange<enzyme::ActivityAttr>(),
246 [](auto attr) { return attr.getValue(); });
247 size_t numSeeds = llvm::count_if(
248 op.getRetActivity().getAsRange<enzyme::ActivityAttr>(), [](auto attr) {
249 switch (attr.getValue()) {
250 case enzyme::Activity::enzyme_active:
251 case enzyme::Activity::enzyme_activenoneed:
252 return true;
253 default:
254 return false;
255 }
256 });
257 size_t numPrimalsAndShadows = op.getInputs().size() - numSeeds;
258 SmallVector<Value> primalsAndShadows(
259 op.getInputs().begin(), op.getInputs().begin() + numPrimalsAndShadows);
260 SmallVector<Value> seedInputs(op.getInputs().begin() + numPrimalsAndShadows,
261 op.getInputs().end());
262
263 // Free variables are appended to primalInputs.
264 // The final input ordering should be:
265 // 1. primals and duplicated argument shadows
266 // 2. free variables
267 // 3. return variable shadows (a.k.a. seeds)
268 FailureOr<func::FuncOp> outlinedFunc = outlineAutoDiffFunc(
269 op, funcName, primalsAndShadows, argActivities, builder);
270 if (failed(outlinedFunc))
271 return failure();
272
273 SmallVector<Value> allInputs;
274 allInputs.append(primalsAndShadows);
275 allInputs.append(seedInputs);
276
277 builder.setInsertionPoint(op);
278 ArrayAttr argActivityAttr = builder.getArrayAttr(llvm::map_to_vector(
279 argActivities, [&op](enzyme::Activity actv) -> Attribute {
280 return enzyme::ActivityAttr::get(op.getContext(), actv);
281 }));
282 auto newOp = enzyme::AutoDiffOp::create(
283 builder, op.getLoc(), op.getResultTypes(), outlinedFunc->getName(),
284 allInputs, argActivityAttr, op.getRetActivity(), op.getWidth(),
285 op.getStrongZero());
286 op.replaceAllUsesWith(newOp.getResults());
287 op.erase();
288 return success();
289}
290
291LogicalResult outlineEnzymeForwardDiffRegion(enzyme::ForwardDiffRegionOp op,
292 StringRef funcName,
293 OpBuilder &builder) {
294 OpBuilder::InsertionGuard insertionGuard(builder);
295 builder.setInsertionPointAfter(op->getParentOfType<SymbolOpInterface>());
296
297 SmallVector<enzyme::Activity> argActivities =
298 llvm::map_to_vector(op.getActivity().getAsRange<enzyme::ActivityAttr>(),
299 [](auto attr) { return attr.getValue(); });
300 SmallVector<Value> primalsAndShadows(op.getInputs());
301
302 // Free variables are appended to primalInputs.
303 // The final input ordering should be:
304 // 1. primals and duplicated argument shadows
305 // 2. free variables
306 FailureOr<func::FuncOp> outlinedFunc = outlineAutoDiffFunc(
307 op, funcName, primalsAndShadows, argActivities, builder);
308 if (failed(outlinedFunc))
309 return failure();
310
311 SmallVector<Value> allInputs;
312 allInputs.append(primalsAndShadows);
313
314 builder.setInsertionPoint(op);
315 ArrayAttr argActivityAttr = builder.getArrayAttr(llvm::map_to_vector(
316 argActivities, [&op](enzyme::Activity actv) -> Attribute {
317 return enzyme::ActivityAttr::get(op.getContext(), actv);
318 }));
319 auto newOp = enzyme::ForwardDiffOp::create(
320 builder, op.getLoc(), op.getResultTypes(), outlinedFunc->getName(),
321 primalsAndShadows, argActivityAttr, op.getRetActivity(), op.getWidth(),
322 op.getStrongZero());
323 op.replaceAllUsesWith(newOp.getResults());
324 op.erase();
325 return success();
326}
327
328struct InlineEnzymeIntoRegion
329 : public enzyme::impl::InlineEnzymeIntoRegionPassBase<
330 InlineEnzymeIntoRegion> {
331 void runOnOperation() override {
332 RewritePatternSet patterns(&getContext());
333 patterns.insert<InlineEnzymeAutoDiff, InlineEnzymeForwardDiff>(
334 &getContext());
335
336 GreedyRewriteConfig config;
337 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
338 }
339};
340
341struct OutlineEnzymeFromRegion
342 : public enzyme::impl::OutlineEnzymeFromRegionPassBase<
343 OutlineEnzymeFromRegion> {
344 void runOnOperation() override {
345 SmallVector<enzyme::AutoDiffRegionOp> toOutlineRev;
346 getOperation()->walk(
347 [&](enzyme::AutoDiffRegionOp op) { toOutlineRev.push_back(op); });
348
349 SmallVector<enzyme::ForwardDiffRegionOp> toOutlineFwd;
350 getOperation()->walk(
351 [&](enzyme::ForwardDiffRegionOp op) { toOutlineFwd.push_back(op); });
352
353 OpBuilder builder(getOperation());
354 unsigned increment = 0;
355 for (auto regionOp : toOutlineRev) {
356 auto symbol = regionOp->getParentOfType<SymbolOpInterface>();
357 std::string defaultName =
358 (Twine(symbol.getName(), "_to_diff") + Twine(increment)).str();
359 if (failed(outlineEnzymeAutoDiffRegion(regionOp, defaultName, builder)))
360 return signalPassFailure();
361
362 ++increment;
363 }
364
365 for (auto regionOp : toOutlineFwd) {
366 auto symbol = regionOp->getParentOfType<SymbolOpInterface>();
367 std::string defaultName =
368 (Twine(symbol.getName(), "_to_fwddiff") + Twine(increment)).str();
369 if (failed(
370 outlineEnzymeForwardDiffRegion(regionOp, defaultName, builder)))
371 return signalPassFailure();
372
373 ++increment;
374 }
375 }
376};
377
378} // namespace
379
380bool mlir::enzyme::inlineAutodiffOp(enzyme::AutoDiffOp &op,
381 RewriterBase &rewriter,
382 SymbolTableCollection &symbolTable) {
383 FunctionOpInterface fn = dyn_cast_or_null<FunctionOpInterface>(
384 symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()));
385 if (!fn)
386 return false;
387 Region &targetRegion = fn.getFunctionBody();
388 if (targetRegion.empty())
389 return false;
390
391 // Use a StringAttr rather than a SymbolRefAttr so the function can get
392 // symbol-DCE'd
393 auto fnAttr = StringAttr::get(op.getContext(), op.getFn());
394 auto regionOp = rewriter.replaceOpWithNewOp<enzyme::AutoDiffRegionOp>(
395 op, op.getResultTypes(), op.getInputs(), op.getActivity(),
396 op.getRetActivity(), op.getWidth(), op.getStrongZero(), fnAttr);
397 serializeFunctionAttributes(fn, regionOp);
398 rewriter.cloneRegionBefore(targetRegion, regionOp.getBody(),
399 regionOp.getBody().begin());
400 SmallVector<Operation *> toErase;
401 for (Operation &bodyOp : regionOp.getBody().getOps()) {
402 if (bodyOp.hasTrait<OpTrait::ReturnLike>()) {
403 PatternRewriter::InsertionGuard insertionGuard(rewriter);
404 rewriter.setInsertionPoint(&bodyOp);
405 enzyme::YieldOp::create(rewriter, bodyOp.getLoc(), bodyOp.getOperands());
406 toErase.push_back(&bodyOp);
407 }
408 }
409
410 for (Operation *opToErase : toErase)
411 rewriter.eraseOp(opToErase);
412 return true;
413}
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
bool inlineAutodiffOp(enzyme::AutoDiffOp &op, RewriterBase &rewriter, SymbolTableCollection &symbolTable)