Enzyme main
Loading...
Searching...
No Matches
EnzymeLogicReverse.cpp
Go to the documentation of this file.
1#include "Dialect/Ops.h"
4#include "mlir/IR/Matchers.h"
5#include "mlir/IR/SymbolTable.h"
6#include "mlir/Interfaces/FunctionInterfaces.h"
7
8// TODO: this shouldn't depend on specific dialects except Enzyme.
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11
12#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Pass/PassManager.h"
15#include "mlir/Pass/PassRegistry.h"
16
17#include "EnzymeLogic.h"
20#include "llvm/ADT/ScopeExit.h"
21
22using namespace mlir;
23using namespace mlir::enzyme;
24
25void handleReturns(Block *oBB, Block *newBB, Block *reverseBB,
26 MGradientUtilsReverse *gutils) {
27 if (oBB->getNumSuccessors() == 0) {
28 Operation *returnStatement = newBB->getTerminator();
29 gutils->erase(returnStatement);
30
31 OpBuilder forwardToBackwardBuilder(newBB, newBB->end());
32
33 Operation *newBranchOp = cf::BranchOp::create(
34 forwardToBackwardBuilder, oBB->getTerminator()->getLoc(), reverseBB);
35
36 gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp;
37 }
38}
39
40// Returns true iff the operation:
41// 1. Produces no active data nor active pointers
42// 2. Does not propagate active data nor pointers (via side effects)
43static bool isFullyInactive(Operation *op, MGradientUtils *gutils) {
44 return llvm::all_of(
45 op->getResults(),
46 [gutils](Value v) { return gutils->isConstantValue(v); }) &&
47 gutils->isConstantInstruction(op);
48}
49
50/*
51Create reverse mode adjoint for an operation.
52*/
53LogicalResult MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder,
54 MGradientUtilsReverse *gutils) {
55 if ((op->getBlock()->getTerminator() != op) && isFullyInactive(op, gutils)) {
56 return success();
57 }
58 if (auto ifaceOp = dyn_cast<ReverseAutoDiffOpInterface>(op)) {
59 SmallVector<Value> caches = ifaceOp.cacheValues(gutils);
60 OpBuilder augmentBuilder(gutils->getNewFromOriginal(op));
61 ifaceOp.createShadowValues(augmentBuilder, gutils);
62 return ifaceOp.createReverseModeAdjoint(builder, gutils, caches);
63 }
64 op->emitError() << "could not compute the adjoint for this operation " << *op;
65 return failure();
66}
67
68LogicalResult MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB,
69 MGradientUtilsReverse *gutils) {
70 OpBuilder revBuilder(reverseBB, reverseBB->end());
71 bool valid = true;
72 if (!oBB->empty()) {
73 auto first = oBB->rbegin();
74 auto last = oBB->rend();
75 for (auto it = first; it != last; ++it) {
76 Operation *op = &*it;
77 valid &= visitChild(op, revBuilder, gutils).succeeded();
78 }
79 }
80 return success(valid);
81}
82
84 Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils,
85 llvm::function_ref<buildReturnFunction> buildReturnOp) {
86 OpBuilder revBuilder(reverseBB, reverseBB->end());
87 if (oBB->hasNoPredecessors()) {
88 buildReturnOp(revBuilder, oBB);
89 } else {
90 Location loc = oBB->rbegin()->getLoc();
91 // TODO remove dependency on CF dialect
92
93 Value cache = gutils->insertInit(gutils->getIndexCacheType());
94
95 Value flag =
96 enzyme::PopOp::create(revBuilder, loc, gutils->getIndexType(), cache);
97
98 Block *defaultBlock = nullptr;
99
100 SmallVector<Block *> blocks;
101 SmallVector<APInt> indices;
102
103 OpBuilder newBuilder(newBB, newBB->begin());
104
105 SmallVector<Value, 1> diffes;
106 for (auto arg : oBB->getArguments()) {
107 if (!gutils->isConstantValue(arg) &&
108 !cast<AutoDiffTypeInterface>(arg.getType()).isMutable()) {
109 diffes.push_back(gutils->diffe(arg, revBuilder));
110 gutils->zeroDiffe(arg, revBuilder);
111 continue;
112 }
113 diffes.push_back(nullptr);
114 }
115
116 for (auto [idx, pred] : llvm::enumerate(oBB->getPredecessors())) {
117 auto reversePred = gutils->mapReverseModeBlocks.lookupOrNull(pred);
118
119 Block *newPred = gutils->getNewFromOriginal(pred);
120
121 OpBuilder predecessorBuilder(newPred->getTerminator());
122
123 Value pred_idx_c =
124 arith::ConstantIntOp::create(predecessorBuilder, loc, idx - 1, 32);
125 enzyme::PushOp::create(predecessorBuilder, loc, cache, pred_idx_c);
126
127 if (idx == 0) {
128 defaultBlock = reversePred;
129
130 } else {
131 indices.push_back(APInt(32, idx - 1));
132 blocks.push_back(reversePred);
133 }
134
135 auto term = pred->getTerminator();
136 if (auto iface = dyn_cast<BranchOpInterface>(term)) {
137 for (auto &op : term->getOpOperands())
138 if (auto blk_idx =
139 iface.getSuccessorBlockArgument(op.getOperandNumber()))
140 if (!gutils->isConstantValue(op.get()) &&
141 (*blk_idx).getOwner() == oBB) {
142 auto idx = (*blk_idx).getArgNumber();
143 if (diffes[idx]) {
144
145 Value rev_idx_c =
146 arith::ConstantIntOp::create(revBuilder, loc, idx - 1, 32);
147
148 auto to_prop = arith::SelectOp::create(
149 revBuilder, loc,
150 arith::CmpIOp::create(revBuilder, loc,
151 arith::CmpIPredicate::eq, flag,
152 rev_idx_c),
153 diffes[idx],
154 cast<AutoDiffTypeInterface>(diffes[idx].getType())
155 .createNullValue(revBuilder, loc));
156
157 gutils->addToDiffe(op.get(), to_prop, revBuilder);
158 }
159 }
160 } else {
161 assert(0 && "predecessor did not implement branch op interface");
162 }
163 }
164
165 cf::SwitchOp::create(revBuilder, loc, flag, defaultBlock, ArrayRef<Value>(),
166 ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks),
167 SmallVector<ValueRange>(indices.size(), ValueRange()));
168 }
169}
170
172 MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion,
173 llvm::function_ref<buildReturnFunction> buildFuncReturnOp,
174 std::function<std::pair<Value, Value>(Type)> cacheCreator) {
175 gutils->registerCacheCreatorHook(cacheCreator);
176 auto scope = llvm::make_scope_exit(
177 [&]() { gutils->deregisterCacheCreatorHook(cacheCreator); });
178
179 gutils->createReverseModeBlocks(oldRegion, newRegion);
180
181 bool valid = true;
182 for (auto &oBB : oldRegion) {
183 Block *newBB = gutils->getNewFromOriginal(&oBB);
184 Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB);
185 handleReturns(&oBB, newBB, reverseBB, gutils);
186 valid &= visitChildren(&oBB, reverseBB, gutils).succeeded();
187 handlePredecessors(&oBB, newBB, reverseBB, gutils, buildFuncReturnOp);
188 }
189 return success(valid);
190}
191
193 FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
194 std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
195 std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
196 DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
197 MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
198 bool omp, llvm::StringRef postpasses, bool verifyPostPasses,
199 bool strongZero) {
200
201 if (fn.getFunctionBody().empty()) {
202 llvm::errs() << fn << "\n";
203 llvm_unreachable("Differentiating empty function");
204 }
205
206 MReverseCacheKey tup = {fn,
207 retType,
208 constants,
209 returnPrimals,
210 returnShadows,
211 mode,
212 freeMemory,
213 static_cast<unsigned>(width),
214 addedType,
215 type_args,
216 volatile_args,
217 omp};
218
219 {
220 auto cachedFn = ReverseCachedFunctions.find(tup);
221 if (cachedFn != ReverseCachedFunctions.end())
222 return cachedFn->second;
223 }
224
225 SmallVector<bool> returnPrimalsP(returnPrimals.begin(), returnPrimals.end());
226 SmallVector<bool> returnShadowsP(returnShadows.begin(), returnShadows.end());
227
229 *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
230 retType, constants, addedType, omp, postpasses, verifyPostPasses,
231 strongZero);
232
233 ReverseCachedFunctions[tup] = gutils->newFunc;
234
235 Region &oldRegion = gutils->oldFunc.getFunctionBody();
236 Region &newRegion = gutils->newFunc.getFunctionBody();
237
238 auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) {
239 SmallVector<mlir::Value> retargs;
240 for (auto [arg, returnPrimal] :
241 llvm::zip(oBB->getTerminator()->getOperands(), returnPrimals)) {
242 if (returnPrimal) {
243 retargs.push_back(gutils->getNewFromOriginal(arg));
244 }
245 }
246 for (auto [arg, cv] : llvm::zip(oBB->getArguments(), constants)) {
247 if (cv == DIFFE_TYPE::OUT_DIFF) {
248 retargs.push_back(gutils->diffe(arg, builder));
249 }
250 }
251
252 Location loc = oBB->rbegin()->getLoc();
253 if (auto iface = dyn_cast<enzyme::AutoDiffFunctionInterface>(*fn))
254 iface.createReturn(builder, loc, retargs);
255 else
256 fn->emitError() << "this function operation does not implement "
257 "AutoDiffFunctionInterface";
258 return;
259 };
260
261 gutils->forceAugmentedReturns();
262
263 auto res =
264 differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr);
265
266 auto nf = gutils->newFunc;
267
268 delete gutils;
269
270 if (!res.succeeded())
271 return nullptr;
272
273 if (postpasses != "") {
274 mlir::PassManager pm(nf->getContext());
275 pm.enableVerifier(verifyPostPasses);
276 std::string error_message;
277 // llvm::raw_string_ostream error_stream(error_message);
278 mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
279 if (mlir::failed(result)) {
280 return nullptr;
281 }
282
283 if (!mlir::succeeded(pm.run(nf))) {
284 return nullptr;
285 }
286 }
287
288 return nf;
289}
static bool isFullyInactive(Operation *op, MGradientUtils *gutils)
void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils)
DerivativeMode
Definition Utils.h:390
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder)
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
LogicalResult visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils)
LogicalResult visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
std::map< MReverseCacheKey, FunctionOpInterface > ReverseCachedFunctions
void handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils, llvm::function_ref< buildReturnFunction > buildReturnOp)
LogicalResult differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, llvm::function_ref< buildReturnFunction > buildFuncRetrunOp, std::function< std::pair< Value, Value >(Type)> cacheCreator)
void createReverseModeBlocks(Region &oldFunc, Region &newFunc)
void deregisterCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
void registerCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef< bool > returnPrimals, const ArrayRef< bool > returnShadows, llvm::ArrayRef< DIFFE_TYPE > retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
FunctionOpInterface oldFunc
std::map< Operation *, Operation * > originalToNewFnOps
void erase(Operation *op)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
FunctionOpInterface newFunc
bool isConstantInstruction(mlir::Operation *v) const
bool isConstantValue(mlir::Value v) const