Enzyme main
Loading...
Searching...
No Matches
EnzymeLogic.cpp
Go to the documentation of this file.
1#include "Dialect/Ops.h"
7#include "mlir/IR/Matchers.h"
8#include "mlir/Interfaces/FunctionInterfaces.h"
9
10// TODO: this shouldn't depend on specific dialects except Enzyme.
11#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12
13#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/IR/Dominance.h"
16#include "mlir/Pass/PassManager.h"
17#include "mlir/Pass/PassRegistry.h"
18
19#include "llvm/ADT/BreadthFirstIterator.h"
20
21#include "EnzymeLogic.h"
22#include "GradientUtils.h"
23
24using namespace mlir;
25using namespace mlir::enzyme;
26
27void createTerminator(MGradientUtils *gutils, mlir::Block *oBB,
28 const ArrayRef<bool> returnPrimals,
29 const ArrayRef<bool> returnShadows) {
30 auto inst = oBB->getTerminator();
31
32 mlir::Block *nBB = gutils->getNewFromOriginal(inst->getBlock());
33 assert(nBB);
34 auto newInst = nBB->getTerminator();
35
36 OpBuilder nBuilder(inst);
37 nBuilder.setInsertionPointToEnd(nBB);
38
39 if (auto binst = dyn_cast<BranchOpInterface>(inst)) {
41 return;
42 }
43
44 // In forward mode we only need to update the return value
45 if (!inst->hasTrait<OpTrait::ReturnLike>())
46 return;
47
48 SmallVector<mlir::Value, 2> retargs;
49
50 for (auto &&[ret, returnPrimal, returnShadow] :
51 llvm::zip(inst->getOperands(), returnPrimals, returnShadows)) {
52 if (returnPrimal) {
53 retargs.push_back(gutils->getNewFromOriginal(ret));
54 }
55 if (returnShadow) {
56 if (!gutils->isConstantValue(ret)) {
57 retargs.push_back(gutils->invertPointerM(ret, nBuilder));
58 } else {
59 Type retTy = cast<AutoDiffTypeInterface>(ret.getType()).getShadowType();
60 auto toret = cast<AutoDiffTypeInterface>(retTy).createNullValue(
61 nBuilder, ret.getLoc());
62 retargs.push_back(toret);
63 }
64 }
65 }
66
67 nBB->push_back(newInst->create(newInst->getLoc(), newInst->getName(),
68 TypeRange(), retargs, newInst->getAttrs(),
69 mlir::PropertyRef(), newInst->getSuccessors(),
70 newInst->getNumRegions()));
71 gutils->erase(newInst);
72 return;
73}
74
75//===----------------------------------------------------------------------===//
76//===----------------------------------------------------------------------===//
77
79 FunctionOpInterface fn, std::vector<DIFFE_TYPE> RetActivity,
80 std::vector<DIFFE_TYPE> ArgActivity, MTypeAnalysis &TA,
81 std::vector<bool> returnPrimals, DerivativeMode mode, bool freeMemory,
82 size_t width, mlir::Type addedType, MFnTypeInfo type_args,
83 std::vector<bool> volatile_args, void *augmented, bool omp,
84 llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero) {
85 if (fn.getFunctionBody().empty()) {
86 llvm::errs() << fn << "\n";
87 llvm_unreachable("Differentiating empty function");
88 }
89 assert(fn.getFunctionBody().front().getNumArguments() == ArgActivity.size());
90 assert(fn.getFunctionBody().front().getNumArguments() ==
91 volatile_args.size());
92
93 MForwardCacheKey tup = {
94 fn, RetActivity, ArgActivity,
95 // std::map<Argument *, bool>(_uncacheable_args.begin(),
96 // _uncacheable_args.end()),
97 returnPrimals, mode, static_cast<unsigned>(width), addedType, type_args,
98 omp, strongZero};
99
100 if (ForwardCachedFunctions.find(tup) != ForwardCachedFunctions.end()) {
101 return ForwardCachedFunctions.find(tup)->second;
102 }
103 std::vector<bool> returnShadows;
104 for (auto act : RetActivity) {
105 returnShadows.push_back(act != DIFFE_TYPE::CONSTANT);
106 }
107 SmallVector<bool> returnPrimalsP(returnPrimals.begin(), returnPrimals.end());
108 SmallVector<bool> returnShadowsP(returnShadows.begin(), returnShadows.end());
110 *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
111 RetActivity, ArgActivity, addedType,
112 /*omp*/ false, postpasses, verifyPostPasses, strongZero);
113 ForwardCachedFunctions[tup] = gutils->newFunc;
114
116 ForwardCachedFunctions, tup, gutils->newFunc);
118 // gutils->FreeMemory = freeMemory;
119
120 const SmallPtrSet<mlir::Block *, 4> guaranteedUnreachable;
121 // = getGuaranteedUnreachable(gutils->oldFunc);
122
123 // gutils->forceActiveDetection();
124 gutils->forceAugmentedReturns();
125 /*
126
127 // TODO populate with actual unnecessaryInstructions once the dependency
128 // cycle with activity analysis is removed
129 SmallPtrSet<const Instruction *, 4> unnecessaryInstructionsTmp;
130 for (auto BB : guaranteedUnreachable) {
131 for (auto &I : *BB)
132 unnecessaryInstructionsTmp.insert(&I);
134 if (mode == DerivativeMode::ForwardModeSplit)
135 gutils->computeGuaranteedFrees();
137 SmallPtrSet<const Value *, 4> unnecessaryValues;
138 SmallPtrSet<const Instruction *, 4> unnecessaryInstructions;
139 calculateUnusedValuesInFunction(
140 *gutils->oldFunc, unnecessaryValues, unnecessaryInstructions,
141 returnUsed, mode, gutils, TLI, constant_args, guaranteedUnreachable);
142 gutils->unnecessaryValuesP = &unnecessaryValues;
144 SmallPtrSet<const Instruction *, 4> unnecessaryStores;
145 calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores,
146 unnecessaryInstructions, gutils, TLI);
147 */
148
149 bool valid = true;
150 for (Block &oBB : gutils->oldFunc.getFunctionBody().getBlocks()) {
151 // Don't create derivatives for code that results in termination
152 if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) {
153 auto newBB = gutils->getNewFromOriginal(&oBB);
154
155 for (auto &I : make_early_inc_range(reverse(oBB))) {
156 gutils->eraseIfUnused(&I, /*erase*/ true, /*check*/ false);
157 }
158
159 OpBuilder builder(gutils->oldFunc.getContext());
160 builder.setInsertionPointToEnd(newBB);
161 LLVM::UnreachableOp::create(builder, gutils->oldFunc.getLoc());
162 continue;
163 }
165 assert(oBB.getTerminator());
166
167 auto first = oBB.begin();
168 auto last = oBB.empty() ? oBB.end() : std::prev(oBB.end());
169 for (auto it = first; it != last; ++it) {
170 // TODO: propagate errors.
171 auto res = gutils->visitChild(&*it);
172 valid &= res.succeeded();
173 }
174
175 createTerminator(gutils, &oBB, returnPrimalsP, returnShadowsP);
176 }
177
178 // if (mode == DerivativeMode::ForwardModeSplit && augmenteddata)
179 // restoreCache(gutils, augmenteddata->tapeIndices, guaranteedUnreachable);
180
181 // gutils->eraseFictiousPHIs();
182
183 // mlir::Block *entry = &gutils->newFunc.getFunctionBody().front();
184
185 // cleanupInversionAllocs(gutils, entry);
186 // clearFunctionAttributes(gutils->newFunc);
187
188 /*
189 if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) {
190 llvm::errs() << *gutils->oldFunc << "\n";
191 llvm::errs() << *gutils->newFunc << "\n";
192 report_fatal_error("function failed verification (4)");
193 }
194 */
195
196 auto nf = gutils->newFunc;
197 delete gutils;
198
199 if (!valid)
200 return nullptr;
201
202 if (postpasses != "") {
203 mlir::PassManager pm(nf->getContext());
204 pm.enableVerifier(verifyPostPasses);
205 std::string error_message;
206 // llvm::raw_string_ostream error_stream(error_message);
207 mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
208 if (mlir::failed(result)) {
209 return nullptr;
210 }
211
212 if (!mlir::succeeded(pm.run(nf))) {
213 return nullptr;
214 }
215 }
216
217 return nf;
218}
void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB, DIFFE_TYPE retType, bool returnPrimal, bool returnShadow)
static std::map< K, V >::iterator insert_or_assign2(std::map< K, V > &map, K key, V val)
Insert into a map.
Definition Utils.h:846
DerivativeMode
Definition Utils.h:390
static MDiffeGradientUtils * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, ArrayRef< DIFFE_TYPE > ArgActivity, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
std::map< MForwardCacheKey, FunctionOpInterface > ForwardCachedFunctions
void erase(Operation *op)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
bool isConstantValue(mlir::Value v) const
void branchingForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)