Enzyme main
Loading...
Searching...
No Matches
DifferentialUseAnalysis.h
Go to the documentation of this file.
//===- DifferentialUseAnalysis.h - Determine values needed in reverse pass-===//
//
//                             Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
//          Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains the declaration of Differential Use Analysis -- an
// AD-specific analysis that deduces if a given value is needed in the reverse
// pass.
//
//===----------------------------------------------------------------------===//
26
27#ifndef ENZYME_DIFFERENTIALUSEANALYSIS_H_
28#define ENZYME_DIFFERENTIALUSEANALYSIS_H_
29
30#include <map>
31#include <set>
32
33#include "llvm/IR/BasicBlock.h"
34#include "llvm/IR/Instruction.h"
35
36#include "llvm/ADT/ArrayRef.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/SmallVector.h"
39
40#include "llvm/Support/Casting.h"
41#include "llvm/Support/ErrorHandling.h"
42
43#include "DiffeGradientUtils.h"
44#include "GradientUtils.h"
45#include "LibraryFuncs.h"
46
47extern "C" {
48extern llvm::cl::opt<bool> EnzymePrintDiffUse;
49}
50
51extern llvm::StringMap<
52 std::function<bool(const llvm::CallInst *, const GradientUtils *,
53 const llvm::Value *, bool, DerivativeMode, bool &)>>
55
/// Classification of what type of use is requested
enum class QueryType {
  // The original value is needed for the derivative
  Primal = 0,
  // The shadow value is needed for the derivative
  Shadow = 1,
  // The primal value is needed to stand in for the shadow
  // value and compute the derivative of an instruction
  ShadowByConstPrimal = 2
};
66
67static inline std::string to_string(QueryType mode) {
68 switch (mode) {
70 return "Primal";
72 return "Shadow";
74 return "ShadowByConstPrimal";
75 }
76 llvm_unreachable("illegal QueryType");
77}
78
79typedef std::pair<const llvm::Value *, QueryType> UsageKey;
80
82
83/// Determine if a value is needed directly to compute the adjoint
84/// of the given instruction user. `shadow` denotes whether we are considering
85/// the shadow of the value (shadow=true) or the primal of the value
86/// (shadow=false).
87/// Recursive use is only usable in shadow mode.
89 const GradientUtils *gutils, const llvm::Value *val, DerivativeMode mode,
90 const llvm::Instruction *user,
91 const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &oldUnreachable,
92 QueryType shadow, bool *recursiveUse = nullptr);
93
94template <QueryType VT, bool OneLevel = false>
96 const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode,
97 std::map<UsageKey, bool> &seen,
98 const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &oldUnreachable) {
99 using namespace llvm;
100
101 TypeResults const &TR = gutils->TR;
102 static_assert(VT == QueryType::Primal || VT == QueryType::Shadow ||
104 auto idx = UsageKey(inst, VT);
105 if (seen.find(idx) != seen.end())
106 return seen[idx];
107 if (auto ainst = dyn_cast<Instruction>(inst)) {
108 assert(ainst->getParent()->getParent() == gutils->oldFunc);
109 }
110
111 // Inductively claim we aren't needed (and try to find contradiction)
112 seen[idx] = false;
113
114 if (VT == QueryType::Primal) {
115 if (auto op = dyn_cast<BinaryOperator>(inst)) {
116 if (op->getOpcode() == Instruction::FDiv) {
117 if (!gutils->isConstantValue(const_cast<Value *>(inst)) &&
118 !gutils->isConstantValue(op->getOperand(1))) {
120 llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
121 << " in reverse as is active div\n";
122 return seen[idx] = true;
123 }
124 }
125 }
126 if (gutils->mode == DerivativeMode::ForwardModeError &&
127 !gutils->isConstantValue(const_cast<Value *>(inst))) {
129 llvm::errs()
130 << " Need: " << to_string(VT) << " of " << *inst
131 << " in reverse as forward mode error always needs result\n";
132 return seen[idx] = true;
133 }
134 }
135
136 if (auto CI = dyn_cast<CallInst>(inst)) {
137 StringRef funcName = getFuncNameFromCall(const_cast<CallInst *>(CI));
138 if (funcName == "julia.get_pgcstack" || funcName == "julia.ptls_states")
139 return seen[idx] = true;
140 }
141
142 bool inst_cv = gutils->isConstantValue(const_cast<Value *>(inst));
143
144 // Consider all users of this value, do any of them need this in the reverse?
145 for (auto use : inst->users()) {
146 if (use == inst)
147 continue;
148
149 const Instruction *user = dyn_cast<Instruction>(use);
150
151 // A shadow value is only needed in reverse if it or one of its descendants
152 // is used in an active instruction.
153 // If inst is a constant value, the primal may be used in its place and
154 // thus required.
156 inst_cv) {
157 bool recursiveUse = false;
159 gutils, inst, mode, user, oldUnreachable,
162 &recursiveUse)) {
163 return seen[idx] = true;
164 }
165
166 if (recursiveUse && !OneLevel) {
167 bool val;
168 if (VT == QueryType::Shadow)
170 gutils, user, mode, seen, oldUnreachable);
171 else
173 gutils, user, mode, seen, oldUnreachable);
174 if (val) {
176 llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
177 << " in reverse as shadow sub-need " << *user << "\n";
178 return seen[idx] = true;
179 }
180 }
181
182 if (!TR.allFloat(const_cast<Value *>(inst)))
183 if (auto IVI = dyn_cast<Instruction>(user)) {
184 bool inserted = false;
185 if (auto II = dyn_cast<InsertValueInst>(IVI))
186 inserted = II->getInsertedValueOperand() == inst ||
187 II->getAggregateOperand() == inst;
188 if (auto II = dyn_cast<ExtractValueInst>(IVI))
189 inserted = II->getAggregateOperand() == inst;
190 if (auto II = dyn_cast<InsertElementInst>(IVI))
191 inserted = II->getOperand(1) == inst || II->getOperand(0) == inst;
192 if (auto II = dyn_cast<ExtractElementInst>(IVI))
193 inserted = II->getOperand(0) == inst;
194 if (inserted) {
195 SmallVector<const Instruction *, 1> todo;
196 todo.push_back(IVI);
197 while (todo.size()) {
198 auto cur = todo.pop_back_val();
199 for (auto u : cur->users()) {
200 if (auto IVI2 = dyn_cast<InsertValueInst>(u)) {
201 todo.push_back(IVI2);
202 continue;
203 }
204 if (auto IVI2 = dyn_cast<ExtractValueInst>(u)) {
205 todo.push_back(IVI2);
206 continue;
207 }
208 if (auto IVI2 = dyn_cast<InsertElementInst>(u)) {
209 todo.push_back(IVI2);
210 continue;
211 }
212 if (auto IVI2 = dyn_cast<ExtractElementInst>(u)) {
213 todo.push_back(IVI2);
214 continue;
215 }
216
217 bool partial = false;
218 if (auto UI = dyn_cast<Instruction>(u)) {
219 if (!gutils->isConstantValue(
220 const_cast<Instruction *>(cur))) {
221 bool recursiveUse = false;
223 gutils, cur, mode, UI, oldUnreachable,
224 QueryType::Shadow, &recursiveUse)) {
225 partial = true;
226 } else if (recursiveUse && !OneLevel) {
228 gutils, UI, mode, seen, oldUnreachable);
229 }
230 } else if (VT == QueryType::Shadow) {
231 bool recursiveUse = false;
233 gutils, cur, mode, UI, oldUnreachable,
234 QueryType::ShadowByConstPrimal, &recursiveUse)) {
235 partial = true;
236 } else if (recursiveUse && !OneLevel) {
238 QueryType::ShadowByConstPrimal>(gutils, UI, mode,
239 seen, oldUnreachable);
240 }
241 }
242 }
243
244 if (partial) {
245
247 llvm::errs()
248 << " Need (partial) direct " << to_string(VT) << " of "
249 << *inst << " in reverse from insertelem " << *user
250 << " via " << *cur << " in " << *u << "\n";
251 return seen[idx] = true;
252 }
253 }
254 }
255 }
256 }
257
258 if (VT != QueryType::Primal)
259 continue;
260 }
261
262 assert(VT == QueryType::Primal);
263
264 // If a sub user needs, we need
265 if (!OneLevel && is_value_needed_in_reverse<VT>(gutils, user, mode, seen,
266 oldUnreachable)) {
268 llvm::errs() << " Need: " << to_string(VT) << "(" << mode << ") of "
269 << *inst << " in reverse as sub-need " << *user << "\n";
270 return seen[idx] = true;
271 }
272
273 // Anything we may try to rematerialize requires its store operands for
274 // the reverse pass.
275 if (!OneLevel) {
276 bool isStored = false;
277 if (auto SI = dyn_cast<StoreInst>(user))
278 isStored = inst == SI->getValueOperand();
279 else if (auto MTI = dyn_cast<MemTransferInst>(user)) {
280 isStored = inst == MTI->getSource() || inst == MTI->getLength();
281 } else if (auto MS = dyn_cast<MemSetInst>(user)) {
282 isStored = inst == MS->getLength() || inst == MS->getValue();
283 } else if (auto CB = dyn_cast<CallBase>(user)) {
284 auto name = getFuncNameFromCall(CB);
285 if (name == "julia.write_barrier" ||
286 name == "julia.write_barrier_binding") {
287 auto sz = CB->arg_size();
288 // First pointer is the destination
289 for (size_t i = 1; i < sz; i++)
290 isStored |= inst == CB->getArgOperand(i);
291 }
292 }
293 if (isStored) {
294 for (auto pair : gutils->rematerializableAllocations) {
295 // If already decided to cache the whole allocation, ignore
296 if (gutils->needsCacheWholeAllocation(pair.first)) {
297 continue;
298 }
299
300 // If caching the outer allocation and have already set that this is
301 // not needed return early. This is necessary to avoid unnecessarily
302 // deciding stored values are needed if we have already decided to
303 // cache the whole allocation.
304 auto found = seen.find(std::make_pair(pair.first, QueryType::Primal));
305 if (found != seen.end() && !found->second) {
306 continue;
307 }
308
309 // Directly consider all the load uses to avoid an illegal inductive
310 // recurrence. Specifically if we're asking if the alloca is used,
311 // we'll set it to unused, then check the gep, then here we'll
312 // directly say unused by induction instead of checking the final
313 // loads.
314 if (pair.second.stores.count(user)) {
315 for (LoadInst *L : pair.second.loads)
316 if (is_value_needed_in_reverse<VT>(gutils, L, mode, seen,
317 oldUnreachable)) {
319 llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
320 << " in reverse as rematload " << *L << "\n";
321 return seen[idx] = true;
322 }
323 for (auto &pair : pair.second.loadLikeCalls)
325 gutils, pair.operand, mode, pair.loadCall, oldUnreachable,
327 is_value_needed_in_reverse<VT>(gutils, pair.loadCall, mode,
328 seen, oldUnreachable)) {
330 llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
331 << " in reverse as rematloadcall "
332 << *pair.loadCall << "\n";
333 return seen[idx] = true;
334 }
335
336 if (is_value_needed_in_reverse<VT>(gutils, pair.first, mode, seen,
337 oldUnreachable)) {
339 llvm::errs()
340 << " Need: " << to_string(VT) << " of " << *inst
341 << " in reverse as rematalloc " << *pair.first << "\n";
342 return seen[idx] = true;
343 }
344 }
345 }
346 }
347 }
348
349 // One may need to this value in the computation of loop
350 // bounds/comparisons/etc (which even though not active -- will be used for
351 // the reverse pass)
352 // We could potentially optimize this to avoid caching if in combined mode
353 // and the instruction dominates all returns
354 // otherwise it will use the local cache (rather than save for a separate
355 // backwards cache)
356 // We also don't need this if looking at the shadow rather than primal
357 {
358 // Proving that none of the uses (or uses' uses) are used in control flow
359 // allows us to safely not do this load
360
361 // TODO save loop bounds for dynamic loop
362
363 // TODO make this more aggressive and dont need to save loop latch
364 if (isa<BranchInst>(use) || isa<SwitchInst>(use)) {
365 size_t num = 0;
366 for (auto suc : successors(cast<Instruction>(use)->getParent())) {
367 if (!oldUnreachable.count(suc)) {
368 num++;
369 }
370 }
371 if (num <= 1)
372 continue;
374 llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
375 << " in reverse as control-flow " << *user << "\n";
376 return seen[idx] = true;
377 }
378
379 if (auto CI = dyn_cast<CallInst>(use)) {
380 if (auto F = CI->getCalledFunction()) {
381 if (F->getName() == "__kmpc_for_static_init_4" ||
382 F->getName() == "__kmpc_for_static_init_4u" ||
383 F->getName() == "__kmpc_for_static_init_8" ||
384 F->getName() == "__kmpc_for_static_init_8u") {
386 llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
387 << " in reverse as omp init " << *user << "\n";
388 return seen[idx] = true;
389 }
390 }
391 }
392 }
393
394 // The following are types we know we don't need to compute adjoints
395
396 // If a primal value is needed to compute a shadow pointer (e.g. int offset
397 // in gep), it needs preserving.
398 bool primalUsedInShadowPointer = true;
399 if (isa<CastInst>(user) || isa<LoadInst>(user))
400 primalUsedInShadowPointer = false;
401 if (auto CI = dyn_cast<CallInst>(user)) {
402 auto funcName = getFuncNameFromCall(CI);
403 if (funcName == "julia.pointer_from_objref") {
404 primalUsedInShadowPointer = false;
405 }
406 if (funcName == "julia.gc_loaded") {
407 primalUsedInShadowPointer = false;
408 }
409 if (funcName.contains("__enzyme_todense")) {
410 primalUsedInShadowPointer = false;
411 }
412 if (funcName.contains("__enzyme_ignore_derivatives")) {
413 primalUsedInShadowPointer = false;
414 }
415 }
416 if (auto GEP = dyn_cast<GetElementPtrInst>(user)) {
417 bool idxUsed = false;
418 for (auto &idx : GEP->indices()) {
419 if (idx.get() == inst)
420 idxUsed = true;
421 }
422 if (!idxUsed)
423 primalUsedInShadowPointer = false;
424 }
425 if (auto II = dyn_cast<IntrinsicInst>(user)) {
426 if (isIntelSubscriptIntrinsic(*II)) {
427 const std::array<size_t, 4> idxArgsIndices{{0, 1, 2, 4}};
428 bool idxUsed = false;
429 for (auto i : idxArgsIndices) {
430 if (II->getOperand(i) == inst)
431 idxUsed = true;
432 }
433 if (!idxUsed)
434 primalUsedInShadowPointer = false;
435 }
436 }
437 // No need for insert/extractvalue since indices are unsigned
438 // not llvm runtime values
439 if (isa<InsertValueInst>(user) || isa<ExtractValueInst>(user))
440 primalUsedInShadowPointer = false;
441
442 if (primalUsedInShadowPointer)
443 if (!user->getType()->isVoidTy() &&
444 TR.anyPointer(const_cast<Instruction *>(user))) {
446 gutils, user, mode, seen, oldUnreachable)) {
448 llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
449 << " in reverse as used to compute shadow ptr "
450 << *user << "\n";
451 return seen[idx] = true;
452 }
453 }
454
456 gutils, inst, mode, user, oldUnreachable, QueryType::Primal);
457 if (!direct)
458 continue;
459
460 if (inst->getType()->isTokenTy()) {
461 llvm::errs() << " need " << *inst << " via " << *user << "\n";
462 }
463 assert(!inst->getType()->isTokenTy());
464
465 return seen[idx] = true;
466 }
467 return false;
468}
469
470template <QueryType VT>
471static inline bool is_value_needed_in_reverse(
472 const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode,
473 const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &oldUnreachable) {
474 static_assert(VT == QueryType::Primal || VT == QueryType::Shadow);
475 std::map<UsageKey, bool> seen;
476 return is_value_needed_in_reverse<VT>(gutils, inst, mode, seen,
477 oldUnreachable);
478}
479
480struct Node {
481 llvm::Value *V;
483 Node(llvm::Value *V, bool outgoing) : V(V), outgoing(outgoing){};
484 bool operator<(const Node N) const {
485 if (V < N.V)
486 return true;
487 return !(N.V < V) && outgoing < N.outgoing;
488 }
489 void dump() {
490 if (V)
491 llvm::errs() << "[" << *V << ", " << (int)outgoing << "]\n";
492 else
493 llvm::errs() << "[" << V << ", " << (int)outgoing << "]\n";
494 }
495};
496
497using Graph = std::map<Node, std::set<Node>>;
498
499void dump(std::map<Node, std::set<Node>> &G);
500
501/* Returns true if there is a path from source 's' to sink 't' in
502 residual graph. Also fills parent[] to store the path */
503void bfs(const std::map<Node, std::set<Node>> &G,
504 const llvm::SetVector<llvm::Value *> &Recompute,
505 std::map<Node, Node> &parent);
506
507// Return 1 if next is better
508// 0 if equal
509// -1 if prev is better, or unknown
510int cmpLoopNest(llvm::Loop *prev, llvm::Loop *next);
511
512void minCut(const llvm::DataLayout &DL, llvm::LoopInfo &OrigLI,
513 const llvm::SetVector<llvm::Value *> &Recomputes,
514 const llvm::SetVector<llvm::Value *> &Intermediates,
515 llvm::SetVector<llvm::Value *> &Required,
516 llvm::SetVector<llvm::Value *> &MinReq, const GradientUtils *gutils,
517 llvm::TargetLibraryInfo &TLI);
518
519__attribute__((always_inline)) static inline void
520forEachDirectInsertUser(llvm::function_ref<void(llvm::Instruction *)> f,
521 const GradientUtils *gutils, llvm::Instruction *IVI,
522 llvm::Value *val, bool useCheck) {
523 using namespace llvm;
524 if (!gutils->isConstantValue(IVI))
525 return;
526 bool inserted = false;
527 if (auto II = dyn_cast<InsertValueInst>(IVI))
528 inserted = II->getInsertedValueOperand() == val ||
529 II->getAggregateOperand() == val;
530 if (auto II = dyn_cast<ExtractValueInst>(IVI))
531 inserted = II->getAggregateOperand() == val;
532 if (auto II = dyn_cast<InsertElementInst>(IVI))
533 inserted = II->getOperand(1) == val || II->getOperand(0) == val;
534 if (auto II = dyn_cast<ExtractElementInst>(IVI))
535 inserted = II->getOperand(0) == val;
536 if (inserted) {
537 SmallVector<Instruction *, 1> todo;
538 todo.push_back(IVI);
539 while (todo.size()) {
540 auto cur = todo.pop_back_val();
541 for (auto u : cur->users()) {
542 if (isa<InsertValueInst>(u) || isa<InsertElementInst>(u) ||
543 isa<ExtractValueInst>(u) || isa<ExtractElementInst>(u)) {
544 auto I2 = cast<Instruction>(u);
545 bool subCheck = useCheck;
546 if (!subCheck) {
548 gutils, I2, gutils->mode, gutils->notForAnalysis);
549 }
550 if (subCheck)
551 f(I2);
552 todo.push_back(I2);
553 continue;
554 }
555 }
556 }
557 }
558}
559
560__attribute__((always_inline)) static inline void
561forEachDifferentialUser(llvm::function_ref<void(llvm::Value *)> f,
562 const GradientUtils *gutils, llvm::Value *V,
563 bool useCheck = false) {
564 for (auto V2 : V->users()) {
565 if (auto Inst = llvm::dyn_cast<llvm::Instruction>(V2)) {
566 for (const auto &pair : gutils->rematerializableAllocations) {
567 if (pair.second.stores.count(Inst)) {
568 f(llvm::cast<llvm::Instruction>(pair.first));
569 }
570 }
571 f(Inst);
572 forEachDirectInsertUser(f, gutils, Inst, V, useCheck);
573 }
574 }
575}
576
577//! Return whether or not this is a constant and should use reverse pass
579 llvm::CallBase &orig, QueryType qtype,
580 const llvm::Value *val);
581
582}; // namespace DifferentialUseAnalysis
583
584#endif
static std::string to_string(QueryType mode)
llvm::StringMap< std::function< bool(const llvm::CallInst *, const GradientUtils *, const llvm::Value *, bool, DerivativeMode, bool &)> > customDiffUseHandlers
QueryType
Classification of what type of use is requested.
@ ShadowByConstPrimal
std::pair< const llvm::Value *, QueryType > UsageKey
llvm::cl::opt< bool > EnzymePrintDiffUse
constexpr const char * to_string(ActivityAnalyzer::UseActivity UA)
static bool isIntelSubscriptIntrinsic(const llvm::IntrinsicInst &II)
Definition Utils.h:1445
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
Definition Utils.h:1269
DerivativeMode
Definition Utils.h:390
DerivativeMode mode
TypeResults TR
llvm::ValueMap< llvm::Value *, Rematerializer > rematerializableAllocations
llvm::Function * oldFunc
bool isConstantValue(llvm::Value *val) const
bool needsCacheWholeAllocation(const llvm::Value *V) const
llvm::SmallPtrSet< llvm::BasicBlock *, 4 > notForAnalysis
A holder class representing the results of running TypeAnalysis on a given function.
bool anyPointer(llvm::Value *val) const
Whether any part of the top level register can contain a pointer e.g.
bool allFloat(llvm::Value *val) const
Whether all of the top level register is known to contain float data.
void minCut(const llvm::DataLayout &DL, llvm::LoopInfo &OrigLI, const llvm::SetVector< llvm::Value * > &Recomputes, const llvm::SetVector< llvm::Value * > &Intermediates, llvm::SetVector< llvm::Value * > &Required, llvm::SetVector< llvm::Value * > &MinReq, const GradientUtils *gutils, llvm::TargetLibraryInfo &TLI)
bool callShouldNotUseDerivative(const GradientUtils *gutils, llvm::CallBase &orig, QueryType qtype, const llvm::Value *val)
Return whether or not this is a constant and should use reverse pass.
bool is_value_needed_in_reverse(const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode, std::map< UsageKey, bool > &seen, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable)
void bfs(const std::map< Node, std::set< Node > > &G, const llvm::SetVector< llvm::Value * > &Recompute, std::map< Node, Node > &parent)
void dump(std::map< Node, std::set< Node > > &G)
bool is_use_directly_needed_in_reverse(const GradientUtils *gutils, const llvm::Value *val, DerivativeMode mode, const llvm::Instruction *user, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable, QueryType shadow, bool *recursiveUse=nullptr)
Determine if a value is needed directly to compute the adjoint of the given instruction user.
std::map< Node, std::set< Node > > Graph
__attribute__((always_inline)) static inline void forEachDirectInsertUser(llvm
int cmpLoopNest(llvm::Loop *prev, llvm::Loop *next)
Node(llvm::Value *V, bool outgoing)