#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicsX86.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

#include "GradientUtils.h"
// User-registerable handlers that override the default differential-use
// analysis for calls to a particular function name.
StringMap<std::function<bool(const CallInst *, const GradientUtils *,
                             const Value *, bool, DerivativeMode, bool &)>>
    customDiffUseHandlers;
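// A minimal sketch of how such a handler could be registered; the callee name
// "my_custom_op" and the handler body are hypothetical, but the signature
// matches the map above. Returning true marks `val` as needed; setting
// `useDefault` defers to the generic analysis below.
//
//   customDiffUseHandlers["my_custom_op"] =
//       [](const CallInst *CI, const GradientUtils *gutils, const Value *val,
//          bool shadow, DerivativeMode mode, bool &useDefault) -> bool {
//         if (shadow)
//           return val == CI->getArgOperand(0); // only arg 0 needs a shadow
//         useDefault = true; // let the default analysis answer primal queries
//         return false;
//       };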
// Determine if a value is needed directly to compute the adjoint of the given
// instruction user.
bool is_use_directly_needed_in_reverse(
    const GradientUtils *gutils, const Value *val, DerivativeMode mode,
    const Instruction *user,
    const SmallPtrSetImpl<BasicBlock *> &oldUnreachable, QueryType qtype,
    bool *recursiveUse) {
  if (auto ainst = dyn_cast<Instruction>(val)) {
    assert(ainst->getParent()->getParent() == gutils->oldFunc);
  }
  // ...
  assert(recursiveUse == nullptr);
  // ...
  assert(recursiveUse != nullptr);
  // ...
  // ({-1} selects the type common to every byte offset of the value)
  if (TR.query(const_cast<Value *>(val))[{-1}].isFloat())
    // ...

  // ...
  llvm::errs() << " Need: of " << *val << " in reverse as nullptr user\n";
  // ...
  assert(user->getParent()->getParent() == gutils->oldFunc);

  // A use in a statically-unreachable block never creates a need.
  if (oldUnreachable.count(user->getParent()))
    return false;
  if (auto SI = dyn_cast<StoreInst>(user)) {
    // ...
    if (SI->getValueOperand() == val) {
      // A stored value feeding the bounds of an OpenMP statically-scheduled
      // loop must be preserved:
      for (auto U : SI->getPointerOperand()->users()) {
        if (auto CI = dyn_cast<CallInst>(U)) {
          if (auto F = CI->getCalledFunction()) {
            if (F->getName() == "__kmpc_for_static_init_4" ||
                F->getName() == "__kmpc_for_static_init_4u" ||
                F->getName() == "__kmpc_for_static_init_8" ||
                F->getName() == "__kmpc_for_static_init_8u") {
              if (CI->getArgOperand(4) == val ||
                  CI->getArgOperand(5) == val ||
                  CI->getArgOperand(6) == val) {
                // ...
                llvm::errs() << " Need direct primal of " << *val
                             << " in reverse from omp " << *user << "\n";
    // ...
    auto &DL = gutils->newFunc->getParent()->getDataLayout();
    auto ET = SI->getValueOperand()->getType();
    auto storeSize = (DL.getTypeSizeInBits(ET) + 7) / 8;
    auto vd = TR.query(const_cast<Value *>(SI->getPointerOperand()))
                  .Lookup(storeSize, DL);
    // ...
    bool hasFloat = true;
    for (ssize_t i = -1; i < (ssize_t)storeSize; ++i) {
      if (vd[{(int)i}].isFloat()) {
        // ...
      }
    }
    // ...
    if (/* ... */ !gutils->isConstantValue(
            const_cast<Value *>(SI->getPointerOperand()))) {
      // ...
      llvm::errs() << " Need direct primal of " << *val
                   << " in reverse from runtime active store " << *user
                   << "\n";
      // ...
    }
    bool backwardsShadow = false;
    bool forwardsShadow = true;
    // ...
    if (pair.second.stores.count(SI) &&
        /* ... */) {
      backwardsShadow = true;
      forwardsShadow = pair.second.primalInitialize;
    }
    // ...
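    // (backwardsOnlyShadows tracks shadow allocations that are only read in
    // the reverse pass; a store into one matters where that shadow is
    // actually instantiated: in the forward pass only when primalInitialize
    // is set, otherwise in the backward pass.)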
    if (SI->getValueOperand() == val) {
      // ...
          (forwardsShadow || backwardsShadow)) ||
      // ...
      auto ct = TR.query(const_cast<Value *>(SI->getValueOperand()))[{-1}];
      // ...
          (forwardsShadow || backwardsShadow)) ||
      // ...
    }
    if (!gutils->isConstantValue(
            const_cast<Value *>(SI->getPointerOperand()))) {
      // ...
      llvm::errs() << " Need: shadow of " << *val
                   << " in reverse as shadow store " << *SI << "\n";
      // ...
    }
  }
  if (auto LI = dyn_cast<LoadInst>(user)) {
    // ...
    auto vd = TR.query(const_cast<llvm::Instruction *>(user));
    // ...
    auto ET = LI->getType();
    // ...
    auto &DL = gutils->newFunc->getParent()->getDataLayout();
    auto LoadSize = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8;
    bool hasFloat = true;
    for (ssize_t i = -1; i < (ssize_t)LoadSize; ++i) {
      if (vd[{(int)i}].isFloat()) {
        // ...
      }
    }
    // ...
    if (/* ... */ !gutils->isConstantValue(
            const_cast<llvm::Instruction *>(user))) {
      // ...
      llvm::errs() << " Need direct primal of " << *val
                   << " in reverse from runtime active load " << *user
                   << "\n";
      // ...
    }
  }
  if (auto MTI = dyn_cast<MemTransferInst>(user)) {
    // ...
    if (MTI->getArgOperand(1) == val || MTI->getArgOperand(2) == val) {
      // ...
      if (pair.second.stores.count(MTI)) {
        // ...
        llvm::errs() << " Need direct primal of " << *val
                     << " in reverse from remat memtransfer " << *user
                     << "\n";
        // ...
      }
      // ...
      if (MTI->getArgOperand(2) != val)
        // ...
      llvm::errs() << " Need direct primal of " << *val
                   << " in reverse from memtransfer " << *user << "\n";
      // ...
    }
    // Only the destination and source pointers can require a shadow.
    if (MTI->getArgOperand(0) != val && MTI->getArgOperand(1) != val)
      return false;
    // ...
    if (!gutils->isConstantValue(
            const_cast<Value *>(MTI->getArgOperand(0)))) {
      // ...
      llvm::errs() << " Need: shadow of " << *val
                   << " in reverse as shadow MTI " << *MTI << "\n";
      // ...
    }
  }
  if (auto MS = dyn_cast<MemSetInst>(user)) {
    // ...
    if (MS->getArgOperand(1) == val || MS->getArgOperand(2) == val) {
      // ...
      if (pair.second.stores.count(MS)) {
        // ...
        llvm::errs() << " Need direct primal of " << *val
                     << " in reverse from remat memset " << *user << "\n";
        // ...
      }
      // ...
      llvm::errs() << " Need direct primal of " << *val
                   << " in reverse from memset " << *user << "\n";
      // ...
    }
    // Only the destination pointer can require a shadow.
    if (MS->getArgOperand(0) != val)
      return false;
    // ...
    if (!gutils->isConstantValue(
            const_cast<Value *>(MS->getArgOperand(0)))) {
      // ...
      llvm::errs() << " Need: shadow of " << *val
                   << " in reverse as shadow MS " << *MS << "\n";
      // ...
    }
  }
  if (isa<CmpInst>(user) || isa<BranchInst>(user) || isa<ReturnInst>(user) ||
      isa<FPExtInst>(user) || isa<FPTruncInst>(user)
      // ...
  if (auto IEI = dyn_cast<InsertElementInst>(user)) {
    // ...
    if (IEI->getOperand(2) != val) {
      // ...
          TR.query(const_cast<InsertElementInst *>(IEI))[{-1}] ==
      // ...
      llvm::errs() << " Need direct primal of " << *val
                   << " in reverse from non-pointer insertelem " << *user
                   // ...
                   << TR.query(const_cast<InsertElementInst *>(IEI)).str()
                   // ...
    }
  }
  if (auto EEI = dyn_cast<ExtractElementInst>(user)) {
    // ...
    if (EEI->getIndexOperand() != val) {
      // ...
          TR.query(const_cast<ExtractElementInst *>(EEI))[{-1}] ==
      // ...
      llvm::errs() << " Need direct primal of " << *val
                   << " in reverse from non-pointer extractelem " << *user
                   // ...
                   << TR.query(const_cast<ExtractElementInst *>(EEI)).str()
                   // ...
    }
  }
  if (auto IVI = dyn_cast<InsertValueInst>(user)) {
    // ...
    bool valueIsIndex = false;
    for (unsigned i = 2; i < IVI->getNumOperands(); ++i) {
      if (IVI->getOperand(i) == val) {
        // ...
      }
    }
    // ...
        TR.query(const_cast<InsertValueInst *>(IVI))[{-1}] ==
    // ...
    llvm::errs() << " Need direct primal of " << *val
                 << " in reverse from non-pointer insertval " << *user
                 // ...
                 << TR.query(const_cast<InsertValueInst *>(IVI)).str()
                 // ...
  }
  if (auto EVI = dyn_cast<ExtractValueInst>(user)) {
    // ...
    bool valueIsIndex = false;
    for (unsigned i = 2; i < EVI->getNumOperands(); ++i) {
      if (EVI->getOperand(i) == val) {
        // ...
      }
    }
    // ...
        TR.query(const_cast<ExtractValueInst *>(EVI))[{-1}] ==
    // ...
    llvm::errs() << " Need direct primal of " << *val
                 << " in reverse from non-pointer extractval " << *user
                 // ...
                 << TR.query(const_cast<ExtractValueInst *>(EVI)).str()
                 // ...
  }
  Intrinsic::ID ID = Intrinsic::not_intrinsic;
  if (auto II = dyn_cast<IntrinsicInst>(user)) {
    ID = II->getIntrinsicID();
  } else if (auto CI = dyn_cast<CallInst>(user)) {
    // ... (known math-library calls are mapped to an intrinsic ID here)
  }
  // ...
  if (ID != Intrinsic::not_intrinsic) {
    if (ID == Intrinsic::lifetime_start || ID == Intrinsic::lifetime_end ||
        ID == Intrinsic::stacksave || ID == Intrinsic::stackrestore) {
      // ...
  if (auto si = dyn_cast<SelectInst>(user)) {
    // ...
    if (si->getCondition() != val) {
      // ...
    }
    // ...
    llvm::errs() << " Need direct primal of " << *val
                 << " in reverse from select " << *user << "\n";
    // ...
  }
// Tablegen-generated differential-use rules for BLAS routines.
#include "BlasDiffUse.inc"
  if (auto CI = dyn_cast<CallInst>(user)) {
    auto funcName = getFuncNameFromCall(CI);
    // ...
    SmallVector<OperandBundleDef, 2> OrigDefs;
    CI->getOperandBundlesAsDefs(OrigDefs);
    SmallVector<OperandBundleDef, 2> Defs;
    for (auto bund : OrigDefs) {
      for (auto inp : bund.inputs()) {
        // ...
      }
    }
    // ... (look up a registered custom handler for this callee)
    bool useDefault = false;
    bool result = found->second(CI, gutils, val, shadow, mode, useDefault);
    // ...
    llvm::errs() << " Need: " << to_string(qtype) << " of " << *val
                 << " from custom diff use handler of " << *CI
                 // ...
    // ...
    llvm::errs() << " Need: shadow of " << *val
                 << " in reverse as shadow free " << *CI << "\n";
    // ...
    if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" ||
        funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") {
      // ...
      if (val == CI->getArgOperand(6)) {
        // ...
        llvm::errs() << " Need: " << to_string(qtype) << " request " << *val
                     << " in reverse for MPI " << *CI << "\n";
        // ...
      }
      if (shadow && val == CI->getArgOperand(0)) {
        if ((funcName == "MPI_Irecv" || funcName == "PMPI_Irecv") &&
            /* ... */) {
          // ...
          llvm::errs() << " Need: shadow(" << to_string(qtype) << ") of "
                       << *val << " in reverse as shadow MPI " << *CI << "\n";
          // ...
        }
        if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
          // ...
          llvm::errs() << " Need: shadow(" << to_string(qtype) << ") of "
                       << *val << " in reverse as shadow MPI " << *CI << "\n";
          // ...
        }
      }
    }
    if (funcName == "cuStreamSynchronize")
      if (val == CI->getArgOperand(0)) {
        // ...
        llvm::errs() << " Need: primal(" << to_string(qtype) << ") of "
                     << *val << " in reverse for cuda sync " << *CI << "\n";
        // ...
      }
    if (funcName == "MPI_Wait" || funcName == "PMPI_Wait")
      if (val != CI->getArgOperand(0))
        // ...
    // ...
    if (funcName == "MPI_Waitall" || funcName == "PMPI_Waitall")
      if (val != CI->getArgOperand(0) && val != CI->getOperand(1))
        // ...
    // ...
    if (funcName == "MPI_Wait" || funcName == "PMPI_Wait") {
      // ...
      if (val == CI->getArgOperand(0)) {
        // ...
        llvm::errs() << " Need: shadow of " << *val
                     << " in reverse as shadow MPI " << *CI << "\n";
        // ...
      }
    }
    if (funcName == "__kmpc_barrier" || funcName == "MPI_Barrier") {
      // ...
      llvm::errs() << " Need direct primal of " << *val
                   << " in reverse from barrier " << *user << "\n";
      // ...
    }
    if (funcName == "llvm.julia.gc_preserve_begin") {
      // ...
      llvm::errs() << " Need direct primal of " << *val
                   << " in reverse from gc " << *CI << "\n";
      // ...
    }
    if (funcName == "julia.write_barrier" ||
        funcName == "julia.write_barrier_binding") {
      // ...
      llvm::errs() << " Need: shadow of " << *val
                   << " in forward as shadow write_barrier " << *CI << "\n";
      // ...
      auto sz = CI->arg_size();
      bool isStored = false;
      // ...
      for (size_t i = 1; i < sz; i++)
        isStored |= val == CI->getArgOperand(i);
      bool rematerialized = false;
      // ...
      if (pair.second.stores.count(CI) &&
          /* ... */)
        rematerialized = true;
      // ...
      if (rematerialized) {
        // ...
        llvm::errs() << " Need: shadow of " << *val
                     << " in rematerialized reverse as shadow write_barrier "
                     << *CI
                     // ...
      }
    }
    bool writeOnlyNoCapture = true;
    // ...
    writeOnlyNoCapture = false;
    // ...
    for (size_t i = 0; i < CI->arg_size(); i++) {
      if (val == CI->getArgOperand(i)) {
        // ...
        writeOnlyNoCapture = false;
        // ...
        writeOnlyNoCapture = false;
        // ...
      }
    }
    // ...
    if (writeOnlyNoCapture) {
      // ...
      llvm::errs() << " No Need: primal of " << *val
                   << " per write-only no-capture use in " << *CI << "\n";
      // ...
    }
    if (writeOnlyNoCapture &&
        // ...
    const Value *FV = CI->getCalledOperand();
    // ...
    llvm::errs() << " Need: shadow of " << *val
                 << " in reverse as shadow call " << *CI << "\n";
    // ...
  }
  if (isa<ReturnInst>(user)) {
    // ...
    llvm::errs() << " Need: shadow(qtype=" << (int)qtype
                 << ",cv=" << inst_cv << ") of " << *val
                 << " in reverse as shadow return " << *user << "\n";
    // ...
  }
  // ...
      (!isa<ExtractValueInst>(user) && !isa<ExtractElementInst>(user) &&
       !isa<InsertValueInst>(user) && !isa<InsertElementInst>(user) &&
       // ...
  // ...
  llvm::errs() << " Need: shadow of " << *val
               << " in reverse as shadow inst " << *user << "\n";
  // ...
  // ... (enclosing condition elided)
          const_cast<Value *>((const llvm::Value *)user))) {
    // ...
    assert(recursiveUse);
    *recursiveUse = true;
  }
  bool neededFB = false;
  if (auto CB = dyn_cast<CallBase>(const_cast<Instruction *>(user))) {
    neededFB = !callShouldNotUseDerivative(gutils, *CB, qtype, val);
    // ...
  }
  // ...
  llvm::errs() << " Need direct primal(" << mode << ") of " << *val
               << " in reverse from fallback " << *user << "\n";
  // ...
}
// Print the differential-use graph (for debugging).
void dump(std::map<Node, std::set<Node>> &G) {
  for (auto &pair : G) {
    llvm::errs() << "[" << *pair.first.V << ", " << (int)pair.first.outgoing
                 << "]\n";
    for (auto N : pair.second) {
      llvm::errs() << "\t[" << *N.V << ", " << (int)N.outgoing << "]\n";
    }
  }
}
// Breadth-first search from the recomputable sources, recording each reached
// node's parent (sources point at the sentinel Node(nullptr, true)).
void bfs(const std::map<Node, std::set<Node>> &G,
         const SetVector<Value *> &Recompute,
         std::map<Node, Node> &parent) {
  // ...
  for (auto V : Recompute) {
    Node N(V, false);
    parent.emplace(N, Node(nullptr, true));
    // ...
  }
  // ... (worklist loop: pop the next node u)
    auto found = G.find(u);
    if (found == G.end())
      continue;
    // ...
    for (auto v : found->second) {
      if (parent.find(v) == parent.end()) {
        // ...
        parent.emplace(v, u);
        // ...
      }
    }
  // ...
}
// Compare the relative nesting of two loops (used by minCut below to decide
// where a cached value should live).
int cmpLoopNest(Loop *prev, Loop *next) {
  // ...
  else if (prev == nullptr)
    // ...
  for (Loop *L = prev; L != nullptr; L = L->getParentLoop()) {
    // ...
  }
  // ...
}
void minCut(const DataLayout &DL, LoopInfo &OrigLI,
            const SetVector<Value *> &Recomputes,
            const SetVector<Value *> &Intermediates,
            SetVector<Value *> &Required, SetVector<Value *> &MinReq,
            const GradientUtils *gutils, llvm::TargetLibraryInfo &TLI) {
  std::map<Node, std::set<Node>> G;
  // ...
  for (auto V : Intermediates) {
    // Vertex split: each intermediate V becomes a unit-capacity edge from its
    // in-node Node(V, false) to its out-node Node(V, true), so the min cut
    // selects values to cache rather than individual uses.
    G[Node(V, false)].insert(Node(V, true));
    forEachDifferentialUser(
        // ... (lambda over each differential user U of V)
          if (Intermediates.count(U)) {
            // ...
            G[Node(V, true)].insert(Node(U, false));
            // ...
          }
        // ...
  }
  // ... (edges from rematerializable allocations to the loads they feed)
    if (Intermediates.count(pair.first)) {
      for (LoadInst *L : pair.second.loads) {
        if (Intermediates.count(L)) {
          // ...
          G[Node(pair.first, true)].insert(Node(L, false));
          // ...
        }
      }
      for (auto L : pair.second.loadLikeCalls) {
        if (Intermediates.count(L.loadCall)) {
          if (L.loadCall != pair.first)
            G[Node(pair.first, true)].insert(Node(L.loadCall, false));
        }
      }
    }
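  // Shape of G (sketch): for a chain a -> b -> c of intermediates,
  //   (a,in) -> (a,out) -> (b,in) -> (b,out) -> (c,in) -> (c,out)
  // where false = "in" and true = "out"; cutting the unit edge
  // (b,in) -> (b,out) corresponds to caching b itself.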
  // ...
  for (auto R : Required) {
    assert(Intermediates.count(R));
  }
  for (auto R : Recomputes) {
    assert(Intermediates.count(R));
  }
  // Ford-Fulkerson-style augmentation: repeatedly BFS from the recomputable
  // sources and reverse the unit edges along each path reaching a required
  // sink; once no sink is reachable, the saturated frontier is a minimum cut.
  while (1) {
    std::map<Node, Node> parent;
    bfs(G, Recomputes, parent);
    Node end(nullptr, false);
    for (auto req : Required) {
      if (parent.find(Node(req, true)) != parent.end()) {
        end = Node(req, true);
        break;
      }
    }
    if (end.V == nullptr)
      break;
    // ... (walk the augmenting path backwards from `end` via `parent`)
      assert(parent.find(v) != parent.end());
      Node u = parent.find(v)->second;
      assert(u.V != nullptr);
      assert(G[u].count(v) == 1);
      assert(G[v].count(u) == 0);
      // ... (reverse the edge u -> v in the residual graph)
      if (Recomputes.count(u.V) && u.outgoing == false)
        break;
      // ...
  }
  // The cut: edges of the original graph that cross from a node reached in
  // the final BFS to an unreached one; by construction these are exactly the
  // in-node -> out-node edges of the values that must be cached.
  std::map<Node, Node> parent;
  bfs(G, Recomputes, parent);
  // ...
  SetVector<Value *> todo;
  // ... (`Orig` is the graph as built above, before edges were reversed)
  for (auto &pair : Orig) {
    if (parent.find(pair.first) != parent.end())
      for (auto N : pair.second) {
        if (parent.find(N) == parent.end()) {
          assert(pair.first.outgoing == 0 && N.outgoing == 1);
          assert(pair.first.V == N.V);
          // ... (record pair.first.V into MinReq and todo)
        }
      }
  }
  // Heuristic refinement: where a cut value has a single differential user,
  // consider shifting the cut to that user (e.g. into an outermore loop or a
  // smaller value), subject to the recursion/address-space/capture checks.
  while (todo.size()) {
    auto V = todo.front();
    // ...
    assert(MinReq.count(V));
    // ...
    for (auto &pair : Orig) {
      if (pair.second.count(Node(V, false))) {
        MinReq.insert(pair.first.V);
        todo.insert(pair.first.V);
      }
    }
    // ...
    auto found = Orig.find(Node(V, true));
    if (found != Orig.end()) {
      const auto &mp = found->second;
      // ...
      if (mp.size() == 1 && !Required.count(V)) {
        bool potentiallyRecursive =
            isa<PHINode>((*mp.begin()).V) &&
            OrigLI.isLoopHeader(cast<PHINode>((*mp.begin()).V)->getParent());
        int moreOuterLoop =
            cmpLoopNest(OrigLI.getLoopFor(cast<Instruction>(V)->getParent()),
                        OrigLI.getLoopFor(
                            cast<Instruction>(((*mp.begin()).V))->getParent()));
        if (potentiallyRecursive)
          continue;
        if (moreOuterLoop == -1)
          continue;
        if (auto ASC = dyn_cast<AddrSpaceCastInst>((*mp.begin()).V)) {
          if (ASC->getDestAddressSpace() == 11 ||
              ASC->getDestAddressSpace() == 13)
            continue;
          if (ASC->getSrcAddressSpace() == 10 &&
              ASC->getDestAddressSpace() == 0)
            continue;
        }
        if (auto CI = dyn_cast<CastInst>((*mp.begin()).V)) {
          if (CI->getType()->isPointerTy() &&
              CI->getType()->getPointerAddressSpace() == 13)
            continue;
        }
        if (auto G = dyn_cast<GetElementPtrInst>((*mp.begin()).V)) {
          if (G->getType()->getPointerAddressSpace() == 13)
            continue;
        }
        // ...
        auto next = (*mp.begin()).V;
        bool noncapture = false;
        if (isa<LoadInst>(next) || isNVLoad(next)) {
          noncapture = true;
        } else if (auto CI = dyn_cast<CallInst>(next)) {
          bool captures = false;
          for (size_t i = 0; i < CI->arg_size(); i++) {
            if (CI->getArgOperand(i) == V && !isNoCapture(CI, i)) {
              captures = true;
              break;
            }
          }
          noncapture = !captures;
        }
        // ...
        if (moreOuterLoop == 1 ||
            (moreOuterLoop == 0 &&
             DL.getTypeSizeInBits(V->getType()) >=
                 DL.getTypeSizeInBits((*mp.begin()).V->getType()))) {
          // ...
          auto nnode = (*mp.begin()).V;
          MinReq.insert(nnode);
          if (Orig.find(Node(nnode, true)) != Orig.end())
            // ...
        }
      }
    }
  }
  // Final filtering pass over the intermediates.
  for (auto V : Intermediates) {
    // ...
    if (!found->second.nonRepeatableWritingCall)
      // ...
    if (MinReq.count(V))
      // ...
    bool needsLoad = false;
    for (auto load : found->second.loads)
      if (Intermediates.count(load) && !MinReq.count(load)) {
        // ...
      }
    for (auto load : found->second.loadLikeCalls)
      if (Intermediates.count(load.loadCall) && !MinReq.count(load.loadCall)) {
        // ...
      }
    // ...
  }
}
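// A compact, self-contained sketch of the same scheme on plain integer nodes
// (hypothetical helper, not part of Enzyme; kept out of the build): BFS finds
// a source-to-sink path, the unit-capacity edges along it are reversed, and
// once no sink is reachable the original edges crossing the reached/unreached
// frontier form the minimum cut.
#if 0
#include <deque>
#include <map>
#include <set>
#include <utility>

static std::set<std::pair<int, int>>
toyMinCut(std::map<int, std::set<int>> G, const std::set<int> &sources,
          const std::set<int> &sinks) {
  const auto Orig = G; // unmodified edges, used to report the final cut
  auto bfsParents = [&](std::map<int, int> &parent) {
    std::deque<int> q(sources.begin(), sources.end());
    for (int s : sources)
      parent[s] = s; // sources are their own parents (path terminator)
    while (!q.empty()) {
      int u = q.front();
      q.pop_front();
      for (int v : G[u])
        if (!parent.count(v)) {
          parent[v] = u;
          q.push_back(v);
        }
    }
  };
  while (true) {
    std::map<int, int> parent;
    bfsParents(parent);
    int end = -1;
    for (int t : sinks)
      if (parent.count(t)) {
        end = t;
        break;
      }
    if (end == -1) // no augmenting path remains
      break;
    for (int v = end; parent[v] != v; v = parent[v]) {
      int u = parent[v];
      G[u].erase(v); // reverse the unit edge u -> v (residual update)
      G[v].insert(u);
    }
  }
  std::map<int, int> parent;
  bfsParents(parent);
  std::set<std::pair<int, int>> cut;
  for (const auto &pair : Orig)
    if (parent.count(pair.first))
      for (int v : pair.second)
        if (!parent.count(v))
          cut.insert({pair.first, v});
  return cut;
}
#endif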
// ... (decide whether a call can fall back to its constant/primal form; cf.
// callShouldNotUseDerivative in DifferentialUseAnalysis.h)
  bool shadowReturnUsed = false;
  auto smode = gutils->mode;
  // ...
  bool useConstantFallback =
      /* ... */;
  // ...
  bool escapingNeededAllocation = false;
  // Scan the callee (and transitively its callees) for allocations.
  SmallVector<Function *, 1> todo = {F};
  SmallPtrSet<Function *, 1> done;
  bool seenAllocation = false;
  while (todo.size() && !seenAllocation) {
    auto cur = todo.pop_back_val();
    if (done.count(cur))
      continue;
    // ...
    seenAllocation = true;
    // ...
    for (auto &BB : *cur) {
      for (auto &I : BB) {
        // ...
        if (auto CB = dyn_cast<CallBase>(&I)) {
          // ...
          seenAllocation = true;
          // ...
          seenAllocation = true;
          // ...
        }
      }
    }
  }
  // ...
  if (!seenAllocation)
    goto doneEscapeCheck;
  // ...
  for (unsigned i = 0; i < call.arg_size(); ++i) {
    Value *a = call.getOperand(i);
    // ...
    auto vd = gutils->TR.query(a);
    // ...
    if (!vd[{-1, -1}].isPossiblePointer())
      continue;
    // ...
    escapingNeededAllocation = true;
    goto doneEscapeCheck;
    // ...
    escapingNeededAllocation = true;
    goto doneEscapeCheck;
  }
  // ...
  goto doneEscapeCheck;
  // ...
  goto doneEscapeCheck;
  // ...
  std::map<UsageKey, bool> CacheResults;
  // ... (pre-seed the cache for users already known to be needed)
      cast<Instruction>(pair.first))) {
    // ...
    CacheResults[UsageKey(val, qtype)] = true;
    // ...
  }
  // ...
  if (!found->second) {
    // ...
    escapingNeededAllocation =
        // ... (elided call whose trailing arguments are:)
        CacheResults, gutils->notForAnalysis);
    // ...
    escapingNeededAllocation =
        // ... (elided call whose trailing arguments are:)
        CacheResults, gutils->notForAnalysis);
  }
  // ...
doneEscapeCheck:
  if (escapingNeededAllocation)
    useConstantFallback = false;
  // ...
  return useConstantFallback;
}
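// A sketch of how callers typically consult this analysis (assumed usage,
// following the declaration of is_value_needed_in_reverse; the `seen` map
// memoizes answers across the recursive traversal):
//
//   std::map<UsageKey, bool> seen;
//   bool needed = is_value_needed_in_reverse(
//       gutils, origValue, DerivativeMode::ReverseModeGradient, seen,
//       oldUnreachable);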