97 std::map<UsageKey, bool> &seen,
98 const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &oldUnreachable) {
105 if (seen.find(idx) != seen.end())
107 if (
auto ainst = dyn_cast<Instruction>(inst)) {
108 assert(ainst->getParent()->getParent() == gutils->
oldFunc);
115 if (
auto op = dyn_cast<BinaryOperator>(inst)) {
116 if (op->getOpcode() == Instruction::FDiv) {
120 llvm::errs() <<
" Need: " <<
to_string(VT) <<
" of " << *inst
121 <<
" in reverse as is active div\n";
122 return seen[idx] =
true;
130 <<
" Need: " <<
to_string(VT) <<
" of " << *inst
131 <<
" in reverse as forward mode error always needs result\n";
132 return seen[idx] =
true;
136 if (
auto CI = dyn_cast<CallInst>(inst)) {
138 if (funcName ==
"julia.get_pgcstack" || funcName ==
"julia.ptls_states")
139 return seen[idx] =
true;
145 for (
auto use : inst->users()) {
149 const Instruction *user = dyn_cast<Instruction>(use);
157 bool recursiveUse =
false;
159 gutils, inst, mode, user, oldUnreachable,
163 return seen[idx] =
true;
166 if (recursiveUse && !OneLevel) {
170 gutils, user, mode, seen, oldUnreachable);
173 gutils, user, mode, seen, oldUnreachable);
176 llvm::errs() <<
" Need: " <<
to_string(VT) <<
" of " << *inst
177 <<
" in reverse as shadow sub-need " << *user <<
"\n";
178 return seen[idx] =
true;
182 if (!TR.
allFloat(
const_cast<Value *
>(inst)))
183 if (
auto IVI = dyn_cast<Instruction>(user)) {
184 bool inserted =
false;
185 if (
auto II = dyn_cast<InsertValueInst>(IVI))
186 inserted = II->getInsertedValueOperand() == inst ||
187 II->getAggregateOperand() == inst;
188 if (
auto II = dyn_cast<ExtractValueInst>(IVI))
189 inserted = II->getAggregateOperand() == inst;
190 if (
auto II = dyn_cast<InsertElementInst>(IVI))
191 inserted = II->getOperand(1) == inst || II->getOperand(0) == inst;
192 if (
auto II = dyn_cast<ExtractElementInst>(IVI))
193 inserted = II->getOperand(0) == inst;
195 SmallVector<const Instruction *, 1> todo;
197 while (todo.size()) {
198 auto cur = todo.pop_back_val();
199 for (
auto u : cur->users()) {
200 if (
auto IVI2 = dyn_cast<InsertValueInst>(u)) {
201 todo.push_back(IVI2);
204 if (
auto IVI2 = dyn_cast<ExtractValueInst>(u)) {
205 todo.push_back(IVI2);
208 if (
auto IVI2 = dyn_cast<InsertElementInst>(u)) {
209 todo.push_back(IVI2);
212 if (
auto IVI2 = dyn_cast<ExtractElementInst>(u)) {
213 todo.push_back(IVI2);
217 bool partial =
false;
218 if (
auto UI = dyn_cast<Instruction>(u)) {
220 const_cast<Instruction *
>(cur))) {
221 bool recursiveUse =
false;
223 gutils, cur, mode, UI, oldUnreachable,
226 }
else if (recursiveUse && !OneLevel) {
228 gutils, UI, mode, seen, oldUnreachable);
231 bool recursiveUse =
false;
233 gutils, cur, mode, UI, oldUnreachable,
236 }
else if (recursiveUse && !OneLevel) {
239 seen, oldUnreachable);
248 <<
" Need (partial) direct " <<
to_string(VT) <<
" of "
249 << *inst <<
" in reverse from insertelem " << *user
250 <<
" via " << *cur <<
" in " << *u <<
"\n";
251 return seen[idx] =
true;
268 llvm::errs() <<
" Need: " <<
to_string(VT) <<
"(" << mode <<
") of "
269 << *inst <<
" in reverse as sub-need " << *user <<
"\n";
270 return seen[idx] =
true;
276 bool isStored =
false;
277 if (
auto SI = dyn_cast<StoreInst>(user))
278 isStored = inst == SI->getValueOperand();
279 else if (
auto MTI = dyn_cast<MemTransferInst>(user)) {
280 isStored = inst == MTI->getSource() || inst == MTI->getLength();
281 }
else if (
auto MS = dyn_cast<MemSetInst>(user)) {
282 isStored = inst == MS->getLength() || inst == MS->getValue();
283 }
else if (
auto CB = dyn_cast<CallBase>(user)) {
285 if (name ==
"julia.write_barrier" ||
286 name ==
"julia.write_barrier_binding") {
287 auto sz = CB->arg_size();
289 for (
size_t i = 1; i < sz; i++)
290 isStored |= inst == CB->getArgOperand(i);
305 if (found != seen.end() && !found->second) {
314 if (pair.second.stores.count(user)) {
315 for (LoadInst *L : pair.second.loads)
319 llvm::errs() <<
" Need: " <<
to_string(VT) <<
" of " << *inst
320 <<
" in reverse as rematload " << *L <<
"\n";
321 return seen[idx] =
true;
323 for (
auto &pair : pair.second.loadLikeCalls)
325 gutils, pair.operand, mode, pair.loadCall, oldUnreachable,
328 seen, oldUnreachable)) {
330 llvm::errs() <<
" Need: " <<
to_string(VT) <<
" of " << *inst
331 <<
" in reverse as rematloadcall "
332 << *pair.loadCall <<
"\n";
333 return seen[idx] =
true;
340 <<
" Need: " <<
to_string(VT) <<
" of " << *inst
341 <<
" in reverse as rematalloc " << *pair.first <<
"\n";
342 return seen[idx] =
true;
364 if (isa<BranchInst>(use) || isa<SwitchInst>(use)) {
366 for (
auto suc : successors(cast<Instruction>(use)->getParent())) {
367 if (!oldUnreachable.count(suc)) {
374 llvm::errs() <<
" Need: " <<
to_string(VT) <<
" of " << *inst
375 <<
" in reverse as control-flow " << *user <<
"\n";
376 return seen[idx] =
true;
379 if (
auto CI = dyn_cast<CallInst>(use)) {
380 if (
auto F = CI->getCalledFunction()) {
381 if (F->getName() ==
"__kmpc_for_static_init_4" ||
382 F->getName() ==
"__kmpc_for_static_init_4u" ||
383 F->getName() ==
"__kmpc_for_static_init_8" ||
384 F->getName() ==
"__kmpc_for_static_init_8u") {
386 llvm::errs() <<
" Need: " <<
to_string(VT) <<
" of " << *inst
387 <<
" in reverse as omp init " << *user <<
"\n";
388 return seen[idx] =
true;
398 bool primalUsedInShadowPointer =
true;
399 if (isa<CastInst>(user) || isa<LoadInst>(user))
400 primalUsedInShadowPointer =
false;
401 if (
auto CI = dyn_cast<CallInst>(user)) {
403 if (funcName ==
"julia.pointer_from_objref") {
404 primalUsedInShadowPointer =
false;
406 if (funcName ==
"julia.gc_loaded") {
407 primalUsedInShadowPointer =
false;
409 if (funcName.contains(
"__enzyme_todense")) {
410 primalUsedInShadowPointer =
false;
412 if (funcName.contains(
"__enzyme_ignore_derivatives")) {
413 primalUsedInShadowPointer =
false;
416 if (
auto GEP = dyn_cast<GetElementPtrInst>(user)) {
417 bool idxUsed =
false;
418 for (
auto &idx : GEP->indices()) {
419 if (idx.get() == inst)
423 primalUsedInShadowPointer =
false;
425 if (
auto II = dyn_cast<IntrinsicInst>(user)) {
427 const std::array<size_t, 4> idxArgsIndices{{0, 1, 2, 4}};
428 bool idxUsed =
false;
429 for (
auto i : idxArgsIndices) {
430 if (II->getOperand(i) == inst)
434 primalUsedInShadowPointer =
false;
439 if (isa<InsertValueInst>(user) || isa<ExtractValueInst>(user))
440 primalUsedInShadowPointer =
false;
442 if (primalUsedInShadowPointer)
443 if (!user->getType()->isVoidTy() &&
444 TR.
anyPointer(
const_cast<Instruction *
>(user))) {
446 gutils, user, mode, seen, oldUnreachable)) {
448 llvm::errs() <<
" Need: " <<
to_string(VT) <<
" of " << *inst
449 <<
" in reverse as used to compute shadow ptr "
451 return seen[idx] =
true;
460 if (inst->getType()->isTokenTy()) {
461 llvm::errs() <<
" need " << *inst <<
" via " << *user <<
"\n";
463 assert(!inst->getType()->isTokenTy());
465 return seen[idx] =
true;
520forEachDirectInsertUser(llvm::function_ref<
void(llvm::Instruction *)> f,
522 llvm::Value *val,
bool useCheck) {
523 using namespace llvm;
526 bool inserted =
false;
527 if (
auto II = dyn_cast<InsertValueInst>(IVI))
528 inserted = II->getInsertedValueOperand() == val ||
529 II->getAggregateOperand() == val;
530 if (
auto II = dyn_cast<ExtractValueInst>(IVI))
531 inserted = II->getAggregateOperand() == val;
532 if (
auto II = dyn_cast<InsertElementInst>(IVI))
533 inserted = II->getOperand(1) == val || II->getOperand(0) == val;
534 if (
auto II = dyn_cast<ExtractElementInst>(IVI))
535 inserted = II->getOperand(0) == val;
537 SmallVector<Instruction *, 1> todo;
539 while (todo.size()) {
540 auto cur = todo.pop_back_val();
541 for (
auto u : cur->users()) {
542 if (isa<InsertValueInst>(u) || isa<InsertElementInst>(u) ||
543 isa<ExtractValueInst>(u) || isa<ExtractElementInst>(u)) {
544 auto I2 = cast<Instruction>(u);
545 bool subCheck = useCheck;