DifferentialUseAnalysis.cpp
//===- DifferentialUseAnalysis.cpp - Determine values needed in reverse
// pass-===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
//          Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of Differential Use Analysis -- an
// AD-specific analysis that deduces if a given value is needed in the reverse
// pass.
//
//===----------------------------------------------------------------------===//
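
// For intuition, consider the primal computation z = x * y (the names here
// are illustrative, not from this file). Reverse-mode AD emits the adjoints
// dx += dz * y and dy += dz * x, so the primal values x and y must be
// preserved (cached or recomputed) for the reverse pass whenever z is active.
// The queries implemented below answer exactly that question for each
// (value, user) pair, both for primal values and for their shadow
// (derivative) counterparts.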

#include <deque>
#include <map>
#include <set>

#include "DifferentialUseAnalysis.h"
#include "Utils.h"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicsX86.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

#include "DiffeGradientUtils.h"
#include "GradientUtils.h"
#include "LibraryFuncs.h"

using namespace llvm;
StringMap<std::function<bool(const CallInst *, const GradientUtils *,
                             const Value *, bool, DerivativeMode, bool &)>>
    customDiffUseHandlers;

bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse(
    const GradientUtils *gutils, const Value *val, DerivativeMode mode,
    const Instruction *user,
    const SmallPtrSetImpl<BasicBlock *> &oldUnreachable, QueryType qtype,
    bool *recursiveUse) {
  TypeResults const &TR = gutils->TR;
#ifndef NDEBUG
  if (auto ainst = dyn_cast<Instruction>(val)) {
    assert(ainst->getParent()->getParent() == gutils->oldFunc);
  }
#endif

  bool shadow =
      qtype == QueryType::Shadow || qtype == QueryType::ShadowByConstPrimal;

  /// Recursive use is only usable in shadow mode.
  if (!shadow)
    assert(recursiveUse == nullptr);
  else
    assert(recursiveUse != nullptr);

  if (!shadow && isPointerArithmeticInst(user, /*includephi*/ true,
                                         /*includebin*/ false)) {
    return false;
  }

  // Floating point numbers cannot be used as a shadow pointer/etc
  if (shadow)
    if (TR.query(const_cast<Value *>(val))[{-1}].isFloat())
      return false;

  if (!user) {
    if (EnzymePrintDiffUse)
      llvm::errs() << " Need: of " << *val << " in reverse as nullptr user\n";
    return true;
  }

  assert(user->getParent()->getParent() == gutils->oldFunc);

  if (oldUnreachable.count(user->getParent()))
    return false;

  if (auto SI = dyn_cast<StoreInst>(user)) {
    if (!shadow) {

      // We don't need any of the input operands to compute the adjoint of a
      // store instruction. The one exception to this is stores to the loop
      // bounds.
      if (SI->getValueOperand() == val) {
        for (auto U : SI->getPointerOperand()->users()) {
          if (auto CI = dyn_cast<CallInst>(U)) {
            if (auto F = CI->getCalledFunction()) {
              if (F->getName() == "__kmpc_for_static_init_4" ||
                  F->getName() == "__kmpc_for_static_init_4u" ||
                  F->getName() == "__kmpc_for_static_init_8" ||
                  F->getName() == "__kmpc_for_static_init_8u") {
                if (CI->getArgOperand(4) == val ||
                    CI->getArgOperand(5) == val ||
                    CI->getArgOperand(6) == val) {
                  if (EnzymePrintDiffUse)
                    llvm::errs() << " Need direct primal of " << *val
                                 << " in reverse from omp " << *user << "\n";
                  return true;
                }
              }
            }
          }
        }
      }

      // And runtime activity updates
      if (gutils->runtimeActivity && SI->getPointerOperand() == val) {
        auto &DL = gutils->newFunc->getParent()->getDataLayout();
        auto ET = SI->getValueOperand()->getType();
        auto storeSize = (DL.getTypeSizeInBits(ET) + 7) / 8;
        auto vd = TR.query(const_cast<Value *>(SI->getPointerOperand()))
                      .Lookup(storeSize, DL);
        if (!vd.isKnown()) {
          // It verbatim needs to replicate the same behavior as
          // adjointgenerator. From reverse mode type analysis
          // (https://github.com/EnzymeAD/Enzyme/blob/194875cbccd73d63cacfefbfa85c1f583c2fa1fe/enzyme/Enzyme/AdjointGenerator.h#L556)
          if (looseTypeAnalysis || true) {
            vd = defaultTypeTreeForLLVM(ET, const_cast<StoreInst *>(SI));
          }
        }
        bool hasFloat = false;
        for (ssize_t i = -1; i < (ssize_t)storeSize; ++i) {
          if (vd[{(int)i}].isFloat()) {
            hasFloat = true;
            break;
          }
        }
        if (hasFloat && !gutils->isConstantValue(const_cast<llvm::Value *>(
                            SI->getPointerOperand()))) {
          if (EnzymePrintDiffUse)
            llvm::errs() << " Need direct primal of " << *val
                         << " in reverse from runtime active store " << *user
                         << "\n";
          return true;
        }
      }
    } else {
      bool backwardsShadow = false;
      bool forwardsShadow = true;
      for (auto pair : gutils->backwardsOnlyShadows) {
        if (pair.second.stores.count(SI) &&
            !gutils->isConstantValue(pair.first)) {
          backwardsShadow = true;
          forwardsShadow = pair.second.primalInitialize;
        }
      }

      // Preserve any non-floating point values that are stored in an active
      // backwards creation shadow.

      if (SI->getValueOperand() == val) {
        // Storing an active pointer into a location
        // doesn't require the shadow pointer for the
        // reverse pass, unless the store is into a backwards store,
        // which would then be performed in the reverse if the stored
        // value was a possible pointer.

        if (!((mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
              (mode == DerivativeMode::ReverseModeGradient &&
               backwardsShadow) ||
              (mode == DerivativeMode::ForwardModeSplit && backwardsShadow) ||
              (mode == DerivativeMode::ReverseModeCombined &&
               (forwardsShadow || backwardsShadow)) ||
              mode == DerivativeMode::ForwardMode ||
              mode == DerivativeMode::ForwardModeError))
          return false;
      } else {
        // Likewise, if not rematerializing in the reverse pass, you
        // don't need to keep the pointer operand for known pointers.

        auto ct = TR.query(const_cast<Value *>(SI->getValueOperand()))[{-1}];
        if (ct == BaseType::Pointer || ct == BaseType::Integer) {

          if (!((mode == DerivativeMode::ReverseModePrimal &&
                 forwardsShadow) ||
                (mode == DerivativeMode::ReverseModeGradient &&
                 backwardsShadow) ||
                (mode == DerivativeMode::ForwardModeSplit &&
                 backwardsShadow) ||
                (mode == DerivativeMode::ReverseModeCombined &&
                 (forwardsShadow || backwardsShadow)) ||
                mode == DerivativeMode::ForwardMode ||
                mode == DerivativeMode::ForwardModeError))
            return false;
        }
      }

      if (!gutils->isConstantValue(
              const_cast<Value *>(SI->getPointerOperand()))) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need: shadow of " << *val
                       << " in reverse as shadow store " << *SI << "\n";
        return true;
      } else
        return false;
    }
    return false;
  }

  if (!shadow)
    if (auto LI = dyn_cast<LoadInst>(user)) {
      if (gutils->runtimeActivity) {
        auto vd = TR.query(const_cast<llvm::Instruction *>(user));
        if (!vd.isKnown()) {
          auto ET = LI->getType();
          // It verbatim needs to replicate the same behavior as
          // adjointgenerator. From reverse mode type analysis
          // (https://github.com/EnzymeAD/Enzyme/blob/194875cbccd73d63cacfefbfa85c1f583c2fa1fe/enzyme/Enzyme/AdjointGenerator.h#L556)
          if (looseTypeAnalysis || true) {
            vd = defaultTypeTreeForLLVM(ET, const_cast<LoadInst *>(LI));
          }
        }
        auto &DL = gutils->newFunc->getParent()->getDataLayout();
        auto LoadSize = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8;
        bool hasFloat = false;
        for (ssize_t i = -1; i < (ssize_t)LoadSize; ++i) {
          if (vd[{(int)i}].isFloat()) {
            hasFloat = true;
            break;
          }
        }
        if (hasFloat && !gutils->isConstantInstruction(
                            const_cast<llvm::Instruction *>(user))) {
          if (EnzymePrintDiffUse)
            llvm::errs() << " Need direct primal of " << *val
                         << " in reverse from runtime active load " << *user
                         << "\n";
          return true;
        }
      }
      return false;
    }

  if (auto MTI = dyn_cast<MemTransferInst>(user)) {
    // If memtransfer, only the primal of the size is needed in the reverse
    // pass.
    if (!shadow) {
      // Unless we're storing into a backwards only shadow store
      if (MTI->getArgOperand(1) == val || MTI->getArgOperand(2) == val) {
        for (auto pair : gutils->backwardsOnlyShadows)
          if (pair.second.stores.count(MTI)) {
            if (EnzymePrintDiffUse)
              llvm::errs() << " Need direct primal of " << *val
                           << " in reverse from remat memtransfer " << *user
                           << "\n";
            return true;
          }
      }
      if (MTI->getArgOperand(2) != val)
        return false;
      bool res = !gutils->isConstantValue(MTI->getArgOperand(0));
      if (res) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need direct primal of " << *val
                       << " in reverse from memtransfer " << *user << "\n";
      }
      return res;
    } else {

      if (MTI->getArgOperand(0) != val && MTI->getArgOperand(1) != val)
        return false;

      if (!gutils->isConstantValue(
              const_cast<Value *>(MTI->getArgOperand(0)))) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need: shadow of " << *val
                       << " in reverse as shadow MTI " << *MTI << "\n";
        return true;
      } else
        return false;
    }
  }

  if (auto MS = dyn_cast<MemSetInst>(user)) {
    if (!shadow) {
      // Preserve the primal of the length of memsets of backward creation
      // shadows, or if float-like and a non-constant value.
      if (MS->getArgOperand(1) == val || MS->getArgOperand(2) == val) {
        for (auto pair : gutils->backwardsOnlyShadows)
          if (pair.second.stores.count(MS)) {
            if (EnzymePrintDiffUse)
              llvm::errs() << " Need direct primal of " << *val
                           << " in reverse from remat memset " << *user
                           << "\n";
            return true;
          }
        bool res = !gutils->isConstantValue(MS->getArgOperand(0));
        if (res) {
          if (EnzymePrintDiffUse)
            llvm::errs() << " Need direct primal of " << *val
                         << " in reverse from memset " << *user << "\n";
        }
        return res;
      }
    } else {

      if (MS->getArgOperand(0) != val)
        return false;

      if (!gutils->isConstantValue(
              const_cast<Value *>(MS->getArgOperand(0)))) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need: shadow of " << *val
                       << " in reverse as shadow MS " << *MS << "\n";
        return true;
      } else
        return false;
    }
  }

  if (!shadow)
    if (isa<CmpInst>(user) || isa<BranchInst>(user) || isa<ReturnInst>(user) ||
        isa<FPExtInst>(user) || isa<FPTruncInst>(user)
        // isa<ExtractElement>(use) ||
        // isa<InsertElementInst>(use) || isa<ShuffleVectorInst>(use) ||
        // isa<ExtractValueInst>(use) || isa<AllocaInst>(use)
        // || isa<StoreInst>(use)
    ) {
      return false;
    }

  if (!shadow)
    if (auto IEI = dyn_cast<InsertElementInst>(user)) {
      // Only need the index in the reverse, so if the value is not
      // the index, short circuit and say we don't need it.
      if (IEI->getOperand(2) != val) {
        return false;
      }
      // The index is only needed in the reverse if the value being inserted
      // is a possibly active floating point value.
      if (gutils->isConstantValue(const_cast<InsertElementInst *>(IEI)) ||
          TR.query(const_cast<InsertElementInst *>(IEI))[{-1}] ==
              BaseType::Pointer)
        return false;
      // Otherwise, we need the value.
      if (EnzymePrintDiffUse)
        llvm::errs() << " Need direct primal of " << *val
                     << " in reverse from non-pointer insertelem " << *user
                     << " "
                     << TR.query(const_cast<InsertElementInst *>(IEI)).str()
                     << "\n";
      return true;
    }

  if (!shadow)
    if (auto EEI = dyn_cast<ExtractElementInst>(user)) {
      // Only need the index in the reverse, so if the value is not
      // the index, short circuit and say we don't need it.
      if (EEI->getIndexOperand() != val) {
        return false;
      }
      // The index is only needed in the reverse if the value being extracted
      // is a possibly active floating point value.
      if (gutils->isConstantValue(const_cast<ExtractElementInst *>(EEI)) ||
          TR.query(const_cast<ExtractElementInst *>(EEI))[{-1}] ==
              BaseType::Pointer)
        return false;
      // Otherwise, we need the value.
      if (EnzymePrintDiffUse)
        llvm::errs() << " Need direct primal of " << *val
                     << " in reverse from non-pointer extractelem " << *user
                     << " "
                     << TR.query(const_cast<ExtractElementInst *>(EEI)).str()
                     << "\n";
      return true;
    }

  if (!shadow)
    if (auto IVI = dyn_cast<InsertValueInst>(user)) {
      // Only need the index in the reverse, so if the value is not
      // the index, short circuit and say we don't need it.
      bool valueIsIndex = false;
      for (unsigned i = 2; i < IVI->getNumOperands(); ++i) {
        if (IVI->getOperand(i) == val) {
          valueIsIndex = true;
        }
      }

      if (!valueIsIndex)
        return false;

      // The index is only needed in the reverse if the value being inserted
      // is a possibly active floating point value.
      if (gutils->isConstantValue(const_cast<InsertValueInst *>(IVI)) ||
          TR.query(const_cast<InsertValueInst *>(IVI))[{-1}] ==
              BaseType::Pointer)
        return false;
      // Otherwise, we need the value.
      if (EnzymePrintDiffUse)
        llvm::errs() << " Need direct primal of " << *val
                     << " in reverse from non-pointer insertval " << *user
                     << " "
                     << TR.query(const_cast<InsertValueInst *>(IVI)).str()
                     << "\n";
      return true;
    }

  if (!shadow)
    if (auto EVI = dyn_cast<ExtractValueInst>(user)) {
      // Only need the index in the reverse, so if the value is not
      // the index, short circuit and say we don't need it.
      bool valueIsIndex = false;
      for (unsigned i = 2; i < EVI->getNumOperands(); ++i) {
        if (EVI->getOperand(i) == val) {
          valueIsIndex = true;
        }
      }

      if (!valueIsIndex)
        return false;

      // The index is only needed in the reverse if the value being extracted
      // is a possibly active floating point value.
      if (gutils->isConstantValue(const_cast<ExtractValueInst *>(EVI)) ||
          TR.query(const_cast<ExtractValueInst *>(EVI))[{-1}] ==
              BaseType::Pointer)
        return false;
      // Otherwise, we need the value.
      if (EnzymePrintDiffUse)
        llvm::errs() << " Need direct primal of " << *val
                     << " in reverse from non-pointer extractval " << *user
                     << " "
                     << TR.query(const_cast<ExtractValueInst *>(EVI)).str()
                     << "\n";
      return true;
    }

  Intrinsic::ID ID = Intrinsic::not_intrinsic;
  if (auto II = dyn_cast<IntrinsicInst>(user)) {
    ID = II->getIntrinsicID();
  } else if (auto CI = dyn_cast<CallInst>(user)) {
    StringRef funcName = getFuncNameFromCall(const_cast<CallInst *>(CI));
    isMemFreeLibMFunction(funcName, &ID);
  }

  if (ID != Intrinsic::not_intrinsic) {
    if (ID == Intrinsic::lifetime_start || ID == Intrinsic::lifetime_end ||
        ID == Intrinsic::stacksave || ID == Intrinsic::stackrestore) {
      return false;
    }
  }

  if (!shadow)
    if (auto si = dyn_cast<SelectInst>(user)) {
      // Only would potentially need the condition
      if (si->getCondition() != val) {
        return false;
      }

      // Only need the condition if the select is active
      bool needed = !gutils->isConstantValue(const_cast<SelectInst *>(si));
      if (needed) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need direct primal of " << *val
                       << " in reverse from select " << *user << "\n";
      }
      return needed;
    }

#include "BlasDiffUse.inc"

  if (auto CI = dyn_cast<CallInst>(user)) {

    {
      SmallVector<OperandBundleDef, 2> OrigDefs;
      CI->getOperandBundlesAsDefs(OrigDefs);
      SmallVector<OperandBundleDef, 2> Defs;
      for (auto bund : OrigDefs) {
        for (auto inp : bund.inputs()) {
          if (inp == val)
            return true;
        }
      }
    }

    auto funcName = getFuncNameFromCall(CI);

    {
      auto found = customDiffUseHandlers.find(funcName);
      if (found != customDiffUseHandlers.end()) {
        bool useDefault = false;
        bool result = found->second(CI, gutils, val, shadow, mode, useDefault);
        if (!useDefault) {
          if (result) {
            if (EnzymePrintDiffUse)
              llvm::errs() << " Need: " << to_string(qtype) << " of " << *val
                           << " from custom diff use handler of " << *CI
                           << "\n";
          }
          return result;
        }
      }
    }

    // Don't need shadow inputs for alloc function
    if (shadow && isAllocationFunction(funcName, gutils->TLI))
      return false;

    // Even though inactive, keep the shadow pointer around in forward mode
    // to perform the same memory free behavior on the shadow.
    if (shadow &&
        (mode == DerivativeMode::ForwardMode ||
         mode == DerivativeMode::ForwardModeError) &&
        isDeallocationFunction(funcName, gutils->TLI)) {
      if (EnzymePrintDiffUse)
        llvm::errs() << " Need: shadow of " << *val
                     << " in reverse as shadow free " << *CI << "\n";
      return true;
    }

    // Only need primal (and shadow) request for reverse, or shadow buffer
    if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" ||
        funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") {
      if (gutils->isConstantInstruction(const_cast<Instruction *>(user)))
        return false;

      if (val == CI->getArgOperand(6)) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need: " << to_string(qtype) << " request " << *val
                       << " in reverse for MPI " << *CI << "\n";
        return true;
      }
      if (shadow && val == CI->getArgOperand(0)) {
        if ((funcName == "MPI_Irecv" || funcName == "PMPI_Irecv") &&
            mode != DerivativeMode::ReverseModeGradient) {
          // Need shadow buffer for forward pass of irecv
          if (EnzymePrintDiffUse)
            llvm::errs() << " Need: shadow(" << to_string(qtype) << ") of "
                         << *val << " in reverse as shadow MPI " << *CI
                         << "\n";
          return true;
        }
        if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
          // Need shadow buffer for forward or reverse pass of isend
          if (EnzymePrintDiffUse)
            llvm::errs() << " Need: shadow(" << to_string(qtype) << ") of "
                         << *val << " in reverse as shadow MPI " << *CI
                         << "\n";
          return true;
        }
      }

      return false;
    }

    if (!shadow) {

      // Need the primal request in reverse.
      if (funcName == "cuStreamSynchronize")
        if (val == CI->getArgOperand(0)) {
          if (EnzymePrintDiffUse)
            llvm::errs() << " Need: primal(" << to_string(qtype) << ") of "
                         << *val << " in reverse for cuda sync " << *CI
                         << "\n";
          return true;
        }

      // Only need the primal request.
      if (funcName == "MPI_Wait" || funcName == "PMPI_Wait")
        if (val != CI->getArgOperand(0))
          return false;

      // Only need the element count and request array for reverse of waitall
      if (funcName == "MPI_Waitall" || funcName == "PMPI_Waitall")
        if (val != CI->getArgOperand(0) && val != CI->getOperand(1))
          return false;

    } else {
      // Don't need shadow of anything (all via cache for reverse),
      // but need shadow of request for primal.
      if (funcName == "MPI_Wait" || funcName == "PMPI_Wait") {
        if (gutils->isConstantInstruction(const_cast<Instruction *>(user)))
          return false;
        // Need shadow request in forward pass only
        if (mode != DerivativeMode::ReverseModeGradient)
          if (val == CI->getArgOperand(0)) {
            if (EnzymePrintDiffUse)
              llvm::errs() << " Need: shadow of " << *val
                           << " in reverse as shadow MPI " << *CI << "\n";
            return true;
          }
        return false;
      }
    }

    // Since the adjoint of a barrier is another barrier in reverse,
    // we still need it even if the instruction is inactive.
    if (!shadow)
      if (funcName == "__kmpc_barrier" || funcName == "MPI_Barrier") {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need direct primal of " << *val
                       << " in reverse from barrier " << *user << "\n";
        return true;
      }

    // Since the adjoint of a GC preserve is another preserve in reverse,
    // we still need it even if the instruction is inactive.
    if (!shadow)
      if (funcName == "llvm.julia.gc_preserve_begin") {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need direct primal of " << *val
                       << " in reverse from gc " << *CI << "\n";
        return true;
      }

    if (funcName == "julia.write_barrier" ||
        funcName == "julia.write_barrier_binding") {
      // Use in a write barrier requires the shadow in the forward, even
      // if the instruction is inactive.
      if (shadow && (mode != DerivativeMode::ReverseModeGradient &&
                     mode != DerivativeMode::ForwardModeSplit)) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need: shadow of " << *val
                       << " in forward as shadow write_barrier " << *CI
                       << "\n";
        return true;
      }
      if (shadow) {
        auto sz = CI->arg_size();
        bool isStored = false;
        // First pointer is the destination
        for (size_t i = 1; i < sz; i++)
          isStored |= val == CI->getArgOperand(i);
        bool rematerialized = false;
        if (isStored)
          for (auto pair : gutils->backwardsOnlyShadows)
            if (pair.second.stores.count(CI) &&
                !gutils->isConstantValue(pair.first)) {
              rematerialized = true;
              break;
            }

        if (rematerialized) {
          if (EnzymePrintDiffUse)
            llvm::errs()
                << " Need: shadow of " << *val
                << " in rematerialized reverse as shadow write_barrier " << *CI
                << "\n";
          return true;
        }
      }
    }

    bool writeOnlyNoCapture = true;

    if (shouldDisableNoWrite(CI)) {
      writeOnlyNoCapture = false;
    }
    // Outside of forward mode, we don't need to keep the primal around in
    // the reverse just for the deallocation.
    if (!(mode == DerivativeMode::ForwardMode ||
          mode == DerivativeMode::ForwardModeError) &&
        isDeallocationFunction(funcName, gutils->TLI)) {
    } else {
      for (size_t i = 0; i < CI->arg_size(); i++) {
        if (val == CI->getArgOperand(i)) {
          if (!isNoCapture(CI, i)) {
            writeOnlyNoCapture = false;
            break;
          }
          if (!isWriteOnly(CI, i)) {
            writeOnlyNoCapture = false;
            break;
          }
        }
      }
    }

    // Don't need the primal argument if it is write only and not captured
    if (!shadow)
      if (writeOnlyNoCapture) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " No Need: primal of " << *val
                       << " per write-only no-capture use in " << *CI << "\n";
        return false;
      }

    if (shadow) {
      // Don't need the shadow argument if it is a pointer to pointers that
      // is only written, since the shadow pointer store will have been
      // completed in the forward pass.
      if (writeOnlyNoCapture &&
          TR.query(const_cast<Value *>(val))[{-1, -1}] == BaseType::Pointer &&
          mode == DerivativeMode::ReverseModeGradient)
        return false;

      const Value *FV = CI->getCalledOperand();
      if (FV == val) {
        if (!gutils->isConstantInstruction(const_cast<Instruction *>(user)) ||
            !gutils->isConstantValue(const_cast<Value *>((Value *)user))) {
          if (EnzymePrintDiffUse)
            llvm::errs() << " Need: shadow of " << *val
                         << " in reverse as shadow call " << *CI << "\n";
          return true;
        }
      }
    }
  }

  if (shadow) {
    if (isa<ReturnInst>(user)) {
      bool notrev = mode != DerivativeMode::ReverseModeGradient;
      if (gutils->shadowReturnUsed && notrev) {

        bool inst_cv = gutils->isConstantValue(const_cast<Value *>(val));

        if ((qtype == QueryType::ShadowByConstPrimal && inst_cv) ||
            (qtype == QueryType::Shadow && !inst_cv)) {
          if (EnzymePrintDiffUse)
            llvm::errs() << " Need: shadow(qtype=" << (int)qtype
                         << ",cv=" << inst_cv << ") of " << *val
                         << " in reverse as shadow return " << *user << "\n";
          return true;
        }
      }
      return false;
    }

    // With certain exceptions, assume active instructions require the
    // shadow of the operand.
    if (mode == DerivativeMode::ForwardMode ||
        mode == DerivativeMode::ForwardModeSplit ||
        mode == DerivativeMode::ForwardModeError ||
        (!isa<ExtractValueInst>(user) && !isa<ExtractElementInst>(user) &&
         !isa<InsertValueInst>(user) && !isa<InsertElementInst>(user) &&
         !isPointerArithmeticInst(user, /*includephi*/ false,
                                  /*includebin*/ false))) {

      bool inst_cv = gutils->isConstantValue(const_cast<Value *>(val));

      if (!inst_cv &&
          !gutils->isConstantInstruction(const_cast<Instruction *>(user))) {
        if (EnzymePrintDiffUse)
          llvm::errs() << " Need: shadow of " << *val
                       << " in reverse as shadow inst " << *user << "\n";
        return true;
      }
    }

    // Now the remaining instructions are inactive; however, note that
    // a constant instruction may still require the use of the shadow
    // in the forward pass. For example, double* x = load double** y
    // is a constant instruction, but needed in the forward. However,
    // if the value [and from above also the instruction] is constant,
    // we don't need it.
    if (gutils->isConstantValue(
            const_cast<Value *>((const llvm::Value *)user))) {
      return false;
    }

    // We don't need this value directly, but we may need it recursively
    // in one of the active value users.
    assert(recursiveUse);
    *recursiveUse = true;
    return false;
  }

  bool neededFB = false;
  if (auto CB = dyn_cast<CallBase>(const_cast<Instruction *>(user))) {
    neededFB = !callShouldNotUseDerivative(gutils, *CB, qtype, val);
  } else {
    neededFB = !gutils->isConstantInstruction(user) ||
               !gutils->isConstantValue(const_cast<Instruction *>(user));
  }
  if (neededFB) {
    if (EnzymePrintDiffUse)
      llvm::errs() << " Need direct primal(" << mode << ") of " << *val
                   << " in reverse from fallback " << *user << "\n";
  }
  return neededFB;
}
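
// Note: this per-use predicate is consumed by the transitive query
// is_value_needed_in_reverse (declared in DifferentialUseAnalysis.h), which
// memoizes results in a std::map<UsageKey, bool>. A minimal sketch of such a
// top-level query, assuming an already-populated GradientUtils object
// (illustrative only; `someValue` is a hypothetical placeholder):
//
//   std::map<UsageKey, bool> seen;
//   bool needed = DifferentialUseAnalysis::is_value_needed_in_reverse<
//       QueryType::Primal>(gutils, someValue, gutils->mode, seen,
//                          gutils->notForAnalysis);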

void DifferentialUseAnalysis::dump(std::map<Node, std::set<Node>> &G) {
  for (auto &pair : G) {
    llvm::errs() << "[" << *pair.first.V << ", " << (int)pair.first.outgoing
                 << "]\n";
    for (auto N : pair.second) {
      llvm::errs() << "\t[" << *N.V << ", " << (int)N.outgoing << "]\n";
    }
  }
}

/* Performs a multi-source BFS over the residual graph, starting from every
   node in Recompute, and fills parent[] so paths can be reconstructed;
   a sink is reachable iff it appears in parent. */
void DifferentialUseAnalysis::bfs(const std::map<Node, std::set<Node>> &G,
                                  const SetVector<Value *> &Recompute,
                                  std::map<Node, Node> &parent) {
  std::deque<Node> q;
  for (auto V : Recompute) {
    Node N(V, false);
    parent.emplace(N, Node(nullptr, true));
    q.push_back(N);
  }

  // Standard BFS Loop
  while (!q.empty()) {
    auto u = q.front();
    q.pop_front();
    auto found = G.find(u);
    if (found == G.end())
      continue;
    for (auto v : found->second) {
      if (parent.find(v) == parent.end()) {
        q.push_back(v);
        parent.emplace(v, u);
      }
    }
  }
}
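
// For orientation: the Graph used here splits each value V into an input
// node Node(V, /*outgoing*/ false) and an output node Node(V, true), joined
// by the unit-capacity edge V_in -> V_out; a differential use of A by B is
// the edge Node(A, true) -> Node(B, false). With hypothetical values A and B
// where B is recomputable from A, the encoding is (illustrative sketch):
//
//   G[Node(A, false)] = {Node(A, true)};
//   G[Node(A, true)]  = {Node(B, false)};
//   G[Node(B, false)] = {Node(B, true)};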

// Return 1 if next is better
// 0 if equal
// -1 if prev is better, or unknown
int DifferentialUseAnalysis::cmpLoopNest(Loop *prev, Loop *next) {
  if (next == prev)
    return 0;
  if (next == nullptr)
    return 1;
  else if (prev == nullptr)
    return -1;
  for (Loop *L = prev; L != nullptr; L = L->getParentLoop()) {
    if (L == next)
      return 1;
  }
  return -1;
}

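// The routine below is Ford-Fulkerson max-flow / min-cut specialized to unit
// vertex capacities via the node splitting described above: Recomputes act
// as sources and Required values as sinks. Once no augmenting path remains,
// the edges from BFS-reachable to unreachable nodes are exactly the
// V_in -> V_out pairs of a minimum vertex cut, i.e. a cheapest set of values
// to cache (MinReq); the subsequent fixups adjust that choice for
// non-cacheable calls and loop depth.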
void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
                                     const SetVector<Value *> &Recomputes,
                                     const SetVector<Value *> &Intermediates,
                                     SetVector<Value *> &Required,
                                     SetVector<Value *> &MinReq,
                                     const GradientUtils *gutils,
                                     llvm::TargetLibraryInfo &TLI) {
  Graph G;
  for (auto V : Intermediates) {
    G[Node(V, false)].insert(Node(V, true));
    forEachDifferentialUser(
        [&](Value *U) {
          if (Intermediates.count(U)) {
            if (V != U)
              G[Node(V, true)].insert(Node(U, false));
          }
        },
        gutils, V);
  }
  for (auto pair : gutils->rematerializableAllocations) {
    if (Intermediates.count(pair.first)) {
      for (LoadInst *L : pair.second.loads) {
        if (Intermediates.count(L)) {
          if (L != pair.first)
            G[Node(pair.first, true)].insert(Node(L, false));
        }
      }
      for (auto L : pair.second.loadLikeCalls) {
        if (Intermediates.count(L.loadCall)) {
          if (L.loadCall != pair.first)
            G[Node(pair.first, true)].insert(Node(L.loadCall, false));
        }
      }
    }
  }
#ifndef NDEBUG
  for (auto R : Required) {
    assert(Intermediates.count(R));
  }
  for (auto R : Recomputes) {
    assert(Intermediates.count(R));
  }
#endif

  Graph Orig = G;

  // Augment the flow while there is a path from source to sink
  while (1) {
    std::map<Node, Node> parent;
    bfs(G, Recomputes, parent);
    Node end(nullptr, false);
    for (auto req : Required) {
      if (parent.find(Node(req, true)) != parent.end()) {
        end = Node(req, true);
        break;
      }
    }
    if (end.V == nullptr)
      break;
    // Update residual capacities of the edges and reverse edges
    // along the path.
    Node v = end;
    while (1) {
      assert(parent.find(v) != parent.end());
      Node u = parent.find(v)->second;
      assert(u.V != nullptr);
      assert(G[u].count(v) == 1);
      assert(G[v].count(u) == 0);
      G[u].erase(v);
      G[v].insert(u);
      if (Recomputes.count(u.V) && u.outgoing == false)
        break;
      v = u;
    }
  }

  // Flow is maximum now; find vertices reachable from the sources.

  std::map<Node, Node> parent;
  bfs(G, Recomputes, parent);

  SetVector<Value *> todo;

  // Collect all edges that go from a reachable vertex to a
  // non-reachable vertex in the original graph.
  for (auto &pair : Orig) {
    if (parent.find(pair.first) != parent.end())
      for (auto N : pair.second) {
        if (parent.find(N) == parent.end()) {
          assert(pair.first.outgoing == 0 && N.outgoing == 1);
          assert(pair.first.V == N.V);
          MinReq.insert(N.V);
          todo.insert(N.V);
        }
      }
  }

  while (todo.size()) {
    auto V = todo.front();
    todo.remove(V);
    assert(MinReq.count(V));

    // Fix up non-cacheable calls to use their operand(s) instead
    if (hasNoCache(V)) {
      MinReq.remove(V);
      for (auto &pair : Orig) {
        if (pair.second.count(Node(V, false))) {
          MinReq.insert(pair.first.V);
          todo.insert(pair.first.V);
        }
      }
      continue;
    }

    auto found = Orig.find(Node(V, true));
    if (found != Orig.end()) {
      const auto &mp = found->second;

      // When ambiguous, push to cache the last value in a computation chain.
      // This should be considered in a cost for the max flow.
      if (mp.size() == 1 && !Required.count(V)) {
        bool potentiallyRecursive =
            isa<PHINode>((*mp.begin()).V) &&
            OrigLI.isLoopHeader(cast<PHINode>((*mp.begin()).V)->getParent());
        int moreOuterLoop =
            cmpLoopNest(OrigLI.getLoopFor(cast<Instruction>(V)->getParent()),
                        OrigLI.getLoopFor(
                            cast<Instruction>(((*mp.begin()).V))->getParent()));
        if (potentiallyRecursive)
          continue;
        if (moreOuterLoop == -1)
          continue;
        if (auto ASC = dyn_cast<AddrSpaceCastInst>((*mp.begin()).V)) {
          if (ASC->getDestAddressSpace() == 11 ||
              ASC->getDestAddressSpace() == 13)
            continue;
          if (ASC->getSrcAddressSpace() == 10 &&
              ASC->getDestAddressSpace() == 0)
            continue;
        }
        if (auto CI = dyn_cast<CastInst>((*mp.begin()).V)) {
          if (CI->getType()->isPointerTy() &&
              CI->getType()->getPointerAddressSpace() == 13)
            continue;
        }
        if (auto G = dyn_cast<GetElementPtrInst>((*mp.begin()).V)) {
          if (G->getType()->getPointerAddressSpace() == 13)
            continue;
        }
        if (hasNoCache((*mp.begin()).V)) {
          continue;
        }
        // If an allocation call, we cannot cache any "capturing" users
        if (isAllocationCall(V, TLI) || isa<AllocaInst>(V)) {
          auto next = (*mp.begin()).V;
          bool noncapture = false;
          if (isa<LoadInst>(next) || isNVLoad(next)) {
            noncapture = true;
          } else if (auto CI = dyn_cast<CallInst>(next)) {
            bool captures = false;
            for (size_t i = 0; i < CI->arg_size(); i++) {
              if (CI->getArgOperand(i) == V && !isNoCapture(CI, i)) {
                captures = true;
                break;
              }
            }
            noncapture = !captures;
          }

          if (!noncapture)
            continue;
        }

        if (moreOuterLoop == 1 ||
            (moreOuterLoop == 0 &&
             DL.getTypeSizeInBits(V->getType()) >=
                 DL.getTypeSizeInBits((*mp.begin()).V->getType()))) {
          MinReq.remove(V);
          auto nnode = (*mp.begin()).V;
          MinReq.insert(nnode);
          if (Orig.find(Node(nnode, true)) != Orig.end())
            todo.insert(nnode);
        }
      }
    }
  }

  // Fix up non-repeatable writing calls that chain within rematerialized
  // allocations. We could iterate from the keys of the valuemap, but that
  // would be a non-deterministic ordering.
  for (auto V : Intermediates) {
    auto found = gutils->rematerializableAllocations.find(V);
    if (found == gutils->rematerializableAllocations.end())
      continue;
    if (!found->second.nonRepeatableWritingCall)
      continue;

    // We are already caching this allocation directly; we're fine.
    if (MinReq.count(V))
      continue;

    // If we are recomputing a load, we need to fix this.
    bool needsLoad = false;
    for (auto load : found->second.loads)
      if (Intermediates.count(load) && !MinReq.count(load)) {
        needsLoad = true;
        break;
      }
    for (auto load : found->second.loadLikeCalls)
      if (Intermediates.count(load.loadCall) && !MinReq.count(load.loadCall)) {
        needsLoad = true;
        break;
      }

    if (!needsLoad)
      continue;

    // Rewire the uses to cache the allocation directly.
    // TODO: as a further optimization, we can remove potentially unnecessary
    // values that we are keeping for stores.
    MinReq.insert(V);
  }

  return;
}

bool DifferentialUseAnalysis::callShouldNotUseDerivative(
    const GradientUtils *gutils, CallBase &call, QueryType qtype,
    const Value *val) {
  bool shadowReturnUsed = false;
  auto smode = gutils->mode;
  if (smode == DerivativeMode::ReverseModeGradient)
    smode = DerivativeMode::ReverseModePrimal;
  (void)gutils->getReturnDiffeType(&call, nullptr, &shadowReturnUsed, smode);

  bool useConstantFallback =
      gutils->isConstantInstruction(&call) &&
      (gutils->isConstantValue(&call) || !shadowReturnUsed);
  if (useConstantFallback && gutils->mode != DerivativeMode::ForwardMode &&
      gutils->mode != DerivativeMode::ForwardModeError) {

    // If there is an escaping allocation which is deduced to be needed in
    // the reverse pass, we need to do the recursive procedure to perform the
    // free.
    bool escapingNeededAllocation = false;

    // First, some calls may be marked non-escaping. If that's the case, we
    // can avoid unnecessary work.
    if (!isNoEscapingAllocation(&call)) {

      // If the function being called has a definition, check if any of the
      // subcalls can allocate.
      if (auto F = getFunctionFromCall(&call)) {
        SmallVector<Function *, 1> todo = {F};
        SmallPtrSet<Function *, 1> done;
        bool seenAllocation = false;
        while (todo.size() && !seenAllocation) {
          auto cur = todo.pop_back_val();
          if (done.count(cur))
            continue;
          done.insert(cur);
          // Assume empty functions allocate,
          if (cur->empty()) {
            // unless they are marked.
            if (isNoEscapingAllocation(cur))
              continue;
            seenAllocation = true;
            break;
          }
          auto UR = getGuaranteedUnreachable(cur);
          for (auto &BB : *cur) {
            if (UR.count(&BB))
              continue;
            for (auto &I : BB)
              if (auto CB = dyn_cast<CallBase>(&I)) {
                if (isNoEscapingAllocation(CB))
                  continue;
                if (isAllocationCall(CB, gutils->TLI)) {
                  seenAllocation = true;
                  goto finish;
                }
                if (auto F = getFunctionFromCall(CB)) {
                  todo.push_back(F);
                  continue;
                }
                // Conservatively assume indirect functions allocate.
                seenAllocation = true;
                goto finish;
              }
          }
        finish:;
          if (!seenAllocation)
            goto doneEscapeCheck;
        }

        // Next, test if any allocation could be stored into one of the
        // arguments.
        for (unsigned i = 0; i < call.arg_size(); ++i) {
          Value *a = call.getOperand(i);

          if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType()))
            continue;

          if (!gutils->TR.anyPointer(a))
            continue;

          auto vd = gutils->TR.query(a);

          if (!vd[{-1, -1}].isPossiblePointer())
            continue;

          if (isReadOnly(&call, i))
            continue;

          // An allocation could only be needed in the reverse pass if it
          // escapes into an argument. However, if the parameter by which it
          // escapes could capture the pointer, the rest of Enzyme's caching
          // mechanisms cannot assume that the allocation itself is
          // reloadable, since it may have been captured and overwritten
          // elsewhere.
          // TODO: this justification will need revisiting in the future as
          // the caching algorithm becomes increasingly sophisticated.
          if (!isNoCapture(&call, i))
            continue;

          escapingNeededAllocation = true;
          goto doneEscapeCheck;
        }
      } else {
        // No definition available: conservatively assume the callee may
        // create an escaping allocation.
        escapingNeededAllocation = true;
        goto doneEscapeCheck;
      }

      // Finally, test if the return is a potential pointer, and needed for
      // the reverse pass.

      // Not a pointer
      if (!gutils->TR.anyPointer(&call)) {
        goto doneEscapeCheck;
      }
      // GC'd pointer, not needed to be explicitly free'd
      if (EnzymeJuliaAddrLoad && isSpecialPtr(call.getType())) {
        goto doneEscapeCheck;
      }

      std::map<UsageKey, bool> CacheResults;
      for (auto pair : gutils->knownRecomputeHeuristic) {
        if (!pair.second || gutils->unnecessaryIntermediates.count(
                                cast<Instruction>(pair.first))) {
          CacheResults[UsageKey(pair.first, QueryType::Primal)] = false;
        }
      }

      // To avoid an infinite loop, we assume this is needed by
      // marking the query that led us here.
      if (val)
        CacheResults[UsageKey(val, qtype)] = true;

      auto found = gutils->knownRecomputeHeuristic.find(&call);
      if (found != gutils->knownRecomputeHeuristic.end()) {
        if (!found->second) {
          CacheResults.erase(UsageKey(&call, QueryType::Primal));
          escapingNeededAllocation = is_value_needed_in_reverse<
              QueryType::Primal>(gutils, &call, gutils->mode, CacheResults,
                                 gutils->notForAnalysis);
        }
      } else {
        escapingNeededAllocation = is_value_needed_in_reverse<
            QueryType::Primal>(gutils, &call, gutils->mode, CacheResults,
                               gutils->notForAnalysis);
      }
    }

  doneEscapeCheck:;
    if (escapingNeededAllocation)
      useConstantFallback = false;
  }
  return useConstantFallback;
}