Enzyme main
Loading...
Searching...
No Matches
GradientUtils.cpp
Go to the documentation of this file.
1//===- GradientUtils.cpp - Utilities for gradient interfaces --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "Dialect/Ops.h"
14
15#include "mlir/IR/Matchers.h"
16#include "mlir/IR/SymbolTable.h"
17#include "mlir/Interfaces/FunctionInterfaces.h"
18
19// TODO: this shouldn't depend on specific dialects except Enzyme.
20#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/IR/Dominance.h"
25#include "llvm/ADT/BreadthFirstIterator.h"
26
27using namespace mlir;
28using namespace mlir::enzyme;
29
31 MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
32 FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, MTypeResults TR_,
33 IRMapping &invertedPointers_, const llvm::ArrayRef<bool> returnPrimals,
34 const llvm::ArrayRef<bool> returnShadows,
35 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
36 const SmallPtrSetImpl<mlir::Value> &activevals_,
37 ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
38 IRMapping &originalToNewFn_,
39 std::map<Operation *, Operation *> &originalToNewFnOps_,
40 DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses,
41 bool verifyPostPasses, bool strongZero)
  // Member-initializer list: each argument initializes the member of the
  // corresponding name. AtomicAdd starts out false and blocksNotForAnalysis
  // is value-initialized.
42 : newFunc(newFunc_), Logic(Logic), AtomicAdd(false), mode(mode),
43 oldFunc(oldFunc_), invertedPointers(invertedPointers_),
44 originalToNewFn(originalToNewFn_),
45 originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(),
  // The activity analyzer is created here; it is passed blocksNotForAnalysis,
  // readOnlyCache, the caller-provided constant/active value sets and the
  // return activities.
46 activityAnalyzer(std::make_unique<enzyme::ActivityAnalyzer>(
47 blocksNotForAnalysis, readOnlyCache, constantvalues_, activevals_,
48 ReturnActivity)),
49 TA(TA_), TR(TR_), omp(omp), verifyPostPasses(verifyPostPasses),
50 postpasses(postpasses), strongZero(strongZero),
51 returnPrimals(returnPrimals), returnShadows(returnShadows), width(width),
  // ReturnActivity is used a second time here, as RetDiffeTypes.
52 ArgDiffeTypes(ArgDiffeTypes_), RetDiffeTypes(ReturnActivity) {}
53
55 const mlir::Value originst) const {
  // Translates a value of the original function into its counterpart in the
  // new function by looking it up in originalToNewFn.
56 if (!originalToNewFn.contains(originst)) {
  // A missing entry is treated as an internal error: print both functions and
  // the offending value to stderr, then abort via llvm_unreachable.
57 llvm::errs() << oldFunc << "\n";
58 llvm::errs() << newFunc << "\n";
59 llvm::errs() << originst << "\n";
60 llvm_unreachable("Could not get new val from original");
61 }
62 return originalToNewFn.lookupOrNull(originst);
63}
64
// Range variant of getNewFromOriginal: maps every element of `originst`
// through getNewFromOriginal and returns the results in the same order.
65SmallVector<mlir::Value, 1>
67 SmallVector<mlir::Value, 1> results;
68 for (auto op : originst) {
69 results.push_back(getNewFromOriginal(op));
70 }
71 return results;
72}
73
// Block variant of getNewFromOriginal: returns the entry recorded for
// `originst` in originalToNewFn. If there is no entry, both functions and
// `originst` are printed to stderr and llvm_unreachable is hit.
74Block *
76 if (!originalToNewFn.contains(originst)) {
77 llvm::errs() << oldFunc << "\n";
78 llvm::errs() << newFunc << "\n";
79 llvm::errs() << originst << "\n";
80 llvm_unreachable("Could not get new blk from original");
81 }
82 return originalToNewFn.lookupOrNull(originst);
83}
84
// Operation variant of getNewFromOriginal: returns the operation recorded for
// `originst` in originalToNewFnOps. `originst` must be non-null (asserted).
85Operation *
87 assert(originst);
88 auto found = originalToNewFnOps.find(originst);
89 if (found == originalToNewFnOps.end()) {
  // No entry: dump both functions and the complete map (first the raw
  // pointers, then the printed operations), followed by the operation that
  // was asked for, and abort.
90 llvm::errs() << oldFunc << "\n";
91 llvm::errs() << newFunc << "\n";
92 for (auto &pair : originalToNewFnOps) {
93 llvm::errs() << " map[" << pair.first << "] = " << pair.second << "\n";
94 llvm::errs() << " map[" << *pair.first << "] = " << *pair.second << "\n";
95 }
96 llvm::errs() << originst << " - " << *originst << "\n";
97 llvm_unreachable("Could not get new op from original");
98 }
99 return found->second;
100}
101
103 Operation *op) {
  // Clones `op` through builder B with each of its operands remapped to the
  // value returned by getNewFromOriginal for that operand.
104 IRMapping map;
105 for (auto operand : op->getOperands())
106 map.map(operand, getNewFromOriginal(operand));
107 return B.clone(*op, map);
108}
// Forwards the query to activityAnalyzer->isConstantOperation, passing the
// stored type results (TR) and the operation.
111 return activityAnalyzer->isConstantOperation(TR, op);
112}
// Forwards the query to activityAnalyzer->isConstantValue, passing the stored
// type results (TR) and the value.
114 return activityAnalyzer->isConstantValue(TR, v);
115}
116
118 OpBuilder &Builder2) {
  // Returns the shadow recorded for `v` in invertedPointers, creating one on
  // demand for constant values. Aborts for non-constant values that have no
  // recorded shadow.
119 // TODO
120 if (invertedPointers.contains(v))
121 return invertedPointers.lookupOrNull(v);
123 if (isConstantValue(v)) {
  // Constant value whose shadow type implements AutoDiffTypeInterface:
  // create a null value of that type and remember it for later queries.
124 if (auto iface =
125 dyn_cast<AutoDiffTypeInterface>(getShadowType(v.getType()))) {
  // The guard restores Builder2's insertion point when this scope ends.
126 OpBuilder::InsertionGuard guard(Builder2);
  // Insert at the new-function counterpart of v's defining op, or, for a
  // block argument, at the start of the counterpart of its owning block.
127 if (auto op = v.getDefiningOp())
128 Builder2.setInsertionPoint(getNewFromOriginal(op));
129 else {
130 auto ba = cast<BlockArgument>(v);
131 Builder2.setInsertionPointToStart(getNewFromOriginal(ba.getOwner()));
132 }
133 Value dv = iface.createNullValue(Builder2, v.getLoc());
134 invertedPointers.map(v, dv);
135 return dv;
136 }
  // Constant value without the interface: fall back to the value's
  // counterpart in the new function (this result is not cached).
137 return getNewFromOriginal(v);
138 }
139 llvm::errs() << " could not invert pointer v " << v << "\n";
140 llvm_unreachable("could not invert pointer");
141}
142
144 std::function<Value(Location, Type)> hook) {
  // Pushes `hook` onto gradientCreatorHook; an empty (null) hook is ignored.
  // getNewGradient consults the most recently pushed hook.
145 if (hook != nullptr)
146 gradientCreatorHook.push_back(hook);
147}
148
150 std::function<Value(Location, Type)> hook) {
  // Removes the most recently pushed hook when `hook` is non-null. `hook` is
  // only null-checked; it is not compared against the entry being removed.
  // NOTE(review): there is no emptiness check before pop_back — this assumes
  // every call is paired with an earlier registration of a non-null hook;
  // confirm against callers.
151 if (hook != nullptr)
152 gradientCreatorHook.pop_back();
153}
154
155Value MDiffeGradientUtils::getNewGradient(Location loc, Type t) {
156 if (gradientCreatorHook.empty()) {
157 auto shadowty = getShadowType(t);
158 OpBuilder builder(t.getContext());
159 builder.setInsertionPointToStart(initializationBlock);
160
161 auto shadow = enzyme::InitOp::create(
162 builder, loc, enzyme::GradientType::get(t.getContext(), shadowty));
163 auto toset =
164 cast<AutoDiffTypeInterface>(shadowty).createNullValue(builder, loc);
165 enzyme::SetOp::create(builder, loc, shadow, toset);
166 return shadow;
167 } else {
168 return gradientCreatorHook.back()(loc, t);
169 }
170}
171
// Returns the gradient slot associated with `oval` in `differentials`. On the
// first request a slot is created through getNewGradient and recorded, so
// later calls for the same value return the same slot.
172mlir::Value
174 auto found = differentials.lookupOrNull(oval);
175 if (found != nullptr)
176 return found;
177
178 Value shadow = getNewGradient(oval.getLoc(), oval.getType());
179 differentials.map(oval, shadow);
180 return shadow;
181}
182
184 mlir::Value toset,
185 OpBuilder &BuilderM) {
  // Sets the derivative of `oval` to `toset`. `oval` must not be constant
  // (asserted).
186 assert(!isConstantValue(oval));
  // Mutability is queried on the type of `oval` itself.
187 auto iface = cast<AutoDiffTypeInterface>(oval.getType());
188 if (!iface.isMutable()) {
  // Immutable type: emit an enzyme::SetOp that writes `toset` into the
  // gradient slot of `oval`.
189 auto shadow = getDifferential(oval);
190 enzyme::SetOp::create(BuilderM, oval.getLoc(), shadow, toset);
191 } else {
  // Mutable type: defer to the MGradientUtils implementation.
192 MGradientUtils::setDiffe(oval, toset, BuilderM);
193 }
194}
195
197 OpBuilder &BuilderM) {
  // Resets the derivative of `oval` by passing the null value of its shadow
  // type to setDiffe. `oval` must not be constant and its shadow type must
  // not be mutable (both asserted).
198 assert(!isConstantValue(oval));
199 auto iface = cast<AutoDiffTypeInterface>(getShadowType(oval.getType()));
200 assert(!iface.isMutable());
201 setDiffe(oval, iface.createNullValue(BuilderM, oval.getLoc()), BuilderM);
202}
203
204mlir::Value mlir::enzyme::MDiffeGradientUtils::diffe(mlir::Value oval,
205 OpBuilder &BuilderM) {
206
207 auto shadow = getDifferential(oval);
208 return enzyme::GetOp::create(BuilderM, oval.getLoc(),
209 getShadowType(oval.getType()), shadow);
210}
211
// Base implementation of setDiffe: records `toset` as the shadow of `val` via
// setInvertedPointer when the mode check below passes. `val` must not be
// constant; when it is, the new function and the value are printed before the
// assertion fires. `BuilderM` is not used by the live code in this body.
212void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset,
213 OpBuilder &BuilderM) {
214 /*
215 if (auto arg = dyn_cast<Argument>(val))
216 assert(arg->getParent() == oldFunc);
217 if (auto inst = dyn_cast<Instruction>(val))
218 assert(inst->getParent()->getParent() == oldFunc);
219 */
220 if (isConstantValue(val)) {
221 llvm::errs() << newFunc << "\n";
222 llvm::errs() << val << "\n";
223 }
224 assert(!isConstantValue(val));
225
  // NOTE(review): the second operand of this condition (original line 227) is
  // not visible here; only the ForwardMode check can be confirmed.
226 if (mode == DerivativeMode::ForwardMode ||
228 setInvertedPointer(val, toset);
229 }
230 /*
231 Value *tostore = getDifferential(val);
232 if (toset->getType() != tostore->getType()->getPointerElementType()) {
233 llvm::errs() << "toset:" << *toset << "\n";
234 llvm::errs() << "tostore:" << *tostore << "\n";
235 }
236 assert(toset->getType() == tostore->getType()->getPointerElementType());
237 BuilderM.CreateStore(toset, tostore);
238 */
239}
240
// Replaces the shadow currently recorded for `val` with `toset`. `toset` must
// have the shadow type of `val`, and an entry for `val` must already exist in
// invertedPointers (both asserted).
242 assert(getShadowType(val.getType()) == toset.getType());
243 auto found = invertedPointers.lookupOrNull(val);
244 assert(found != nullptr);
// NOTE(review): the existing entry is assumed to be the result of an
// enzyme::PlaceholderOp; the result of getDefiningOp is not null-checked.
245 auto placeholder = found.getDefiningOp<enzyme::PlaceholderOp>();
// Redirect every use of the placeholder to `toset`, drop the placeholder and
// record `toset` as the new shadow.
246 placeholder.replaceAllUsesWith(toset);
247 erase(placeholder);
248 invertedPointers.map(val, toset);
249}
250
// Pre-populates invertedPointers with shadows for the non-constant values of
// oldFunc, in two walks: shadow block arguments first, then placeholder ops
// for operation results.
252 // TODO also block arguments
253 // assert(TR.getFunction() == oldFunc);
254
255 // Don't create derivatives for code that results in termination
256 // if (notForAnalysis.find(&oBB) != notForAnalysis.end())
257 // continue;
258
259 // LoopContext loopContext;
260 // getContext(cast<BasicBlock>(getNewFromOriginal(&oBB)), loopContext);
261
// Walk 1: every block except the first block of the function body.
262 oldFunc.walk([&](Block *blk) {
263 if (blk == &oldFunc.getFunctionBody().getBlocks().front())
264 return;
265 auto nblk = getNewFromOriginal(blk);
// Arguments are visited last-to-first, presumably so that inserting a shadow
// at position i + 1 does not shift arguments that are still to be visited.
266 for (auto val : llvm::reverse(blk->getArguments())) {
267 if (isConstantValue(val))
268 continue;
269 auto i = val.getArgNumber();
// NOTE(review): one line of this condition (original line 271) is not visible
// here; only the ForwardMode and isMutable operands can be confirmed.
270 if (mode == DerivativeMode::ForwardMode ||
272 cast<AutoDiffTypeInterface>(val.getType()).isMutable()) {
273 mlir::Value dval;
// The shadow argument goes directly after the primal one: appended when the
// primal is the last argument of blk, inserted at index i + 1 otherwise.
274 if (i == blk->getArguments().size() - 1)
275 dval = nblk->addArgument(getShadowType(val.getType()), val.getLoc());
276 else
277 dval =
278 nblk->insertArgument(nblk->args_begin() + i + 1,
279 getShadowType(val.getType()), val.getLoc());
280
281 invertedPointers.map(val, dval);
282 }
283 }
284 });
285
// Walk 2: every operation visited by the walk other than oldFunc itself.
286 oldFunc.walk([&](Operation *inst) {
287 if (inst == oldFunc)
288 return;
289
// The builder is constructed from the new-function counterpart of `inst`.
290 OpBuilder BuilderZ(getNewFromOriginal(inst));
291 for (auto res : inst->getResults()) {
292 if (isConstantValue(res))
293 continue;
294
// NOTE(review): one line of this condition (original line 296) is not visible
// here; only the ForwardMode and isMutable operands can be confirmed.
295 if (!(mode == DerivativeMode::ForwardMode ||
297 cast<AutoDiffTypeInterface>(res.getType()).isMutable()))
298 continue;
// Record a PlaceholderOp of the shadow type as the shadow of this result;
// setInvertedPointer replaces such placeholders with the real value.
299 mlir::Type antiTy = getShadowType(res.getType());
300 auto anti = enzyme::PlaceholderOp::create(BuilderZ, res.getLoc(), antiTy);
301 invertedPointers.map(res, anti);
302 }
303 });
304}
305
// Emits the forward-mode tangent for `op` through its AutoDiffOpInterface.
// Returns success without emitting anything when `op` needs no derivative,
// and an error diagnostic on `op` when no earlier return is taken.
// NOTE(review): original line 307 is not visible here; the unmatched closing
// brace near the end of the body suggests it opens an enclosing scope.
306LogicalResult MGradientUtils::visitChild(Operation *op) {
  // Nothing to do when op is not its block's terminator, all of its results
  // are constant, and the activity analyzer reports the operation as constant.
308 if ((op->getBlock()->getTerminator() != op) &&
309 llvm::all_of(op->getResults(),
310 [this](Value v) { return isConstantValue(v); }) &&
311 /*iface.hasNoEffect()*/ activityAnalyzer->isConstantOperation(TR, op)) {
312 return success();
313 }
314 // }
315 if (auto iface = dyn_cast<AutoDiffOpInterface>(op)) {
  // The tangent is built at the new-function counterpart of op.
316 OpBuilder builder(op->getContext());
317 builder.setInsertionPoint(getNewFromOriginal(op));
318 return iface.createForwardModeTangent(builder, this);
319 }
320 }
321 return op->emitError() << "could not compute the adjoint for this operation "
322 << *op;
323}
Type getShadowType(Type type, unsigned width)
DerivativeMode
Definition Utils.h:390
Helper class to analyze the differential activity.
SmallVector< std::function< Value(Location, Type)> > gradientCreatorHook
void registerGradientCreatorHook(std::function< Value(Location, Type)> hook)
mlir::Value getDifferential(mlir::Value origv)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
void deregisterGradientCreatorHook(std::function< Value(Location, Type)> hook)
Value getNewGradient(Location loc, Type t)
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder)
Operation * cloneWithNewOperands(OpBuilder &B, Operation *op)
mlir::Type getShadowType(mlir::Type T)
LogicalResult visitChild(Operation *op)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
MGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, MTypeResults TR_, IRMapping &invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > ReturnActivities, ArrayRef< DIFFE_TYPE > ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map< Operation *, Operation * > &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
std::unique_ptr< enzyme::ActivityAnalyzer > activityAnalyzer
bool isConstantInstruction(mlir::Operation *v) const
void setInvertedPointer(mlir::Value origv, mlir::Value newv)
bool isConstantValue(mlir::Value v) const