Enzyme main
Loading...
Searching...
No Matches
AdjointGenerator.h
Go to the documentation of this file.
1//===- AdjointGenerator.h - Implementation of Adjoint's of instructions --===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @incollection{enzymeNeurips,
11// title = {Instead of Rewriting Foreign Code for Machine Learning,
12// Automatically Synthesize Fast Gradients},
13// author = {Moses, William S. and Churavy, Valentin},
14// booktitle = {Advances in Neural Information Processing Systems 33},
15// year = {2020},
16// note = {To appear in},
17// }
18//
19//===----------------------------------------------------------------------===//
20//
21// This file contains an instruction visitor AdjointGenerator that generates
22// the corresponding augmented forward pass code, and adjoints for all
23// LLVM instructions.
24//
25//===----------------------------------------------------------------------===//
26
27#ifndef ENZYME_ADJOINT_GENERATOR_H
28#define ENZYME_ADJOINT_GENERATOR_H
29
30#include "llvm/ADT/ArrayRef.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallSet.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/Analysis/ValueTracking.h"
35#include "llvm/IR/Constants.h"
36#include "llvm/IR/DerivedTypes.h"
37#include "llvm/IR/IntrinsicsX86.h"
38#include "llvm/IR/Value.h"
39#include "llvm/Transforms/Utils/BasicBlockUtils.h"
40#include "llvm/Transforms/Utils/Cloning.h"
41
42#include "DiffeGradientUtils.h"
44#include "EnzymeLogic.h"
45#include "FunctionUtils.h"
46#include "GradientUtils.h"
47#include "LibraryFuncs.h"
48#include "TraceUtils.h"
49#include "TypeAnalysis/TBAA.h"
50
51#define DEBUG_TYPE "enzyme"
52
53// Helper instruction visitor that generates adjoints
54class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
private:
  // Type of code being generated (forward, reverse, or both)
  const DerivativeMode Mode;

  // Shared differentiation state for the function being processed (maps
  // original values to the new function, activity results, shadow values).
  GradientUtils *const gutils;
  // Activity annotation for each formal argument of the differentiated
  // function.
  llvm::ArrayRef<DIFFE_TYPE> constant_args;
  // Activity annotation of the return value.
  DIFFE_TYPE retType;
  // Type-analysis results; aliases the analysis held by gutils.
  TypeResults &TR = gutils->TR;
  // Callback yielding the cache slot index for a given (instruction,
  // cache kind) pair, emitting any needed IR through the builder.
  std::function<unsigned(llvm::Instruction *, CacheType, llvm::IRBuilder<> &)>
      getIndex;
  // Per-call overwrite information; presumably (subsequently-overwritten?,
  // per-argument overwritten flags) — TODO confirm against EnzymeLogic.
  const std::map<llvm::CallInst *, std::pair<bool, const std::vector<bool>>>
      overwritten_args_map;
  // Result of the augmented forward pass, when one was generated (may be
  // null in pure forward mode).
  const AugmentedReturn *augmentedReturn;
  // Maps original return instructions to the stores that replaced them,
  // when returns were rewritten; may be null.
  const std::map<llvm::ReturnInst *, llvm::StoreInst *> *replacedReturns;

  // Values / instructions / stores determined to be unneeded in the
  // generated code; consulted by eraseIfUnused and the store handling.
  const llvm::SmallPtrSetImpl<const llvm::Value *> &unnecessaryValues;
  const llvm::SmallPtrSetImpl<const llvm::Instruction *>
      &unnecessaryInstructions;
  const llvm::SmallPtrSetImpl<const llvm::Instruction *> &unnecessaryStores;
  // Basic blocks unreachable in the original function.
  const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &oldUnreachable;
75
public:
  // Constructor: captures all per-function differentiation state by
  // reference/pointer (none of it is owned) and sanity-checks that every
  // instruction known to type analysis belongs to the original function.
  //
  // NOTE(review): the line naming the constructor ("AdjointGenerator(")
  // appears to be missing from this copy of the file — restore from
  // upstream before compiling.
      DerivativeMode Mode, GradientUtils *gutils,
      llvm::ArrayRef<DIFFE_TYPE> constant_args, DIFFE_TYPE retType,
      std::function<unsigned(llvm::Instruction *, CacheType,
                             llvm::IRBuilder<> &)>
          getIndex,
      const std::map<llvm::CallInst *, std::pair<bool, const std::vector<bool>>>
          overwritten_args_map,
      const AugmentedReturn *augmentedReturn,
      const std::map<llvm::ReturnInst *, llvm::StoreInst *> *replacedReturns,
      const llvm::SmallPtrSetImpl<const llvm::Value *> &unnecessaryValues,
      const llvm::SmallPtrSetImpl<const llvm::Instruction *>
          &unnecessaryInstructions,
      const llvm::SmallPtrSetImpl<const llvm::Instruction *> &unnecessaryStores,
      const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &oldUnreachable)
      : Mode(Mode), gutils(gutils), constant_args(constant_args),
        retType(retType), getIndex(getIndex),
        overwritten_args_map(overwritten_args_map),
        augmentedReturn(augmentedReturn), replacedReturns(replacedReturns),
        unnecessaryValues(unnecessaryValues),
        unnecessaryInstructions(unnecessaryInstructions),
        unnecessaryStores(unnecessaryStores), oldUnreachable(oldUnreachable) {
    using namespace llvm;

    // Type analysis must have been run on the original (pre-transform)
    // function.
    assert(TR.getFunction() == gutils->oldFunc);
    for (auto &pair : TR.analyzer->analysis) {
      if (auto in = dyn_cast<Instruction>(pair.first)) {
        // Dump diagnostic context before tripping the assert below.
        if (in->getParent()->getParent() != gutils->oldFunc) {
          llvm::errs() << "inf: " << *in->getParent()->getParent() << "\n";
          llvm::errs() << "gutils->oldFunc: " << *gutils->oldFunc << "\n";
          llvm::errs() << "in: " << *in << "\n";
        }
        assert(in->getParent()->getParent() == gutils->oldFunc);
      }
    }
  }
113
114 void eraseIfUnused(llvm::Instruction &I, bool erase = true,
115 bool check = true) {
116 using namespace llvm;
117
118 bool used =
119 unnecessaryInstructions.find(&I) == unnecessaryInstructions.end();
120 if (!used) {
121 // if decided to cache a value, preserve it here for later
122 // replacement in EnzymeLogic
123 auto found = gutils->knownRecomputeHeuristic.find(&I);
124 if (found != gutils->knownRecomputeHeuristic.end() && !found->second)
125 used = true;
126 }
127 auto iload = gutils->getNewFromOriginal((llvm::Value *)&I);
128 if (used && check)
129 return;
130
131 if (auto newi = dyn_cast<Instruction>(iload))
132 gutils->eraseWithPlaceholder(newi, &I, "_replacementA", erase);
133 }
134
  // Emit IR computing the byte size of MPI datatype `DT` as a value of
  // `intType`.  The common OpenMPI builtin datatypes are folded to
  // constants; otherwise a call to MPI_Type_size (renamed according to
  // `caller`'s calling convention) is emitted and its output slot loaded.
  llvm::Value *MPI_TYPE_SIZE(llvm::Value *DT, llvm::IRBuilder<> &B,
                             llvm::Type *intType, llvm::Function *caller) {
    using namespace llvm;

    // Some ABIs pass MPI_Datatype as an integer handle; normalize to i8*.
    if (DT->getType()->isIntegerTy())
      DT = B.CreateIntToPtr(DT, getInt8PtrTy(DT->getContext()));

    if (Constant *C = dyn_cast<Constant>(DT)) {
      // Strip constant expressions (casts/GEPs) to reach the underlying
      // global, if any.
      while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
        C = CE->getOperand(0);
      }
      // Fold OpenMPI's builtin double/float datatypes to their known sizes.
      if (auto GV = dyn_cast<GlobalVariable>(C)) {
        if (GV->getName() == "ompi_mpi_double") {
          return ConstantInt::get(intType, 8, false);
        } else if (GV->getName() == "ompi_mpi_float") {
          return ConstantInt::get(intType, 4, false);
        }
      }
    }
    // Declared as: int MPI_Type_size(MPI_Datatype datatype, int *size)
    Type *pargs[] = {getInt8PtrTy(DT->getContext()), getUnqual(intType)};
    auto FT = FunctionType::get(intType, pargs, false);
    // Output slot, allocated in the dedicated entry allocation block.
    auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(intType);
    llvm::Value *args[] = {DT, alloc};
    if (DT->getType() != pargs[0])
      args[0] = B.CreateBitCast(args[0], pargs[0]);
    // Attributes mirroring MPI_Type_size semantics: arg 0 is read-only,
    // arg 1 write-only; both are nocapture/noalias/nonnull; the call
    // cannot unwind, free, or sync, and always returns.
    AttributeList AL;
    AL = AL.addParamAttribute(DT->getContext(), 0,
                              Attribute::AttrKind::ReadOnly);
    AL = addFunctionNoCapture(DT->getContext(), AL, 0);
    AL =
        AL.addParamAttribute(DT->getContext(), 0, Attribute::AttrKind::NoAlias);
    AL =
        AL.addParamAttribute(DT->getContext(), 0, Attribute::AttrKind::NonNull);
    AL = AL.addParamAttribute(DT->getContext(), 1,
                              Attribute::AttrKind::WriteOnly);
    AL = addFunctionNoCapture(DT->getContext(), AL, 1);
    AL =
        AL.addParamAttribute(DT->getContext(), 1, Attribute::AttrKind::NoAlias);
    AL =
        AL.addParamAttribute(DT->getContext(), 1, Attribute::AttrKind::NonNull);
    AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex,
                                Attribute::AttrKind::NoUnwind);
    AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex,
                                Attribute::AttrKind::NoFree);
    AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex,
                                Attribute::AttrKind::NoSync);
    AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex,
                                Attribute::AttrKind::WillReturn);
    auto CI = B.CreateCall(
        B.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
            getRenamedPerCallingConv(caller->getName(), "MPI_Type_size"), FT,
            AL),
        args);
    // ArgMemOnly was replaced by memory effects in LLVM 16.
#if LLVM_VERSION_MAJOR >= 16
    CI->setOnlyAccessesArgMemory();
#else
    CI->addAttributeAtIndex(AttributeList::FunctionIndex,
                            Attribute::ArgMemOnly);
#endif
    return B.CreateLoad(intType, alloc);
  }
196
197 // To be double-checked against the functionality needed and the respective
198 // implementation in Adjoint-MPI
199 llvm::Value *MPI_COMM_RANK(llvm::Value *comm, llvm::IRBuilder<> &B,
200 llvm::Type *rankTy, llvm::Function *caller) {
201 using namespace llvm;
202
203 Type *pargs[] = {comm->getType(), getUnqual(rankTy)};
204 auto FT = FunctionType::get(rankTy, pargs, false);
205 auto &context = comm->getContext();
206 auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(rankTy);
207 AttributeList AL;
208 AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::ReadOnly);
209 AL = addFunctionNoCapture(context, AL, 0);
210 AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoAlias);
211 AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NonNull);
212 AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::WriteOnly);
213 AL = addFunctionNoCapture(context, AL, 1);
214 AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoAlias);
215 AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NonNull);
216 AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
217 Attribute::AttrKind::NoUnwind);
218 AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
219 Attribute::AttrKind::NoFree);
220 AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
221 Attribute::AttrKind::NoSync);
222 AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
223 Attribute::AttrKind::WillReturn);
224 llvm::Value *args[] = {comm, alloc};
225 B.CreateCall(
226 B.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
227 getRenamedPerCallingConv(caller->getName(), "MPI_Comm_rank"), FT,
228 AL),
229 args);
230 return B.CreateLoad(rankTy, alloc);
231 }
232
233 llvm::Value *MPI_COMM_SIZE(llvm::Value *comm, llvm::IRBuilder<> &B,
234 llvm::Type *rankTy, llvm::Function *caller) {
235 using namespace llvm;
236
237 Type *pargs[] = {comm->getType(), getUnqual(rankTy)};
238 auto FT = FunctionType::get(rankTy, pargs, false);
239 auto &context = comm->getContext();
240 auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(rankTy);
241 AttributeList AL;
242 AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::ReadOnly);
243 AL = addFunctionNoCapture(context, AL, 0);
244 AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoAlias);
245 AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NonNull);
246 AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::WriteOnly);
247 AL = addFunctionNoCapture(context, AL, 1);
248 AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoAlias);
249 AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NonNull);
250 AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
251 Attribute::AttrKind::NoUnwind);
252 AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
253 Attribute::AttrKind::NoFree);
254 AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
255 Attribute::AttrKind::NoSync);
256 AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
257 Attribute::AttrKind::WillReturn);
258 llvm::Value *args[] = {comm, alloc};
259 B.CreateCall(
260 B.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
261 getRenamedPerCallingConv(caller->getName(), "MPI_Comm_size"), FT,
262 AL),
263 args);
264 return B.CreateLoad(rankTy, alloc);
265 }
266
  // Catch-all visitor: first tries the tablegen-generated derivative rules
  // (InstructionDerivatives.inc); constant instructions are skipped; any
  // remaining active instruction is reported as unhandled, given a null
  // shadow, detached from its users, and erased.
  void visitInstruction(llvm::Instruction &inst) {
    using namespace llvm;

    // TODO explicitly handle all instructions rather than using the catch all
    // below

    switch (inst.getOpcode()) {
#include "InstructionDerivatives.inc"
    default:
      break;
    }

    if (gutils->isConstantInstruction(&inst)) {
      // NOTE(review): a source line appears to be missing here in this
      // extraction — the statements below the first return are otherwise
      // unreachable.  Compare against upstream.
      return;
      eraseIfUnused(inst);
      return;
    }

    // Unhandled active instruction: emit a diagnostic, zero its shadow,
    // replace its uses with undef, and erase the primal copy.
    std::string s;
    llvm::raw_string_ostream ss(s);
    ss << "in Mode: " << to_string(Mode) << "\n";
    ss << "cannot handle unknown instruction\n" << inst;
    IRBuilder<> Builder2(&inst);
    getForwardBuilder(Builder2);
    EmitNoDerivativeError(ss.str(), inst, gutils, Builder2);
    if (!gutils->isConstantValue(&inst)) {
      if (Mode == DerivativeMode::ForwardMode ||
          // NOTE(review): the remainder of this condition is missing from
          // this extraction.
        setDiffe(&inst,
                 Constant::getNullValue(gutils->getShadowType(inst.getType())),
                 Builder2);
    }
    if (!inst.getType()->isVoidTy()) {
      for (auto &U :
           make_early_inc_range(gutils->getNewFromOriginal(&inst)->uses())) {
        U.set(UndefValue::get(inst.getType()));
      }
    }
    eraseIfUnused(inst, /*erase*/ true, /*check*/ false);
    return;
  }
310
311 // Common function for falling back to the implementation
312 // of dual propagation, as available in invertPointerM.
  // Common function for falling back to the implementation
  // of dual propagation, as available in invertPointerM.
  // Replaces the placeholder shadow PHI registered in invertedPointers
  // with the actual shadow value.
  void forwardModeInvertedPointerFallback(llvm::Instruction &I) {
    using namespace llvm;

    auto found = gutils->invertedPointers.find(&I);
    // Constant values carry no shadow; nothing to materialize.
    if (gutils->isConstantValue(&I)) {
      assert(found == gutils->invertedPointers.end());
      return;
    }

    assert(found != gutils->invertedPointers.end());
    auto placeholder = cast<PHINode>(&*found->second);
    gutils->invertedPointers.erase(found);

    // NOTE(review): the opening of this condition (a shadow-neededness
    // query, presumably an is_value_needed_in_reverse<Shadow> call) is
    // missing from this extraction.
        gutils, &I, Mode, oldUnreachable)) {
      gutils->erase(placeholder);
      return;
    }

    IRBuilder<> Builder2(&I);
    getForwardBuilder(Builder2);

    // Compute the real shadow and swap it in for the placeholder.
    auto toset = gutils->invertPointerM(&I, Builder2, /*nullShadow*/ true);

    assert(toset != placeholder);

    gutils->replaceAWithB(placeholder, toset);
    placeholder->replaceAllUsesWith(toset);
    gutils->erase(placeholder);
    gutils->invertedPointers.insert(
        std::make_pair((const Value *)&I, InvertedPointerVH(gutils, toset)));
    return;
  }
346
  // Allocas generate no adjoint themselves; the primal copy is dropped
  // when unused and mode-specific handling follows.
  // NOTE(review): the case labels (and any per-case body) of this switch
  // are missing from this extraction — compare against upstream.
  void visitAllocaInst(llvm::AllocaInst &I) {
    eraseIfUnused(I);
    switch (Mode) {
      return;
    }
    default:
      return;
    }
  }
360
361 void visitICmpInst(llvm::ICmpInst &I) { eraseIfUnused(I); }
362
363 void visitFCmpInst(llvm::FCmpInst &I) { eraseIfUnused(I); }
364
  // Shared handling for load-like instructions (plain and masked loads):
  // sets up derivative alias scopes on the primal copy, materializes or
  // caches the shadow value, caches the primal when the reverse pass needs
  // it, and emits the reverse-mode adjoint accumulation.
  //
  // NOTE(review): this copy of the file is an extraction with multiple
  // dropped source lines (flagged inline below); the control flow here is
  // not compilable as-is — compare against upstream before editing.
  void visitLoadLike(llvm::Instruction &I, llvm::MaybeAlign alignment,
                     bool constantval, llvm::Value *mask = nullptr,
                     llvm::Value *orig_maskInit = nullptr) {
    using namespace llvm;

    auto &DL = gutils->newFunc->getParent()->getDataLayout();
    auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 1) / 8;

    // NOTE(review): the tails of the two asserts and of the can_modref
    // condition below are missing from this extraction (each original
    // condition continued onto a dropped line).
    assert(Mode == DerivativeMode::ForwardMode ||
    assert(Mode == DerivativeMode::ForwardMode ||
        gutils->can_modref_map->find(&I) != gutils->can_modref_map->end());
    bool can_modref = (Mode == DerivativeMode::ForwardMode ||
                          ? false
                          : gutils->can_modref_map->find(&I)->second;

    constantval |= gutils->isConstantValue(&I);

    Type *type = gutils->getShadowType(I.getType());
    (void)type;

    auto *newi = dyn_cast<Instruction>(gutils->getNewFromOriginal(&I));

    // Attach derivative alias scopes: the primal load is in the "-1"
    // (primal) scope and declared noalias with each shadow scope.
    SmallVector<Metadata *, 1> scopeMD = {
        gutils->getDerivativeAliasScope(I.getOperand(0), -1)};
    if (auto prev = I.getMetadata(LLVMContext::MD_alias_scope)) {
      for (auto &M : cast<MDNode>(prev)->operands()) {
        scopeMD.push_back(M);
      }
    }
    auto scope = MDNode::get(I.getContext(), scopeMD);
    newi->setMetadata(LLVMContext::MD_alias_scope, scope);

    SmallVector<Metadata *, 1> MDs;
    for (size_t j = 0; j < gutils->getWidth(); j++) {
      MDs.push_back(gutils->getDerivativeAliasScope(I.getOperand(0), j));
    }
    if (auto prev = I.getMetadata(LLVMContext::MD_noalias)) {
      for (auto &M : cast<MDNode>(prev)->operands()) {
        MDs.push_back(M);
      }
    }
    auto noscope = MDNode::get(I.getContext(), MDs);
    newi->setMetadata(LLVMContext::MD_noalias, noscope);

    auto vd = TR.query(&I);

    IRBuilder<> BuilderZ(newi);
    if (!vd.isKnown()) {
      // Type analysis failed; fall back to the default type tree for the
      // LLVM type and warn.
      std::string str;
      raw_string_ostream ss(str);
      ss << "Cannot deduce type of load " << I;
      auto ET = I.getType();
      if (looseTypeAnalysis || true) {
        vd = defaultTypeTreeForLLVM(ET, &I);
        ss << ", assumed " << vd.str() << "\n";
        EmitWarning("CannotDeduceType", I, ss.str());
        goto known;
      }
      EmitNoTypeError(str, I, gutils, BuilderZ);
    known:;
    }

    // Forward-mode shadow handling.
    // NOTE(review): the remainder of this condition is missing from this
    // extraction.
    if (Mode == DerivativeMode::ForwardMode ||
      if (!constantval) {
        auto found = gutils->invertedPointers.find(&I);
        assert(found != gutils->invertedPointers.end());
        Instruction *placeholder = cast<Instruction>(&*found->second);
        assert(placeholder->getType() == type);
        gutils->invertedPointers.erase(found);

        // only make shadow where caching needed
        // NOTE(review): the opening of this shadow-neededness query is
        // missing from this extraction.
          QueryType::Shadow>(gutils, &I, Mode, oldUnreachable)) {
          gutils->erase(placeholder);
          return;
        }

        if (can_modref) {
          if (vd[{-1}].isPossiblePointer()) {
            Value *newip = gutils->cacheForReverse(
                BuilderZ, placeholder,
                getIndex(&I, CacheType::Shadow, BuilderZ));
            assert(newip->getType() == type);
            gutils->invertedPointers.insert(std::make_pair(
                (const Value *)&I, InvertedPointerVH(gutils, newip)));
          } else {
            gutils->erase(placeholder);
          }
        } else {
          Value *newip = gutils->invertPointerM(&I, BuilderZ);
          // Under runtime activity, zero the shadow when the shadow
          // pointer aliases the primal pointer (inactive at runtime).
          if (gutils->runtimeActivity && vd[{-1}].isFloat()) {
            // TODO handle mask
            assert(!mask);

            auto rule = [&](Value *inop, Value *newip) -> Value * {
              Value *shadow = BuilderZ.CreateICmpNE(
                  gutils->getNewFromOriginal(I.getOperand(0)), inop);
              newip = CreateSelect(BuilderZ, shadow, newip,
                                   Constant::getNullValue(newip->getType()));
              return newip;
            };
            newip = applyChainRule(
                I.getType(), BuilderZ, rule,
                gutils->invertPointerM(I.getOperand(0), BuilderZ), newip);
          }
          assert(newip->getType() == type);
          placeholder->replaceAllUsesWith(newip);
          gutils->erase(placeholder);
          gutils->invertedPointers.erase(&I);
          gutils->invertedPointers.insert(std::make_pair(
              (const Value *)&I, InvertedPointerVH(gutils, newip)));
        }
      }
      return;
    }

    //! Store inverted pointer loads that need to be cached for use in reverse
    //! pass
    if (vd[{-1}].isPossiblePointer()) {
      auto found = gutils->invertedPointers.find(&I);
      if (found != gutils->invertedPointers.end()) {
        Instruction *placeholder = cast<Instruction>(&*found->second);
        assert(placeholder->getType() == type);
        gutils->invertedPointers.erase(found);

        if (!constantval) {
          Value *newip = nullptr;

          // TODO: In the case of fwd mode this should be true if the loaded
          // value itself is used as a pointer.
          // NOTE(review): the opening of this neededness query (and the
          // `needShadow` variable it defines) is missing from this
          // extraction.
            QueryType::Shadow>(gutils, &I, Mode, oldUnreachable);

          // NOTE(review): the case labels of this switch are missing from
          // this extraction.
          switch (Mode) {

            if (!needShadow) {
              gutils->erase(placeholder);
            } else {
              newip = gutils->invertPointerM(&I, BuilderZ);
              assert(newip->getType() == type);
              if (Mode == DerivativeMode::ReverseModePrimal && can_modref &&
                  QueryType::Shadow>(gutils, &I,
                  oldUnreachable)) {
                gutils->cacheForReverse(
                    BuilderZ, newip, getIndex(&I, CacheType::Shadow, BuilderZ));
              }
              placeholder->replaceAllUsesWith(newip);
              gutils->erase(placeholder);
              gutils->invertedPointers.insert(std::make_pair(
                  (const Value *)&I, InvertedPointerVH(gutils, newip)));
            }
            break;
          }
            assert(0 && "impossible branch");
            return;
          }
          if (!needShadow) {
            gutils->erase(placeholder);
          } else {
            // only make shadow where caching needed
            if (can_modref) {
              newip = gutils->cacheForReverse(
                  BuilderZ, placeholder,
                  getIndex(&I, CacheType::Shadow, BuilderZ));
              assert(newip->getType() == type);
              gutils->invertedPointers.insert(std::make_pair(
                  (const Value *)&I, InvertedPointerVH(gutils, newip)));
            } else {
              newip = gutils->invertPointerM(&I, BuilderZ);
              assert(newip->getType() == type);
              placeholder->replaceAllUsesWith(newip);
              gutils->erase(placeholder);
              gutils->invertedPointers.insert(std::make_pair(
                  (const Value *)&I, InvertedPointerVH(gutils, newip)));
            }
          }
          break;
        }
      }

      } else {
        gutils->erase(placeholder);
      }
    }
  }

    Value *inst = newi;

    //! Store loads that need to be cached for use in reverse pass

    // Only cache value here if caching decision isn't precomputed.
    // Otherwise caching will be done inside EnzymeLogic.cpp at
    // the end of the function jointly.
    // NOTE(review): part of this condition is missing from this extraction.
    if (Mode != DerivativeMode::ForwardMode &&
        !gutils->knownRecomputeHeuristic.count(&I) && can_modref &&
        !gutils->unnecessaryIntermediates.count(&I)) {
      // we can pre initialize all the knownRecomputeHeuristic values to false
      // (not needing) as we may assume that minCutCache already preserves
      // everything it requires.
      std::map<UsageKey, bool> Seen;
      bool primalNeededInReverse = false;
      for (auto pair : gutils->knownRecomputeHeuristic)
        if (!pair.second) {
          Seen[UsageKey(pair.first, QueryType::Primal)] = false;
          if (pair.first == &I)
            primalNeededInReverse = true;
        }
      // NOTE(review): the true-branch of this conditional expression and
      // the opening of the neededness query below are missing from this
      // extraction.
      auto cacheMode = (Mode == DerivativeMode::ReverseModePrimal)
                           : Mode;
      primalNeededInReverse |=
          QueryType::Primal>(gutils, &I, cacheMode, Seen, oldUnreachable);
      if (primalNeededInReverse) {
        inst = gutils->cacheForReverse(BuilderZ, newi,
                                       getIndex(&I, CacheType::Self, BuilderZ));
        (void)inst;
        assert(inst->getType() == I.getType());

        assert(inst != newi);
      } else {
        assert(inst == newi);
      }
    }
  }

    // NOTE(review): the guard preceding this return is missing from this
    // extraction.
    return;

  if (constantval)
    return;

    // Assume that non enzyme_shadow globals are inactive
    // If we ever store to a global variable, we will error if it doesn't
    // have a shadow This allows functions who only read global memory to
    // have their derivative computed Note that this is too aggressive for
    // general programs as if the global aliases with an argument something
    // that is written to, then we will have a logical error
    if (auto arg = dyn_cast<GlobalVariable>(I.getOperand(0))) {
      if (!hasMetadata(arg, "enzyme_shadow")) {
        return;
      }
    }
  }

  // Only propagate if instruction is active. The value can be active and not
  // the instruction if the value is a potential pointer. This may not be
  // caught by type analysis is the result does not have a known type.
  if (!gutils->isConstantInstruction(&I)) {
    // NOTE(review): the case labels of this switch are missing from this
    // extraction.
    switch (Mode) {
      assert(0 && "impossible branch");
      return;
    }

    IRBuilder<> Builder2(&I);
    getReverseBuilder(Builder2);

    Value *prediff = nullptr;

    // Fetch the incoming adjoint only if some byte of the load is float.
    for (ssize_t i = -1; i < (ssize_t)LoadSize; ++i) {
      if (vd[{(int)i}].isFloat()) {
        prediff = diffe(&I, Builder2);
        break;
      }
    }

    Value *premask = nullptr;

    if (prediff && mask) {
      premask = lookup(mask, Builder2);
    }

    // Accumulate the adjoint into the shadow of the loaded-from pointer.
    if (prediff)
      ((DiffeGradientUtils *)gutils)
          ->addToInvertedPtrDiffe(&I, &I, vd, LoadSize, I.getOperand(0),
                                  prediff, Builder2, alignment, premask);

    // Walk maximal byte ranges of a single deduced type.
    unsigned start = 0;
    unsigned size = LoadSize;

    while (1) {
      unsigned nextStart = size;

      auto dt = vd[{-1}];
      for (size_t i = start; i < size; ++i) {
        bool Legal = true;
        dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
        if (!Legal) {
          nextStart = i;
          break;
        }
      }
      if (!dt.isKnown()) {
        std::string str;
        raw_string_ostream ss(str);
        ss << "Cannot deduce type of load " << I;
        ss << " vd:" << vd.str() << " start:" << start << " size: " << size
           << " dt:" << dt.str() << "\n";
        EmitNoTypeError(str, I, gutils, BuilderZ);
        continue;
      }
      assert(dt.isKnown());

      if (Type *isfloat = dt.isFloat()) {
        if (premask && !gutils->isConstantValue(orig_maskInit)) {
          // Masked partial type is unhandled.
          if (premask)
            assert(start == 0 && nextStart == LoadSize);
          addToDiffe(orig_maskInit, prediff, Builder2, isfloat,
                     Builder2.CreateNot(premask));
        }
      }

      if (nextStart == size)
        break;
      start = nextStart;
    }
    break;
  }
    break;
  }
  }
  }
712
713 void visitLoadInst(llvm::LoadInst &LI) {
714 using namespace llvm;
715
716 // If a load of an omp init argument, don't cache for reverse
717 // and don't do any adjoint propagation (assumed integral)
718 for (auto U : LI.getPointerOperand()->users()) {
719 if (auto CI = dyn_cast<CallInst>(U)) {
720 if (auto F = CI->getCalledFunction()) {
721 if (F->getName() == "__kmpc_for_static_init_4" ||
722 F->getName() == "__kmpc_for_static_init_4u" ||
723 F->getName() == "__kmpc_for_static_init_8" ||
724 F->getName() == "__kmpc_for_static_init_8u") {
725 eraseIfUnused(LI);
726 return;
727 }
728 }
729 }
730 }
731
732 auto alignment = LI.getAlign();
733 auto &DL = gutils->newFunc->getParent()->getDataLayout();
734
735 bool constantval = parseTBAA(LI, DL, nullptr)[{-1}].isIntegral();
736 visitLoadLike(LI, alignment, constantval);
737 eraseIfUnused(LI);
738 }
739
  // Handle atomic read-modify-write: FAdd/FSub get a mirrored atomic on the
  // shadow (forward) or a shadow load accumulated into the value's adjoint
  // (reverse); other operations fall back to a loose-type-analysis escape
  // hatch or an "unhandled" diagnostic.
  //
  // NOTE(review): this copy of the file is an extraction with several
  // dropped source lines (flagged inline); the control flow here is not
  // compilable as-is — compare against upstream before editing.
  void visitAtomicRMWInst(llvm::AtomicRMWInst &I) {
    using namespace llvm;

    // Fully constant atomics need no derivative work.
    // NOTE(review): line(s) missing after this condition in this
    // extraction.
    if (gutils->isConstantInstruction(&I) && gutils->isConstantValue(&I)) {
        eraseIfUnused(I, /*erase*/ true, /*check*/ false);
      } else {
        eraseIfUnused(I);
      }
      return;
    }

    IRBuilder<> BuilderZ(&I);
    getForwardBuilder(BuilderZ);

    switch (I.getOperation()) {
    case AtomicRMWInst::FAdd:
    case AtomicRMWInst::FSub: {

      // Forward mode: mirror the atomic op on the shadow operands.
      // NOTE(review): the remainder of this condition is missing from this
      // extraction.
      if (Mode == DerivativeMode::ForwardMode ||
        auto rule = [&](Value *ptr, Value *dif) -> Value * {
          if (dif == nullptr)
            dif = Constant::getNullValue(I.getType());
          if (!gutils->isConstantInstruction(&I)) {
            assert(ptr);
            AtomicRMWInst *rmw = nullptr;
            rmw = BuilderZ.CreateAtomicRMW(I.getOperation(), ptr, dif,
                                           I.getAlign(), I.getOrdering(),
                                           I.getSyncScopeID());
            rmw->setVolatile(I.isVolatile());
            if (gutils->isConstantValue(&I))
              return Constant::getNullValue(dif->getType());
            else
              return rmw;
          } else {
            assert(gutils->isConstantValue(&I));
            return Constant::getNullValue(dif->getType());
          }
        };

        Value *diff = applyChainRule(
            I.getType(), BuilderZ, rule,
            gutils->isConstantValue(I.getPointerOperand())
                ? nullptr
                : gutils->invertPointerM(I.getPointerOperand(), BuilderZ),
            gutils->isConstantValue(I.getValOperand())
                ? nullptr
                : gutils->invertPointerM(I.getValOperand(), BuilderZ));
        if (!gutils->isConstantValue(&I))
          setDiffe(&I, diff, BuilderZ);
        return;
      }
        // NOTE(review): the guard preceding this branch is missing from
        // this extraction.
        eraseIfUnused(I);
        return;
      }
        // Reverse mode: load the shadow of the pointer (with the ordering
        // weakened from release to its acquire-compatible form) and add it
        // to the adjoint of the value operand.
        // NOTE(review): the opening of this condition is missing from this
        // extraction.
          gutils->isConstantValue(&I)) {
        if (!gutils->isConstantValue(I.getValOperand())) {
          assert(!gutils->isConstantValue(I.getPointerOperand()));
          IRBuilder<> Builder2(&I);
          getReverseBuilder(Builder2);
          Value *ip = gutils->invertPointerM(I.getPointerOperand(), Builder2);
          ip = lookup(ip, Builder2);
          auto order = I.getOrdering();
          if (order == AtomicOrdering::Release)
            order = AtomicOrdering::Monotonic;
          else if (order == AtomicOrdering::AcquireRelease)
            order = AtomicOrdering::Acquire;

          auto rule = [&](Value *ip) -> Value * {
            LoadInst *dif1 =
                Builder2.CreateLoad(I.getType(), ip, I.isVolatile());

            dif1->setAlignment(I.getAlign());
            dif1->setOrdering(order);
            dif1->setSyncScopeID(I.getSyncScopeID());
            return dif1;
          };
          Value *diff = applyChainRule(I.getType(), Builder2, rule, ip);

          addToDiffe(I.getValOperand(), diff, Builder2,
                     I.getValOperand()->getType()->getScalarType());
        }
          // NOTE(review): line(s) missing before this branch in this
          // extraction.
          eraseIfUnused(I, /*erase*/ true, /*check*/ false);
        } else
          eraseIfUnused(I);
        return;
      }
      break;
    }
    default:
      break;
    }

    // Loose type analysis escape: integral-typed atomics with no detected
    // pointer content are treated as inactive.
    if (looseTypeAnalysis) {
      auto &DL = gutils->newFunc->getParent()->getDataLayout();
      auto valType = I.getValOperand()->getType();
      auto storeSize = DL.getTypeSizeInBits(valType) / 8;
      auto fp = TR.firstPointer(storeSize, I.getPointerOperand(), &I,
                                /*errifnotfound*/ false,
                                /*pointerIntSame*/ true);
      if (!fp.isKnown() && valType->isIntOrIntVectorTy()) {
          // NOTE(review): line(s) missing before this branch in this
          // extraction.
          eraseIfUnused(I, /*erase*/ true, /*check*/ false);
        } else
          eraseIfUnused(I);
        return;
      }
    }
    // Unhandled active atomic: diagnose, zero the shadow, detach users,
    // and erase the primal copy.
    std::string s;
    llvm::raw_string_ostream ss(s);
    ss << *I.getParent()->getParent() << "\n" << I << "\n";
    ss << " Active atomic inst not yet handled";
    EmitNoDerivativeError(ss.str(), I, gutils, BuilderZ);
    if (!gutils->isConstantValue(&I)) {
      if (Mode == DerivativeMode::ForwardMode ||
        // NOTE(review): the remainder of this condition is missing from
        // this extraction.
        setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())),
                 BuilderZ);
    }
    if (!I.getType()->isVoidTy()) {
      for (auto &U :
           make_early_inc_range(gutils->getNewFromOriginal(&I)->uses())) {
        U.set(UndefValue::get(I.getType()));
      }
    }
    eraseIfUnused(I, /*erase*/ true, /*check*/ false);
    return;
  }
877
  // Handle a store: delegates to visitCommonStore, then decides whether
  // the primal store copy can be erased unconditionally (no reverse-pass
  // replay needed) or only if unused.
  void visitStoreInst(llvm::StoreInst &SI) {
    using namespace llvm;

    // If a store of an omp init argument, don't delete in reverse
    // and don't do any adjoint propagation (assumed integral)
    for (auto U : SI.getPointerOperand()->users()) {
      if (auto CI = dyn_cast<CallInst>(U)) {
        if (auto F = CI->getCalledFunction()) {
          if (F->getName() == "__kmpc_for_static_init_4" ||
              F->getName() == "__kmpc_for_static_init_4u" ||
              F->getName() == "__kmpc_for_static_init_8" ||
              F->getName() == "__kmpc_for_static_init_8u") {
            return;
          }
        }
      }
    }
    auto align = SI.getAlign();

    visitCommonStore(SI, SI.getPointerOperand(), SI.getValueOperand(), align,
                     SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(),
                     /*mask=*/nullptr);

    bool forceErase = false;
    // NOTE(review): a guarding mode-check line is missing from this
    // extraction before the block below.
      // Since we won't redo the store in the reverse pass, do not
      // force the write barrier.
      forceErase = true;
      for (const auto &pair : gutils->rematerializableAllocations) {
        // However, if we are rematerializing the allocation and not
        // inside the loop level rematerialization, we do still need the
        // reverse passes ``fake primal'' store and therefore write barrier
        if (pair.second.stores.count(&SI) &&
            (!pair.second.LI || !pair.second.LI->contains(&SI))) {
          forceErase = false;
        }
      }
    }
    if (forceErase)
      eraseIfUnused(SI, /*erase*/ true, /*check*/ false);
    else
      eraseIfUnused(SI);
  }
921
922 void visitCommonStore(llvm::Instruction &I, llvm::Value *orig_ptr,
923 llvm::Value *orig_val, llvm::MaybeAlign prevalign,
924 bool isVolatile, llvm::AtomicOrdering ordering,
925 llvm::SyncScope::ID syncScope, llvm::Value *mask) {
926 using namespace llvm;
927
928 Value *val = gutils->getNewFromOriginal(orig_val);
929 Type *valType = orig_val->getType();
930
931 auto &DL = gutils->newFunc->getParent()->getDataLayout();
932
933 if (unnecessaryStores.count(&I)) {
934 return;
935 }
936
937 if (gutils->isConstantValue(orig_ptr)) {
938 return;
939 }
940
941 SmallVector<Metadata *, 1> scopeMD = {
942 gutils->getDerivativeAliasScope(orig_ptr, -1)};
943 SmallVector<Metadata *, 1> prevScopes;
944 if (auto prev = I.getMetadata(LLVMContext::MD_alias_scope)) {
945 for (auto &M : cast<MDNode>(prev)->operands()) {
946 scopeMD.push_back(M);
947 prevScopes.push_back(M);
948 }
949 }
950 auto scope = MDNode::get(I.getContext(), scopeMD);
951 auto NewI = gutils->getNewFromOriginal(&I);
952 NewI->setMetadata(LLVMContext::MD_alias_scope, scope);
953
954 SmallVector<Metadata *, 1> MDs;
955 SmallVector<Metadata *, 1> prevNoAlias;
956 for (size_t j = 0; j < gutils->getWidth(); j++) {
957 MDs.push_back(gutils->getDerivativeAliasScope(orig_ptr, j));
958 }
959 if (auto prev = I.getMetadata(LLVMContext::MD_noalias)) {
960 for (auto &M : cast<MDNode>(prev)->operands()) {
961 MDs.push_back(M);
962 prevNoAlias.push_back(M);
963 }
964 }
965 auto noscope = MDNode::get(I.getContext(), MDs);
966 NewI->setMetadata(LLVMContext::MD_noalias, noscope);
967
968 bool constantval = gutils->isConstantValue(orig_val) ||
969 parseTBAA(I, DL, nullptr)[{-1}].isIntegral();
970
971 IRBuilder<> BuilderZ(NewI);
972 BuilderZ.setFastMathFlags(getFast());
973
974 // TODO allow recognition of other types that could contain pointers [e.g.
975 // {void*, void*} or <2 x i64> ]
976 auto storeSize = (DL.getTypeSizeInBits(valType) + 7) / 8;
977
978 auto vd = TR.query(orig_ptr).Lookup(storeSize, DL);
979
980 if (!vd.isKnown()) {
981 std::string str;
982 raw_string_ostream ss(str);
983 ss << "Cannot deduce type of store " << I;
984 if (looseTypeAnalysis || true) {
985 vd = defaultTypeTreeForLLVM(valType, &I);
986 ss << ", assumed " << vd.str() << "\n";
987 EmitWarning("CannotDeduceType", I, ss.str());
988 goto known;
989 }
990 EmitNoTypeError(str, I, gutils, BuilderZ);
991 return;
992 known:;
993 }
994
995 if (Mode == DerivativeMode::ForwardMode ||
997
998 auto dt = vd[{-1}];
999 // Only need the full type in forward mode, if storing a constant
1000 // and therefore may need to zero some floats.
1001 if (constantval)
1002 for (size_t i = 0; i < storeSize; ++i) {
1003 bool Legal = true;
1004 dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
1005 if (!Legal) {
1006 std::string str;
1007 raw_string_ostream ss(str);
1008 ss << "Cannot deduce single type of store " << I << vd.str()
1009 << " size: " << storeSize;
1010 EmitNoTypeError(str, I, gutils, BuilderZ);
1011 return;
1012 }
1013 }
1014
1015 Value *diff = nullptr;
1016 bool needs_writebarrier = false;
1017 if (!gutils->runtimeActivity && constantval) {
1018 if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) {
1019 if (!isa<UndefValue>(orig_val) &&
1020 !isa<ConstantPointerNull>(orig_val)) {
1021 std::string str;
1022 raw_string_ostream ss(str);
1023 ss << "Mismatched activity for: " << I
1024 << " const val: " << *orig_val;
1025 if (CustomErrorHandler) {
1026 diff = unwrap(CustomErrorHandler(
1027 str.c_str(), wrap(&I), ErrorType::MixedActivityError, gutils,
1028 wrap(orig_val), wrap(&BuilderZ)));
1029 if (diff)
1030 needs_writebarrier = true;
1031 } else
1032 EmitWarning("MixedActivityError", I, ss.str());
1033 }
1034 }
1035 }
1036
1037 // TODO type analyze
1038 if (!diff) {
1039 if (!constantval)
1040 diff =
1041 gutils->invertPointerM(orig_val, BuilderZ, /*nullShadow*/ true);
1042 else if (orig_val->getType()->isPointerTy() ||
1043 dt == BaseType::Pointer || dt == BaseType::Integer)
1044 diff =
1045 gutils->invertPointerM(orig_val, BuilderZ, /*nullShadow*/ false);
1046 else
1047 diff =
1048 gutils->invertPointerM(orig_val, BuilderZ, /*nullShadow*/ true);
1049 }
1050
1051 gutils->setPtrDiffe(&I, orig_ptr, diff, BuilderZ, prevalign, 0, storeSize,
1052 isVolatile, ordering, syncScope, mask, prevNoAlias,
1053 prevScopes, needs_writebarrier);
1054
1055 return;
1056 }
1057
1058 unsigned start = 0;
1059
1060 IRBuilder<> Builder2(&I);
1061 BasicBlock *merge = nullptr;
1064 getReverseBuilder(Builder2);
1065
1066 while (1) {
1067 unsigned nextStart = storeSize;
1068
1069 auto dt = vd[{-1}];
1070 for (size_t i = start; i < storeSize; ++i) {
1071 auto nex = vd[{(int)i}];
1072 if ((nex == BaseType::Anything && dt.isFloat()) ||
1073 (dt == BaseType::Anything && nex.isFloat())) {
1074 nextStart = i;
1075 break;
1076 }
1077 bool Legal = true;
1078 dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
1079 if (!Legal) {
1080 nextStart = i;
1081 break;
1082 }
1083 }
1084 unsigned size = nextStart - start;
1085 if (!dt.isKnown()) {
1086
1087 std::string str;
1088 raw_string_ostream ss(str);
1089 ss << "Cannot deduce type of store " << I << vd.str()
1090 << " start: " << start << " size: " << size
1091 << " storeSize: " << storeSize;
1092 EmitNoTypeError(str, I, gutils, BuilderZ);
1093 break;
1094 }
1095
1096 MaybeAlign align;
1097 if (prevalign) {
1098 if (start % prevalign->value() == 0)
1099 align = prevalign;
1100 else
1101 align = Align(1);
1102 }
1103 //! Storing a floating point value
1104 if (Type *FT = dt.isFloat()) {
1105 //! Only need to update the reverse function
1106 switch (Mode) {
1108 break;
1111
1112 if (!merge && gutils->runtimeActivity) {
1113 auto basePtr = getBaseObject(orig_ptr);
1114
1115 // If runtime activity, first see if we can prove that the
1116 // shadow/primal are distinct statically as they are
1117 // allocas/mallocs, if not compare the pointers and conditionally
1118 // execute.
1119 if (!isa<AllocaInst>(basePtr) &&
1120 !isAllocationCall(basePtr, gutils->TLI)) {
1121 auto shadow_ptr =
1122 lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2);
1123 auto primal_ptr =
1124 lookup(gutils->getNewFromOriginal(orig_ptr), Builder2);
1125 if (gutils->getWidth() != 1) {
1126 shadow_ptr = gutils->extractMeta(Builder2, shadow_ptr, 0);
1127 }
1128 Value *shadow = Builder2.CreateICmpNE(primal_ptr, shadow_ptr);
1129
1130 BasicBlock *current = Builder2.GetInsertBlock();
1131 BasicBlock *conditional = gutils->addReverseBlock(
1132 current, current->getName() + "_active");
1133 merge = gutils->addReverseBlock(conditional,
1134 current->getName() + "_amerge");
1135 Builder2.CreateCondBr(shadow, conditional, merge);
1136 Builder2.SetInsertPoint(conditional);
1137 }
1138 }
1139
1140 if (constantval) {
1141 gutils->setPtrDiffe(
1142 &I, orig_ptr,
1143 Constant::getNullValue(gutils->getShadowType(valType)),
1144 Builder2, align, start, size, isVolatile, ordering, syncScope,
1145 mask, prevNoAlias, prevScopes);
1146 } else {
1147 Value *diff;
1148 Value *maskL = mask;
1149 if (!mask) {
1150 Value *dif1Ptr =
1151 lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2);
1152
1153 size_t idx = 0;
1154 auto rule = [&](Value *dif1Ptr) {
1155 LoadInst *dif1 =
1156 Builder2.CreateLoad(valType, dif1Ptr, isVolatile);
1157 if (align)
1158 dif1->setAlignment(*align);
1159 dif1->setOrdering(ordering);
1160 dif1->setSyncScopeID(syncScope);
1161
1162 SmallVector<Metadata *, 1> scopeMD = {
1163 gutils->getDerivativeAliasScope(orig_ptr, idx)};
1164 for (auto M : prevScopes)
1165 scopeMD.push_back(M);
1166
1167 SmallVector<Metadata *, 1> MDs;
1168 for (ssize_t j = -1; j < gutils->getWidth(); j++) {
1169 if (j != (ssize_t)idx)
1170 MDs.push_back(gutils->getDerivativeAliasScope(orig_ptr, j));
1171 }
1172 for (auto M : prevNoAlias)
1173 MDs.push_back(M);
1174
1175 dif1->setMetadata(LLVMContext::MD_alias_scope,
1176 MDNode::get(I.getContext(), scopeMD));
1177 dif1->setMetadata(LLVMContext::MD_noalias,
1178 MDNode::get(I.getContext(), MDs));
1179 dif1->setMetadata(LLVMContext::MD_tbaa,
1180 I.getMetadata(LLVMContext::MD_tbaa));
1181 dif1->setMetadata(LLVMContext::MD_tbaa_struct,
1182 I.getMetadata(LLVMContext::MD_tbaa_struct));
1183 idx++;
1184 return dif1;
1185 };
1186
1187 diff = applyChainRule(valType, Builder2, rule, dif1Ptr);
1188 } else {
1189 maskL = lookup(mask, Builder2);
1190 Type *tys[] = {valType, orig_ptr->getType()};
1191 auto F = getIntrinsicDeclaration(gutils->oldFunc->getParent(),
1192 Intrinsic::masked_load, tys);
1193 Value *alignv =
1194 ConstantInt::get(Type::getInt32Ty(mask->getContext()),
1195 align ? align->value() : 0);
1196 Value *ip =
1197 lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2);
1198
1199 auto rule = [&](Value *ip) {
1200 Value *args[] = {ip, alignv, maskL,
1201 Constant::getNullValue(valType)};
1202 diff = Builder2.CreateCall(F, args);
1203 return diff;
1204 };
1205
1206 diff = applyChainRule(valType, Builder2, rule, ip);
1207 }
1208
1209 gutils->setPtrDiffe(
1210 &I, orig_ptr,
1211 Constant::getNullValue(gutils->getShadowType(valType)),
1212 Builder2, align, start, size, isVolatile, ordering, syncScope,
1213 mask, prevNoAlias, prevScopes);
1214 ((DiffeGradientUtils *)gutils)
1215 ->addToDiffe(orig_val, diff, Builder2, FT, start, size, {},
1216 maskL);
1217 }
1218 break;
1219 }
1223
1224 Type *diffeTy = gutils->getShadowType(valType);
1225
1226 Value *diff = constantval
1227 ? Constant::getNullValue(diffeTy)
1228 : gutils->invertPointerM(orig_val, BuilderZ,
1229 /*nullShadow*/ true);
1230 gutils->setPtrDiffe(&I, orig_ptr, diff, BuilderZ, align, start, size,
1231 isVolatile, ordering, syncScope, mask,
1232 prevNoAlias, prevScopes);
1233
1234 break;
1235 }
1236 }
1237
1238 //! Storing an integer or pointer
1239 } else {
1240 //! Only need to update the forward function
1241
1242 // Don't reproduce mpi null requests
1243 if (constantval)
1244 if (Constant *C = dyn_cast<Constant>(orig_val)) {
1245 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
1246 C = CE->getOperand(0);
1247 }
1248 if (auto GV = dyn_cast<GlobalVariable>(C)) {
1249 if (GV->getName() == "ompi_request_null") {
1250 continue;
1251 }
1252 }
1253 }
1254
1255 bool backwardsShadow = false;
1256 bool forwardsShadow = true;
1257 for (auto pair : gutils->backwardsOnlyShadows) {
1258 if (pair.second.stores.count(&I)) {
1259 backwardsShadow = true;
1260 forwardsShadow = pair.second.primalInitialize;
1261 if (auto inst = dyn_cast<Instruction>(pair.first))
1262 if (!forwardsShadow && pair.second.LI &&
1263 pair.second.LI->contains(inst->getParent()))
1264 backwardsShadow = false;
1265 }
1266 }
1267
1268 if ((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
1269 (Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) ||
1270 (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow) ||
1272 (forwardsShadow || backwardsShadow)) ||
1275
1276 Value *valueop = nullptr;
1277
1278 bool needs_writebarrier = false;
1279 if (constantval) {
1280 if (!gutils->runtimeActivity) {
1281 if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) {
1282 if (!isa<UndefValue>(orig_val) &&
1283 !isa<ConstantPointerNull>(orig_val)) {
1284 std::string str;
1285 raw_string_ostream ss(str);
1286 ss << "Mismatched activity for: " << I
1287 << " const val: " << *orig_val;
1288 if (CustomErrorHandler) {
1289 valueop = unwrap(CustomErrorHandler(
1290 str.c_str(), wrap(&I), ErrorType::MixedActivityError,
1291 gutils, wrap(orig_val), wrap(&BuilderZ)));
1292 if (valueop)
1293 needs_writebarrier = true;
1294 } else
1295 EmitWarning("MixedActivityError", I, ss.str());
1296 }
1297 }
1298 }
1299 if (!valueop) {
1300 valueop = val;
1301 if (gutils->getWidth() > 1) {
1302 Value *array =
1303 UndefValue::get(gutils->getShadowType(val->getType()));
1304 for (unsigned i = 0; i < gutils->getWidth(); ++i) {
1305 array = BuilderZ.CreateInsertValue(array, val, {i});
1306 }
1307 valueop = array;
1308 }
1309 }
1310 } else {
1311 valueop = gutils->invertPointerM(orig_val, BuilderZ);
1312 }
1313 gutils->setPtrDiffe(&I, orig_ptr, valueop, BuilderZ, align, start,
1314 size, isVolatile, ordering, syncScope, mask,
1315 prevNoAlias, prevScopes, needs_writebarrier);
1316 }
1317 }
1318
1319 if (nextStart == storeSize)
1320 break;
1321 start = nextStart;
1322 }
1323
1324 if (merge) {
1325 Builder2.CreateBr(merge);
1326 Builder2.SetInsertPoint(merge);
1327 }
1328 }
1329
1330 void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) {
// A GEP is pure address arithmetic and produces no floating-point value of
// its own, so no adjoint is accumulated for it; every visible arm of the
// mode switch simply returns after the eraseIfUnused call.
1331 eraseIfUnused(gep);
// NOTE(review): the non-default case labels of this switch (original source
// lines 1333-1336) were dropped by the listing this chunk was extracted from.
1332 switch (Mode) {
1337 return;
1338 }
1339 default:
1340 return;
1341 }
1342 }
1343
1344 void visitPHINode(llvm::PHINode &phi) {
// No adjoint work is emitted for a phi here: both visible switch arms fall
// through to plain returns after the eraseIfUnused call.
// NOTE(review): the case labels of this switch (original source lines
// 1348-1350 and 1353-1356) are missing from this extracted listing.
1345 eraseIfUnused(phi);
1346
1347 switch (Mode) {
1351 return;
1352 }
1357 return;
1358 }
1359 }
1360 }
1361
1362 void visitCastInst(llvm::CastInst &I) {
// Reverse-mode adjoint of a cast: the derivative of the cast's result is
// mapped back through the inverse cast and accumulated into the adjoint of
// the operand, after which the result's adjoint is zeroed.
// NOTE(review): several case labels of the mode switch (original source
// lines 1368, 1371-1372) were dropped by the listing this chunk came from.
1363 using namespace llvm;
1364
1365 eraseIfUnused(I);
1366
1367 switch (Mode) {
1369 return;
1370 }
// Nothing to do for inactive casts, and pointer-valued casts (or ptrtoint)
// carry no float adjoint — shadow pointers are handled by pointer inversion.
1373 if (gutils->isConstantInstruction(&I))
1374 return;
1375
1376 if (I.getType()->isPointerTy() ||
1377 I.getOpcode() == CastInst::CastOps::PtrToInt)
1378 return;
1379
1380 Value *orig_op0 = I.getOperand(0);
1381 Value *op0 = gutils->getNewFromOriginal(orig_op0);
1382
1383 IRBuilder<> Builder2(&I);
1384 getReverseBuilder(Builder2);
1385
1386 if (!gutils->isConstantValue(orig_op0)) {
// Byte size of the operand, used as the extent for the type-analysis
// query that determines the floating-point type being accumulated.
1387 size_t size = 1;
1388 if (orig_op0->getType()->isSized())
1389 size =
1390 (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1391 orig_op0->getType()) +
1392 7) /
1393 8;
1394 Type *FT = TR.addingType(size, orig_op0);
// Under loose type analysis, fall back to the scalar FP type of the cast's
// source (then destination) type when type analysis could not decide.
1395 if (!FT && looseTypeAnalysis) {
1396 if (auto ET = I.getSrcTy()->getScalarType())
1397 if (ET->isFPOrFPVectorTy()) {
1398 FT = ET;
1399 EmitWarning("CannotDeduceType", I,
1400 "failed to deduce adding type of cast ", I,
1401 " assumed ", FT, " from src");
1402 }
1403 }
1404 if (!FT && looseTypeAnalysis) {
1405 if (auto ET = I.getDestTy()->getScalarType())
1406 if (ET->isFPOrFPVectorTy()) {
1407 FT = ET;
1408 EmitWarning("CannotDeduceType", I,
1409 "failed to deduce adding type of cast ", I,
1410 " assumed ", FT, " from dst");
1411 }
1412 }
// Still undecided: integer-to-integer casts need no adjoint; otherwise
// either warn (loose mode, integral source) or raise a no-type error.
1413 if (!FT) {
1414 if (TR.query(orig_op0)[{-1}] == BaseType::Integer &&
1415 TR.query(&I)[{-1}] == BaseType::Integer)
1416 return;
1417 if (looseTypeAnalysis) {
1418 if (auto ET = I.getSrcTy()->getScalarType())
1419 if (ET->isIntOrIntVectorTy()) {
1420 EmitWarning("CannotDeduceType", I,
1421 "failed to deduce adding type of cast ", I,
1422 " assumed integral from src");
1423 return;
1424 }
1425 }
1426 std::string str;
1427 raw_string_ostream ss(str);
1428 ss << "Cannot deduce adding type (cast) of " << I;
1429 EmitNoTypeError(str, I, gutils, Builder2);
1430 }
1431
1432 if (FT) {
1433
// Per-shadow rule: cast the incoming adjoint back to the operand's type.
// FP casts reverse with an FP cast, bitcasts with a bitcast; an integer
// Trunc is reversed with ZExt (flagged TODO below). Anything else is an
// unhandled cast and reports a no-derivative error.
1434 auto rule = [&](Value *dif) {
1435 if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
1436 I.getOpcode() == CastInst::CastOps::FPExt) {
1437 return Builder2.CreateFPCast(dif, op0->getType());
1438 } else if (I.getOpcode() == CastInst::CastOps::BitCast) {
1439 return Builder2.CreateBitCast(dif, op0->getType());
1440 } else if (I.getOpcode() == CastInst::CastOps::Trunc) {
1441 // TODO CHECK THIS
1442 return Builder2.CreateZExt(dif, op0->getType());
1443 } else {
1444 std::string s;
1445 llvm::raw_string_ostream ss(s);
1446 ss << *I.getParent()->getParent() << "\n";
1447 ss << "cannot handle above cast " << I << "\n";
1448 EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
1449 return (llvm::Value *)UndefValue::get(op0->getType());
1450 }
1451 };
1452
1453 Value *dif = diffe(&I, Builder2);
1454 Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif);
1455
1456 addToDiffe(orig_op0, diff, Builder2, FT);
1457 }
1458 }
1459
// The adjoint of the cast's result has been consumed; reset it to zero.
1460 Type *diffTy = gutils->getShadowType(I.getType());
1461 setDiffe(&I, Constant::getNullValue(diffTy), Builder2);
1462
1463 break;
1464 }
1469 return;
1470 }
1471 }
1472 }
1473
1474 void visitSelectInst(llvm::SelectInst &SI) {
// Mode dispatch for select: constant instructions and pointer-typed
// selects need no adjoint (pointer shadows are produced by inversion).
// The reverse-mode arm presumably forwards to createSelectInstAdjoint
// below — the call itself sits on a line (original 1486) dropped from
// this extracted listing, as are the case labels of this switch.
1475 eraseIfUnused(SI);
1476
1477 switch (Mode) {
1479 return;
1482 if (gutils->isConstantInstruction(&SI))
1483 return;
1484 if (SI.getType()->isPointerTy())
1485 return;
1487 return;
1488 }
1493 return;
1494 }
1495 }
1496 }
1497
// Reverse-mode adjoint of a select: the result's adjoint flows to whichever
// operand the (looked-up) condition chose; the other operand receives zero.
// A special case first detects a "loopy" select feeding a recurrence phi
// (conditional running product/accumulation) and, instead of materializing a
// per-iteration select, records the iteration index at which the condition
// fired and compares against it in the exit block.
1498 void createSelectInstAdjoint(llvm::SelectInst &SI) {
1499 using namespace llvm;
1500
1501 Value *op0 = gutils->getNewFromOriginal(SI.getOperand(0));
1502 Value *orig_op1 = SI.getOperand(1);
1503 Value *op1 = gutils->getNewFromOriginal(orig_op1);
1504 Value *orig_op2 = SI.getOperand(2);
1505 Value *op2 = gutils->getNewFromOriginal(orig_op2);
1506
1507 // TODO fix all the reverse builders
1508 IRBuilder<> Builder2(&SI);
1509 getReverseBuilder(Builder2);
1510
1511 Value *dif1 = nullptr;
1512 Value *dif2 = nullptr;
1513
// Byte size of the select's operands, used for type-analysis lookups.
1514 size_t size = 1;
1515 if (orig_op1->getType()->isSized())
1516 size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1517 orig_op1->getType()) +
1518 7) /
1519 8;
1520 // Required loopy phi = [in, BO, BO, ..., BO]
1521 // 1) phi is only used in this B0
1522 // 2) BO dominates all latches
1523 // 3) phi == B0 whenever not coming from preheader [implies 2]
1524 // 4) [optional but done for ease] one exit to make it easier to
1525 // calculate the product at that point
1526 for (int i = 0; i < 2; i++)
1527 if (auto P0 = dyn_cast<PHINode>(SI.getOperand(i + 1))) {
1528 LoopContext lc;
1529 SmallVector<Instruction *, 4> activeUses;
1530 for (auto u : P0->users()) {
1531 if (!gutils->isConstantInstruction(cast<Instruction>(u))) {
1532 activeUses.push_back(cast<Instruction>(u));
1533 } else if (retType == DIFFE_TYPE::OUT_DIFF && isa<ReturnInst>(u))
1534 activeUses.push_back(cast<Instruction>(u));
1535 }
// The phi must feed only this select, be a loop-header phi, and receive
// this select back on every latch edge, with a single loop exit.
1536 if (activeUses.size() == 1 && activeUses[0] == &SI &&
1537 gutils->getContext(gutils->getNewFromOriginal(P0->getParent()),
1538 lc) &&
1539 gutils->getNewFromOriginal(P0->getParent()) == lc.header) {
1540 SmallVector<BasicBlock *, 1> Latches;
1541 gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches);
1542 bool allIncoming = true;
1543 for (auto Latch : Latches) {
1544 if (&SI != P0->getIncomingValueForBlock(Latch)) {
1545 allIncoming = false;
1546 break;
1547 }
1548 }
1549 if (allIncoming && lc.exitBlocks.size() == 1) {
1550 if (!gutils->isConstantValue(SI.getOperand(2 - i))) {
1551 auto addingType = TR.addingType(size, SI.getOperand(2 - i));
1552 if (addingType || !looseTypeAnalysis) {
// Record (in the forward pass) the iteration at which the condition
// selected the non-phi operand, then in the reverse pass only add the
// adjoint on that recorded iteration.
1553 auto index = gutils->getOrInsertConditionalIndex(
1554 gutils->getNewFromOriginal(SI.getOperand(0)), lc, i == 1);
1555 IRBuilder<> EB(*lc.exitBlocks.begin());
1556 getReverseBuilder(EB, /*original=*/false);
1557 Value *inc = lookup(lc.incvar, Builder2);
1558 if (VectorType *VTy =
1559 dyn_cast<VectorType>(SI.getOperand(0)->getType())) {
1560 inc = Builder2.CreateVectorSplat(VTy->getElementCount(), inc);
1561 }
1562 Value *dif = CreateSelect(
1563 Builder2,
1564 Builder2.CreateICmpEQ(gutils->lookupM(index, EB), inc),
1565 diffe(&SI, Builder2),
1566 Constant::getNullValue(
1567 gutils->getShadowType(op1->getType())));
1568 addToDiffe(SI.getOperand(2 - i), dif, Builder2, addingType);
1569 }
1570 }
1571 return;
1572 }
1573 }
1574 }
1575
// General case: route the result adjoint to the chosen operand, zero to the
// other, based on the looked-up condition value.
1576 if (!gutils->isConstantValue(orig_op1))
1577 dif1 = CreateSelect(
1578 Builder2, lookup(op0, Builder2), diffe(&SI, Builder2),
1579 Constant::getNullValue(gutils->getShadowType(op1->getType())),
1580 "diffe" + op1->getName());
1581 if (!gutils->isConstantValue(orig_op2))
1582 dif2 = CreateSelect(
1583 Builder2, lookup(op0, Builder2),
1584 Constant::getNullValue(gutils->getShadowType(op2->getType())),
1585 diffe(&SI, Builder2), "diffe" + op2->getName());
1586
1587 setDiffe(&SI, Constant::getNullValue(gutils->getShadowType(SI.getType())),
1588 Builder2);
1589 if (dif1) {
1590 Type *addingType = TR.addingType(size, orig_op1);
1591 if (addingType || !looseTypeAnalysis)
1592 addToDiffe(orig_op1, dif1, Builder2, addingType);
1593 else
1594 llvm::errs() << " warning: assuming integral for " << SI << "\n";
1595 }
1596 if (dif2) {
1597 Type *addingType = TR.addingType(size, orig_op2);
1598 if (addingType || !looseTypeAnalysis)
1599 addToDiffe(orig_op2, dif2, Builder2, addingType);
1600 else
1601 llvm::errs() << " warning: assuming integral for " << SI << "\n";
1602 }
1603 }
1604
1605 void visitExtractElementInst(llvm::ExtractElementInst &EEI) {
// Reverse-mode adjoint of extractelement: scatter the scalar result's
// adjoint back into the vector operand's adjoint at the extracted index,
// then zero the result's adjoint. For vector-width > 1, each shadow slot
// is handled individually.
// NOTE(review): the case labels of the mode switch (original source lines
// 1610-1613, 1616-1617, 1656) were dropped by this extracted listing.
1606 using namespace llvm;
1607
1608 eraseIfUnused(EEI);
1609 switch (Mode) {
1614 return;
1615 }
1618 if (gutils->isConstantInstruction(&EEI))
1619 return;
1620 IRBuilder<> Builder2(&EEI);
1621 getReverseBuilder(Builder2);
1622
1623 Value *orig_vec = EEI.getVectorOperand();
1624
1625 if (!gutils->isConstantValue(orig_vec)) {
1626
// Byte size of the extracted scalar, for the type-analysis lookup.
1627 size_t size = 1;
1628 if (EEI.getType()->isSized())
1629 size =
1630 (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1631 EEI.getType()) +
1632 7) /
1633 8;
1634 auto diff = diffe(&EEI, Builder2);
1635 if (gutils->getWidth() == 1) {
// Single shadow: index only by the (primal) extract index.
1636 Value *sv[] = {gutils->getNewFromOriginal(EEI.getIndexOperand())};
1637 ((DiffeGradientUtils *)gutils)
1638 ->addToDiffe(orig_vec, diff, Builder2, TR.addingType(size, &EEI),
1639 sv);
1640 } else {
// Vectorized shadow: index first by shadow slot i, then by extract index.
1641 for (size_t i = 0; i < gutils->getWidth(); i++) {
1642 Value *sv[] = {nullptr,
1643 gutils->getNewFromOriginal(EEI.getIndexOperand())};
1644 sv[0] = ConstantInt::get(sv[1]->getType(), i);
1645 ((DiffeGradientUtils *)gutils)
1646 ->addToDiffe(orig_vec, gutils->extractMeta(Builder2, diff, i),
1647 Builder2, TR.addingType(size, &EEI), sv);
1648 }
1649 }
1650 }
1651 setDiffe(&EEI,
1652 Constant::getNullValue(gutils->getShadowType(EEI.getType())),
1653 Builder2);
1654 return;
1655 }
1657 return;
1658 }
1659 }
1660 }
1661
1662 void visitInsertElementInst(llvm::InsertElementInst &IEI) {
// Reverse-mode adjoint of insertelement (result = op0 with lane op2
// replaced by scalar op1): op0's adjoint receives the result adjoint with
// the inserted lane zeroed, op1's adjoint receives the extracted lane, and
// the result's adjoint is then reset to zero.
// NOTE(review): the case labels of the mode switch (original source lines
// 1668-1671, 1674-1675, 1749) were dropped by this extracted listing.
1663 using namespace llvm;
1664
1665 eraseIfUnused(IEI);
1666
1667 switch (Mode) {
1672 return;
1673 }
1676 if (gutils->isConstantInstruction(&IEI))
1677 return;
1678 IRBuilder<> Builder2(&IEI);
1679 getReverseBuilder(Builder2);
1680
1681 Value *dif1 = diffe(&IEI, Builder2);
1682
1683 Value *orig_op0 = IEI.getOperand(0);
1684 Value *orig_op1 = IEI.getOperand(1);
1685 Value *op1 = gutils->getNewFromOriginal(orig_op1);
1686 Value *op2 = gutils->getNewFromOriginal(IEI.getOperand(2));
1687
// Byte sizes of the vector and inserted-scalar operands, used for the
// type-analysis lookups below.
1688 size_t size0 = 1;
1689 if (orig_op0->getType()->isSized())
1690 size0 =
1691 (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1692 orig_op0->getType()) +
1693 7) /
1694 8;
1695 size_t size1 = 1;
1696 if (orig_op1->getType()->isSized())
1697 size1 =
1698 (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1699 orig_op1->getType()) +
1700 7) /
1701 8;
1702
// Vector operand: result adjoint with the inserted lane masked to zero.
1703 if (!gutils->isConstantValue(orig_op0)) {
1704 if (gutils->getWidth() == 1) {
1705 addToDiffe(
1706 orig_op0,
1707 Builder2.CreateInsertElement(
1708 dif1,
1709 Constant::getNullValue(gutils->getShadowType(op1->getType())),
1710 lookup(op2, Builder2)),
1711 Builder2, TR.addingType(size0, orig_op0));
1712 } else {
// Vectorized shadows: mask the lane in each shadow slot independently.
1713 for (size_t i = 0; i < gutils->getWidth(); i++) {
1714 Value *sv[] = {ConstantInt::get(op2->getType(), i)};
1715 ((DiffeGradientUtils *)gutils)
1716 ->addToDiffe(orig_op0,
1717 Builder2.CreateInsertElement(
1718 gutils->extractMeta(Builder2, dif1, i),
1719 Constant::getNullValue(op1->getType()),
1720 lookup(op2, Builder2)),
1721 Builder2, TR.addingType(size0, orig_op0), sv);
1722 }
1723 }
1724 }
1725
// Inserted scalar: the lane of the result adjoint at the insert index.
1726 if (!gutils->isConstantValue(orig_op1)) {
1727 if (gutils->getWidth() == 1) {
1728 addToDiffe(orig_op1,
1729 Builder2.CreateExtractElement(dif1, lookup(op2, Builder2)),
1730 Builder2, TR.addingType(size1, orig_op1));
1731 } else {
1732 for (size_t i = 0; i < gutils->getWidth(); i++) {
1733 Value *sv[] = {ConstantInt::get(op2->getType(), i)};
1734 ((DiffeGradientUtils *)gutils)
1735 ->addToDiffe(orig_op1,
1736 Builder2.CreateExtractElement(
1737 gutils->extractMeta(Builder2, dif1, i),
1738 lookup(op2, Builder2)),
1739 Builder2, TR.addingType(size1, orig_op1), sv);
1740 }
1741 }
1742 }
1743
1744 setDiffe(&IEI,
1745 Constant::getNullValue(gutils->getShadowType(IEI.getType())),
1746 Builder2);
1747 return;
1748 }
1750 return;
1751 }
1752 }
1753 }
1754
1755 void visitShuffleVectorInst(llvm::ShuffleVectorInst &SVI) {
// Reverse-mode adjoint of shufflevector: for every lane of the result,
// the shuffle mask tells which operand (0 or 1) and which lane it came
// from, so the corresponding lane of the result adjoint is accumulated
// back into that operand's adjoint. The result adjoint is then zeroed.
// NOTE(review): the case labels of the mode switch (original source lines
// 1761-1764, 1767-1768, 1820) were dropped by this extracted listing.
1756 using namespace llvm;
1757
1758 eraseIfUnused(SVI);
1759
1760 switch (Mode) {
1765 return;
1766 }
1769 if (gutils->isConstantInstruction(&SVI))
1770 return;
1771 IRBuilder<> Builder2(&SVI);
1772 getReverseBuilder(Builder2);
1773
1774 auto loaded = diffe(&SVI, Builder2);
// l1 = element count of the first operand; mask indices >= l1 refer to
// lanes of the second operand. Scalable vectors are not supported here.
1775 auto count =
1776 cast<VectorType>(SVI.getOperand(0)->getType())->getElementCount();
1777 assert(!count.isScalable());
1778 size_t l1 = count.getKnownMinValue();
1779 uint64_t instidx = 0;
1780
1781 for (size_t idx : SVI.getShuffleMask()) {
1782 auto opnum = (idx < l1) ? 0 : 1;
1783 auto opidx = (idx < l1) ? idx : (idx - l1);
1784
1785 if (!gutils->isConstantValue(SVI.getOperand(opnum))) {
1786 size_t size = 1;
1787 if (SVI.getOperand(opnum)->getType()->isSized())
1788 size = (gutils->newFunc->getParent()
1789 ->getDataLayout()
1790 .getTypeSizeInBits(SVI.getOperand(opnum)->getType()) +
1791 7) /
1792 8;
1793 if (gutils->getWidth() == 1) {
1794 Value *sv[] = {
1795 ConstantInt::get(Type::getInt32Ty(SVI.getContext()), opidx)};
1796 Value *toadd = Builder2.CreateExtractElement(loaded, instidx);
1797 ((DiffeGradientUtils *)gutils)
1798 ->addToDiffe(SVI.getOperand(opnum), toadd, Builder2,
1799 TR.addingType(size, SVI.getOperand(opnum)), sv);
1800 } else {
// Vectorized shadows: route the lane within each shadow slot i.
1801 for (size_t i = 0; i < gutils->getWidth(); i++) {
1802 Value *sv[] = {
1803 ConstantInt::get(Type::getInt32Ty(SVI.getContext()), i),
1804 ConstantInt::get(Type::getInt32Ty(SVI.getContext()), opidx)};
1805 Value *toadd = Builder2.CreateExtractElement(
1806 GradientUtils::extractMeta(Builder2, loaded, i), instidx);
1807 ((DiffeGradientUtils *)gutils)
1808 ->addToDiffe(SVI.getOperand(opnum), toadd, Builder2,
1809 TR.addingType(size, SVI.getOperand(opnum)), sv);
1810 }
1811 }
1812 }
1813 ++instidx;
1814 }
1815 setDiffe(&SVI,
1816 Constant::getNullValue(gutils->getShadowType(SVI.getType())),
1817 Builder2);
1818 return;
1819 }
1821 return;
1822 }
1823 }
1824 }
1825
1826 void visitExtractValueInst(llvm::ExtractValueInst &EVI) {
// Reverse-mode adjoint of extractvalue: the adjoint of the extracted slice
// is accumulated back into the aggregate operand's adjoint at the same
// indices (float byte-ranges only, per type analysis), then the result's
// adjoint is zeroed.
// NOTE(review): the case labels of the mode switch (original source lines
// 1838-1841, 1844-1845, 1929) were dropped by this extracted listing.
1827 using namespace llvm;
1828
1829 eraseIfUnused(EVI);
1830
// BUGFIX: the second conjunct previously re-tested &EVI, making the
// condition "!c && c" — always false — so this activity sanity check was
// dead code. The intended contradiction is an *active* result extracted
// from a *constant* aggregate operand.
1831 if (!gutils->isConstantValue(&EVI) && gutils->isConstantValue(EVI.getOperand(0))) {
1832 llvm::errs() << *gutils->oldFunc->getParent() << "\n";
1833 llvm::errs() << EVI << "\n";
1834 llvm_unreachable("Illegal activity for extractvalue");
1835 }
1836
1837 switch (Mode) {
1842 return;
1843 }
1846 if (gutils->isConstantInstruction(&EVI))
1847 return;
1848 if (EVI.getType()->isPointerTy())
1849 return;
1850 IRBuilder<> Builder2(&EVI);
1851 getReverseBuilder(Builder2);
1852
1853 Value *orig_op0 = EVI.getOperand(0);
1854
1855 auto prediff = diffe(&EVI, Builder2);
1856
1857 // todo const
1858 if (!gutils->isConstantValue(orig_op0)) {
// sv holds the constant extract indices used to address the slice within
// the aggregate's adjoint.
1859 SmallVector<Value *, 4> sv;
1860 for (auto i : EVI.getIndices())
1861 sv.push_back(ConstantInt::get(Type::getInt32Ty(EVI.getContext()), i));
1862 size_t storeSize = 1;
1863 if (EVI.getType()->isSized())
1864 storeSize =
1865 (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1866 EVI.getType()) +
1867 7) /
1868 8;
1869
// Walk the extracted value byte-by-byte, splitting it into maximal runs
// [start, nextStart) that share a single concrete type, and accumulate
// the adjoint only for runs whose type is floating point.
1870 unsigned start = 0;
1871 auto vd = TR.query(&EVI);
1872
1873 while (1) {
1874 unsigned nextStart = storeSize;
1875
1876 auto dt = vd[{-1}];
1877 for (size_t i = start; i < storeSize; ++i) {
1878 auto nex = vd[{(int)i}];
1879 if ((nex == BaseType::Anything && dt.isFloat()) ||
1880 (dt == BaseType::Anything && nex.isFloat())) {
1881 nextStart = i;
1882 break;
1883 }
1884 bool Legal = true;
1885 dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
1886 if (!Legal) {
1887 nextStart = i;
1888 break;
1889 }
1890 }
1891 unsigned size = nextStart - start;
1892 if (!dt.isKnown()) {
// Loose type analysis: fall back to the LLVM type of the extracted
// value (FP scalar type, or integer for int/pointer types).
1893 bool found = false;
1894 if (looseTypeAnalysis) {
1895 if (EVI.getType()->isFPOrFPVectorTy()) {
1896 dt = ConcreteType(EVI.getType()->getScalarType());
1897 found = true;
1898 } else if (EVI.getType()->isIntOrIntVectorTy() ||
1899 EVI.getType()->isPointerTy()) {
1900 dt = BaseType::Integer;
1901 found = true;
1902 }
1903 }
1904 if (!found) {
1905 std::string str;
1906 raw_string_ostream ss(str);
1907 ss << "Cannot deduce type of extract " << EVI << vd.str()
1908 << " start: " << start << " size: " << size
1909 << " extractSize: " << storeSize;
1910 EmitNoTypeError(str, EVI, gutils, Builder2);
1911 }
1912 }
1913 if (auto FT = dt.isFloat())
1914 ((DiffeGradientUtils *)gutils)
1915 ->addToDiffe(orig_op0, prediff, Builder2, FT, start, size, sv,
1916 nullptr, /*ignoreFirstSlicesToDiff*/ sv.size())
;
1917
1918 if (nextStart == storeSize)
1919 break;
1920 start = nextStart;
1921 }
1922 }
1923
1924 setDiffe(&EVI,
1925 Constant::getNullValue(gutils->getShadowType(EVI.getType())),
1926 Builder2);
1927 return;
1928 }
1930 return;
1931 }
1932 }
1933 }
1934
1935 void visitInsertValueInst(llvm::InsertValueInst &IVI) {
// Reverse-mode adjoint of insertvalue: the inserted operand's adjoint
// receives the slice of the result adjoint at the insert indices, and the
// aggregate operand's adjoint receives the result adjoint with that slice
// zeroed; the result's adjoint is then reset. Only runs when the aggregate
// may actually carry floating-point data.
// NOTE(review): this extracted listing dropped several original source
// lines inside this function (1942, 1946-1948, 2002-2004, 2007-2008,
// 2174), including mode-switch case labels and part of the Mode condition
// below.
1936 using namespace llvm;
1937
1938 eraseIfUnused(IVI);
1939 if (gutils->isConstantValue(&IVI))
1940 return;
1941
1943 return;
1944
1945 if (Mode == DerivativeMode::ForwardMode ||
1949 return;
1950 }
1951
// Skip aggregates consisting purely of pointers: pointer shadows are not
// handled through adjoint accumulation.
1952 bool hasNonPointer = false;
1953 if (auto st = dyn_cast<StructType>(IVI.getType())) {
1954 for (unsigned i = 0; i < st->getNumElements(); ++i) {
1955 if (!st->getElementType(i)->isPointerTy()) {
1956 hasNonPointer = true;
1957 }
1958 }
1959 } else if (auto at = dyn_cast<ArrayType>(IVI.getType())) {
1960 if (!at->getElementType()->isPointerTy()) {
1961 hasNonPointer = true;
1962 }
1963 }
1964 if (!hasNonPointer)
1965 return;
1966
// Walk the chain of insertvalue instructions feeding this one; if every
// inserted operand is provably non-float (per type analysis) the whole
// chain needs no adjoint and we return below.
1967 bool floatingInsertion = false;
1968 for (InsertValueInst *iv = &IVI;;) {
1969 size_t size0 = 1;
1970 if (iv->getInsertedValueOperand()->getType()->isSized() &&
1971 (iv->getInsertedValueOperand()->getType()->isIntOrIntVectorTy() ||
1972 iv->getInsertedValueOperand()->getType()->isFPOrFPVectorTy()))
1973 size0 =
1974 (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1975 iv->getInsertedValueOperand()->getType()) +
1976 7) /
1977 8;
1978 auto it = TR.intType(size0, iv->getInsertedValueOperand(), false);
1979 if (it.isFloat() || !it.isKnown()) {
1980 floatingInsertion = true;
1981 break;
1982 }
1983 Value *val = iv->getAggregateOperand();
1984 if (gutils->isConstantValue(val))
1985 break;
1986 if (auto dc = dyn_cast<InsertValueInst>(val)) {
1987 iv = dc;
1988 } else {
1989 // unsure where this came from, conservatively assume contains float
1990 floatingInsertion = true;
1991 break;
1992 }
1993 }
1994
1995 if (!floatingInsertion)
1996 return;
1997
1998 // TODO handle pointers
1999 // TODO type analysis handle structs
2000
2001 switch (Mode) {
2005 assert(0 && "should be handled above");
2006 return;
2009 IRBuilder<> Builder2(&IVI);
2010 getReverseBuilder(Builder2);
2011
2012 Value *orig_inserted = IVI.getInsertedValueOperand();
2013 Value *orig_agg = IVI.getAggregateOperand();
2014
2015 size_t size0 = 1;
2016 if (orig_inserted->getType()->isSized())
2017 size0 =
2018 (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
2019 orig_inserted->getType()) +
2020 7) /
2021 8;
2022
// Inserted operand: extract the slice of the result adjoint at the insert
// indices and accumulate it, split into maximal single-type byte runs so
// only float runs are added.
2023 if (!gutils->isConstantValue(orig_inserted)) {
2024 auto TT = TR.query(orig_inserted);
2025
2026 unsigned start = 0;
2027 Value *dindex = nullptr;
2028
2029 while (1) {
2030 unsigned nextStart = size0;
2031
2032 auto dt = TT[{-1}];
2033 for (size_t i = start; i < size0; ++i) {
2034 auto nex = TT[{(int)i}];
2035 if ((nex == BaseType::Anything && dt.isFloat()) ||
2036 (dt == BaseType::Anything && nex.isFloat())) {
2037 nextStart = i;
2038 break;
2039 }
2040 bool Legal = true;
2041 dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
2042 if (!Legal) {
2043 nextStart = i;
2044 break;
2045 }
2046 }
2047 Type *flt = dt.isFloat();
2048 if (!dt.isKnown()) {
2049 bool found = false;
2050 if (looseTypeAnalysis) {
2051 if (orig_inserted->getType()->isFPOrFPVectorTy()) {
2052 flt = orig_inserted->getType()->getScalarType();
2053 found = true;
2054 } else if (orig_inserted->getType()->isIntOrIntVectorTy() ||
2055 orig_inserted->getType()->isPointerTy()) {
2056 flt = nullptr;
2057 found = true;
2058 }
2059 }
2060 if (!found) {
2061 std::string str;
2062 raw_string_ostream ss(str);
2063 ss << "Cannot deduce type of insertvalue ins " << IVI
2064 << " size: " << size0 << " TT: " << TT.str();
2065 EmitNoTypeError(str, IVI, gutils, Builder2);
2066 }
2067 }
2068
2069 if (flt) {
// dindex (the extracted slice) is computed lazily, once, on the first
// float run encountered.
2070 if (!dindex) {
2071 auto rule = [&](Value *prediff) {
2072 return Builder2.CreateExtractValue(prediff, IVI.getIndices());
2073 };
2074 auto prediff = diffe(&IVI, Builder2);
2075 dindex = applyChainRule(orig_inserted->getType(), Builder2, rule,
2076 prediff);
2077 }
2078
2079 auto TT = TR.query(orig_inserted);
2080
2081 ((DiffeGradientUtils *)gutils)
2082 ->addToDiffe(orig_inserted, dindex, Builder2, flt, start,
2083 nextStart - start);
2084 }
2085 if (nextStart == size0)
2086 break;
2087 start = nextStart;
2088 }
2089 }
2090
2091 size_t size1 = 1;
2092 if (orig_agg->getType()->isSized())
2093 size1 =
2094 (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
2095 orig_agg->getType()) +
2096 7) /
2097 8;
2098
// Aggregate operand: accumulate the result adjoint with the inserted slice
// zeroed out, again split into single-type byte runs.
2099 if (!gutils->isConstantValue(orig_agg)) {
2100
2101 auto TT = TR.query(orig_agg);
2102
2103 unsigned start = 0;
2104
2105 Value *dindex = nullptr;
2106
2107 while (1) {
2108 unsigned nextStart = size1;
2109
2110 auto dt = TT[{-1}];
2111 for (size_t i = start; i < size1; ++i) {
2112 auto nex = TT[{(int)i}];
2113 if ((nex == BaseType::Anything && dt.isFloat()) ||
2114 (dt == BaseType::Anything && nex.isFloat())) {
2115 nextStart = i;
2116 break;
2117 }
2118 bool Legal = true;
2119 dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
2120 if (!Legal) {
2121 nextStart = i;
2122 break;
2123 }
2124 }
2125 Type *flt = dt.isFloat();
2126 if (!dt.isKnown()) {
2127 bool found = false;
2128 if (looseTypeAnalysis) {
2129 if (orig_agg->getType()->isFPOrFPVectorTy()) {
2130 flt = orig_agg->getType()->getScalarType();
2131 found = true;
2132 } else if (orig_agg->getType()->isIntOrIntVectorTy() ||
2133 orig_agg->getType()->isPointerTy()) {
2134 flt = nullptr;
2135 found = true;
2136 }
2137 }
2138 if (!found) {
2139 std::string str;
2140 raw_string_ostream ss(str);
2141 ss << "Cannot deduce type of insertvalue agg " << IVI
2142 << " start: " << start << " size: " << size1
2143 << " TT: " << TT.str();
2144 EmitNoTypeError(str, IVI, gutils, Builder2);
2145 }
2146 }
2147
2148 if (flt) {
2149 if (!dindex) {
2150 auto rule = [&](Value *prediff) {
2151 return Builder2.CreateInsertValue(
2152 prediff, Constant::getNullValue(orig_inserted->getType()),
2153 IVI.getIndices());
2154 };
2155 auto prediff = diffe(&IVI, Builder2);
2156 dindex =
2157 applyChainRule(orig_agg->getType(), Builder2, rule, prediff);
2158 }
2159 ((DiffeGradientUtils *)gutils)
2160 ->addToDiffe(orig_agg, dindex, Builder2, flt, start,
2161 nextStart - start);
2162 }
2163 if (nextStart == size1)
2164 break;
2165 start = nextStart;
2166 }
2167 }
2168
2169 setDiffe(&IVI,
2170 Constant::getNullValue(gutils->getShadowType(IVI.getType())),
2171 Builder2);
2172 return;
2173 }
2175 return;
2176 }
2177 }
2178 }
2179
2180 void getReverseBuilder(llvm::IRBuilder<> &Builder2, bool original = true) {
2181 ((GradientUtils *)gutils)->getReverseBuilder(Builder2, original);
2182 }
2183
2184 void getForwardBuilder(llvm::IRBuilder<> &Builder2) {
2185 ((GradientUtils *)gutils)->getForwardBuilder(Builder2);
2186 }
2187
2188 llvm::Value *diffe(llvm::Value *val, llvm::IRBuilder<> &Builder) {
2189 assert(Mode != DerivativeMode::ReverseModePrimal);
2190 return ((DiffeGradientUtils *)gutils)->diffe(val, Builder);
2191 }
2192
2193 void setDiffe(llvm::Value *val, llvm::Value *dif,
2194 llvm::IRBuilder<> &Builder) {
2195 assert(Mode != DerivativeMode::ReverseModePrimal);
2196 ((DiffeGradientUtils *)gutils)->setDiffe(val, dif, Builder);
2197 }
2198
2199 /// Unwraps a vector derivative from its internal representation and applies a
2200 /// function f to each element. Return values of f are collected and wrapped.
2201 template <typename Func, typename... Args>
2202 llvm::Value *applyChainRule(llvm::Type *diffType, llvm::IRBuilder<> &Builder,
2203 Func rule, Args... args) {
2204 return ((GradientUtils *)gutils)
2205 ->applyChainRule(diffType, Builder, rule, args...);
2206 }
2207
2208 /// Unwraps a vector derivative from its internal representation and applies a
2209 /// function f to each element.
2210 template <typename Func, typename... Args>
2211 void applyChainRule(llvm::IRBuilder<> &Builder, Func rule, Args... args) {
2212 ((GradientUtils *)gutils)->applyChainRule(Builder, rule, args...);
2213 }
2214
2215 /// Unwraps an collection of constant vector derivatives from their internal
2216 /// representations and applies a function f to each element.
2217 template <typename Func>
2218 void applyChainRule(llvm::ArrayRef<llvm::Value *> diffs,
2219 llvm::IRBuilder<> &Builder, Func rule) {
2220 ((GradientUtils *)gutils)->applyChainRule(diffs, Builder, rule);
2221 }
2222
  // Whether this generation pass is responsible for freeing memory it
  // allocated (delegates to DiffeGradientUtils::FreeMemory).
  // NOTE(review): the assert's condition is truncated in this view — the
  // remaining allowed DerivativeMode disjuncts are not visible here;
  // confirm against the full source before editing.
  bool shouldFree() {
    assert(Mode == DerivativeMode::ReverseModeCombined ||
    return ((DiffeGradientUtils *)gutils)->FreeMemory;
  }
2229
2230 llvm::SmallVector<llvm::SelectInst *, 4>
2231 addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &Builder,
2232 llvm::Type *T, llvm::Value *mask = nullptr) {
2233 return ((DiffeGradientUtils *)gutils)
2234 ->addToDiffe(val, dif, Builder, T, /*idxs*/ {}, mask);
2235 }
2236
2237 llvm::Value *lookup(llvm::Value *val, llvm::IRBuilder<> &Builder) {
2238 return gutils->lookupM(val, Builder);
2239 }
2240
  /// Visitor entry point for binary operators. First handles a special
  /// FDiv pattern — a loop-carried phi that repeatedly divides (a running
  /// product/quotient) — using a total-multiplicative-product helper so the
  /// adjoint can be computed at the loop exit. Otherwise dispatches to the
  /// generated per-opcode derivative table (BinopDerivatives.inc) and then
  /// to the mode-specific fallback.
  void visitBinaryOperator(llvm::BinaryOperator &BO) {
    eraseIfUnused(BO);

    // NOTE(review): additional conjuncts of this condition (between the
    // opcode check and the constant-value check) were elided in this view.
    if (BO.getOpcode() == llvm::Instruction::FDiv &&
        !gutils->isConstantValue(&BO)) {
      using namespace llvm;
      // Required loopy phi = [in, BO, BO, ..., BO]
      // 1) phi is only used in this B0
      // 2) BO dominates all latches
      // 3) phi == B0 whenever not coming from preheader [implies 2]
      // 4) [optional but done for ease] one exit to make it easier to
      // calculation the product at that point
      Value *orig_op0 = BO.getOperand(0);
      if (auto P0 = dyn_cast<PHINode>(orig_op0)) {
        LoopContext lc;
        // Collect all active (non-constant) users of the phi; the pattern
        // only applies when the division itself is the sole active user.
        SmallVector<Instruction *, 4> activeUses;
        for (auto u : P0->users()) {
          if (!gutils->isConstantInstruction(cast<Instruction>(u))) {
            activeUses.push_back(cast<Instruction>(u));
          } else if (retType == DIFFE_TYPE::OUT_DIFF && isa<ReturnInst>(u)) {
            activeUses.push_back(cast<Instruction>(u));
          }
        }
        if (activeUses.size() == 1 && activeUses[0] == &BO &&
            gutils->getContext(gutils->getNewFromOriginal(P0->getParent()),
                               lc) &&
            gutils->getNewFromOriginal(P0->getParent()) == lc.header) {
          SmallVector<BasicBlock *, 1> Latches;
          gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches);
          // The phi must receive BO itself from every latch (condition 3).
          bool allIncoming = true;
          for (auto Latch : Latches) {
            if (&BO != P0->getIncomingValueForBlock(Latch)) {
              allIncoming = false;
              break;
            }
          }
          if (allIncoming && lc.exitBlocks.size() == 1) {

            IRBuilder<> Builder2(&BO);
            getReverseBuilder(Builder2);

            Value *orig_op1 = BO.getOperand(1);

            Value *dif1 = nullptr;
            Value *idiff = diffe(&BO, Builder2);

            Type *addingType = BO.getType();

            if (!gutils->isConstantValue(orig_op1)) {
              IRBuilder<> EB(*lc.exitBlocks.begin());
              getReverseBuilder(EB, /*original=*/false);
              Value *Pstart = P0->getIncomingValueForBlock(
                  gutils->getOriginalFromNew(lc.preheader));
              if (gutils->isConstantValue(Pstart)) {
                // Initial value is inactive: d/d(op1) = -idiff * BO / op1,
                // using the value of BO looked up at the loop exit.
                Value *lop0 = lookup(gutils->getNewFromOriginal(&BO), EB);
                Value *lop1 =
                    lookup(gutils->getNewFromOriginal(orig_op1), Builder2);
                auto rule = [&](Value *idiff) {
                  auto res = Builder2.CreateFDiv(
                      Builder2.CreateFNeg(Builder2.CreateFMul(idiff, lop0)),
                      lop1);
                  if (gutils->strongZero) {
                    // strongZero: propagate an exactly-zero incoming
                    // derivative unchanged (avoids NaN from 0 * inf).
                    res = CreateSelect(
                        Builder2,
                        Builder2.CreateFCmpOEQ(
                            idiff, Constant::getNullValue(idiff->getType())),
                        idiff, res);
                  }
                  return res;
                };
                dif1 =
                    applyChainRule(orig_op1->getType(), Builder2, rule, idiff);
              } else {
                // Active initial value: use the cached total multiplicative
                // product of op1 across the loop instead of BO itself.
                auto product = gutils->getOrInsertTotalMultiplicativeProduct(
                    gutils->getNewFromOriginal(orig_op1), lc);
                IRBuilder<> EB(*lc.exitBlocks.begin());
                getReverseBuilder(EB, /*original=*/false);
                Value *s = lookup(gutils->getNewFromOriginal(Pstart), Builder2);
                Value *lop0 = lookup(product, EB);
                Value *lop1 =
                    lookup(gutils->getNewFromOriginal(orig_op1), Builder2);
                auto rule = [&](Value *idiff) {
                  auto res = Builder2.CreateFDiv(
                      Builder2.CreateFNeg(Builder2.CreateFMul(
                          s, Builder2.CreateFDiv(idiff, lop0))),
                      lop1);
                  if (gutils->strongZero) {
                    res = CreateSelect(
                        Builder2,
                        Builder2.CreateFCmpOEQ(
                            idiff, Constant::getNullValue(idiff->getType())),
                        idiff, res);
                  }
                  return res;
                };
                dif1 =
                    applyChainRule(orig_op1->getType(), Builder2, rule, idiff);
              }
              addToDiffe(orig_op1, dif1, Builder2, addingType);
            }
            return;
          }
        }
      }
    }

    {
      using namespace llvm;
      // Table-generated derivatives; a matching case returns/handles the op.
      switch (BO.getOpcode()) {
#include "BinopDerivatives.inc"
      default:
        break;
      }
    }

    // NOTE(review): the case labels of this mode dispatch were elided in
    // this view — only the case bodies survived extraction; consult the
    // full source before modifying.
    switch (Mode) {
      if (gutils->isConstantInstruction(&BO))
        return;
      break;
      break;
      return;
    }
  }
2374
  /// Emit the reverse-mode (adjoint) update for a binary operator that was
  /// not handled by the generated derivative table. Only integer/bitwise
  /// opcodes reach here; each case either recognizes a bit-level trick on
  /// floating-point data stored in integers (shifting packed floats,
  /// masking/merging exponent bits) or concludes no derivative rule exists
  /// and emits a diagnostic.
  void createBinaryOperatorAdjoint(llvm::BinaryOperator &BO) {
    // Inactive instructions contribute no adjoint.
    if (gutils->isConstantInstruction(&BO)) {
      return;
    }
    using namespace llvm;

    IRBuilder<> Builder2(&BO);
    getReverseBuilder(Builder2);

    Value *orig_op0 = BO.getOperand(0);
    Value *orig_op1 = BO.getOperand(1);

    // Adjoint increments for each operand; set by the cases below and
    // accumulated at `done`.
    Value *dif0 = nullptr;
    Value *dif1 = nullptr;
    Value *idiff = diffe(&BO, Builder2);

    Type *addingType = BO.getType();

    switch (BO.getOpcode()) {
    case Instruction::LShr: {
      // lshr by a whole number of float-sized elements moves packed float
      // data; the adjoint is the inverse shift (shl) of the shadow bits.
      if (!gutils->isConstantValue(orig_op0)) {
        if (auto ci = dyn_cast<ConstantInt>(orig_op1)) {
          size_t size = 1;
          if (orig_op0->getType()->isSized())
            size = (gutils->newFunc->getParent()
                        ->getDataLayout()
                        .getTypeSizeInBits(orig_op0->getType()) +
                    7) /
                   8;

          if (Type *flt = TR.addingType(size, orig_op0)) {
            auto bits = gutils->newFunc->getParent()
                            ->getDataLayout()
                            .getTypeAllocSizeInBits(flt);
            // Only valid when shifting by at least one element and by an
            // exact multiple of the element width.
            if (ci->getSExtValue() >= (int64_t)bits &&
                ci->getSExtValue() % bits == 0) {
              auto rule = [&](Value *idiff) {
                return Builder2.CreateShl(idiff, ci);
              };
              dif0 = applyChainRule(orig_op0->getType(), Builder2, rule, idiff);
              addingType = flt;
              goto done;
            }
          }
        }
      }
      if (looseTypeAnalysis) {
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer and is constant
        return;
      }
      goto def;
    }
    case Instruction::AShr: {
      if (looseTypeAnalysis) {
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer and is constant
        return;
      }
      goto def;
    }
    case Instruction::And: {
      // If & against 0b10000000000 and a float the result is 0
      auto &dl = gutils->oldFunc->getParent()->getDataLayout();
      auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

      auto FT = TR.query(&BO).IsAllFloat(size, dl);
      auto eFT = FT;
      if (FT)
        for (int i = 0; i < 2; ++i) {
          auto CI = dyn_cast<ConstantInt>(BO.getOperand(i));
          if (CI && dl.getTypeSizeInBits(eFT) ==
                        dl.getTypeSizeInBits(CI->getType())) {
            // -134217728 == 0xFFFFFFFFF8000000: masks off low mantissa bits
            // of a double (a truncation/rounding), whose derivative is zero.
            if (eFT->isDoubleTy() && CI->getValue() == -134217728) {
              setDiffe(
                  &BO,
                  Constant::getNullValue(gutils->getShadowType(BO.getType())),
                  Builder2);
              // Derivative is zero (equivalent to rounding as just chopping
              // off bits of mantissa), no update
              return;
            }
          }
        }
      if (looseTypeAnalysis) {
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer and is constant
        return;
      }
      goto def;
    }
    case Instruction::Xor: {
      auto &dl = gutils->oldFunc->getParent()->getDataLayout();
      auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

      auto FT = TR.query(&BO).IsAllFloat(size, dl);
      auto eFT = FT;
      // If ^ against 0b10000000000 and a float the result is a float
      if (FT)
        for (int i = 0; i < 2; ++i) {
          // XOR with (at most) the sign bit flips the float's sign; the
          // adjoint is the (conditionally) negated incoming derivative.
          if (containsOnlyAtMostTopBit(BO.getOperand(i), eFT, dl, &FT)) {
            setDiffe(
                &BO,
                Constant::getNullValue(gutils->getShadowType(BO.getType())),
                Builder2);
            auto isZero = Builder2.CreateICmpEQ(
                lookup(gutils->getNewFromOriginal(BO.getOperand(i)), Builder2),
                Constant::getNullValue(BO.getType()));
            auto rule = [&](Value *idiff) {
              auto ext = Builder2.CreateBitCast(idiff, FT);
              auto neg = Builder2.CreateFNeg(ext);
              // A zero mask is the identity, so keep the derivative as-is.
              neg = CreateSelect(Builder2, isZero, ext, neg);
              neg = Builder2.CreateBitCast(neg, BO.getType());
              return neg;
            };
            auto bc = applyChainRule(BO.getOperand(1 - i)->getType(), Builder2,
                                     rule, idiff);
            addToDiffe(BO.getOperand(1 - i), bc, Builder2, FT);
            return;
          }
        }
      if (looseTypeAnalysis) {
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer and is constant
        return;
      }
      goto def;
    }
    case Instruction::Or: {
      auto &dl = gutils->oldFunc->getParent()->getDataLayout();
      auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

      auto FT = TR.query(&BO).IsAllFloat(size, dl);
      auto eFT = FT;
      // If & against 0b10000000000 and a float the result is a float
      if (FT)
        for (int i = 0; i < 2; ++i) {
          auto CI = dyn_cast<ConstantInt>(BO.getOperand(i));
          // Splat vector constants are treated like their scalar splat.
          if (auto CV = dyn_cast<ConstantVector>(BO.getOperand(i))) {
            CI = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
            FT = VectorType::get(FT, CV->getType()->getElementCount());
          }
          if (auto CV = dyn_cast<ConstantDataVector>(BO.getOperand(i))) {
            CI = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
            FT = VectorType::get(FT, CV->getType()->getElementCount());
          }
          if (CI && dl.getTypeSizeInBits(eFT) ==
                        dl.getTypeSizeInBits(CI->getType())) {
            auto AP = CI->getValue();
            // validXor: the constant touches only exponent bits (or is
            // zero), so the OR acts as a multiplication by a power of two.
            bool validXor = false;
#if LLVM_VERSION_MAJOR > 16
            if (AP.isZero())
#else
            if (AP.isNullValue())
#endif
            {
              validXor = true;
            } else if (
                !AP.isNegative() &&
                ((FT->isFloatTy()
#if LLVM_VERSION_MAJOR > 16
                  && (AP & ~0b01111111100000000000000000000000ULL).isZero()
#else
                  && (AP & ~0b01111111100000000000000000000000ULL).isNullValue()
#endif
                  ) ||
                 (FT->isDoubleTy()
#if LLVM_VERSION_MAJOR > 16
                  &&
                  (AP &
                   ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
                      .isZero()
#else
                  &&
                  (AP &
                   ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
                      .isNullValue()
#endif
                  ))) {
              validXor = true;
            }
            if (validXor) {
              setDiffe(
                  &BO,
                  Constant::getNullValue(gutils->getShadowType(BO.getType())),
                  Builder2);

              auto arg = lookup(
                  gutils->getNewFromOriginal(BO.getOperand(1 - i)), Builder2);

              auto rule = [&](Value *idiff) {
                // Reconstruct the multiplicative factor the OR applied:
                // (arg | C) - arg, rebiased into a valid float exponent.
                auto prev = Builder2.CreateOr(arg, BO.getOperand(i));
                prev = Builder2.CreateSub(prev, arg, "", /*NUW*/ true,
                                          /*NSW*/ false);
                uint64_t num = 0;
                if (FT->isFloatTy()) {
                  num = 127ULL << 23;
                } else {
                  assert(FT->isDoubleTy());
                  num = 1023ULL << 52;
                }
                prev = Builder2.CreateAdd(
                    prev, ConstantInt::get(prev->getType(), num, false), "",
                    /*NUW*/ true, /*NSW*/ true);
                prev = Builder2.CreateBitCast(
                    checkedMul(gutils->strongZero, Builder2,
                               Builder2.CreateBitCast(idiff, FT),
                               Builder2.CreateBitCast(prev, FT)),
                    prev->getType());
                return prev;
              };

              Value *prev = applyChainRule(BO.getOperand(1 - i)->getType(),
                                           Builder2, rule, idiff);
              addToDiffe(BO.getOperand(1 - i), prev, Builder2, FT);
              return;
            }
          }
        }
      if (looseTypeAnalysis) {
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer or is constant
        return;
      }
      goto def;
    }
    case Instruction::UDiv:
    case Instruction::URem:
    case Instruction::SRem:
    case Instruction::SDiv:
    case Instruction::Shl:
    case Instruction::Mul:
    case Instruction::Sub:
    case Instruction::Add: {
      // Pure integer arithmetic has zero derivative under loose TA.
      if (looseTypeAnalysis) {
        llvm::errs()
            << "warning: binary operator is integer and assumed constant: "
            << BO << "\n";
        // if loose type analysis, assume this integer add is constant
        return;
      }
      goto def;
    }
    default:
    def:;
      // No rule applies: build a detailed activity/type report and emit a
      // no-derivative error for diagnosis.
      std::string s;
      llvm::raw_string_ostream ss(s);
      ss << *gutils->oldFunc << "\n";
      for (auto &arg : gutils->oldFunc->args()) {
        ss << " constantarg[" << arg << "] = " << gutils->isConstantValue(&arg)
           << " type: " << TR.query(&arg).str() << " - vals: {";
        for (auto v : TR.knownIntegralValues(&arg))
          ss << v << ",";
        ss << "}\n";
      }
      for (auto &BB : *gutils->oldFunc)
        for (auto &I : BB) {
          ss << " constantinst[" << I
             << "] = " << gutils->isConstantInstruction(&I)
             << " val:" << gutils->isConstantValue(&I)
             << " type: " << TR.query(&I).str() << "\n";
        }
      ss << "cannot handle unknown binary operator: " << BO << "\n";
      EmitNoDerivativeError(ss.str(), BO, gutils, Builder2);
    }

  done:;
    // Zero this instruction's shadow, then accumulate the operand adjoints.
    if (dif0 || dif1)
      setDiffe(&BO, Constant::getNullValue(gutils->getShadowType(BO.getType())),
               Builder2);
    if (dif0)
      addToDiffe(orig_op0, dif0, Builder2, addingType);
    if (dif1)
      addToDiffe(orig_op1, dif1, Builder2, addingType);
  }
2655
  /// Emit the forward-mode (dual/tangent) derivative for a binary operator
  /// not handled by the generated derivative table. Mirrors
  /// createBinaryOperatorAdjoint: only integer/bitwise opcodes reach here,
  /// handled via bit-level float tricks or reported as underivable.
  /// NOTE(review): several single source lines were elided from this view
  /// by extraction (marked below); consult the full source before editing.
  void createBinaryOperatorDual(llvm::BinaryOperator &BO) {
    using namespace llvm;

    if (gutils->isConstantInstruction(&BO)) {
      // NOTE(review): one source line is missing here in this view.
      return;
    }

    IRBuilder<> Builder2(&BO);
    getForwardBuilder(Builder2);

    Value *orig_op0 = BO.getOperand(0);
    Value *orig_op1 = BO.getOperand(1);

    bool constantval0 = gutils->isConstantValue(orig_op0);
    bool constantval1 = gutils->isConstantValue(orig_op1);

    switch (BO.getOpcode()) {
    case Instruction::And: {
      // If & against 0b10000000000 and a float the result is 0
      auto &dl = gutils->oldFunc->getParent()->getDataLayout();
      auto size = dl.getTypeSizeInBits(BO.getType()) / 8;
      Type *diffTy = gutils->getShadowType(BO.getType());

      auto FT = TR.query(&BO).IsAllFloat(size, dl);
      auto eFT = FT;
      if (FT)
        for (int i = 0; i < 2; ++i) {
          auto CI = dyn_cast<ConstantInt>(BO.getOperand(i));
          if (CI && dl.getTypeSizeInBits(eFT) ==
                        dl.getTypeSizeInBits(CI->getType())) {
            // -134217728 masks off low mantissa bits of a double (a
            // rounding), whose derivative is zero.
            if (eFT->isDoubleTy() && CI->getValue() == -134217728) {
              setDiffe(&BO, Constant::getNullValue(diffTy), Builder2);
              // Derivative is zero (equivalent to rounding as just chopping
              // off bits of mantissa), no update
              return;
            }
          }
        }
      if (looseTypeAnalysis) {
        // NOTE(review): one source line is missing here in this view.
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer and is constant
        return;
      }
      goto def;
    }
    case Instruction::Xor: {
      auto &dl = gutils->oldFunc->getParent()->getDataLayout();
      auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

      auto FT = TR.query(&BO).IsAllFloat(size, dl);
      auto eFT = FT;

      // Tangents of each operand (null when that operand is inactive).
      Value *dif[2] = {constantval0 ? nullptr : diffe(orig_op0, Builder2),
                       constantval1 ? nullptr : diffe(orig_op1, Builder2)};

      for (int i = 0; i < 2; ++i) {
        // XOR with (at most) the sign bit flips a float's sign: the dual is
        // the other operand's tangent, conditionally negated.
        if (containsOnlyAtMostTopBit(BO.getOperand(i), eFT, dl, &FT) &&
            dif[1 - i] && !dif[i]) {
          auto isZero = Builder2.CreateICmpEQ(
              gutils->getNewFromOriginal(BO.getOperand(i)),
              Constant::getNullValue(BO.getType()));
          auto rule = [&](Value *idiff) {
            auto ext = Builder2.CreateBitCast(idiff, FT);
            auto neg = Builder2.CreateFNeg(ext);
            neg = CreateSelect(Builder2, isZero, ext, neg);
            neg = Builder2.CreateBitCast(neg, BO.getType());
            return neg;
          };
          auto bc = applyChainRule(BO.getOperand(1 - i)->getType(), Builder2,
                                   rule, dif[1 - i]);
          setDiffe(&BO, bc, Builder2);
          return;
        }
      }
      if (looseTypeAnalysis) {
        // NOTE(review): one source line is missing here in this view.
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer and is constant
        return;
      }
      goto def;
    }
    case Instruction::Or: {
      auto &dl = gutils->oldFunc->getParent()->getDataLayout();
      auto size = dl.getTypeSizeInBits(BO.getType()) / 8;

      Value *dif[2] = {constantval0 ? nullptr : diffe(orig_op0, Builder2),
                       constantval1 ? nullptr : diffe(orig_op1, Builder2)};

      auto FT = TR.query(&BO).IsAllFloat(size, dl);
      auto eFT = FT;
      // If & against 0b10000000000 and a float the result is a float
      if (FT)
        for (int i = 0; i < 2; ++i) {
          auto CI = dyn_cast<ConstantInt>(BO.getOperand(i));
          if (auto CV = dyn_cast<ConstantVector>(BO.getOperand(i))) {
            CI = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
            FT = VectorType::get(FT, CV->getType()->getElementCount());
          }
          if (auto CV = dyn_cast<ConstantDataVector>(BO.getOperand(i))) {
            CI = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
          }
          if (CI && dl.getTypeSizeInBits(eFT) ==
                        dl.getTypeSizeInBits(CI->getType())) {
            auto AP = CI->getValue();
            // validXor: the constant touches only exponent bits (or is
            // zero), so the OR acts as a multiplication by a power of two.
            bool validXor = false;
#if LLVM_VERSION_MAJOR > 16
            if (AP.isZero())
#else
            if (AP.isNullValue())
#endif
            {
              validXor = true;
            } else if (
                !AP.isNegative() &&
                ((FT->isFloatTy()
#if LLVM_VERSION_MAJOR > 16
                  && (AP & ~0b01111111100000000000000000000000ULL).isZero()
#else
                  && (AP & ~0b01111111100000000000000000000000ULL).isNullValue()
#endif
                  ) ||
                 (FT->isDoubleTy()
#if LLVM_VERSION_MAJOR > 16
                  &&
                  (AP &
                   ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
                      .isZero()
#else
                  &&
                  (AP &
                   ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
                      .isNullValue()
#endif
                  ))) {
              validXor = true;
            }
            if (validXor) {
              auto rule = [&](Value *difi) {
                // Reconstruct the multiplicative factor the OR applied:
                // (arg | C) - arg, rebiased into a valid float exponent.
                auto arg = gutils->getNewFromOriginal(BO.getOperand(1 - i));
                auto prev = Builder2.CreateOr(arg, BO.getOperand(i));
                prev = Builder2.CreateSub(prev, arg, "", /*NUW*/ true,
                                          /*NSW*/ false);
                uint64_t num = 0;
                if (FT->isFloatTy()) {
                  num = 127ULL << 23;
                } else {
                  assert(FT->isDoubleTy());
                  num = 1023ULL << 52;
                }
                prev = Builder2.CreateAdd(
                    prev, ConstantInt::get(prev->getType(), num, false), "",
                    /*NUW*/ true, /*NSW*/ true);
                prev = Builder2.CreateBitCast(
                    checkedMul(gutils->strongZero, Builder2,
                               Builder2.CreateBitCast(difi, FT),
                               Builder2.CreateBitCast(prev, FT)),
                    prev->getType());

                return prev;
              };

              auto diffe =
                  applyChainRule(BO.getType(), Builder2, rule, dif[1 - i]);
              setDiffe(&BO, diffe, Builder2);
              return;
            }
          }
        }
      if (looseTypeAnalysis) {
        // NOTE(review): one source line is missing here in this view.
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer or is constant
        return;
      }
      goto def;
    }
    case Instruction::LShr: {
      // lshr by a whole number of float-sized elements moves packed float
      // data; the dual applies the same shift to the tangent bits.
      if (!gutils->isConstantValue(orig_op0)) {
        if (auto ci = dyn_cast<ConstantInt>(orig_op1)) {
          size_t size = 1;
          if (orig_op0->getType()->isSized())
            size = (gutils->newFunc->getParent()
                        ->getDataLayout()
                        .getTypeSizeInBits(orig_op0->getType()) +
                    7) /
                   8;

          if (Type *flt = TR.addingType(size, orig_op0)) {
            auto bits = gutils->newFunc->getParent()
                            ->getDataLayout()
                            .getTypeAllocSizeInBits(flt);
            if (ci->getSExtValue() >= (int64_t)bits &&
                ci->getSExtValue() % bits == 0) {
              auto rule = [&](Value *idiff) {
                return Builder2.CreateLShr(idiff, ci);
              };
              auto dif = applyChainRule(orig_op0->getType(), Builder2, rule,
                                        diffe(orig_op0, Builder2));
              setDiffe(&BO, dif, Builder2);
              return;
            }
          }
        }
      }
      if (looseTypeAnalysis) {
        // NOTE(review): one source line is missing here in this view.
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer or is constant
        return;
      }
      goto def;
    }
    case Instruction::AShr:
    case Instruction::SDiv:
    case Instruction::UDiv:
    case Instruction::SRem:
    case Instruction::URem:
    case Instruction::Shl:
    case Instruction::Mul:
    case Instruction::Sub:
    case Instruction::Add: {
      // Pure integer arithmetic has zero derivative under loose TA.
      if (looseTypeAnalysis) {
        // NOTE(review): one source line is missing here in this view.
        llvm::errs() << "warning: binary operator is integer and constant: "
                     << BO << "\n";
        // if loose type analysis, assume this integer add is constant
        return;
      }
      goto def;
    }
    default:
    def:;
      // No rule applies: build a detailed activity/type report, emit a
      // no-derivative error, and replace any shadow placeholder so the
      // module stays consistent.
      std::string s;
      llvm::raw_string_ostream ss(s);
      ss << *gutils->oldFunc << "\n";
      for (auto &arg : gutils->oldFunc->args()) {
        ss << " constantarg[" << arg << "] = " << gutils->isConstantValue(&arg)
           << " type: " << TR.query(&arg).str() << " - vals: {";
        for (auto v : TR.knownIntegralValues(&arg))
          ss << v << ",";
        ss << "}\n";
      }
      for (auto &BB : *gutils->oldFunc)
        for (auto &I : BB) {
          ss << " constantinst[" << I
             << "] = " << gutils->isConstantInstruction(&I)
             << " val:" << gutils->isConstantValue(&I)
             << " type: " << TR.query(&I).str() << "\n";
        }
      ss << "cannot handle unknown binary operator: " << BO << "\n";
      auto rval = EmitNoDerivativeError(ss.str(), BO, gutils, Builder2);
      if (!rval)
        rval = Constant::getNullValue(gutils->getShadowType(BO.getType()));
      auto ifound = gutils->invertedPointers.find(&BO);
      if (!gutils->isConstantValue(&BO)) {
        if (ifound != gutils->invertedPointers.end()) {
          auto placeholder = cast<PHINode>(&*ifound->second);
          gutils->invertedPointers.erase(ifound);
          gutils->replaceAWithB(placeholder, rval);
          gutils->erase(placeholder);
          gutils->invertedPointers.insert(std::make_pair(
              (const Value *)&BO, InvertedPointerVH(gutils, rval)));
        }
      } else {
        assert(ifound == gutils->invertedPointers.end());
      }
      break;
    }
  }
2932
2933 void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); }
2934
2935 void visitMemSetCommon(llvm::CallInst &MS) {
2936 using namespace llvm;
2937
2938 IRBuilder<> BuilderZ(&MS);
2939 getForwardBuilder(BuilderZ);
2940
2941 IRBuilder<> Builder2(&MS);
2944 getReverseBuilder(Builder2);
2945
2946 bool forceErase = false;
2948 for (const auto &pair : gutils->rematerializableAllocations) {
2949 if (pair.second.stores.count(&MS) && pair.second.LI) {
2950 forceErase = true;
2951 }
2952 }
2953 }
2954 if (forceErase)
2955 eraseIfUnused(MS, /*erase*/ true, /*check*/ false);
2956 else
2957 eraseIfUnused(MS);
2958
2959 Value *orig_op0 = MS.getArgOperand(0);
2960 Value *orig_op1 = MS.getArgOperand(1);
2961
2962 // If constant destination then no operation needs doing
2963 if (gutils->isConstantValue(orig_op0)) {
2964 return;
2965 }
2966
2967 bool activeValToSet = !gutils->isConstantValue(orig_op1);
2968 if (activeValToSet)
2969 if (auto CI = dyn_cast<ConstantInt>(orig_op1))
2970 if (CI->isZero())
2971 activeValToSet = false;
2972 if (activeValToSet) {
2973 std::string s;
2974 llvm::raw_string_ostream ss(s);
2975 ss << "couldn't handle non constant inst in memset to "
2976 "propagate differential to\n"
2977 << MS;
2978 EmitNoDerivativeError(ss.str(), MS, gutils, BuilderZ);
2979 }
2980
2981 if (Mode == DerivativeMode::ForwardMode ||
2983 Value *op0 = gutils->invertPointerM(orig_op0, BuilderZ);
2984 Value *op1 = gutils->getNewFromOriginal(MS.getArgOperand(1));
2985 Value *op2 = gutils->getNewFromOriginal(MS.getArgOperand(2));
2986 Value *op3 = nullptr;
2987 if (3 < MS.arg_size()) {
2988 op3 = gutils->getNewFromOriginal(MS.getOperand(3));
2989 }
2990
2991 auto Defs =
2992 gutils->getInvertedBundles(&MS,
2995 BuilderZ, /*lookup*/ false);
2996
2997 auto funcName = getFuncNameFromCall(&MS);
2999 BuilderZ,
3000 [&](Value *op0) {
3001 SmallVector<Value *, 4> args = {op0, op1, op2};
3002 if (op3)
3003 args.push_back(op3);
3004
3005 CallInst *cal;
3006 if (startsWith(funcName, "memset_pattern") ||
3007 startsWith(funcName, "llvm.experimental.memset"))
3008 cal = Builder2.CreateMemSet(
3009 op0, ConstantInt::get(Builder2.getInt8Ty(), 0), op2, {});
3010 else
3011 cal = BuilderZ.CreateCall(MS.getCalledFunction(), args, Defs);
3012
3013 llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy);
3014 ToCopy2.push_back(LLVMContext::MD_noalias);
3015 cal->copyMetadata(MS, ToCopy2);
3016 if (auto m = hasMetadata(&MS, "enzyme_zerostack"))
3017 cal->setMetadata("enzyme_zerostack", m);
3018
3019 if (startsWith(funcName, "memset_pattern") ||
3020 startsWith(funcName, "llvm.experimental.memset")) {
3021 AttributeList NewAttrs;
3022 for (auto idx :
3023 {AttributeList::ReturnIndex, AttributeList::FunctionIndex,
3024 AttributeList::FirstArgIndex})
3025 for (auto attr : MS.getAttributes().getAttributes(idx))
3026 NewAttrs =
3027 NewAttrs.addAttributeAtIndex(MS.getContext(), idx, attr);
3028 cal->setAttributes(NewAttrs);
3029 } else
3030 cal->setAttributes(MS.getAttributes());
3031 cal->setCallingConv(MS.getCallingConv());
3032 cal->setTailCallKind(MS.getTailCallKind());
3033 cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc()));
3034 },
3035 op0);
3036 return;
3037 }
3038
3039 bool backwardsShadow = false;
3040 bool forwardsShadow = true;
3041 for (auto pair : gutils->backwardsOnlyShadows) {
3042 if (pair.second.stores.count(&MS)) {
3043 backwardsShadow = true;
3044 forwardsShadow = pair.second.primalInitialize;
3045 if (auto inst = dyn_cast<Instruction>(pair.first))
3046 if (!forwardsShadow && pair.second.LI &&
3047 pair.second.LI->contains(inst->getParent()))
3048 backwardsShadow = false;
3049 }
3050 }
3051
3052 size_t size = 1;
3053 if (auto ci = dyn_cast<ConstantInt>(MS.getOperand(2))) {
3054 size = ci->getLimitedValue();
3055 }
3056
3057 // TODO note that we only handle memset of ONE type (aka memset of {int,
3058 // double} not allowed)
3059
3060 if (size == 0) {
3061 llvm::errs() << MS << "\n";
3062 }
3063 assert(size != 0);
3064
3065 // Offsets of the form Optional<floating type>, segment start, segment size
3066 std::vector<std::tuple<Type *, size_t, size_t>> toIterate;
3067
3068 // Special handling mechanism to bypass TA limitations by supporting
3069 // arbitrary sized types.
3070 if (auto MD = hasMetadata(&MS, "enzyme_truetype")) {
3071 toIterate = parseTrueType(MD, Mode, false);
3072 } else {
3073 auto &DL = gutils->newFunc->getParent()->getDataLayout();
3074 auto vd = TR.query(MS.getOperand(0)).Data0().ShiftIndices(DL, 0, size, 0);
3075
3076 if (!vd.isKnownPastPointer()) {
3077 // If unknown type results, and zeroing known undef allocation, consider
3078 // integers
3079 if (auto CI = dyn_cast<ConstantInt>(MS.getOperand(1)))
3080 if (CI->isZero()) {
3081 auto root = getBaseObject(MS.getOperand(0));
3082 bool writtenTo = false;
3083 bool undefMemory =
3084 isa<AllocaInst>(root) || isAllocationCall(root, gutils->TLI);
3085 if (auto arg = dyn_cast<Argument>(root))
3086 if (arg->hasStructRetAttr())
3087 undefMemory = true;
3088 if (undefMemory) {
3089 Instruction *cur = MS.getPrevNode();
3090 while (cur) {
3091 if (cur == root)
3092 break;
3093 if (auto MCI = dyn_cast<ConstantInt>(MS.getOperand(2))) {
3094 if (auto II = dyn_cast<IntrinsicInst>(cur)) {
3095 if (II->getCalledFunction()->getName() ==
3096 "llvm.enzyme.lifetime_start") {
3097 if (getBaseObject(II->getOperand(1)) == root) {
3098 if (auto CI2 =
3099 dyn_cast<ConstantInt>(II->getOperand(0))) {
3100 if (MCI->getValue().ule(CI2->getValue()))
3101 break;
3102 }
3103 }
3104 cur = cur->getPrevNode();
3105 continue;
3106 }
3107 // If the start of the lifetime for more memory than being
3108 // memset, its valid.
3109 if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
3110 if (getBaseObject(II->getOperand(1)) == root) {
3111 if (auto CI2 =
3112 dyn_cast<ConstantInt>(II->getOperand(0))) {
3113 if (MCI->getValue().ule(CI2->getValue()))
3114 break;
3115 }
3116 }
3117 cur = cur->getPrevNode();
3118 continue;
3119 }
3120 }
3121 }
3122 if (cur->mayWriteToMemory()) {
3123 writtenTo = true;
3124 break;
3125 }
3126 cur = cur->getPrevNode();
3127 }
3128
3129 if (!writtenTo) {
3131 vd.insert({-1}, BaseType::Integer);
3132 }
3133 }
3134 }
3135 }
3136
3137 if (!vd.isKnownPastPointer()) {
3138 // If unknown type results, consider the intersection of all incoming.
3139 if (isa<PHINode>(MS.getOperand(0)) ||
3140 isa<SelectInst>(MS.getOperand(0))) {
3141 SmallVector<Value *, 2> todo = {MS.getOperand(0)};
3142 bool set = false;
3143 SmallSet<Value *, 2> seen;
3144 TypeTree vd2;
3145 while (todo.size()) {
3146 Value *cur = todo.back();
3147 todo.pop_back();
3148 if (seen.count(cur))
3149 continue;
3150 seen.insert(cur);
3151 if (auto PN = dyn_cast<PHINode>(cur)) {
3152 for (size_t i = 0, end = PN->getNumIncomingValues(); i < end;
3153 i++) {
3154 todo.push_back(PN->getIncomingValue(i));
3155 }
3156 continue;
3157 }
3158 if (auto S = dyn_cast<SelectInst>(cur)) {
3159 todo.push_back(S->getTrueValue());
3160 todo.push_back(S->getFalseValue());
3161 continue;
3162 }
3163 if (auto CE = dyn_cast<ConstantExpr>(cur)) {
3164 if (CE->isCast()) {
3165 todo.push_back(CE->getOperand(0));
3166 continue;
3167 }
3168 }
3169 if (auto CI = dyn_cast<CastInst>(cur)) {
3170 todo.push_back(CI->getOperand(0));
3171 continue;
3172 }
3173 if (isa<ConstantPointerNull>(cur))
3174 continue;
3175 if (auto CI = dyn_cast<ConstantInt>(cur))
3176 if (CI->isZero())
3177 continue;
3178 auto curTT = TR.query(cur).Data0().ShiftIndices(DL, 0, size, 0);
3179 if (!set)
3180 vd2 = curTT;
3181 else
3182 vd2 &= curTT;
3183 set = true;
3184 }
3185 vd = vd2;
3186 }
3187 }
3188 if (!vd.isKnownPastPointer()) {
3189 if (looseTypeAnalysis) {
3190#if LLVM_VERSION_MAJOR < 17
3191 if (auto CI = dyn_cast<CastInst>(MS.getOperand(0))) {
3192 if (auto PT = dyn_cast<PointerType>(CI->getSrcTy())) {
3193 auto ET = PT->getPointerElementType();
3194 while (1) {
3195 if (auto ST = dyn_cast<StructType>(ET)) {
3196 if (ST->getNumElements()) {
3197 ET = ST->getElementType(0);
3198 continue;
3199 }
3200 }
3201 if (auto AT = dyn_cast<ArrayType>(ET)) {
3202 ET = AT->getElementType();
3203 continue;
3204 }
3205 break;
3206 }
3207 if (ET->isFPOrFPVectorTy()) {
3208 vd = TypeTree(ConcreteType(ET->getScalarType())).Only(0, &MS);
3209 goto known;
3210 }
3211 if (ET->isPointerTy()) {
3212 vd = TypeTree(BaseType::Pointer).Only(0, &MS);
3213 goto known;
3214 }
3215 if (ET->isIntOrIntVectorTy()) {
3216 vd = TypeTree(BaseType::Integer).Only(0, &MS);
3217 goto known;
3218 }
3219 }
3220 }
3221#endif
3222 if (auto gep = dyn_cast<GetElementPtrInst>(MS.getOperand(0))) {
3223 if (auto AT = dyn_cast<ArrayType>(gep->getSourceElementType())) {
3224 if (AT->getElementType()->isIntegerTy()) {
3225 vd = TypeTree(BaseType::Integer).Only(0, &MS);
3226 goto known;
3227 }
3228 }
3229 }
3230 EmitWarning("CannotDeduceType", MS,
3231 "failed to deduce type of memset ", MS);
3232 vd = TypeTree(BaseType::Pointer).Only(0, &MS);
3233 goto known;
3234 }
3235 std::string str;
3236 raw_string_ostream ss(str);
3237 ss << "Cannot deduce type of memset " << MS;
3238 EmitNoTypeError(str, MS, gutils, BuilderZ);
3239 return;
3240 }
3241 known:;
3242 {
3243 unsigned start = 0;
3244 while (1) {
3245 unsigned nextStart = size;
3246
3247 auto dt = vd[{-1}];
3248 for (size_t i = start; i < size; ++i) {
3249 bool Legal = true;
3250 dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
3251 if (!Legal) {
3252 nextStart = i;
3253 break;
3254 }
3255 }
3256 if (!dt.isKnown()) {
3257 TR.dump();
3258 llvm::errs() << " vd:" << vd.str() << " start:" << start
3259 << " size: " << size << " dt:" << dt.str() << "\n";
3260 }
3261 assert(dt.isKnown());
3262 toIterate.emplace_back(dt.isFloat(), start, nextStart - start);
3263
3264 if (nextStart == size)
3265 break;
3266 start = nextStart;
3267 }
3268 }
3269 }
3270
3271#if 0
3272 unsigned dstalign = dstAlign.valueOrOne().value();
3273 unsigned srcalign = srcAlign.valueOrOne().value();
3274#endif
3275
3276 Value *op1 = gutils->getNewFromOriginal(MS.getArgOperand(1));
3277 Value *new_size = gutils->getNewFromOriginal(MS.getArgOperand(2));
3278 Value *op3 = nullptr;
3279 if (3 < MS.arg_size()) {
3280 op3 = gutils->getNewFromOriginal(MS.getOperand(3));
3281 }
3282
3283 for (auto &&[secretty_ref, seg_start_ref, seg_size_ref] : toIterate) {
3284 auto secretty = secretty_ref;
3285 auto seg_start = seg_start_ref;
3286 auto seg_size = seg_size_ref;
3287
3288 Value *length = new_size;
3289 if (seg_start != std::get<1>(toIterate.back())) {
3290 length = ConstantInt::get(new_size->getType(), seg_start + seg_size);
3291 }
3292 if (seg_start != 0)
3293 length = BuilderZ.CreateSub(
3294 length, ConstantInt::get(new_size->getType(), seg_start));
3295
3296#if 0
3297 unsigned subdstalign = dstalign;
3298 // todo make better alignment calculation
3299 if (dstalign != 0) {
3300 if (start % dstalign != 0) {
3301 dstalign = 1;
3302 }
3303 }
3304 unsigned subsrcalign = srcalign;
3305 // todo make better alignment calculation
3306 if (srcalign != 0) {
3307 if (start % srcalign != 0) {
3308 srcalign = 1;
3309 }
3310 }
3311#endif
3312
3313 Value *shadow_dst = gutils->invertPointerM(MS.getOperand(0), BuilderZ);
3314
3315 // TODO ponder forward split mode
3316 if (!secretty &&
3317 ((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
3318 (Mode == DerivativeMode::ReverseModeCombined && forwardsShadow) ||
3319 (Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) ||
3320 (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow))) {
3321 auto Defs =
3322 gutils->getInvertedBundles(&MS,
3325 BuilderZ, /*lookup*/ false);
3326 auto rule = [&](Value *op0) {
3327 if (seg_start != 0) {
3328 Value *idxs[] = {ConstantInt::get(
3329 Type::getInt32Ty(op0->getContext()), seg_start)};
3330 op0 = BuilderZ.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()),
3331 op0, idxs);
3332 }
3333 SmallVector<Value *, 4> args = {op0, op1, length};
3334 if (op3)
3335 args.push_back(op3);
3336 auto cal = BuilderZ.CreateCall(MS.getCalledFunction(), args, Defs);
3337 llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy);
3338 ToCopy2.push_back(LLVMContext::MD_noalias);
3339 if (auto m = hasMetadata(&MS, "enzyme_zerostack"))
3340 cal->setMetadata("enzyme_zerostack", m);
3341 cal->copyMetadata(MS, ToCopy2);
3342 cal->setAttributes(MS.getAttributes());
3343 cal->setCallingConv(MS.getCallingConv());
3344 cal->setTailCallKind(MS.getTailCallKind());
3345 cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc()));
3346 };
3347
3348 applyChainRule(BuilderZ, rule, shadow_dst);
3349 }
3350 if (secretty && (Mode == DerivativeMode::ReverseModeGradient ||
3352
3353 auto Defs =
3354 gutils->getInvertedBundles(&MS,
3357 BuilderZ, /*lookup*/ true);
3358 Value *op1l = gutils->lookupM(op1, Builder2);
3359 Value *op3l = op3;
3360 if (op3l)
3361 op3l = gutils->lookupM(op3l, BuilderZ);
3362 length = gutils->lookupM(length, Builder2);
3363 auto rule = [&](Value *op0) {
3364 if (seg_start != 0) {
3365 Value *idxs[] = {ConstantInt::get(
3366 Type::getInt32Ty(op0->getContext()), seg_start)};
3367 op0 = Builder2.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()),
3368 op0, idxs);
3369 }
3370 SmallVector<Value *, 4> args = {op0, op1l, length};
3371 if (op3l)
3372 args.push_back(op3l);
3373 CallInst *cal;
3374 auto funcName = getFuncNameFromCall(&MS);
3375 if (startsWith(funcName, "memset_pattern") ||
3376 startsWith(funcName, "llvm.experimental.memset"))
3377 cal = Builder2.CreateMemSet(
3378 op0, ConstantInt::get(Builder2.getInt8Ty(), 0), length, {});
3379 else
3380 cal = Builder2.CreateCall(MS.getCalledFunction(), args, Defs);
3381 llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy);
3382 ToCopy2.push_back(LLVMContext::MD_noalias);
3383 cal->copyMetadata(MS, ToCopy2);
3384 if (auto m = hasMetadata(&MS, "enzyme_zerostack"))
3385 cal->setMetadata("enzyme_zerostack", m);
3386
3387 if (startsWith(funcName, "memset_pattern") ||
3388 startsWith(funcName, "llvm.experimental.memset")) {
3389 AttributeList NewAttrs;
3390 for (auto idx :
3391 {AttributeList::ReturnIndex, AttributeList::FunctionIndex,
3392 AttributeList::FirstArgIndex})
3393 for (auto attr : MS.getAttributes().getAttributes(idx))
3394 NewAttrs =
3395 NewAttrs.addAttributeAtIndex(MS.getContext(), idx, attr);
3396 cal->setAttributes(NewAttrs);
3397 } else
3398 cal->setAttributes(MS.getAttributes());
3399 cal->setCallingConv(MS.getCallingConv());
3400 cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc()));
3401 };
3402
3403 applyChainRule(Builder2, rule, gutils->lookupM(shadow_dst, Builder2));
3404 }
3405 }
3406 }
3407
3408 void visitMemTransferInst(llvm::MemTransferInst &MTI) {
3409 using namespace llvm;
3410 Value *isVolatile = gutils->getNewFromOriginal(MTI.getOperand(3));
3411 auto srcAlign = MTI.getSourceAlign();
3412 auto dstAlign = MTI.getDestAlign();
3413 visitMemTransferCommon(MTI.getIntrinsicID(), srcAlign, dstAlign, MTI,
3414 MTI.getOperand(0), MTI.getOperand(1),
3415 gutils->getNewFromOriginal(MTI.getOperand(2)),
3416 isVolatile);
3417 }
3418
  // Common lowering for memcpy/memmove intrinsics. Splits the copied byte
  // range [0, size) into maximal segments of a single base type (float vs
  // int/pointer), because float bytes need derivative propagation/zeroing
  // while integer/pointer bytes are copied verbatim, then emits the shadow
  // (forward) transfer and/or the reverse-pass handling per segment.
  //
  // \p ID          the transfer intrinsic (memcpy or memmove)
  // \p srcAlign/dstAlign  alignments from the original call
  // \p MTI         the original call instruction
  // \p orig_dst/orig_src  original pointer operands
  // \p new_size    remapped length operand; \p isVolatile remapped volatile flag
  //
  // NOTE(review): this listing appears to have dropped several continuation
  // lines (flagged inline below) — verify each flagged condition/call
  // against the upstream source before editing this function.
  void visitMemTransferCommon(llvm::Intrinsic::ID ID, llvm::MaybeAlign srcAlign,
                              llvm::MaybeAlign dstAlign, llvm::CallInst &MTI,
                              llvm::Value *orig_dst, llvm::Value *orig_src,
                              llvm::Value *new_size, llvm::Value *isVolatile) {
    using namespace llvm;

    // Inactive destination: nothing differentiable is written.
    if (gutils->isConstantValue(MTI.getOperand(0))) {
      eraseIfUnused(MTI);
      return;
    }

    // Stores proven unnecessary by earlier analysis are dropped entirely.
    if (unnecessaryStores.count(&MTI)) {
      eraseIfUnused(MTI);
      return;
    }

    // memcpy of size 1 cannot move differentiable data [single byte copy]
    if (auto ci = dyn_cast<ConstantInt>(new_size)) {
      if (ci->getValue() == 1) {
        eraseIfUnused(MTI);
        return;
      }
    }

    // copying into nullptr is invalid (not sure why it exists here), but we
    // shouldn't do it in reverse pass or shadow
    if (isa<ConstantPointerNull>(orig_dst) ||
        TR.query(orig_dst)[{-1}] == BaseType::Anything) {
      eraseIfUnused(MTI);
      return;
    }

    size_t size = 1;
    if (auto ci = dyn_cast<ConstantInt>(new_size)) {
      size = ci->getLimitedValue();
    }

    // TODO note that we only handle memcpy/etc of ONE type (aka memcpy of {int,
    // double} not allowed)
    if (size == 0) {
      eraseIfUnused(MTI);
      return;
    }

    // NOTE(review): a continuation line of this condition is missing from
    // this listing (presumably another forward-mode case ANDed with the
    // constant-destination check) — confirm against upstream.
    if ((Mode == DerivativeMode::ForwardMode ||
        gutils->isConstantValue(orig_dst)) {
      eraseIfUnused(MTI);
      return;
    }

    // Offsets of the form Optional<floating type>, segment start, segment size
    std::vector<std::tuple<Type *, size_t, size_t>> toIterate;
    IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MTI));

    // Special handling mechanism to bypass TA limitations by supporting
    // arbitrary sized types.
    if (auto MD = hasMetadata(&MTI, "enzyme_truetype")) {
      toIterate = parseTrueType(MD, Mode,
                                !gutils->isConstantValue(orig_src) &&
                                    !gutils->runtimeActivity);
    } else {
      // Merge type info observed at both endpoints of the copy.
      auto &DL = gutils->newFunc->getParent()->getDataLayout();
      auto vd = TR.query(orig_dst).Data0().ShiftIndices(DL, 0, size, 0);
      vd |= TR.query(orig_src).Data0().PurgeAnything().ShiftIndices(DL, 0, size,
                                                                    0);
      // An explicit `enzyme_type` attribute on the destination parameter
      // augments the inferred type tree.
      for (size_t i = 0; i < MTI.getNumOperands(); i++)
        if (MTI.getOperand(i) == orig_dst)
          if (MTI.getAttributes().hasParamAttr(i, "enzyme_type")) {
            auto attr = MTI.getAttributes().getParamAttr(i, "enzyme_type");
            auto TT =
                TypeTree::parse(attr.getValueAsString(), MTI.getContext());
            vd |= TT.Data0().ShiftIndices(DL, 0, size, 0);
            break;
          }

      bool errorIfNoType = true;
      // NOTE(review): a continuation line of this condition is missing from
      // this listing — confirm against upstream.
      if ((Mode == DerivativeMode::ForwardMode ||
          (!gutils->isConstantValue(orig_src) && !gutils->runtimeActivity)) {
        errorIfNoType = false;
      }

      if (!vd.isKnownPastPointer()) {
        if (looseTypeAnalysis) {
          for (auto val : {orig_dst, orig_src}) {
#if LLVM_VERSION_MAJOR < 17
            // Typed-pointer LLVM only: peel casts and leading struct/array
            // nesting to find the innermost element type as a fallback.
            if (auto CI = dyn_cast<CastInst>(val)) {
              if (auto PT = dyn_cast<PointerType>(CI->getSrcTy())) {
                auto ET = PT->getPointerElementType();
                while (1) {
                  if (auto ST = dyn_cast<StructType>(ET)) {
                    if (ST->getNumElements()) {
                      ET = ST->getElementType(0);
                      continue;
                    }
                  }
                  if (auto AT = dyn_cast<ArrayType>(ET)) {
                    ET = AT->getElementType();
                    continue;
                  }
                  break;
                }
                if (ET->isFPOrFPVectorTy()) {
                  vd =
                      TypeTree(ConcreteType(ET->getScalarType())).Only(0, &MTI);
                  goto known;
                }
                if (ET->isPointerTy()) {
                  vd = TypeTree(BaseType::Pointer).Only(0, &MTI);
                  goto known;
                }
                if (ET->isIntOrIntVectorTy()) {
                  vd = TypeTree(BaseType::Integer).Only(0, &MTI);
                  goto known;
                }
              }
            }
#endif
            // GEP into an integer array implies an integer copy.
            if (auto gep = dyn_cast<GetElementPtrInst>(val)) {
              if (auto AT = dyn_cast<ArrayType>(gep->getSourceElementType())) {
                if (AT->getElementType()->isIntegerTy()) {
                  vd = TypeTree(BaseType::Integer).Only(0, &MTI);
                  goto known;
                }
              }
            }
          }
          // If the type is known, but outside of the known range
          // (but the memcpy size is a variable), attempt to use
          // the first type out of range as the memcpy type.
          if (size == 1 && !isa<ConstantInt>(new_size)) {
            for (auto ptr : {orig_dst, orig_src}) {
              vd = TR.query(ptr).Data0().ShiftIndices(DL, 0, -1, 0);
              if (vd.isKnownPastPointer()) {
                // NOTE(review): the declaration of `mv` (used below) is
                // missing from this listing — confirm against upstream.
                size_t minInt = 0xFFFFFFFF;
                // Pick the type at the smallest single-index offset.
                for (const auto &pair : vd.getMapping()) {
                  if (pair.first.size() != 1)
                    continue;
                  if (minInt < (size_t)pair.first[0])
                    continue;
                  minInt = pair.first[0];
                  mv = pair.second;
                }
                assert(mv != BaseType::Unknown);
                vd.insert({0}, mv);
                goto known;
              }
            }
          }
          if (errorIfNoType)
            EmitWarning("CannotDeduceType", MTI,
                        "failed to deduce type of copy ", MTI);
          vd = TypeTree(BaseType::Pointer).Only(0, &MTI);
          goto known;
        }
        if (errorIfNoType) {
          std::string str;
          raw_string_ostream ss(str);
          ss << "Cannot deduce type of copy " << MTI;
          EmitNoTypeError(str, MTI, gutils, BuilderZ);
          vd = TypeTree(BaseType::Integer).Only(0, &MTI);
        } else {
          vd = TypeTree(BaseType::Pointer).Only(0, &MTI);
        }
      }

    known:;
      {
        // Partition [0, size) into maximal segments that can legally share
        // one base type.
        unsigned start = 0;
        while (1) {
          unsigned nextStart = size;

          auto dt = vd[{-1}];
          for (size_t i = start; i < size; ++i) {
            bool Legal = true;
            auto tmp = dt;
            auto next = vd[{(int)i}];
            tmp.checkedOrIn(next, /*PointerIntSame*/ true, Legal);
            // Prevent fusion of {Anything, Float} since anything is an int rule
            // but float requires zeroing.
            if ((dt == BaseType::Anything &&
                 (next != BaseType::Anything && next.isKnown())) ||
                (next == BaseType::Anything &&
                 (dt != BaseType::Anything && dt.isKnown())))
              Legal = false;
            if (!Legal) {
              // NOTE(review): the tail of this condition is missing from
              // this listing — confirm against upstream.
              if (Mode == DerivativeMode::ForwardMode ||
                // if both are floats (of any type), forward mode is the same.
                // + [potentially zero if const, otherwise copy]
                // if both are int/pointer (of any type), also the same
                // + copy
                // if known non-constant, also the same
                // + copy
                if ((dt.isFloat() == nullptr) ==
                    (vd[{(int)i}].isFloat() == nullptr)) {
                  Legal = true;
                }
                if (!gutils->isConstantValue(orig_src) &&
                    !gutils->runtimeActivity) {
                  Legal = true;
                }
              }
              if (!Legal) {
                nextStart = i;
                break;
              }
            } else
              dt = tmp;
          }
          if (!dt.isKnown()) {
            TR.dump();
            llvm::errs() << " vd:" << vd.str() << " start:" << start
                         << " size: " << size << " dt:" << dt.str() << "\n";
          }
          assert(dt.isKnown());
          toIterate.emplace_back(dt.isFloat(), start, nextStart - start);

          if (nextStart == size)
            break;
          start = nextStart;
        }
      }
    }

    // llvm::errs() << "MIT: " << MTI << "|size: " << size << " vd: " <<
    // vd.str() << "\n";

    unsigned dstalign = dstAlign.valueOrOne().value();
    unsigned srcalign = srcAlign.valueOrOne().value();

    // Decide whether the shadow transfer is emitted in the forward pass,
    // the backward pass, or both.
    bool backwardsShadow = false;
    bool forwardsShadow = true;
    for (auto pair : gutils->backwardsOnlyShadows) {
      if (pair.second.stores.count(&MTI)) {
        backwardsShadow = true;
        forwardsShadow = pair.second.primalInitialize;
        if (auto inst = dyn_cast<Instruction>(pair.first))
          if (!forwardsShadow && pair.second.LI &&
              pair.second.LI->contains(inst->getParent()))
            backwardsShadow = false;
      }
    }

    for (auto &&[floatTy_ref, seg_start_ref, seg_size_ref] : toIterate) {
      auto floatTy = floatTy_ref;
      auto seg_start = seg_start_ref;
      auto seg_size = seg_size_ref;

      // Length of this segment: the full (dynamic) size for the last
      // segment, else the constant end offset; then drop the start offset.
      Value *length = new_size;
      if (seg_start != std::get<1>(toIterate.back())) {
        length = ConstantInt::get(new_size->getType(), seg_start + seg_size);
      }
      if (seg_start != 0)
        length = BuilderZ.CreateSub(
            length, ConstantInt::get(new_size->getType(), seg_start));

      unsigned subdstalign = dstalign;
      // todo make better alignment calculation
      if (dstalign != 0) {
        if (seg_start % dstalign != 0) {
          dstalign = 1;
        }
      }
      unsigned subsrcalign = srcalign;
      // todo make better alignment calculation
      if (srcalign != 0) {
        if (seg_start % srcalign != 0) {
          srcalign = 1;
        }
      }
      IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MTI));
      Value *shadow_dst = gutils->isConstantValue(orig_dst)
                              ? nullptr
                              : gutils->invertPointerM(orig_dst, BuilderZ);
      Value *shadow_src = gutils->isConstantValue(orig_src)
                              ? nullptr
                              : gutils->invertPointerM(orig_src, BuilderZ);

      // Reverse-mode handling for this segment: falls back to the primal
      // pointer when a shadow is absent (constant operand).
      auto rev_rule = [&](Value *shadow_dst, Value *shadow_src) {
        if (shadow_dst == nullptr)
          shadow_dst = gutils->getNewFromOriginal(orig_dst);
        if (shadow_src == nullptr)
          shadow_src = gutils->getNewFromOriginal(orig_src);
        // NOTE(review): the line naming the helper invoked with the
        // argument list below is missing from this listing — confirm
        // against upstream.
            gutils, Mode, floatTy, ID, subdstalign, subsrcalign,
            /*offset*/ seg_start, gutils->isConstantValue(orig_dst), shadow_dst,
            gutils->isConstantValue(orig_src), shadow_src,
            /*length*/ length, /*volatile*/ isVolatile, &MTI,
            /*allowForward*/ forwardsShadow, /*shadowsLookedup*/ false,
            /*backwardsShadow*/ backwardsShadow);
      };

      // Forward-mode handling: zero the shadow destination when copying
      // float data from an inactive source, otherwise mirror the copy on
      // the shadow pointers.
      auto fwd_rule = [&](Value *ddst, Value *dsrc) {
        if (ddst == nullptr)
          ddst = gutils->getNewFromOriginal(orig_dst);
        if (dsrc == nullptr)
          dsrc = gutils->getNewFromOriginal(orig_src);
        MaybeAlign dalign;
        if (subdstalign)
          dalign = MaybeAlign(subdstalign);
        MaybeAlign salign;
        if (subsrcalign)
          salign = MaybeAlign(subsrcalign);
        if (ddst->getType()->isIntegerTy())
          ddst =
              BuilderZ.CreateIntToPtr(ddst, getInt8PtrTy(ddst->getContext()));
        if (seg_start != 0) {
          ddst = BuilderZ.CreateConstInBoundsGEP1_64(
              Type::getInt8Ty(ddst->getContext()), ddst, seg_start);
        }
        CallInst *call;
        // TODO add gutils->runtimeActivity (correctness)
        if (floatTy && gutils->isConstantValue(orig_src)) {
          call = BuilderZ.CreateMemSet(
              ddst, ConstantInt::get(Type::getInt8Ty(ddst->getContext()), 0),
              length, dalign, isVolatile);
        } else {
          if (dsrc->getType()->isIntegerTy())
            dsrc =
                BuilderZ.CreateIntToPtr(dsrc, getInt8PtrTy(dsrc->getContext()));
          if (seg_start != 0) {
            dsrc = BuilderZ.CreateConstInBoundsGEP1_64(
                Type::getInt8Ty(ddst->getContext()), dsrc, seg_start);
          }
          if (ID == Intrinsic::memmove) {
            call = BuilderZ.CreateMemMove(ddst, dalign, dsrc, salign, length);
          } else {
            call = BuilderZ.CreateMemCpy(ddst, dalign, dsrc, salign, length);
          }
          call->setAttributes(MTI.getAttributes());
        }
        // TODO shadow scope/noalias (performance)
        call->setMetadata(LLVMContext::MD_alias_scope,
                          MTI.getMetadata(LLVMContext::MD_alias_scope));
        call->setMetadata(LLVMContext::MD_noalias,
                          MTI.getMetadata(LLVMContext::MD_noalias));
        call->setMetadata(LLVMContext::MD_tbaa,
                          MTI.getMetadata(LLVMContext::MD_tbaa));
        call->setMetadata(LLVMContext::MD_tbaa_struct,
                          MTI.getMetadata(LLVMContext::MD_tbaa_struct));
        call->setMetadata(LLVMContext::MD_invariant_group,
                          MTI.getMetadata(LLVMContext::MD_invariant_group));
        call->setTailCallKind(MTI.getTailCallKind());
      };

      // NOTE(review): the tail of this condition is missing from this
      // listing — confirm against upstream.
      if (Mode == DerivativeMode::ForwardMode ||
        applyChainRule(BuilderZ, fwd_rule, shadow_dst, shadow_src);
      else
        applyChainRule(BuilderZ, rev_rule, shadow_dst, shadow_src);
    }

    eraseIfUnused(MTI);
  }
3777
  // Fence instructions: the reverse pass mirrors the fence with its
  // acquire/release ordering swapped. When EnzymeJuliaAddrLoad is set,
  // fences immediately adjacent to a julia.safepoint call are not mirrored.
  void visitFenceInst(llvm::FenceInst &FI) {
    using namespace llvm;

    switch (Mode) {
    default:
      break;
    // NOTE(review): the case label(s) opening the region below are missing
    // from this listing (presumably the reverse-mode cases) — confirm
    // against upstream.
      bool emitReverse = true;
      if (EnzymeJuliaAddrLoad) {
        // A fence directly before or after julia.safepoint belongs to the
        // safepoint protocol and is not replicated in the reverse pass.
        if (auto prev = dyn_cast_or_null<CallBase>(FI.getPrevNode())) {
          if (auto F = prev->getCalledFunction())
            if (F->getName() == "julia.safepoint")
              emitReverse = false;
        }
        if (auto prev = dyn_cast_or_null<CallBase>(FI.getNextNode())) {
          if (auto F = prev->getCalledFunction())
            if (F->getName() == "julia.safepoint")
              emitReverse = false;
        }
      }
      if (emitReverse) {
        IRBuilder<> Builder2(&FI);
        getReverseBuilder(Builder2);
        // Swap the ordering so the reverse fence pairs with the forward
        // one: acquire <-> release; other orderings unchanged.
        auto order = FI.getOrdering();
        switch (order) {
        case AtomicOrdering::Acquire:
          order = AtomicOrdering::Release;
          break;
        case AtomicOrdering::Release:
          order = AtomicOrdering::Acquire;
          break;
        default:
          break;
        }
        Builder2.CreateFence(order, FI.getSyncScopeID());
      }
    }
    }
    eraseIfUnused(FI);
  }
3819
  // Dispatcher for intrinsic instructions: erases stack/lifetime intrinsics,
  // routes the remainder through handleAdjointForIntrinsic, and finally
  // applies the known-recompute caching decision for the primal value.
  void visitIntrinsicInst(llvm::IntrinsicInst &II) {
    using namespace llvm;

    // Stack save/restore and lifetime-end markers carry no derivative
    // information; drop them unconditionally.
    if (II.getIntrinsicID() == Intrinsic::stacksave) {
      eraseIfUnused(II, /*erase*/ true, /*check*/ false);
      return;
    }
    if (II.getIntrinsicID() == Intrinsic::stackrestore ||
        II.getIntrinsicID() == Intrinsic::lifetime_end ||
        II.getCalledFunction()->getName() == "llvm.enzyme.lifetime_end") {
      eraseIfUnused(II, /*erase*/ true, /*check*/ false);
      return;
    }
#if LLVM_VERSION_MAJOR >= 20
    if (II.getIntrinsicID() == Intrinsic::experimental_memset_pattern) {
      // NOTE(review): the body of this branch is missing from this listing
      // (presumably forwarding to the memset handler) — confirm upstream.
      return;
    }
#endif

    // When compiling Enzyme against standard LLVM, and not Intel's
    // modified version of LLVM, the intrinsic `llvm.intel.subscript` is
    // not fully understood by LLVM. One of the results of this is that the ID
    // of the intrinsic is set to Intrinsic::not_intrinsic - hence we are
    // handling the intrinsic here.
    if (isIntelSubscriptIntrinsic(II)) {
      // NOTE(review): the body of this branch (lines handling the Intel
      // subscript intrinsic) is missing from this listing — confirm
      // against upstream.
      }
    } else {
      // Collect the operands and delegate to the generic intrinsic
      // adjoint handler.
      SmallVector<Value *, 2> orig_ops(II.getNumOperands());

      for (unsigned i = 0; i < II.getNumOperands(); ++i) {
        orig_ops[i] = II.getOperand(i);
      }
      if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops))
        return;
    }
    // If the heuristic decided this value must not be recomputed, cache the
    // primal result for the reverse pass.
    if (gutils->knownRecomputeHeuristic.find(&II) !=
        gutils->knownRecomputeHeuristic.end()) {
      if (!gutils->knownRecomputeHeuristic[&II]) {
        CallInst *const newCall =
            cast<CallInst>(gutils->getNewFromOriginal(&II));
        IRBuilder<> BuilderZ(newCall);
        BuilderZ.setFastMathFlags(getFast());

        gutils->cacheForReverse(BuilderZ, newCall,
                                getIndex(&II, CacheType::Self, BuilderZ));
      }
    }
    eraseIfUnused(II);
  }
3874
  // Emits the augmented-forward, reverse, or forward-mode derivative code
  // for an intrinsic. Load-like and store-like intrinsics are routed to the
  // common load/store handlers; tablegen-generated rules cover the known
  // math intrinsics; the per-mode switch at the end handles barriers,
  // lifetime markers, vector reductions, and error reporting. Every path
  // in this listing returns false.
  //
  // NOTE(review): this listing is missing several lines (case labels of the
  // mode switch and one call line, flagged inline) — confirm each flagged
  // spot against the upstream source before editing.
  bool
  handleAdjointForIntrinsic(llvm::Intrinsic::ID ID, llvm::Instruction &I,
                            llvm::SmallVectorImpl<llvm::Value *> &orig_ops) {
    using namespace llvm;

    Module *M = I.getParent()->getParent()->getParent();

    // NVVM global-load intrinsics behave like ordinary loads.
    switch (ID) {
#if LLVM_VERSION_MAJOR < 20
    case Intrinsic::nvvm_ldg_global_i:
    case Intrinsic::nvvm_ldg_global_p:
    case Intrinsic::nvvm_ldg_global_f:
#endif
    case Intrinsic::nvvm_ldu_global_i:
    case Intrinsic::nvvm_ldu_global_p:
    case Intrinsic::nvvm_ldu_global_f: {
      // Operand 1 carries the alignment as a constant.
      auto CI = cast<ConstantInt>(I.getOperand(1));
      visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue()),
                    /*constantval*/ false);
      return false;
    }
    default:
      break;
    }

    // Masked store/load are routed through the common store/load paths,
    // passing the (remapped) mask operand along.
    if (ID == Intrinsic::masked_store) {
      auto align0 = cast<ConstantInt>(I.getOperand(2))->getZExtValue();
      auto align = MaybeAlign(align0);
      visitCommonStore(I, /*orig_ptr*/ I.getOperand(1),
                       /*orig_val*/ I.getOperand(0), align,
                       /*isVolatile*/ false, llvm::AtomicOrdering::NotAtomic,
                       SyncScope::SingleThread,
                       /*mask*/ gutils->getNewFromOriginal(I.getOperand(3)));
      return false;
    }
    if (ID == Intrinsic::masked_load) {
      auto align0 = cast<ConstantInt>(I.getOperand(1))->getZExtValue();
      auto align = MaybeAlign(align0);
      auto &DL = gutils->newFunc->getParent()->getDataLayout();
      // TBAA marking the value integral implies an inactive load.
      bool constantval = parseTBAA(I, DL, nullptr)[{-1}].isIntegral();
      visitLoadLike(I, align, constantval,
                    /*mask*/ gutils->getNewFromOriginal(I.getOperand(2)),
                    /*orig_maskInit*/ I.getOperand(3));
      return false;
    }

    // Tablegen-generated derivative rules for recognized intrinsics.
    auto mod = I.getParent()->getParent()->getParent();
    auto called = cast<CallInst>(&I)->getCalledFunction();
    (void)called;
    switch (ID) {
#include "IntrinsicDerivatives.inc"
    default:
      break;
    }

    switch (Mode) {
    // NOTE(review): the case label(s) opening this region are missing from
    // this listing (presumably the augmented/primal case) — confirm
    // against upstream.
      switch (ID) {
      // GPU barrier/membar intrinsics need no augmented handling.
#if LLVM_VERSION_MAJOR <= 20
      case Intrinsic::nvvm_barrier0:
#else
      case Intrinsic::nvvm_barrier_cta_sync_aligned_all:
      case Intrinsic::nvvm_barrier_cta_sync_aligned_count:
#endif
#if LLVM_VERSION_MAJOR < 22
      case Intrinsic::nvvm_barrier0_popc:
      case Intrinsic::nvvm_barrier0_and:
      case Intrinsic::nvvm_barrier0_or:
#else
      case Intrinsic::nvvm_barrier_cta_red_and_aligned_all:
      case Intrinsic::nvvm_barrier_cta_red_and_aligned_count:
      case Intrinsic::nvvm_barrier_cta_red_or_aligned_all:
      case Intrinsic::nvvm_barrier_cta_red_or_aligned_count:
      case Intrinsic::nvvm_barrier_cta_red_popc_aligned_all:
      case Intrinsic::nvvm_barrier_cta_red_popc_aligned_count:
#endif
      case Intrinsic::nvvm_membar_cta:
      case Intrinsic::nvvm_membar_gl:
      case Intrinsic::nvvm_membar_sys:
      case Intrinsic::amdgcn_s_barrier:
        return false;
      default:
        if (gutils->isConstantInstruction(&I))
          return false;
        // Integer min/max/abs/overflow intrinsics are tolerated under
        // loose type analysis (treated as inactive) with a warning.
        if (ID == Intrinsic::umax || ID == Intrinsic::smax ||
            ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow ||
            ID == Intrinsic::uadd_with_overflow ||
            ID == Intrinsic::smul_with_overflow ||
            ID == Intrinsic::umul_with_overflow ||
            ID == Intrinsic::ssub_with_overflow ||
            ID == Intrinsic::usub_with_overflow)
          if (looseTypeAnalysis) {
            EmitWarning("CannotDeduceType", I,
                        "failed to deduce type of intrinsic ", I);
            return false;
          }
        std::string s;
        llvm::raw_string_ostream ss(s);
        ss << *gutils->oldFunc << "\n";
        ss << *gutils->newFunc << "\n";
        ss << "cannot handle (augmented) unknown intrinsic\n" << I;
        IRBuilder<> BuilderZ(&I);
        getForwardBuilder(BuilderZ);
        EmitNoDerivativeError(ss.str(), I, gutils, BuilderZ);
        return false;
      }
      return false;
    }
    // NOTE(review): the case label(s) opening the reverse-mode region are
    // missing from this listing — confirm against upstream.

      IRBuilder<> Builder2(&I);
      getReverseBuilder(Builder2);

      // Pull the incoming adjoint of the result (if active) and zero it
      // before propagating to the operands.
      Value *vdiff = nullptr;
      if (!gutils->isConstantValue(&I)) {
        vdiff = diffe(&I, Builder2);
        setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())),
                 Builder2);
      }
      (void)vdiff;

      switch (ID) {
      // Reducing barriers are mirrored by a plain synchronizing barrier in
      // the reverse pass.
#if LLVM_VERSION_MAJOR < 22
      case Intrinsic::nvvm_barrier0_popc:
      case Intrinsic::nvvm_barrier0_and:
      case Intrinsic::nvvm_barrier0_or:
#else
      case Intrinsic::nvvm_barrier_cta_red_and_aligned_all:
      case Intrinsic::nvvm_barrier_cta_red_and_aligned_count:
      case Intrinsic::nvvm_barrier_cta_red_or_aligned_all:
      case Intrinsic::nvvm_barrier_cta_red_or_aligned_count:
      case Intrinsic::nvvm_barrier_cta_red_popc_aligned_all:
      case Intrinsic::nvvm_barrier_cta_red_popc_aligned_count:
#endif
      {
        SmallVector<Value *, 1> args = {};
#if LLVM_VERSION_MAJOR > 20
        auto cal = cast<CallInst>(Builder2.CreateCall(
        // NOTE(review): a line (plausibly `getIntrinsicDeclaration(`) is
        // missing from this listing — confirm against upstream.
            M, Intrinsic::nvvm_barrier_cta_sync_aligned_all),
            args));
        cal->setCallingConv(getIntrinsicDeclaration(
                                M, Intrinsic::nvvm_barrier_cta_sync_aligned_all)
                                ->getCallingConv());
#else
        auto cal = cast<CallInst>(Builder2.CreateCall(
            getIntrinsicDeclaration(M, Intrinsic::nvvm_barrier0), args));
        cal->setCallingConv(getIntrinsicDeclaration(M, Intrinsic::nvvm_barrier0)
                                ->getCallingConv());
#endif
        cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
        return false;
      }

      // Plain barriers and memory fences are replayed verbatim in reverse.
#if LLVM_VERSION_MAJOR <= 20
      case Intrinsic::nvvm_barrier0:
#else
      case Intrinsic::nvvm_barrier_cta_sync_aligned_all:
      case Intrinsic::nvvm_barrier_cta_sync_aligned_count:
#endif
      case Intrinsic::amdgcn_s_barrier:
      case Intrinsic::nvvm_membar_cta:
      case Intrinsic::nvvm_membar_gl:
      case Intrinsic::nvvm_membar_sys: {
        SmallVector<Value *, 1> args = {};
        auto cal = cast<CallInst>(
            Builder2.CreateCall(getIntrinsicDeclaration(M, ID), args));
        cal->setCallingConv(getIntrinsicDeclaration(M, ID)->getCallingConv());
        cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
        return false;
      }

      // A lifetime_start in the forward pass pairs with a lifetime_end in
      // the reverse pass.
      case Intrinsic::lifetime_start: {
        if (gutils->isConstantInstruction(&I))
          return false;
        SmallVector<Value *, 2> args = {
            lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2),
            lookup(gutils->getNewFromOriginal(orig_ops[1]), Builder2)};
        Type *tys[] = {args[1]->getType()};
        auto cal = Builder2.CreateCall(
            getIntrinsicDeclaration(M, Intrinsic::lifetime_end, tys), args);
        cal->setCallingConv(
            getIntrinsicDeclaration(M, Intrinsic::lifetime_end, tys)
                ->getCallingConv());
        return false;
      }

      // d/dx reduce_fmax: route the adjoint to the lane holding the max,
      // reconstructed with a chain of OLT comparisons/selects.
      case Intrinsic::vector_reduce_fmax: {
        if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
          auto prev = lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2);
          auto VT = cast<VectorType>(orig_ops[0]->getType());

          assert(!VT->getElementCount().isScalable());
          size_t numElems = VT->getElementCount().getKnownMinValue();
          SmallVector<Value *> elems;
          SmallVector<Value *> cmps;

          for (size_t i = 0; i < numElems; ++i)
            elems.push_back(Builder2.CreateExtractElement(prev, (uint64_t)i));

          // cmps[i] records whether element i+1 beats the running max.
          Value *curmax = elems[0];
          for (size_t i = 0; i < numElems - 1; ++i) {
            cmps.push_back(Builder2.CreateFCmpOLT(curmax, elems[i + 1]));
            if (i + 2 != numElems)
              curmax = CreateSelect(Builder2, cmps[i], elems[i + 1], curmax);
          }

          auto rule = [&](Value *vdiff) {
            auto nv = Constant::getNullValue(orig_ops[0]->getType());
            Value *res = Builder2.CreateInsertElement(nv, vdiff, (uint64_t)0);

            for (size_t i = 0; i < numElems - 1; ++i) {
              auto rhs_v = Builder2.CreateInsertElement(nv, vdiff, i + 1);
              res = CreateSelect(Builder2, cmps[i], rhs_v, res);
            }
            return res;
          };
          Value *dif0 =
              applyChainRule(orig_ops[0]->getType(), Builder2, rule, vdiff);
          addToDiffe(orig_ops[0], dif0, Builder2, I.getType());
        }
        return false;
      }
      default:
        if (gutils->isConstantInstruction(&I))
          return false;
        // Same loose-type tolerance as the augmented path above.
        if (ID == Intrinsic::umax || ID == Intrinsic::smax ||
            ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow ||
            ID == Intrinsic::uadd_with_overflow ||
            ID == Intrinsic::smul_with_overflow ||
            ID == Intrinsic::umul_with_overflow ||
            ID == Intrinsic::ssub_with_overflow ||
            ID == Intrinsic::usub_with_overflow)
          if (looseTypeAnalysis) {
            EmitWarning("CannotDeduceType", I,
                        "failed to deduce type of intrinsic ", I);
            return false;
          }
        std::string s;
        llvm::raw_string_ostream ss(s);
        ss << *gutils->oldFunc << "\n";
        ss << *gutils->newFunc << "\n";
        if (Intrinsic::isOverloaded(ID))
          ss << "cannot handle (reverse) unknown intrinsic\n"
             << Intrinsic::getName(ID, ArrayRef<Type *>(),
                                   gutils->oldFunc->getParent(), nullptr)
             << "\n"
             << I;
        else
          ss << "cannot handle (reverse) unknown intrinsic\n"
             << Intrinsic::getName(ID) << "\n"
             << I;
        EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
        return false;
      }
      return false;
    }
    // NOTE(review): the case label(s) opening the forward-mode region are
    // missing from this listing — confirm against upstream.

      IRBuilder<> Builder2(&I);
      getForwardBuilder(Builder2);

      switch (ID) {

      // Forward derivative of reduce_fmax: select the tangent entry of the
      // lane holding the maximum, mirroring the comparison chain used to
      // compute the primal max.
      case Intrinsic::vector_reduce_fmax: {
        if (gutils->isConstantInstruction(&I))
          return false;
        auto prev = gutils->getNewFromOriginal(orig_ops[0]);
        auto VT = cast<VectorType>(orig_ops[0]->getType());

        assert(!VT->getElementCount().isScalable());
        size_t numElems = VT->getElementCount().getKnownMinValue();
        SmallVector<Value *> elems;
        SmallVector<Value *> cmps;

        for (size_t i = 0; i < numElems; ++i)
          elems.push_back(Builder2.CreateExtractElement(prev, (uint64_t)i));

        Value *curmax = elems[0];
        for (size_t i = 0; i < numElems - 1; ++i) {
          cmps.push_back(Builder2.CreateFCmpOLT(curmax, elems[i + 1]));
          if (i + 2 != numElems)
            curmax = CreateSelect(Builder2, cmps[i], elems[i + 1], curmax);
        }

        auto rule = [&](Value *vdiff) {
          Value *res = Builder2.CreateExtractElement(vdiff, (uint64_t)0);

          for (size_t i = 0; i < numElems - 1; ++i) {
            auto rhs_v = Builder2.CreateExtractElement(vdiff, i + 1);
            res = CreateSelect(Builder2, cmps[i], rhs_v, res);
          }
          return res;
        };
        auto vdiff = diffe(orig_ops[0], Builder2);

        Value *dif = applyChainRule(I.getType(), Builder2, rule, vdiff);
        setDiffe(&I, dif, Builder2);
        return false;
      }
      default:
        // Zero the shadow of an active-but-unhandled result before
        // reporting the failure.
        if (!gutils->isConstantValue(&I)) {
          auto toset =
              Constant::getNullValue(gutils->getShadowType(I.getType()));
          setDiffe(&I, toset, Builder2);
        }
        if (gutils->isConstantInstruction(&I))
          return false;
        if (ID == Intrinsic::umax || ID == Intrinsic::smax ||
            ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow ||
            ID == Intrinsic::uadd_with_overflow ||
            ID == Intrinsic::smul_with_overflow ||
            ID == Intrinsic::umul_with_overflow ||
            ID == Intrinsic::ssub_with_overflow ||
            ID == Intrinsic::usub_with_overflow)
          if (looseTypeAnalysis) {
            EmitWarning("CannotDeduceType", I,
                        "failed to deduce type of intrinsic ", I);
            return false;
          }
        std::string s;
        llvm::raw_string_ostream ss(s);
        if (Intrinsic::isOverloaded(ID))
          ss << "cannot handle (forward) unknown intrinsic\n"
             << Intrinsic::getName(ID, ArrayRef<Type *>(),
                                   gutils->oldFunc->getParent(), nullptr)
             << "\n"
             << I;
        else
          ss << "cannot handle (forward) unknown intrinsic\n"
             << Intrinsic::getName(ID) << "\n"
             << I;
        EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
        return false;
      }
      return false;
    }
    }

    return false;
  }
4220
4221// The first include allows adding attributes to the BLAS functions declared in the second.
4222#include "BlasAttributor.inc"
4223#include "BlasDerivatives.inc"
4224
4225 void visitOMPCall(llvm::CallInst &call) {
4226 using namespace llvm;
4227
4228 Function *kmpc = call.getCalledFunction();
4229
4230 if (overwritten_args_map.find(&call) == overwritten_args_map.end()) {
4231 llvm::errs() << " call: " << call << "\n";
4232 for (auto &pair : overwritten_args_map) {
4233 llvm::errs() << " + " << *pair.first << "\n";
4234 }
4235 }
4236
4237 auto found_ow = overwritten_args_map.find(&call);
4238 assert(found_ow != overwritten_args_map.end());
4239 const bool subsequent_calls_may_write = found_ow->second.first;
4240 const std::vector<bool> &overwritten_args = found_ow->second.second;
4241
4242 IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call));
4243 BuilderZ.setFastMathFlags(getFast());
4244
4245 Function *task = dyn_cast<Function>(call.getArgOperand(2));
4246 if (task == nullptr && isa<ConstantExpr>(call.getArgOperand(2))) {
4247 task = dyn_cast<Function>(
4248 cast<ConstantExpr>(call.getArgOperand(2))->getOperand(0));
4249 }
4250 if (task == nullptr) {
4251 llvm::errs() << "could not derive underlying task from omp call: " << call
4252 << "\n";
4253 llvm_unreachable("could not derive underlying task from omp call");
4254 }
4255 if (task->empty()) {
4256 llvm::errs()
4257 << "could not derive underlying task contents from omp call: " << call
4258 << "\n";
4259 llvm_unreachable(
4260 "could not derive underlying task contents from omp call");
4261 }
4262
4263 auto called = task;
4264 // bool modifyPrimal = true;
4265
4266 bool foreignFunction = called == nullptr;
4267
4268 SmallVector<Value *, 8> args = {0, 0, 0};
4269 SmallVector<Value *, 8> pre_args = {0, 0, 0};
4270 std::vector<DIFFE_TYPE> argsInverted = {DIFFE_TYPE::CONSTANT,
4272 SmallVector<Instruction *, 4> postCreate;
4273 SmallVector<Instruction *, 4> userReplace;
4274
4275 SmallVector<Value *, 4> OutTypes;
4276 SmallVector<Type *, 4> OutFPTypes;
4277
4278 for (unsigned i = 3; i < call.arg_size(); ++i) {
4279
4280 auto argi = gutils->getNewFromOriginal(call.getArgOperand(i));
4281
4282 pre_args.push_back(argi);
4283
4285 IRBuilder<> Builder2(&call);
4286 getReverseBuilder(Builder2);
4287 args.push_back(lookup(argi, Builder2));
4288 }
4289
4290 auto argTy = gutils->getDiffeType(call.getArgOperand(i), foreignFunction);
4291 argsInverted.push_back(argTy);
4292
4293 if (argTy == DIFFE_TYPE::CONSTANT) {
4294 continue;
4295 }
4296
4297 auto argType = argi->getType();
4298
4299 if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) {
4301 IRBuilder<> Builder2(&call);
4302 getReverseBuilder(Builder2);
4303 args.push_back(
4304 lookup(gutils->invertPointerM(call.getArgOperand(i), Builder2),
4305 Builder2));
4306 }
4307 pre_args.push_back(
4308 gutils->invertPointerM(call.getArgOperand(i), BuilderZ));
4309
4310 // Note sometimes whattype mistakenly says something should be constant
4311 // [because composed of integer pointers alone]
4312 assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG ||
4313 whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
4314 } else {
4315 assert(TR.query(call.getArgOperand(i))[{-1}].isFloat());
4316 OutTypes.push_back(call.getArgOperand(i));
4317 OutFPTypes.push_back(argType);
4318 assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
4319 whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
4320 }
4321 }
4322
4323 DIFFE_TYPE subretType = DIFFE_TYPE::CONSTANT;
4324
4325 Value *tape = nullptr;
4326 CallInst *augmentcall = nullptr;
4327 // Value *cachereplace = nullptr;
4328
4329 // TODO consider reduction of int 0 args
4330 FnTypeInfo nextTypeInfo(called);
4331
4332 if (called) {
4333 std::map<Value *, std::set<int64_t>> intseen;
4334
4335 TypeTree IntPtr;
4336 IntPtr.insert({-1, -1}, BaseType::Integer);
4337 IntPtr.insert({-1}, BaseType::Pointer);
4338
4339 int argnum = 0;
4340 for (auto &arg : called->args()) {
4341 if (argnum <= 1) {
4342 nextTypeInfo.Arguments.insert(
4343 std::pair<Argument *, TypeTree>(&arg, IntPtr));
4344 nextTypeInfo.KnownValues.insert(
4345 std::pair<Argument *, std::set<int64_t>>(&arg, {0}));
4346 } else {
4347 nextTypeInfo.Arguments.insert(std::pair<Argument *, TypeTree>(
4348 &arg, TR.query(call.getArgOperand(argnum - 2 + 3))));
4349 nextTypeInfo.KnownValues.insert(
4350 std::pair<Argument *, std::set<int64_t>>(
4351 &arg,
4352 TR.knownIntegralValues(call.getArgOperand(argnum - 2 + 3))));
4353 }
4354
4355 ++argnum;
4356 }
4357 nextTypeInfo.Return = TR.query(&call);
4358 }
4359
4360 // std::optional<std::map<std::pair<Instruction*, std::string>, unsigned>>
4361 // sub_index_map;
4362 // Optional<int> tapeIdx;
4363 // Optional<int> returnIdx;
4364 // Optional<int> differetIdx;
4365
4366 const AugmentedReturn *subdata = nullptr;
4368 assert(augmentedReturn);
4369 if (augmentedReturn) {
4370 auto fd = augmentedReturn->subaugmentations.find(&call);
4371 if (fd != augmentedReturn->subaugmentations.end()) {
4372 subdata = fd->second;
4373 }
4374 }
4375 }
4376
4379 if (called) {
4380 subdata = &gutils->Logic.CreateAugmentedPrimal(
4381 RequestContext(&call, &BuilderZ), cast<Function>(called),
4382 subretType, argsInverted, TR.analyzer->interprocedural,
4383 /*return is used*/ false,
4384 /*shadowReturnUsed*/ false, nextTypeInfo,
4385 subsequent_calls_may_write, overwritten_args, false,
4386 gutils->runtimeActivity, gutils->strongZero, gutils->getWidth(),
4387 /*AtomicAdd*/ true,
4388 /*OpenMP*/ true);
4390 assert(augmentedReturn);
4391 auto subaugmentations =
4392 (std::map<const llvm::CallInst *, AugmentedReturn *>
4393 *)&augmentedReturn->subaugmentations;
4395 *subaugmentations, &call, (AugmentedReturn *)subdata);
4396 }
4397
4398 assert(subdata);
4399 auto newcalled = subdata->fn;
4400
4401 if (subdata->returns.find(AugmentedStruct::Tape) !=
4402 subdata->returns.end()) {
4403 ValueToValueMapTy VMap;
4404 newcalled = CloneFunction(newcalled, VMap);
4405 auto tapeArg = newcalled->arg_end();
4406 tapeArg--;
4407 Type *tapeElemType = subdata->tapeType;
4408 SmallVector<std::pair<ssize_t, Value *>, 4> geps;
4409 SmallPtrSet<Instruction *, 4> gepsToErase;
4410 for (auto a : tapeArg->users()) {
4411 if (auto gep = dyn_cast<GetElementPtrInst>(a)) {
4412 auto idx = gep->idx_begin();
4413 idx++;
4414 auto cidx = cast<ConstantInt>(idx->get());
4415 assert(gep->getNumIndices() == 2);
4416 SmallPtrSet<StoreInst *, 1> storesToErase;
4417 for (auto st : gep->users()) {
4418 auto SI = cast<StoreInst>(st);
4419 Value *op = SI->getValueOperand();
4420 storesToErase.insert(SI);
4421 geps.emplace_back(cidx->getLimitedValue(), op);
4422 }
4423 for (auto SI : storesToErase)
4424 SI->eraseFromParent();
4425 gepsToErase.insert(gep);
4426 } else if (auto SI = dyn_cast<StoreInst>(a)) {
4427 Value *op = SI->getValueOperand();
4428 gepsToErase.insert(SI);
4429 geps.emplace_back(-1, op);
4430 } else {
4431 llvm::errs() << "unknown tape user: " << a << "\n";
4432 assert(0 && "unknown tape user");
4433 llvm_unreachable("unknown tape user");
4434 }
4435 }
4436 for (auto gep : gepsToErase)
4437 gep->eraseFromParent();
4438 IRBuilder<> ph(&*newcalled->getEntryBlock().begin());
4439 tape = UndefValue::get(tapeElemType);
4440 ValueToValueMapTy available;
4441 auto subarg = newcalled->arg_begin();
4442 subarg++;
4443 subarg++;
4444 for (size_t i = 3; i < pre_args.size(); ++i) {
4445 available[&*subarg] = pre_args[i];
4446 subarg++;
4447 }
4448 for (auto pair : geps) {
4449 Value *op = pair.second;
4450 Value *alloc = op;
4451 Value *replacement = gutils->unwrapM(op, BuilderZ, available,
4453 tape =
4454 pair.first == -1
4455 ? replacement
4456 : BuilderZ.CreateInsertValue(tape, replacement, pair.first);
4457 if (auto ci = dyn_cast<CastInst>(alloc)) {
4458 alloc = ci->getOperand(0);
4459 }
4460 if (auto uload = dyn_cast<Instruction>(replacement)) {
4461 gutils->unwrappedLoads.erase(uload);
4462 if (auto ci = dyn_cast<CastInst>(replacement)) {
4463 if (auto ucast = dyn_cast<Instruction>(ci->getOperand(0)))
4464 gutils->unwrappedLoads.erase(ucast);
4465 }
4466 }
4467 if (auto ci = dyn_cast<CallInst>(alloc)) {
4468 if (auto F = ci->getCalledFunction()) {
4469 // Store cached values
4470 if (F->getName() == "malloc") {
4471 const_cast<AugmentedReturn *>(subdata)
4472 ->tapeIndiciesToFree.emplace(pair.first);
4473 Value *toload = tapeArg;
4474 if (pair.first != -1) {
4475 Value *Idxs[] = {
4476 ConstantInt::get(
4477 Type::getInt64Ty(tapeArg->getContext()), 0),
4478 ConstantInt::get(
4479 Type::getInt32Ty(tapeArg->getContext()),
4480 pair.first)};
4481 toload = ph.CreateInBoundsGEP(tapeElemType, toload, Idxs);
4482 }
4483 op->replaceAllUsesWith(ph.CreateLoad(op->getType(), toload));
4484 cast<Instruction>(op)->eraseFromParent();
4485 if (op != alloc)
4486 ci->eraseFromParent();
4487 continue;
4488 }
4489 }
4490 }
4491 Value *Idxs[] = {
4492 ConstantInt::get(Type::getInt64Ty(tapeArg->getContext()), 0),
4493 ConstantInt::get(Type::getInt32Ty(tapeArg->getContext()),
4494 pair.first)};
4495 op->replaceAllUsesWith(ph.CreateLoad(
4496 op->getType(),
4497 pair.first == -1
4498 ? tapeArg
4499 : ph.CreateInBoundsGEP(tapeElemType, tapeArg, Idxs)));
4500 cast<Instruction>(op)->eraseFromParent();
4501 }
4502 assert(tape);
4503 auto alloc =
4504 IRBuilder<>(gutils->inversionAllocs).CreateAlloca(tapeElemType);
4505 BuilderZ.CreateStore(tape, alloc);
4506 pre_args.push_back(alloc);
4507 assert(tape);
4508 gutils->cacheForReverse(BuilderZ, tape,
4509 getIndex(&call, CacheType::Tape, BuilderZ));
4510 }
4511
4512 auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()),
4513 pre_args.size() - 3);
4514 pre_args[0] = gutils->getNewFromOriginal(call.getArgOperand(0));
4515 pre_args[1] = numargs;
4516 pre_args[2] = BuilderZ.CreatePointerCast(
4517 newcalled, kmpc->getFunctionType()->getParamType(2));
4518 augmentcall =
4519 BuilderZ.CreateCall(kmpc->getFunctionType(), kmpc, pre_args);
4520 augmentcall->setCallingConv(call.getCallingConv());
4521 augmentcall->setDebugLoc(
4522 gutils->getNewFromOriginal(call.getDebugLoc()));
4523 BuilderZ.SetInsertPoint(
4524 gutils->getNewFromOriginal(&call)->getNextNode());
4525 gutils->erase(gutils->getNewFromOriginal(&call));
4526 } else {
4527 assert(0 && "unhandled unknown outline");
4528 }
4529 }
4530
4531 {
4532 Intrinsic::ID ID = Intrinsic::not_intrinsic;
4533 if (!subdata && !isMemFreeLibMFunction(getFuncNameFromCall(&call), &ID)) {
4534 llvm::errs() << *gutils->oldFunc->getParent() << "\n";
4535 llvm::errs() << *gutils->oldFunc << "\n";
4536 llvm::errs() << *gutils->newFunc << "\n";
4537 llvm::errs() << *called << "\n";
4538 llvm_unreachable("no subdata");
4539 }
4540 }
4541
4542 if (subdata) {
4543 auto found = subdata->returns.find(AugmentedStruct::DifferentialReturn);
4544 assert(found == subdata->returns.end());
4545 }
4546 if (subdata) {
4547 auto found = subdata->returns.find(AugmentedStruct::Return);
4548 assert(found == subdata->returns.end());
4549 }
4550
4553 IRBuilder<> Builder2(&call);
4554 getReverseBuilder(Builder2);
4555
4557 BuilderZ.SetInsertPoint(
4558 gutils->getNewFromOriginal(&call)->getNextNode());
4559 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
4560 }
4561
4562 Function *newcalled = nullptr;
4563 if (called) {
4564 if (subdata && subdata->returns.find(AugmentedStruct::Tape) !=
4565 subdata->returns.end()) {
4567 if (tape == nullptr) {
4568#if LLVM_VERSION_MAJOR >= 18
4569 auto It = BuilderZ.GetInsertPoint();
4570 It.setHeadBit(true);
4571 BuilderZ.SetInsertPoint(It);
4572#endif
4573 tape = BuilderZ.CreatePHI(subdata->tapeType, 0, "tapeArg");
4574 }
4575 tape = gutils->cacheForReverse(
4576 BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ));
4577 }
4578 tape = lookup(tape, Builder2);
4579 auto alloc = IRBuilder<>(gutils->inversionAllocs)
4580 .CreateAlloca(tape->getType());
4581 Builder2.CreateStore(tape, alloc);
4582 args.push_back(alloc);
4583 }
4584
4585 if (Mode == DerivativeMode::ReverseModeGradient && subdata) {
4586 for (size_t i = 0; i < argsInverted.size(); i++) {
4587 if (subdata->constant_args[i] == argsInverted[i])
4588 continue;
4589 assert(subdata->constant_args[i] == DIFFE_TYPE::DUP_ARG);
4590 assert(argsInverted[i] == DIFFE_TYPE::DUP_NONEED);
4591 argsInverted[i] = DIFFE_TYPE::DUP_ARG;
4592 }
4593 }
4594
4595 newcalled = gutils->Logic.CreatePrimalAndGradient(
4596 RequestContext(&call, &Builder2),
4598 .todiff = cast<Function>(called),
4599 .retType = subretType,
4600 .constant_args = argsInverted,
4601 .subsequent_calls_may_write = subsequent_calls_may_write,
4602 .overwritten_args = overwritten_args,
4603 .returnUsed = false,
4604 .shadowReturnUsed = false,
4606 .width = gutils->getWidth(),
4607 .freeMemory = true,
4608 .AtomicAdd = true,
4609 .additionalType = tape ? getUnqual(tape->getType()) : nullptr,
4610 .forceAnonymousTape = false,
4611 .typeInfo = nextTypeInfo,
4612 .runtimeActivity = gutils->runtimeActivity,
4613 .strongZero = gutils->strongZero},
4614 TR.analyzer->interprocedural, subdata,
4615 /*omp*/ true);
4616
4617 if (subdata && subdata->returns.find(AugmentedStruct::Tape) !=
4618 subdata->returns.end()) {
4619 auto tapeArg = newcalled->arg_end();
4620 tapeArg--;
4621 LoadInst *tape = nullptr;
4622 for (auto u : tapeArg->users()) {
4623 assert(!tape);
4624 if (!isa<LoadInst>(u)) {
4625 llvm::errs() << " newcalled: " << *newcalled << "\n";
4626 llvm::errs() << " u: " << *u << "\n";
4627 }
4628 tape = cast<LoadInst>(u);
4629 }
4630 assert(tape);
4631 SmallVector<Value *, 4> extracts;
4632 if (subdata->tapeIndices.size() == 1) {
4633 assert(subdata->tapeIndices.begin()->second == -1);
4634 extracts.push_back(tape);
4635 } else {
4636 for (auto a : tape->users()) {
4637 extracts.push_back(a);
4638 }
4639 }
4640 SmallVector<LoadInst *, 4> geps;
4641 for (auto E : extracts) {
4642 AllocaInst *AI = nullptr;
4643 for (auto U : E->users()) {
4644 if (auto SI = dyn_cast<StoreInst>(U)) {
4645 assert(SI->getValueOperand() == E);
4646 AI = cast<AllocaInst>(SI->getPointerOperand());
4647 }
4648 }
4649 if (AI) {
4650 for (auto U : AI->users()) {
4651 if (auto LI = dyn_cast<LoadInst>(U)) {
4652 geps.push_back(LI);
4653 }
4654 }
4655 }
4656 }
4657 for (auto LI : geps) {
4658 CallInst *freeCall = nullptr;
4659 for (auto LU : LI->users()) {
4660 if (auto CI = dyn_cast<CallInst>(LU)) {
4661 if (auto F = CI->getCalledFunction()) {
4662 if (F->getName() == "free") {
4663 freeCall = CI;
4664 break;
4665 }
4666 }
4667 } else if (auto BC = dyn_cast<CastInst>(LU)) {
4668 for (auto CU : BC->users()) {
4669 if (auto CI = dyn_cast<CallInst>(CU)) {
4670 if (auto F = CI->getCalledFunction()) {
4671 if (F->getName() == "free") {
4672 freeCall = CI;
4673 break;
4674 }
4675 }
4676 }
4677 }
4678 if (freeCall)
4679 break;
4680 }
4681 }
4682 if (freeCall) {
4683 freeCall->eraseFromParent();
4684 }
4685 }
4686 }
4687
4688 Value *OutAlloc = nullptr;
4689 auto ST = StructType::get(newcalled->getContext(), OutFPTypes);
4690 if (OutTypes.size()) {
4691 OutAlloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(ST);
4692 args.push_back(OutAlloc);
4693
4694 SmallVector<Type *, 3> MetaTypes;
4695 for (auto P :
4696 cast<Function>(newcalled)->getFunctionType()->params()) {
4697 MetaTypes.push_back(P);
4698 }
4699 MetaTypes.push_back(getUnqual(ST));
4700 auto FT = FunctionType::get(Type::getVoidTy(newcalled->getContext()),
4701 MetaTypes, false);
4702 Function *F =
4703 Function::Create(FT, GlobalVariable::InternalLinkage,
4704 cast<Function>(newcalled)->getName() + "#out",
4705 *task->getParent());
4706 BasicBlock *entry =
4707 BasicBlock::Create(newcalled->getContext(), "entry", F);
4708 IRBuilder<> B(entry);
4709 SmallVector<Value *, 2> SubArgs;
4710 for (auto &arg : F->args())
4711 SubArgs.push_back(&arg);
4712 Value *cacheArg = SubArgs.back();
4713 SubArgs.pop_back();
4714 Value *outdiff = B.CreateCall(newcalled, SubArgs);
4715 for (size_t ee = 0; ee < OutTypes.size(); ee++) {
4716 Value *dif = B.CreateExtractValue(outdiff, ee);
4717 Value *Idxs[] = {
4718 ConstantInt::get(Type::getInt64Ty(ST->getContext()), 0),
4719 ConstantInt::get(Type::getInt32Ty(ST->getContext()), ee)};
4720 Value *ptr = B.CreateInBoundsGEP(ST, cacheArg, Idxs);
4721
4722 if (dif->getType()->isIntOrIntVectorTy()) {
4723
4724 ptr = B.CreateBitCast(
4725 ptr,
4727 IntToFloatTy(dif->getType()),
4728 cast<PointerType>(ptr->getType())->getAddressSpace()));
4729 dif = B.CreateBitCast(dif, IntToFloatTy(dif->getType()));
4730 }
4731
4732 MaybeAlign align;
4733 AtomicRMWInst::BinOp op = AtomicRMWInst::FAdd;
4734 if (auto vt = dyn_cast<VectorType>(dif->getType())) {
4735 assert(!vt->getElementCount().isScalable());
4736 size_t numElems = vt->getElementCount().getKnownMinValue();
4737 for (size_t i = 0; i < numElems; ++i) {
4738 auto vdif = B.CreateExtractElement(dif, i);
4739 Value *Idxs[] = {
4740 ConstantInt::get(Type::getInt64Ty(vt->getContext()), 0),
4741 ConstantInt::get(Type::getInt32Ty(vt->getContext()), i)};
4742 auto vptr = B.CreateInBoundsGEP(vt, ptr, Idxs);
4743 B.CreateAtomicRMW(op, vptr, vdif, align,
4744 AtomicOrdering::Monotonic, SyncScope::System);
4745 }
4746 } else {
4747 B.CreateAtomicRMW(op, ptr, dif, align, AtomicOrdering::Monotonic,
4748 SyncScope::System);
4749 }
4750 }
4751 B.CreateRetVoid();
4752 newcalled = F;
4753 }
4754
4755 auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()),
4756 args.size() - 3);
4757 args[0] =
4758 lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), Builder2);
4759 args[1] = numargs;
4760 args[2] = Builder2.CreatePointerCast(
4761 newcalled, kmpc->getFunctionType()->getParamType(2));
4762
4763 CallInst *diffes =
4764 Builder2.CreateCall(kmpc->getFunctionType(), kmpc, args);
4765 diffes->setCallingConv(call.getCallingConv());
4766 diffes->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
4767
4768 for (size_t i = 0; i < OutTypes.size(); i++) {
4769
4770 size_t size = 1;
4771 if (OutTypes[i]->getType()->isSized())
4772 size = (gutils->newFunc->getParent()
4773 ->getDataLayout()
4774 .getTypeSizeInBits(OutTypes[i]->getType()) +
4775 7) /
4776 8;
4777 Value *Idxs[] = {
4778 ConstantInt::get(Type::getInt64Ty(call.getContext()), 0),
4779 ConstantInt::get(Type::getInt32Ty(call.getContext()), i)};
4780 ((DiffeGradientUtils *)gutils)
4781 ->addToDiffe(OutTypes[i],
4782 Builder2.CreateLoad(
4783 OutFPTypes[i],
4784 Builder2.CreateInBoundsGEP(ST, OutAlloc, Idxs)),
4785 Builder2, TR.addingType(size, OutTypes[i]));
4786 }
4787
4788 if (tape && shouldFree()) {
4789 for (auto idx : subdata->tapeIndiciesToFree) {
4790 CreateDealloc(Builder2,
4791 idx == -1 ? tape
4792 : Builder2.CreateExtractValue(tape, idx));
4793 }
4794 }
4795 } else {
4796 assert(0 && "openmp indirect unhandled");
4797 }
4798 }
4799 }
4800
4802 llvm::CallInst &call, llvm::Value *origArg, llvm::Value *dsto,
4803 llvm::Value *srco, llvm::Value *len_arg, llvm::IRBuilder<> &Builder2,
4804 llvm::ArrayRef<llvm::OperandBundleDef> ReverseDefs) {
4805 using namespace llvm;
4806
4807 size_t size = 1;
4808 if (auto ci = dyn_cast<ConstantInt>(len_arg)) {
4809 size = ci->getLimitedValue();
4810 }
4811 auto &DL = gutils->newFunc->getParent()->getDataLayout();
4812 auto vd = TR.query(origArg).Data0().ShiftIndices(DL, 0, size, 0);
4813 if (!vd.isKnownPastPointer()) {
4814#if LLVM_VERSION_MAJOR < 17
4815 if (looseTypeAnalysis) {
4816 if (isa<CastInst>(origArg) &&
4817 cast<CastInst>(origArg)->getSrcTy()->isPointerTy() &&
4818 cast<CastInst>(origArg)
4819 ->getSrcTy()
4820 ->getPointerElementType()
4821 ->isFPOrFPVectorTy()) {
4822 vd = TypeTree(ConcreteType(cast<CastInst>(origArg)
4823 ->getSrcTy()
4824 ->getPointerElementType()
4825 ->getScalarType()))
4826 .Only(0, &call);
4827 goto knownF;
4828 }
4829 }
4830#endif
4831 TR.dump();
4832 EmitFailure("CannotDeduceType", call.getDebugLoc(), &call,
4833 "failed to deduce type of copy ", call);
4834 }
4835#if LLVM_VERSION_MAJOR < 17
4836 knownF:
4837#endif
4838 unsigned start = 0;
4839 while (1) {
4840 unsigned nextStart = size;
4841
4842 auto dt = vd[{-1}];
4843 for (size_t i = start; i < size; ++i) {
4844 bool Legal = true;
4845 dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
4846 if (!Legal) {
4847 nextStart = i;
4848 break;
4849 }
4850 }
4851 if (!dt.isKnown()) {
4852 TR.dump();
4853 llvm::errs() << " vd:" << vd.str() << " start:" << start
4854 << " size: " << size << " dt:" << dt.str() << "\n";
4855 }
4856 assert(dt.isKnown());
4857
4858 Value *length = len_arg;
4859 if (nextStart != size) {
4860 length = ConstantInt::get(len_arg->getType(), nextStart);
4861 }
4862 if (start != 0)
4863 length = Builder2.CreateSub(
4864 length, ConstantInt::get(len_arg->getType(), start));
4865
4866 if (auto secretty = dt.isFloat()) {
4867 auto offset = start;
4868 if (dsto->getType()->isIntegerTy())
4869 dsto =
4870 Builder2.CreateIntToPtr(dsto, getInt8PtrTy(dsto->getContext()));
4871 unsigned dstaddr =
4872 cast<PointerType>(dsto->getType())->getAddressSpace();
4873 auto secretpt = getPointerType(secretty, dstaddr);
4874 if (offset != 0) {
4875 dsto = Builder2.CreateConstInBoundsGEP1_64(
4876 Type::getInt8Ty(dsto->getContext()), dsto, offset);
4877 }
4878 if (srco->getType()->isIntegerTy())
4879 srco =
4880 Builder2.CreateIntToPtr(srco, getInt8PtrTy(dsto->getContext()));
4881 unsigned srcaddr =
4882 cast<PointerType>(srco->getType())->getAddressSpace();
4883 secretpt = getPointerType(secretty, srcaddr);
4884
4885 if (offset != 0) {
4886 srco = Builder2.CreateConstInBoundsGEP1_64(
4887 Type::getInt8Ty(srco->getContext()), srco, offset);
4888 }
4889 Value *args[3] = {
4890 Builder2.CreatePointerCast(dsto, secretpt),
4891 Builder2.CreatePointerCast(srco, secretpt),
4892 Builder2.CreateUDiv(
4893 length,
4894
4895 ConstantInt::get(length->getType(),
4896 Builder2.GetInsertBlock()
4897 ->getParent()
4898 ->getParent()
4899 ->getDataLayout()
4900 .getTypeAllocSizeInBits(secretty) /
4901 8))};
4902
4904 *Builder2.GetInsertBlock()->getParent()->getParent(), secretty,
4905 /*dstalign*/ 1, /*srcalign*/ 1, dstaddr, srcaddr,
4906 cast<IntegerType>(length->getType())->getBitWidth());
4907
4908 Builder2.CreateCall(dmemcpy, args, ReverseDefs);
4909 }
4910
4911 if (nextStart == size)
4912 break;
4913 start = nextStart;
4914 }
4915 }
4916
4917 void recursivelyHandleSubfunction(llvm::CallInst &call,
4918 llvm::Function *called,
4919 bool subsequent_calls_may_write,
4920 const std::vector<bool> &overwritten_args,
4921 bool shadowReturnUsed,
4922 DIFFE_TYPE subretType, bool subretused) {
4923 using namespace llvm;
4924
4925 IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call));
4926 BuilderZ.setFastMathFlags(getFast());
4927
4928 CallInst *newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
4929 Module &M = *call.getParent()->getParent()->getParent();
4930
4931 bool foreignFunction = called == nullptr;
4932
4933 FnTypeInfo nextTypeInfo(called);
4934
4935 if (called) {
4936 nextTypeInfo = TR.getCallInfo(call, *called);
4937 }
4938
4939 const AugmentedReturn *subdata = nullptr;
4942 assert(augmentedReturn);
4943 if (augmentedReturn) {
4944 auto fd = augmentedReturn->subaugmentations.find(&call);
4945 if (fd != augmentedReturn->subaugmentations.end()) {
4946 subdata = fd->second;
4947 }
4948 }
4949 }
4950
4951 if (Mode == DerivativeMode::ForwardMode ||
4954 IRBuilder<> Builder2(&call);
4955 getForwardBuilder(Builder2);
4956
4957 SmallVector<Value *, 8> args;
4958 std::vector<DIFFE_TYPE> argsInverted;
4959 std::map<int, Type *> gradByVal;
4960 std::map<int, std::vector<Attribute>> structAttrs;
4961
4962 for (unsigned i = 0; i < call.arg_size(); ++i) {
4963
4964 if (call.paramHasAttr(i, Attribute::StructRet)) {
4965 structAttrs[args.size()].push_back(Attribute::get(
4966 call.getContext(), "enzyme_sret",
4967 convertSRetTypeToString(call.getParamAttr(i, Attribute::StructRet)
4968 .getValueAsType())));
4969 }
4970 for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype",
4971 "enzymejl_parmtype_ref", "enzyme_type",
4972 "enzymejl_sret_union_bytes", "enzymejl_rooted_typ"})
4973 if (call.getAttributes().hasParamAttr(i, attr)) {
4974 structAttrs[args.size()].push_back(call.getParamAttr(i, attr));
4975 }
4976 for (auto ty : PrimalParamAttrsToPreserve)
4977 if (call.getAttributes().hasParamAttr(i, ty)) {
4978 auto attr = call.getAttributes().getParamAttr(i, ty);
4979 structAttrs[args.size()].push_back(attr);
4980 }
4981
4982 auto argi = gutils->getNewFromOriginal(call.getArgOperand(i));
4983
4984 if (call.isByValArgument(i)) {
4985 gradByVal[args.size()] = call.getParamByValType(i);
4986 }
4987
4988 bool writeOnlyNoCapture = true;
4989 bool readOnly = true;
4990 if (!isNoCapture(&call, i)) {
4991 writeOnlyNoCapture = false;
4992 }
4993 if (!isWriteOnly(&call, i)) {
4994 writeOnlyNoCapture = false;
4995 }
4996 if (!isReadOnly(&call, i)) {
4997 readOnly = false;
4998 }
4999
5000 if (shouldDisableNoWrite(&call))
5001 writeOnlyNoCapture = false;
5002
5003 auto argTy =
5004 gutils->getDiffeType(call.getArgOperand(i), foreignFunction);
5005
5006 bool replace =
5007 (argTy == DIFFE_TYPE::DUP_NONEED &&
5008 (writeOnlyNoCapture ||
5009 !isa<Argument>(getBaseObject(call.getArgOperand(i))))) ||
5010 (writeOnlyNoCapture && Mode == DerivativeMode::ForwardModeSplit) ||
5011 (writeOnlyNoCapture && readOnly);
5012
5013 if (replace) {
5014 argi = getUndefinedValueForType(M, argi->getType());
5015 }
5016 argsInverted.push_back(argTy);
5017 args.push_back(argi);
5018
5019 if (argTy == DIFFE_TYPE::CONSTANT) {
5020 continue;
5021 }
5022
5023 if (gutils->getWidth() == 1)
5024 for (auto ty : ShadowParamAttrsToPreserve)
5025 if (call.getAttributes().hasParamAttr(i, ty)) {
5026 auto attr = call.getAttributes().getParamAttr(i, ty);
5027 structAttrs[args.size()].push_back(attr);
5028 }
5029
5030 for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype",
5031 "enzymejl_parmtype_ref", "enzyme_type",
5032 "enzymejl_sret_union_bytes", "enzymejl_rooted_typ"})
5033 if (call.getAttributes().hasParamAttr(i, attr)) {
5034 if (gutils->getWidth() == 1) {
5035 structAttrs[args.size()].push_back(call.getParamAttr(i, attr));
5036 } else if (attr == std::string("enzymejl_returnRoots")) {
5037 structAttrs[args.size()].push_back(
5038 Attribute::get(call.getContext(), "enzymejl_returnRoots_v",
5039 call.getAttributes()
5040 .getParamAttr(i, "enzymejl_returnRoots")
5041 .getValueAsString()));
5042 } else if (attr == std::string("enzymejl_sret_union_bytes")) {
5043 structAttrs[args.size()].push_back(Attribute::get(
5044 call.getContext(), "enzymejl_sret_union_bytes_v",
5045 call.getAttributes()
5046 .getParamAttr(i, "enzymejl_sret_union_bytes")
5047 .getValueAsString()));
5048 } else if (attr == std::string("enzymejl_rooted_typ")) {
5049 structAttrs[args.size()].push_back(
5050 Attribute::get(call.getContext(), "enzymejl_rooted_typ_v",
5051 call.getAttributes()
5052 .getParamAttr(i, "enzymejl_rooted_typ")
5053 .getValueAsString()));
5054 }
5055 }
5056 if (call.paramHasAttr(i, Attribute::StructRet)) {
5057 if (gutils->getWidth() == 1) {
5058 structAttrs[args.size()].push_back(
5059 Attribute::get(call.getContext(), "enzyme_sret",
5061 call.getParamAttr(i, Attribute::StructRet)
5062 .getValueAsType())));
5063 } else {
5064 structAttrs[args.size()].push_back(
5065 Attribute::get(call.getContext(), "enzyme_sret_v",
5067 call.getParamAttr(i, Attribute::StructRet)
5068 .getValueAsType())));
5069 }
5070 }
5071
5072 assert(argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED);
5073
5074 args.push_back(gutils->invertPointerM(call.getArgOperand(i), Builder2));
5075 }
5076#if LLVM_VERSION_MAJOR >= 16
5077 std::optional<int> tapeIdx;
5078#else
5079 Optional<int> tapeIdx;
5080#endif
5081 if (subdata) {
5082 auto found = subdata->returns.find(AugmentedStruct::Tape);
5083 if (found != subdata->returns.end()) {
5084 tapeIdx = found->second;
5085 }
5086 }
5087 Value *tape = nullptr;
5088 if (tapeIdx) {
5089
5090 auto idx = *tapeIdx;
5091 FunctionType *FT = subdata->fn->getFunctionType();
5092#if LLVM_VERSION_MAJOR >= 18
5093 auto It = BuilderZ.GetInsertPoint();
5094 It.setHeadBit(true);
5095 BuilderZ.SetInsertPoint(It);
5096#endif
5097 tape = BuilderZ.CreatePHI(
5098 (tapeIdx == -1)
5099 ? FT->getReturnType()
5100 : cast<StructType>(FT->getReturnType())->getElementType(idx),
5101 1, "tapeArg");
5102
5103 assert(!tape->getType()->isEmptyTy());
5104 gutils->TapesToPreventRecomputation.insert(cast<Instruction>(tape));
5105 tape = gutils->cacheForReverse(
5106 BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ));
5107 args.push_back(tape);
5108 }
5109
5110 Value *newcalled = nullptr;
5111 FunctionType *FT = nullptr;
5112
5113 if (called) {
5114 newcalled = gutils->Logic.CreateForwardDiff(
5115 RequestContext(&call, &BuilderZ), cast<Function>(called),
5116 subretType, argsInverted, TR.analyzer->interprocedural,
5117 /*returnValue*/ subretused, Mode,
5118 ((DiffeGradientUtils *)gutils)->FreeMemory, gutils->runtimeActivity,
5119 gutils->strongZero, gutils->getWidth(),
5120 tape ? tape->getType() : nullptr, nextTypeInfo,
5121 subsequent_calls_may_write, overwritten_args,
5122 /*augmented*/ subdata);
5123 FT = cast<Function>(newcalled)->getFunctionType();
5124 } else {
5125 auto callval = call.getCalledOperand();
5126 newcalled = gutils->invertPointerM(callval, BuilderZ);
5127
5128 if (gutils->getWidth() > 1) {
5129 newcalled = BuilderZ.CreateExtractValue(newcalled, {0});
5130 }
5131
5133 BuilderZ, gutils->getNewFromOriginal(callval), newcalled,
5134 "Attempting to call an indirect active function "
5135 "whose runtime value is inactive",
5136 gutils->getNewFromOriginal(call.getDebugLoc()), &call);
5137
5138 auto ft = call.getFunctionType();
5139 bool retActive = subretType != DIFFE_TYPE::CONSTANT;
5140
5142 ft, Mode, gutils->getWidth(), tape ? tape->getType() : nullptr,
5143 argsInverted, false, /*returnTape*/ false,
5144 /*returnPrimal*/ subretused, /*returnShadow*/ retActive);
5145 PointerType *fptype = getUnqual(FT);
5146 newcalled = BuilderZ.CreatePointerCast(newcalled, getUnqual(fptype));
5147 newcalled = BuilderZ.CreateLoad(fptype, newcalled);
5148 }
5149
5150 assert(newcalled);
5151 assert(FT);
5152
5153 SmallVector<ValueType, 2> BundleTypes;
5154 for (auto A : argsInverted)
5155 if (A == DIFFE_TYPE::CONSTANT)
5156 BundleTypes.push_back(ValueType::Primal);
5157 else
5158 BundleTypes.push_back(ValueType::Both);
5159
5160 auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2,
5161 /*lookup*/ false);
5162
5163 CallInst *diffes = Builder2.CreateCall(FT, newcalled, args, Defs);
5164 diffes->setCallingConv(call.getCallingConv());
5165 diffes->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
5166
5167 for (auto pair : gradByVal) {
5168 diffes->addParamAttr(
5169 pair.first,
5170 Attribute::getWithByValType(diffes->getContext(), pair.second));
5171 }
5172
5173 for (auto &pair : structAttrs) {
5174 for (auto val : pair.second)
5175 diffes->addParamAttr(pair.first, val);
5176 }
5177
5178 auto newcall = gutils->getNewFromOriginal(&call);
5179 auto ifound = gutils->invertedPointers.find(&call);
5180 Value *primal = nullptr;
5181 Value *diffe = nullptr;
5182
5183 if (subretused && subretType != DIFFE_TYPE::CONSTANT) {
5184 primal = Builder2.CreateExtractValue(diffes, 0);
5185 diffe = Builder2.CreateExtractValue(diffes, 1);
5186 } else if (subretType != DIFFE_TYPE::CONSTANT) {
5187 diffe = diffes;
5188 } else if (!FT->getReturnType()->isVoidTy()) {
5189 primal = diffes;
5190 }
5191
5192 if (ifound != gutils->invertedPointers.end()) {
5193 auto placeholder = cast<PHINode>(&*ifound->second);
5194 if (primal) {
5195 gutils->replaceAWithB(newcall, primal);
5196 gutils->erase(newcall);
5197 } else {
5198 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
5199 }
5200 if (diffe) {
5201 gutils->replaceAWithB(placeholder, diffe);
5202 } else {
5203 gutils->invertedPointers.erase(ifound);
5204 }
5205 gutils->erase(placeholder);
5206 } else {
5207 if (primal && diffe) {
5208 gutils->replaceAWithB(newcall, primal);
5209 if (!gutils->isConstantValue(&call)) {
5210 setDiffe(&call, diffe, Builder2);
5211 }
5212 gutils->erase(newcall);
5213 } else if (diffe) {
5214 setDiffe(&call, diffe, Builder2);
5215 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
5216 } else if (primal) {
5217 gutils->replaceAWithB(newcall, primal);
5218 gutils->erase(newcall);
5219 } else {
5220 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
5221 }
5222 }
5223
5224 return;
5225 }
5226
5227 bool modifyPrimal = shouldAugmentCall(&call, gutils);
5228
5229 SmallVector<Value *, 8> args;
5230 SmallVector<Value *, 8> pre_args;
5231 std::vector<DIFFE_TYPE> argsInverted;
5232 SmallVector<Instruction *, 4> postCreate;
5233 SmallVector<Instruction *, 4> userReplace;
5234 std::map<int, Type *> preByVal;
5235 std::map<int, Type *> gradByVal;
5236 std::map<int, std::vector<Attribute>> structAttrs;
5237
5238 bool replaceFunction = false;
5239
5240 if (Mode == DerivativeMode::ReverseModeCombined && !foreignFunction) {
5241 replaceFunction = legalCombinedForwardReverse(
5242 &call, *replacedReturns, postCreate, userReplace, gutils,
5243 unnecessaryInstructions, oldUnreachable, subretused);
5244 if (replaceFunction) {
5245 modifyPrimal = false;
5246 }
5247 }
5248
5249 SmallVector<ValueType, 2> PreBundleTypes;
5250 SmallVector<ValueType, 2> BundleTypes;
5251
5252 for (unsigned i = 0; i < call.arg_size(); ++i) {
5253
5254 auto argi = gutils->getNewFromOriginal(call.getArgOperand(i));
5255
5256 if (call.isByValArgument(i)) {
5257 preByVal[pre_args.size()] = call.getParamByValType(i);
5258 }
5259 for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype",
5260 "enzymejl_parmtype_ref", "enzyme_type",
5261 "enzymejl_sret_union_bytes", "enzymejl_rooted_typ"})
5262 if (call.getAttributes().hasParamAttr(i, attr)) {
5263 structAttrs[pre_args.size()].push_back(call.getParamAttr(i, attr));
5264 }
5265 if (call.paramHasAttr(i, Attribute::StructRet)) {
5266 structAttrs[pre_args.size()].push_back(Attribute::get(
5267 call.getContext(), "enzyme_sret",
5269 call.getParamAttr(i, Attribute::StructRet).getValueAsType())));
5270 }
5271 for (auto ty : PrimalParamAttrsToPreserve)
5272 if (call.getAttributes().hasParamAttr(i, ty)) {
5273 auto attr = call.getAttributes().getParamAttr(i, ty);
5274 structAttrs[pre_args.size()].push_back(attr);
5275 }
5276
5277 auto argTy = gutils->getDiffeType(call.getArgOperand(i), foreignFunction);
5278
5279 bool writeOnlyNoCapture = true;
5280 bool readNoneNoCapture = false;
5281 if (!isNoCapture(&call, i)) {
5282 writeOnlyNoCapture = false;
5283 readNoneNoCapture = false;
5284 }
5285 if (!isWriteOnly(&call, i)) {
5286 writeOnlyNoCapture = false;
5287 }
5288 if (!(isReadOnly(&call, i) && isWriteOnly(&call, i))) {
5289 readNoneNoCapture = false;
5290 }
5291
5292 if (shouldDisableNoWrite(&call)) {
5293 writeOnlyNoCapture = false;
5294 readNoneNoCapture = false;
5295 }
5296
5297 Value *prearg = argi;
5298
5299 ValueType preType = ValueType::Primal;
5300 ValueType revType = ValueType::Primal;
5301
5302 // Keep the existing passed value if coming from outside.
5303 if (readNoneNoCapture ||
5304 (argTy == DIFFE_TYPE::DUP_NONEED &&
5305 (writeOnlyNoCapture ||
5306 !isa<Argument>(getBaseObject(call.getArgOperand(i)))))) {
5307 prearg = getUndefinedValueForType(M, argi->getType());
5308 preType = ValueType::None;
5309 }
5310 pre_args.push_back(prearg);
5311
5313 IRBuilder<> Builder2(&call);
5314 getReverseBuilder(Builder2);
5315
5316 if (call.isByValArgument(i)) {
5317 gradByVal[args.size()] = call.getParamByValType(i);
5318 }
5319
5320 if ((writeOnlyNoCapture && !replaceFunction) ||
5321 (readNoneNoCapture ||
5322 (argTy == DIFFE_TYPE::DUP_NONEED &&
5323 (writeOnlyNoCapture ||
5324 !isa<Argument>(getBaseObject(call.getOperand(i))))))) {
5325 argi = getUndefinedValueForType(M, argi->getType());
5326 revType = ValueType::None;
5327 }
5328 args.push_back(lookup(argi, Builder2));
5329 }
5330
5331 argsInverted.push_back(argTy);
5332
5333 if (argTy == DIFFE_TYPE::CONSTANT) {
5334 PreBundleTypes.push_back(preType);
5335 BundleTypes.push_back(revType);
5336 continue;
5337 }
5338
5339 auto argType = argi->getType();
5340
5341 if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) {
5342 if (gutils->getWidth() == 1)
5343 for (auto ty : ShadowParamAttrsToPreserve)
5344 if (call.getAttributes().hasParamAttr(i, ty)) {
5345 auto attr = call.getAttributes().getParamAttr(i, ty);
5346 structAttrs[pre_args.size()].push_back(attr);
5347 }
5348
5349 for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype",
5350 "enzymejl_parmtype_ref", "enzyme_type",
5351 "enzymejl_sret_union_bytes", "enzymejl_rooted_typ"})
5352 if (call.getAttributes().hasParamAttr(i, attr)) {
5353 if (gutils->getWidth() == 1) {
5354 structAttrs[pre_args.size()].push_back(
5355 call.getParamAttr(i, attr));
5356 } else if (attr == std::string("enzymejl_returnRoots")) {
5357 structAttrs[pre_args.size()].push_back(
5358 Attribute::get(call.getContext(), "enzymejl_returnRoots_v",
5359 call.getAttributes()
5360 .getParamAttr(i, attr)
5361 .getValueAsString()));
5362 } else if (attr == std::string("enzymejl_sret_union_bytes")) {
5363 structAttrs[pre_args.size()].push_back(Attribute::get(
5364 call.getContext(), "enzymejl_sret_union_bytes_v",
5365 call.getAttributes()
5366 .getParamAttr(i, attr)
5367 .getValueAsString()));
5368 } else if (attr == std::string("enzymejl_rooted_typ")) {
5369 structAttrs[pre_args.size()].push_back(
5370 Attribute::get(call.getContext(), "enzymejl_rooted_typ_v",
5371 call.getAttributes()
5372 .getParamAttr(i, attr)
5373 .getValueAsString()));
5374 }
5375 }
5376 if (call.paramHasAttr(i, Attribute::StructRet)) {
5377 if (gutils->getWidth() == 1) {
5378 structAttrs[pre_args.size()].push_back(
5379 Attribute::get(call.getContext(), "enzyme_sret",
5381 call.getParamAttr(i, Attribute::StructRet)
5382 .getValueAsType())));
5383 } else {
5384 structAttrs[pre_args.size()].push_back(
5385 Attribute::get(call.getContext(), "enzyme_sret_v",
5387 call.getParamAttr(i, Attribute::StructRet)
5388 .getValueAsType())));
5389 }
5390 }
5392 IRBuilder<> Builder2(&call);
5393 getReverseBuilder(Builder2);
5394
5395 Value *darg = nullptr;
5396
5397 if (((writeOnlyNoCapture && TR.query(call.getArgOperand(
5398 i))[{-1, -1}] == BaseType::Pointer) ||
5399 gutils->isConstantInstruction(&call)) &&
5400 !replaceFunction) {
5402 M, gutils->getShadowType(argi->getType()));
5403 } else {
5404 darg = gutils->invertPointerM(call.getArgOperand(i), Builder2);
5405 revType = (revType == ValueType::None) ? ValueType::Shadow
5407 }
5408 args.push_back(lookup(darg, Builder2));
5409 }
5410 if (Mode == DerivativeMode::ReverseModeGradient && !replaceFunction) {
5411 pre_args.push_back(getUndefinedValueForType(M, argi->getType()));
5412 } else {
5413 pre_args.push_back(
5414 gutils->invertPointerM(call.getArgOperand(i), BuilderZ));
5415 }
5416 preType =
5418
5419 // Note sometimes whattype mistakenly says something should be
5420 // constant [because composed of integer pointers alone]
5421 auto wt = whatType(argType, Mode);
5422 if (wt != DIFFE_TYPE::DUP_ARG && wt != DIFFE_TYPE::CONSTANT) {
5423 std::string str;
5424 raw_string_ostream ss(str);
5425 ss << "Mismatched estimated activity type for " << *argType
5426 << " expected DUP_ARG or CONSTANT found " << wt
5427 << ", call = " << call << "\n";
5428 if (CustomErrorHandler) {
5429 CustomErrorHandler(str.c_str(), wrap(&call),
5430 ErrorType::InternalError, nullptr, nullptr,
5431 nullptr);
5432 } else {
5433 EmitFailure("MismatchArgType", call.getDebugLoc(), &call, ss.str());
5434 }
5435 }
5436 } else {
5437 if (foreignFunction)
5438 assert(!argType->isIntOrIntVectorTy());
5439 assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
5440 whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
5441 }
5442 PreBundleTypes.push_back(preType);
5443 BundleTypes.push_back(revType);
5444 }
5445 if (called) {
5446 if (call.arg_size() !=
5447 cast<Function>(called)->getFunctionType()->getNumParams()) {
5448 llvm::errs() << *gutils->oldFunc->getParent() << "\n";
5449 llvm::errs() << *gutils->oldFunc << "\n";
5450 llvm::errs() << call << "\n";
5451 llvm::errs() << " number of arg operands != function parameters\n";
5452 EmitFailure("MismatchArgs", call.getDebugLoc(), &call,
5453 "Number of arg operands != function parameters\n", call);
5454 }
5455 }
5456
5457 Value *tape = nullptr;
5458 CallInst *augmentcall = nullptr;
5459 Value *cachereplace = nullptr;
5460
5461 // std::optional<std::map<std::pair<Instruction*, std::string>,
5462 // unsigned>> sub_index_map;
5463#if LLVM_VERSION_MAJOR >= 16
5464 std::optional<int> tapeIdx;
5465 std::optional<int> returnIdx;
5466 std::optional<int> differetIdx;
5467#else
5468 Optional<int> tapeIdx;
5469 Optional<int> returnIdx;
5470 Optional<int> differetIdx;
5471#endif
5472 if (modifyPrimal) {
5473
5474 Value *newcalled = nullptr;
5475 FunctionType *FT = nullptr;
5476 const AugmentedReturn *fnandtapetype = nullptr;
5477
5478 if (!called) {
5479 auto callval = call.getCalledOperand();
5480 Value *uncast = callval;
5481 while (auto CE = dyn_cast<ConstantExpr>(uncast)) {
5482 if (CE->isCast()) {
5483 uncast = CE->getOperand(0);
5484 continue;
5485 }
5486 break;
5487 }
5488 if (isa<ConstantInt>(uncast)) {
5489 std::string str;
5490 raw_string_ostream ss(str);
5491 ss << "cannot find shadow for " << *callval
5492 << " for use as function in " << call;
5493 EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ);
5494 }
5495 newcalled = gutils->invertPointerM(callval, BuilderZ);
5496
5499 BuilderZ, gutils->getNewFromOriginal(callval), newcalled,
5500 "Attempting to call an indirect active function "
5501 "whose runtime value is inactive",
5502 gutils->getNewFromOriginal(call.getDebugLoc()), &call);
5503
5504 FunctionType *ft = call.getFunctionType();
5505
5506 std::set<llvm::Type *> seen;
5507 DIFFE_TYPE subretType = whatType(call.getType(), Mode,
5508 /*intAreConstant*/ false, seen);
5510 ft, /*returnUsed*/ true, /*subretType*/ subretType);
5511 FT = FunctionType::get(
5512 StructType::get(newcalled->getContext(), res.second), res.first,
5513 ft->isVarArg());
5514 auto fptype = getUnqual(FT);
5515 newcalled = BuilderZ.CreatePointerCast(newcalled, getUnqual(fptype));
5516 newcalled = BuilderZ.CreateLoad(fptype, newcalled);
5517 tapeIdx = 0;
5518
5519 if (!call.getType()->isVoidTy()) {
5520 returnIdx = 1;
5521 if (subretType == DIFFE_TYPE::DUP_ARG ||
5522 subretType == DIFFE_TYPE::DUP_NONEED) {
5523 differetIdx = 2;
5524 }
5525 }
5526 } else {
5529 subdata = &gutils->Logic.CreateAugmentedPrimal(
5530 RequestContext(&call, &BuilderZ), cast<Function>(called),
5531 subretType, argsInverted, TR.analyzer->interprocedural,
5532 /*return is used*/ subretused, shadowReturnUsed, nextTypeInfo,
5533 subsequent_calls_may_write, overwritten_args, false,
5534 gutils->runtimeActivity, gutils->strongZero, gutils->getWidth(),
5535 gutils->AtomicAdd);
5537 assert(augmentedReturn);
5538 auto subaugmentations =
5539 (std::map<const llvm::CallInst *, AugmentedReturn *>
5540 *)&augmentedReturn->subaugmentations;
5542 *subaugmentations, &call, (AugmentedReturn *)subdata);
5543 }
5544 }
5545 {
5546 Intrinsic::ID ID = Intrinsic::not_intrinsic;
5547 if (!subdata &&
5549 llvm::errs() << *gutils->oldFunc->getParent() << "\n";
5550 llvm::errs() << *gutils->oldFunc << "\n";
5551 llvm::errs() << *gutils->newFunc << "\n";
5552 llvm::errs() << *called << "\n";
5553 assert(subdata);
5554 }
5555 }
5556
5557 if (subdata) {
5558 fnandtapetype = subdata;
5559 newcalled = subdata->fn;
5560 FT = cast<Function>(newcalled)->getFunctionType();
5561
5562 auto found =
5564 if (found != subdata->returns.end()) {
5565 differetIdx = found->second;
5566 } else {
5567 assert(!shadowReturnUsed);
5568 }
5569
5570 found = subdata->returns.find(AugmentedStruct::Return);
5571 if (found != subdata->returns.end()) {
5572 returnIdx = found->second;
5573 } else {
5574 assert(!subretused);
5575 }
5576
5577 found = subdata->returns.find(AugmentedStruct::Tape);
5578 if (found != subdata->returns.end()) {
5579 tapeIdx = found->second;
5580 }
5581 }
5582 }
5583 // sub_index_map = fnandtapetype.tapeIndices;
5584
5585 // llvm::errs() << "seeing sub_index_map of " << sub_index_map->size()
5586 // << " in ap " << cast<Function>(called)->getName() << "\n";
5589
5590 assert(newcalled);
5591 assert(FT);
5592
5593 if (false) {
5594 badaugmentedfn:;
5595 auto NC = dyn_cast<Function>(newcalled);
5596 llvm::errs() << *gutils->oldFunc << "\n";
5597 llvm::errs() << *gutils->newFunc << "\n";
5598 if (NC)
5599 llvm::errs() << " trying to call " << NC->getName() << " " << *FT
5600 << "\n";
5601 else
5602 llvm::errs() << " trying to call " << *newcalled << " " << *FT
5603 << "\n";
5604
5605 for (unsigned i = 0; i < pre_args.size(); ++i) {
5606 assert(pre_args[i]);
5607 assert(pre_args[i]->getType());
5608 llvm::errs() << "args[" << i << "] = " << *pre_args[i]
5609 << " FT:" << *FT->getParamType(i) << "\n";
5610 }
5611 assert(0 && "calling with wrong number of arguments");
5612 exit(1);
5613 }
5614
5615 if (pre_args.size() != FT->getNumParams())
5616 goto badaugmentedfn;
5617
5618 for (unsigned i = 0; i < pre_args.size(); ++i) {
5619 if (pre_args[i]->getType() == FT->getParamType(i))
5620 continue;
5621 else if (!call.getCalledFunction())
5622 pre_args[i] =
5623 BuilderZ.CreateBitCast(pre_args[i], FT->getParamType(i));
5624 else
5625 goto badaugmentedfn;
5626 }
5627
5628 augmentcall = BuilderZ.CreateCall(
5629 FT, newcalled, pre_args,
5630 gutils->getInvertedBundles(&call, PreBundleTypes, BuilderZ,
5631 /*lookup*/ false));
5632 augmentcall->setCallingConv(call.getCallingConv());
5633 augmentcall->setDebugLoc(
5634 gutils->getNewFromOriginal(call.getDebugLoc()));
5635
5636 for (auto pair : preByVal) {
5637 augmentcall->addParamAttr(
5638 pair.first, Attribute::getWithByValType(augmentcall->getContext(),
5639 pair.second));
5640 }
5641
5642 for (auto &pair : structAttrs) {
5643 for (auto val : pair.second)
5644 augmentcall->addParamAttr(pair.first, val);
5645 }
5646
5647 if (!augmentcall->getType()->isVoidTy())
5648 augmentcall->setName(call.getName() + "_augmented");
5649
5650 if (tapeIdx) {
5651 auto tval = *tapeIdx;
5652 tape = (tval == -1) ? augmentcall
5653 : BuilderZ.CreateExtractValue(
5654 augmentcall, {(unsigned)tval}, "subcache");
5655 if (tape->getType()->isEmptyTy()) {
5656 auto tt = tape->getType();
5657 gutils->erase(cast<Instruction>(tape));
5658 tape = UndefValue::get(tt);
5659 } else {
5660 gutils->TapesToPreventRecomputation.insert(cast<Instruction>(tape));
5661 }
5662 tape = gutils->cacheForReverse(
5663 BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ));
5664 }
5665
5666 if (subretused) {
5667 Value *dcall = nullptr;
5668 assert(returnIdx);
5669 assert(augmentcall);
5670 auto rval = *returnIdx;
5671 dcall = (rval < 0) ? augmentcall
5672 : BuilderZ.CreateExtractValue(augmentcall,
5673 {(unsigned)rval});
5674 gutils->originalToNewFn[&call] = dcall;
5675 gutils->newToOriginalFn.erase(newCall);
5676 gutils->newToOriginalFn[dcall] = &call;
5677
5678 assert(dcall->getType() == call.getType());
5679 assert(dcall);
5680
5681 if (!gutils->isConstantValue(&call)) {
5682 if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) {
5683 } else if (Mode != DerivativeMode::ReverseModePrimal) {
5684 ((DiffeGradientUtils *)gutils)->differentials[dcall] =
5685 ((DiffeGradientUtils *)gutils)->differentials[newCall];
5686 ((DiffeGradientUtils *)gutils)->differentials.erase(newCall);
5687 }
5688 }
5689 assert(dcall->getType() == call.getType());
5690 gutils->replaceAWithB(newCall, dcall);
5691
5692 if (isa<Instruction>(dcall) && !isa<PHINode>(dcall)) {
5693 cast<Instruction>(dcall)->takeName(newCall);
5694 }
5695
5697 !gutils->unnecessaryIntermediates.count(&call)) {
5698
5699 std::map<UsageKey, bool> Seen;
5700 bool primalNeededInReverse = false;
5701 for (auto pair : gutils->knownRecomputeHeuristic)
5702 if (!pair.second) {
5703 if (pair.first == &call) {
5704 primalNeededInReverse = true;
5705 break;
5706 } else {
5707 Seen[UsageKey(pair.first, QueryType::Primal)] = false;
5708 }
5709 }
5710 if (!primalNeededInReverse) {
5711
5712 auto minCutMode = (Mode == DerivativeMode::ReverseModePrimal)
5714 : Mode;
5715 primalNeededInReverse =
5717 QueryType::Primal>(gutils, &call, minCutMode, Seen,
5718 oldUnreachable);
5719 }
5720 if (primalNeededInReverse)
5721 gutils->cacheForReverse(
5722 BuilderZ, dcall, getIndex(&call, CacheType::Self, BuilderZ));
5723 }
5724 BuilderZ.SetInsertPoint(newCall->getNextNode());
5725 gutils->erase(newCall);
5726 } else {
5727 BuilderZ.SetInsertPoint(BuilderZ.GetInsertPoint()->getNextNode());
5728 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
5729 gutils->originalToNewFn[&call] = augmentcall;
5730 gutils->newToOriginalFn[augmentcall] = &call;
5731 }
5732
5733 } else {
5734 if (subdata && subdata->returns.find(AugmentedStruct::Tape) ==
5735 subdata->returns.end()) {
5736 } else {
5737 // assert(!tape);
5738 // assert(subdata);
5739 if (FT) {
5740 if (!tape) {
5741 assert(tapeIdx);
5742 auto tval = *tapeIdx;
5743#if LLVM_VERSION_MAJOR >= 18
5744 auto It = BuilderZ.GetInsertPoint();
5745 It.setHeadBit(true);
5746 BuilderZ.SetInsertPoint(It);
5747#endif
5748 tape = BuilderZ.CreatePHI(
5749 (tapeIdx == -1) ? FT->getReturnType()
5750 : cast<StructType>(FT->getReturnType())
5751 ->getElementType(tval),
5752 1, "tapeArg");
5753 }
5754 tape = gutils->cacheForReverse(
5755 BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ));
5756 }
5757 }
5758
5759 if (subretused) {
5760 Intrinsic::ID ID = Intrinsic::not_intrinsic;
5762 QueryType::Primal>(gutils, &call, Mode, oldUnreachable) &&
5763 !gutils->unnecessaryIntermediates.count(&call)) {
5764
5765 if (!isMemFreeLibMFunction(getFuncNameFromCall(&call), &ID)) {
5766
5767#if LLVM_VERSION_MAJOR >= 18
5768 auto It = BuilderZ.GetInsertPoint();
5769 It.setHeadBit(true);
5770 BuilderZ.SetInsertPoint(It);
5771#endif
5772 auto idx = getIndex(&call, CacheType::Self, BuilderZ);
5773 if (idx == IndexMappingError) {
5774 std::string str;
5775 raw_string_ostream ss(str);
5776 ss << "Failed to compute consistent cache index for operation: "
5777 << call << "\n";
5778 if (CustomErrorHandler) {
5779 CustomErrorHandler(str.c_str(), wrap(&call),
5780 ErrorType::InternalError, nullptr, nullptr,
5781 nullptr);
5782 } else {
5783 EmitFailure("GetIndexError", call.getDebugLoc(), &call,
5784 ss.str());
5785 }
5786 } else {
5788 cachereplace = newCall;
5789 else
5790 cachereplace = BuilderZ.CreatePHI(
5791 call.getType(), 1, call.getName() + "_tmpcacheB");
5792 cachereplace =
5793 gutils->cacheForReverse(BuilderZ, cachereplace, idx);
5794 }
5795 }
5796 } else {
5797#if LLVM_VERSION_MAJOR >= 18
5798 auto It = BuilderZ.GetInsertPoint();
5799 It.setHeadBit(true);
5800 BuilderZ.SetInsertPoint(It);
5801#endif
5802 auto pn = BuilderZ.CreatePHI(
5803 call.getType(), 1, (call.getName() + "_replacementE").str());
5804 gutils->fictiousPHIs[pn] = &call;
5805 cachereplace = pn;
5806 }
5807 } else {
5808 // TODO move right after newCall for the insertion point of BuilderZ
5809
5810 BuilderZ.SetInsertPoint(BuilderZ.GetInsertPoint()->getNextNode());
5811 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
5812 }
5813 }
5814
5815 auto ifound = gutils->invertedPointers.find(&call);
5816 if (ifound != gutils->invertedPointers.end()) {
5817 auto placeholder = cast<PHINode>(&*ifound->second);
5818
5819 bool subcheck = (subretType == DIFFE_TYPE::DUP_ARG ||
5820 subretType == DIFFE_TYPE::DUP_NONEED);
5821
5822 //! We only need the shadow pointer for non-forward Mode if it is used
5823 //! in a non return setting
5824 bool hasNonReturnUse = false;
5825 for (auto use : call.users()) {
5827 !isa<ReturnInst>(use)) {
5828 hasNonReturnUse = true;
5829 }
5830 }
5831
5832 if (subcheck && hasNonReturnUse) {
5833
5834 Value *newip = nullptr;
5837
5838 if (!differetIdx) {
5839 std::string str;
5840 raw_string_ostream ss(str);
5841 ss << "Did not have return index set when differentiating "
5842 "function\n";
5843 ss << " call" << call << "\n";
5844 ss << " augmentcall" << *augmentcall << "\n";
5845 if (CustomErrorHandler) {
5846 CustomErrorHandler(str.c_str(), wrap(&call),
5847 ErrorType::InternalError, nullptr, nullptr,
5848 nullptr);
5849 } else {
5850 EmitFailure("GetIndexError", call.getDebugLoc(), &call,
5851 ss.str());
5852 }
5853 placeholder->replaceAllUsesWith(
5854 UndefValue::get(placeholder->getType()));
5855 if (placeholder == &*BuilderZ.GetInsertPoint()) {
5856 BuilderZ.SetInsertPoint(placeholder->getNextNode());
5857 }
5858 gutils->erase(placeholder);
5859 } else {
5860 auto drval = *differetIdx;
5861 newip = (drval < 0)
5862 ? augmentcall
5863 : BuilderZ.CreateExtractValue(augmentcall,
5864 {(unsigned)drval},
5865 call.getName() + "'ac");
5866 assert(newip->getType() == placeholder->getType());
5867 placeholder->replaceAllUsesWith(newip);
5868 if (placeholder == &*BuilderZ.GetInsertPoint()) {
5869 BuilderZ.SetInsertPoint(placeholder->getNextNode());
5870 }
5871 gutils->erase(placeholder);
5872 }
5873 } else {
5874 newip = placeholder;
5875 }
5876
5877 newip = gutils->cacheForReverse(
5878 BuilderZ, newip, getIndex(&call, CacheType::Shadow, BuilderZ));
5879
5880 gutils->invertedPointers.insert(std::make_pair(
5881 (const Value *)&call, InvertedPointerVH(gutils, newip)));
5882 } else {
5883 gutils->invertedPointers.erase(ifound);
5884 if (placeholder == &*BuilderZ.GetInsertPoint()) {
5885 BuilderZ.SetInsertPoint(placeholder->getNextNode());
5886 }
5887 gutils->erase(placeholder);
5888 }
5889 }
5890
5891 if (fnandtapetype && fnandtapetype->tapeType &&
5895 shouldFree()) {
5896 assert(tape);
5897 auto tapep = BuilderZ.CreatePointerCast(
5898 tape, getPointerType(
5899 fnandtapetype->tapeType,
5900 cast<PointerType>(tape->getType())->getAddressSpace()));
5901 auto truetape =
5902 BuilderZ.CreateLoad(fnandtapetype->tapeType, tapep, "tapeld");
5903 truetape->setMetadata("enzyme_mustcache",
5904 MDNode::get(truetape->getContext(), {}));
5905
5906 CreateDealloc(BuilderZ, tape);
5907 tape = truetape;
5908 }
5909 } else {
5910 auto ifound = gutils->invertedPointers.find(&call);
5911 if (ifound != gutils->invertedPointers.end()) {
5912 auto placeholder = cast<PHINode>(&*ifound->second);
5913 gutils->invertedPointers.erase(ifound);
5914 gutils->erase(placeholder);
5915 }
5916 if (/*!topLevel*/ Mode != DerivativeMode::ReverseModeCombined &&
5917 subretused && !call.doesNotAccessMemory()) {
5919 QueryType::Primal>(gutils, &call, Mode, oldUnreachable) &&
5920 !gutils->unnecessaryIntermediates.count(&call)) {
5921 assert(!replaceFunction);
5922#if LLVM_VERSION_MAJOR >= 18
5923 auto It = BuilderZ.GetInsertPoint();
5924 It.setHeadBit(true);
5925 BuilderZ.SetInsertPoint(It);
5926#endif
5927 cachereplace = BuilderZ.CreatePHI(call.getType(), 1,
5928 call.getName() + "_cachereplace2");
5929 cachereplace = gutils->cacheForReverse(
5930 BuilderZ, cachereplace,
5931 getIndex(&call, CacheType::Self, BuilderZ));
5932 } else {
5933#if LLVM_VERSION_MAJOR >= 18
5934 auto It = BuilderZ.GetInsertPoint();
5935 It.setHeadBit(true);
5936 BuilderZ.SetInsertPoint(It);
5937#endif
5938 auto pn = BuilderZ.CreatePHI(call.getType(), 1,
5939 call.getName() + "_replacementC");
5940 gutils->fictiousPHIs[pn] = &call;
5941 cachereplace = pn;
5942 }
5943 }
5944
5945 if (!subretused && !replaceFunction)
5946 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
5947 }
5948
5949 // Note here down only contains the reverse bits
5951 return;
5952 }
5953
5954 IRBuilder<> Builder2(&call);
5955 getReverseBuilder(Builder2);
5956
5957 Value *newcalled = nullptr;
5958 FunctionType *FT = nullptr;
5959
5960 DerivativeMode subMode = (replaceFunction || !modifyPrimal)
5963 if (called) {
5964 if (Mode == DerivativeMode::ReverseModeGradient && subdata) {
5965 for (size_t i = 0; i < argsInverted.size(); i++) {
5966 if (subdata->constant_args[i] == argsInverted[i])
5967 continue;
5968 assert(subdata->constant_args[i] == DIFFE_TYPE::DUP_ARG);
5969 assert(argsInverted[i] == DIFFE_TYPE::DUP_NONEED);
5970 argsInverted[i] = DIFFE_TYPE::DUP_ARG;
5971 }
5972 }
5973
5974 newcalled = gutils->Logic.CreatePrimalAndGradient(
5975 RequestContext(&call, &Builder2),
5977 .todiff = cast<Function>(called),
5978 .retType = subretType,
5979 .constant_args = argsInverted,
5980 .subsequent_calls_may_write = subsequent_calls_may_write,
5981 .overwritten_args = overwritten_args,
5982 .returnUsed = replaceFunction && subretused,
5983 .shadowReturnUsed = shadowReturnUsed && replaceFunction,
5984 .mode = subMode,
5985 .width = gutils->getWidth(),
5986 .freeMemory = true,
5987 .AtomicAdd = gutils->AtomicAdd,
5988 .additionalType = tape ? tape->getType() : nullptr,
5989 .forceAnonymousTape = false,
5990 .typeInfo = nextTypeInfo,
5991 .runtimeActivity = gutils->runtimeActivity,
5992 .strongZero = gutils->strongZero},
5993 TR.analyzer->interprocedural, subdata);
5994 if (!newcalled)
5995 return;
5996 FT = cast<Function>(newcalled)->getFunctionType();
5997 } else {
5998
5999 assert(subMode != DerivativeMode::ReverseModeCombined);
6000
6001 auto callval = call.getCalledOperand();
6002
6003 if (gutils->isConstantValue(callval)) {
6004 std::string s;
6005 llvm::raw_string_ostream ss(s);
6006 ss << *gutils->oldFunc << "\n";
6007 ss << "in Mode: " << to_string(Mode) << "\n";
6008 ss << " orig: " << call << " callval: " << *callval << "\n";
6009 ss << " constant function being called, but active call instruction\n";
6010 auto val = EmitNoDerivativeError(ss.str(), call, gutils, Builder2);
6011 if (val)
6012 newcalled = val;
6013 else
6014 newcalled =
6015 UndefValue::get(gutils->getShadowType(callval->getType()));
6016 } else {
6017 newcalled = lookup(gutils->invertPointerM(callval, Builder2), Builder2);
6018 }
6019
6020 auto ft = call.getFunctionType();
6021
6022 auto res =
6023 getDefaultFunctionTypeForGradient(ft, /*subretType*/ subretType);
6024 // TODO Note there is empty tape added here, replace with generic
6025 res.first.push_back(getInt8PtrTy(newcalled->getContext()));
6026 FT = FunctionType::get(
6027 StructType::get(newcalled->getContext(), res.second), res.first,
6028 ft->isVarArg());
6029 auto fptype = getUnqual(FT);
6030 newcalled = Builder2.CreatePointerCast(newcalled, getUnqual(fptype));
6031 newcalled = Builder2.CreateLoad(
6032 fptype, Builder2.CreateConstGEP1_64(fptype, newcalled, 1));
6033 }
6034
6035 if (subretType == DIFFE_TYPE::OUT_DIFF) {
6036 args.push_back(diffe(&call, Builder2));
6037 }
6038
6039 if (tape) {
6040 auto ntape = gutils->lookupM(tape, Builder2);
6041 assert(ntape);
6042 assert(ntape->getType());
6043 args.push_back(ntape);
6044 }
6045
6046 assert(newcalled);
6047 assert(FT);
6048
6049 if (false) {
6050 badfn:;
6051 auto NC = dyn_cast<Function>(newcalled);
6052 llvm::errs() << *gutils->oldFunc << "\n";
6053 llvm::errs() << *gutils->newFunc << "\n";
6054 if (NC)
6055 llvm::errs() << " trying to call " << NC->getName() << " " << *FT
6056 << "\n";
6057 else
6058 llvm::errs() << " trying to call " << *newcalled << " " << *FT << "\n";
6059
6060 for (unsigned i = 0; i < args.size(); ++i) {
6061 assert(args[i]);
6062 assert(args[i]->getType());
6063 llvm::errs() << "args[" << i << "] = " << *args[i]
6064 << " FT:" << *FT->getParamType(i) << "\n";
6065 }
6066 assert(0 && "calling with wrong number of arguments");
6067 exit(1);
6068 }
6069
6070 if (args.size() != FT->getNumParams())
6071 goto badfn;
6072
6073 for (unsigned i = 0; i < args.size(); ++i) {
6074 if (args[i]->getType() == FT->getParamType(i))
6075 continue;
6076 else if (!call.getCalledFunction())
6077 args[i] = Builder2.CreateBitCast(args[i], FT->getParamType(i));
6078 else
6079 goto badfn;
6080 }
6081
6082 CallInst *diffes =
6083 Builder2.CreateCall(FT, newcalled, args,
6084 gutils->getInvertedBundles(
6085 &call, BundleTypes, Builder2, /*lookup*/ true));
6086 diffes->setCallingConv(call.getCallingConv());
6087 diffes->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
6088
6089 for (auto pair : gradByVal) {
6090 diffes->addParamAttr(pair.first, Attribute::getWithByValType(
6091 diffes->getContext(), pair.second));
6092 }
6093
6094 for (auto &pair : structAttrs) {
6095 for (auto val : pair.second)
6096 diffes->addParamAttr(pair.first, val);
6097 }
6098
6099 unsigned structidx = 0;
6100 if (replaceFunction) {
6101 if (subretused)
6102 structidx++;
6103 if (shadowReturnUsed)
6104 structidx++;
6105 }
6106
6107 for (unsigned i = 0; i < call.arg_size(); ++i) {
6108 if (argsInverted[i] == DIFFE_TYPE::OUT_DIFF) {
6109 Value *diffeadd = Builder2.CreateExtractValue(diffes, {structidx});
6110 ++structidx;
6111
6112 if (!gutils->isConstantValue(call.getArgOperand(i))) {
6113 size_t size = 1;
6114 if (call.getArgOperand(i)->getType()->isSized())
6115 size = (gutils->newFunc->getParent()
6116 ->getDataLayout()
6117 .getTypeSizeInBits(call.getArgOperand(i)->getType()) +
6118 7) /
6119 8;
6120
6121 addToDiffe(call.getArgOperand(i), diffeadd, Builder2,
6122 TR.addingType(size, call.getArgOperand(i)));
6123 }
6124 }
6125 }
6126
6127 if (diffes->getType()->isVoidTy()) {
6128 if (structidx != 0) {
6129 llvm::errs() << *gutils->oldFunc->getParent() << "\n";
6130 llvm::errs() << "diffes: " << *diffes << " structidx=" << structidx
6131 << " subretused=" << subretused
6132 << " shadowReturnUsed=" << shadowReturnUsed << "\n";
6133 }
6134 assert(structidx == 0);
6135 } else {
6136 assert(cast<StructType>(diffes->getType())->getNumElements() ==
6137 structidx);
6138 }
6139
6140 if (subretType == DIFFE_TYPE::OUT_DIFF)
6141 setDiffe(&call,
6142 Constant::getNullValue(gutils->getShadowType(call.getType())),
6143 Builder2);
6144
6145 if (replaceFunction) {
6146
6147 // if a function is replaced for joint forward/reverse, handle inverted
6148 // pointers
6149 auto ifound = gutils->invertedPointers.find(&call);
6150 if (ifound != gutils->invertedPointers.end()) {
6151 auto placeholder = cast<PHINode>(&*ifound->second);
6152 gutils->invertedPointers.erase(ifound);
6153 if (shadowReturnUsed) {
6154 dumpMap(gutils->invertedPointers);
6155 auto dretval = cast<Instruction>(
6156 Builder2.CreateExtractValue(diffes, {subretused ? 1U : 0U}));
6157 /* todo handle this case later */
6158 assert(!subretused);
6159 gutils->invertedPointers.insert(std::make_pair(
6160 (const Value *)&call, InvertedPointerVH(gutils, dretval)));
6161 }
6162 gutils->erase(placeholder);
6163 }
6164
6165 Instruction *retval = nullptr;
6166
6167 if (subretused) {
6168 retval = cast<Instruction>(Builder2.CreateExtractValue(diffes, {0}));
6169 if (retval) {
6170 gutils->replaceAndRemoveUnwrapCacheFor(newCall, retval);
6171 }
6172 gutils->replaceAWithB(newCall, retval, /*storeInCache*/ true);
6173 } else {
6174 eraseIfUnused(call, /*erase*/ false, /*check*/ false);
6175 }
6176
6177 SmallPtrSet<Value *, 2> postCreateSet(postCreate.begin(),
6178 postCreate.end());
6179 for (auto a : postCreate) {
6180 a->moveBefore(*Builder2.GetInsertBlock(), Builder2.GetInsertPoint());
6181 for (size_t i = 0; i < a->getNumOperands(); i++) {
6182 auto op = dyn_cast<Instruction>(a->getOperand(i));
6183 if (!op || postCreateSet.count(op))
6184 continue;
6185 if (gutils->isOriginal(op->getParent())) {
6186 IRBuilder<> BuilderA(a);
6187 a->setOperand(i, gutils->lookupM(op, BuilderA));
6188 }
6189 }
6190 }
6191
6192 gutils->originalToNewFn[&call] = retval ? retval : diffes;
6193 gutils->newToOriginalFn.erase(newCall);
6194 gutils->newToOriginalFn[retval ? retval : diffes] = &call;
6195
6196 gutils->erase(newCall);
6197
6198 return;
6199 }
6200
6201 if (cachereplace) {
6202 if (subretused) {
6203 Value *dcall = nullptr;
6204 assert(cachereplace->getType() == call.getType());
6205 assert(dcall == nullptr);
6206 dcall = cachereplace;
6207 assert(dcall);
6208
6209 if (!gutils->isConstantValue(&call)) {
6210 gutils->originalToNewFn[&call] = dcall;
6211 gutils->newToOriginalFn.erase(newCall);
6212 gutils->newToOriginalFn[dcall] = &call;
6213 if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) {
6214 } else {
6215 ((DiffeGradientUtils *)gutils)->differentials[dcall] =
6216 ((DiffeGradientUtils *)gutils)->differentials[newCall];
6217 ((DiffeGradientUtils *)gutils)->differentials.erase(newCall);
6218 }
6219 }
6220 assert(dcall->getType() == call.getType());
6221 newCall->replaceAllUsesWith(dcall);
6222 if (isa<Instruction>(dcall) && !isa<PHINode>(dcall)) {
6223 cast<Instruction>(dcall)->takeName(&call);
6224 }
6225 gutils->erase(newCall);
6226 } else {
6227 eraseIfUnused(call, /*erase*/ true, /*check*/ false);
6228 if (augmentcall) {
6229 gutils->originalToNewFn[&call] = augmentcall;
6230 gutils->newToOriginalFn.erase(newCall);
6231 gutils->newToOriginalFn[augmentcall] = &call;
6232 }
6233 }
6234 }
6235 return;
6236 }
6237
// Generate the forward/adjoint handling for MPI library calls; dispatched
// from the call visitor by the resolved callee name. Defined out of line.
// NOTE(review): presumably switches on `funcName` (MPI_Send/MPI_Recv/...) —
// confirm against the out-of-line definition.
6238  void handleMPI(llvm::CallInst &call, llvm::Function *called,
6239                 llvm::StringRef funcName);
6240
// Attempt to apply a known (builtin/library) derivative rule for the callee
// named `funcName`. Returns true if the call was fully handled, in which
// case the call visitor performs no further processing. Defined out of line.
6241  bool handleKnownCallDerivatives(llvm::CallInst &call, llvm::Function *called,
6242                                  llvm::StringRef funcName,
6243                                  bool subsequent_calls_may_write,
6244                                  const std::vector<bool> &overwritten_args,
6245                                  llvm::CallInst *const newCall);
6246
6247 // Return
// Visitor entry point for call instructions. Resolves the callee name and
// dispatches, in priority order:
//   1. `llvm.intel.subscript` and `llvm.enzyme.lifetime_*` markers, routed to
//      the intrinsic handlers;
//   2. user-registered custom forward-mode handlers (customFwdCallHandlers);
//   3. user-registered custom augmented-forward/reverse handler pairs
//      (customCallHandlers), including tape caching across passes;
//   4. OpenMP `__kmpc_fork_call` outlined regions (visitOMPCall);
//   5. known builtin/library derivatives (handleKnownCallDerivatives);
//   6. an inactive-call fallback that preserves the primal call (wrapping the
//      callee via CreateNoFree when it may free needed memory) and caches its
//      value if required in the reverse pass;
//   7. otherwise, full recursive differentiation of the callee.
// NOTE(review): several multi-line conditions in this listing are elided by
// the documentation renderer; consult the original header for the complete
// predicates.
6248  void visitCallInst(llvm::CallInst &call) {
6249    using namespace llvm;
6250
6251    StringRef funcName = getFuncNameFromCall(&call);
6252
6253    // When compiling Enzyme against standard LLVM, and not Intel's
6254    // modified version of LLVM, the intrinsic `llvm.intel.subscript` is
6255    // not fully understood by LLVM. One of the results of this is that the
6256    // visitor dispatches to visitCallInst, rather than visitIntrinsicInst, when
6257    // presented with the intrinsic - hence why we are handling it here.
6258    if (startsWith(funcName, ("llvm.intel.subscript"))) {
6259      assert(isa<IntrinsicInst>(call));
6260      visitIntrinsicInst(cast<IntrinsicInst>(call));
6261      return;
6262    }
6263
    // Enzyme's own lifetime markers are lowered through the intrinsic
    // adjoint machinery.
6264    if (funcName == "llvm.enzyme.lifetime_start") {
6265      visitIntrinsicInst(cast<IntrinsicInst>(call));
6266      return;
6267    }
6268    if (funcName == "llvm.enzyme.lifetime_end") {
6269      SmallVector<Value *, 2> orig_ops(call.getNumOperands());
6270      for (unsigned i = 0; i < call.getNumOperands(); ++i) {
6271        orig_ops[i] = call.getOperand(i);
6272      }
6273      handleAdjointForIntrinsic(Intrinsic::lifetime_end, call, orig_ops);
6274      eraseIfUnused(call);
6275      return;
6276    }
6277
    // The cloned call in the new (derivative) function; forward-pass IR is
    // inserted at its position.
6278    CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
6279    IRBuilder<> BuilderZ(newCall);
6280    BuilderZ.setFastMathFlags(getFast());
6281
    // Debug dump when the overwritten-args analysis lacks an entry for this
    // call (the remainder of the condition is elided in this rendering).
6282    if (overwritten_args_map.find(&call) == overwritten_args_map.end() &&
6285      llvm::errs() << " call: " << call << "\n";
6286      for (auto &pair : overwritten_args_map) {
6287        llvm::errs() << " + " << *pair.first << "\n";
6288      }
6289    }
6290
6291    assert(overwritten_args_map.find(&call) != overwritten_args_map.end() ||
    // Outside forward mode(s), pull the per-call overwritten-arguments
    // analysis results; forward-mode variants use empty defaults (the mode
    // predicates are partially elided in this rendering).
6294    const bool subsequent_calls_may_write =
6295        (Mode == DerivativeMode::ForwardMode ||
6297            ? false
6298            : overwritten_args_map.find(&call)->second.first;
6299    const std::vector<bool> &overwritten_args =
6300        (Mode == DerivativeMode::ForwardMode ||
6302            ? std::vector<bool>()
6303            : overwritten_args_map.find(&call)->second.second;
6304
6305    auto called = getFunctionFromCall(&call);
6306
    // Classify the call's return: its differential activity type, whether
    // the primal return value is used, and whether a shadow (inverted)
    // return is needed.
6307    bool subretused = false;
6308    bool shadowReturnUsed = false;
6309    auto smode = Mode;
6312    DIFFE_TYPE subretType = gutils->getReturnDiffeType(
6313        &call, &subretused, &shadowReturnUsed, smode);
6314
    // Case: a user-registered custom forward-mode handler exists for this
    // callee name (second mode predicate elided in this rendering).
6315    if (Mode == DerivativeMode::ForwardMode ||
6317      auto found = customFwdCallHandlers.find(funcName);
6318      if (found != customFwdCallHandlers.end()) {
6319        Value *invertedReturn = nullptr;
6320        auto ifound = gutils->invertedPointers.find(&call);
6321        if (ifound != gutils->invertedPointers.end()) {
6322          invertedReturn = cast<PHINode>(&*ifound->second);
6323        }
6324
6325        Value *normalReturn = subretused ? newCall : nullptr;
6326
        // The handler may rewrite normalReturn/invertedReturn in place;
        // `noMod` signals it left the primal call untouched.
6327        bool noMod = found->second(BuilderZ, &call, *gutils, normalReturn,
6328                                   invertedReturn);
6329        if (noMod) {
6330          if (subretused)
6331            assert(normalReturn == newCall);
6332          eraseIfUnused(call);
6333        }
6334
        // Reconcile the handler-produced shadow with the preexisting
        // placeholder PHI for this call, keeping invertedPointers consistent.
6335        ifound = gutils->invertedPointers.find(&call);
6336        if (ifound != gutils->invertedPointers.end()) {
6337          auto placeholder = cast<PHINode>(&*ifound->second);
6338          if (invertedReturn && invertedReturn != placeholder) {
6339            if (invertedReturn->getType() !=
6340                gutils->getShadowType(call.getType())) {
6341              llvm::errs() << " o: " << call << "\n";
6342              llvm::errs() << " ot: " << *call.getType() << "\n";
6343              llvm::errs() << " ir: " << *invertedReturn << "\n";
6344              llvm::errs() << " irt: " << *invertedReturn->getType() << "\n";
6345              llvm::errs() << " p: " << *placeholder << "\n";
6346              llvm::errs() << " PT: " << *placeholder->getType() << "\n";
6347              llvm::errs() << " newCall: " << *newCall << "\n";
6348              llvm::errs() << " newCallT: " << *newCall->getType() << "\n";
6349            }
6350            assert(invertedReturn->getType() ==
6351                   gutils->getShadowType(call.getType()));
6352            placeholder->replaceAllUsesWith(invertedReturn);
6353            gutils->erase(placeholder);
6354            gutils->invertedPointers.insert(
6355                std::make_pair((const Value *)&call,
6356                               InvertedPointerVH(gutils, invertedReturn)));
6357          } else {
6358            gutils->invertedPointers.erase(&call);
6359            gutils->erase(placeholder);
6360          }
6361        }
6362
        // If the handler substituted its own primal value, replace the
        // cloned call with it.
6363        if (normalReturn && normalReturn != newCall) {
6364          assert(normalReturn->getType() == newCall->getType());
6365          gutils->replaceAWithB(newCall, normalReturn);
6366          gutils->erase(newCall);
6367        }
6368        return;
6369      }
6370    }
6371
    // Case: a user-registered custom augmented-forward/reverse handler pair
    // exists for this callee (enclosing mode guard elided in this rendering).
6375      auto found = customCallHandlers.find(funcName);
6376      if (found != customCallHandlers.end()) {
6377        IRBuilder<> Builder2(&call);
6380          getReverseBuilder(Builder2);
6381
6382        Value *invertedReturn = nullptr;
6383        auto ifound = gutils->invertedPointers.find(&call);
6384        PHINode *placeholder = nullptr;
6385        if (ifound != gutils->invertedPointers.end()) {
6386          placeholder = cast<PHINode>(&*ifound->second);
6387          if (shadowReturnUsed)
6388            invertedReturn = placeholder;
6389        }
6390
6391        Value *normalReturn = subretused ? newCall : nullptr;
6392
6393        Value *tape = nullptr;
6394
6395        Type *tapeType = nullptr;
6396
        // Augmented-forward half of the handler; it may produce a tape value
        // to be persisted for the reverse pass (guard elided in rendering).
6399          bool noMod = found->second.first(BuilderZ, &call, *gutils,
6400                                           normalReturn, invertedReturn, tape);
6401          if (noMod) {
6402            if (subretused)
6403              assert(normalReturn == newCall);
6404            eraseIfUnused(call);
6405          }
6406          if (tape) {
6407            tapeType = tape->getType();
6408            gutils->cacheForReverse(BuilderZ, tape,
6409                                    getIndex(&call, CacheType::Tape, BuilderZ));
6410          }
6412            assert(augmentedReturn);
6413            auto subaugmentations =
6414                (std::map<const llvm::CallInst *, AugmentedReturn *>
6415                     *)&augmentedReturn->subaugmentations;
6417                *subaugmentations, &call, (AugmentedReturn *)tapeType);
6418          }
6419        }
6420
        // Reverse half: if a tape was recorded in a prior (split) forward
        // pass, recover its type and materialize a PHI to be filled by the
        // cache machinery (guards elided in this rendering).
6424              augmentedReturn->tapeIndices.find(
6425                  std::make_pair(&call, CacheType::Tape)) !=
6426              augmentedReturn->tapeIndices.end()) {
6427            assert(augmentedReturn);
6428            auto subaugmentations =
6429                (std::map<const llvm::CallInst *, AugmentedReturn *>
6430                     *)&augmentedReturn->subaugmentations;
6431            auto fd = subaugmentations->find(&call);
6432            assert(fd != subaugmentations->end());
6433            // Note we are using the storage space here to persist
6434            // the LLVM type, as storing a new augmentedReturn has issues
6435            // regarding persisting the data structure, and when it will
6436            // be freed, since it will no longer live in the map in
6437            // EnzymeLogic.
6438            tapeType = (llvm::Type *)fd->second;
6439
6440#if LLVM_VERSION_MAJOR >= 18
6441            auto It = BuilderZ.GetInsertPoint();
6442            It.setHeadBit(true);
6443            BuilderZ.SetInsertPoint(It);
6444#endif
6445            tape = BuilderZ.CreatePHI(tapeType, 0);
6446            tape = gutils->cacheForReverse(
6447                BuilderZ, tape, getIndex(&call, CacheType::Tape, BuilderZ),
6448                /*ignoreType*/ true);
6449          }
6450          if (tape)
6451            tape = gutils->lookupM(tape, Builder2);
          // Invoke the user's reverse handler with the (looked-up) tape.
6452          found->second.second(Builder2, &call, *(DiffeGradientUtils *)gutils,
6453                               tape);
6454        }
6455
        // Reconcile the shadow placeholder with the handler's inverted
        // return, caching the shadow value across forward/reverse passes.
6456        if (placeholder) {
6457          if (!shadowReturnUsed) {
6458            gutils->invertedPointers.erase(&call);
6459            gutils->erase(placeholder);
6460          } else {
6461            if (invertedReturn && invertedReturn != placeholder) {
6462              if (invertedReturn->getType() !=
6463                  gutils->getShadowType(call.getType())) {
6464                llvm::errs() << " o: " << call << "\n";
6465                llvm::errs() << " ot: " << *call.getType() << "\n";
6466                llvm::errs() << " ir: " << *invertedReturn << "\n";
6467                llvm::errs() << " irt: " << *invertedReturn->getType() << "\n";
6468                llvm::errs() << " p: " << *placeholder << "\n";
6469                llvm::errs() << " PT: " << *placeholder->getType() << "\n";
6470                llvm::errs() << " newCall: " << *newCall << "\n";
6471                llvm::errs() << " newCallT: " << *newCall->getType() << "\n";
6472              }
6473              assert(invertedReturn->getType() ==
6474                     gutils->getShadowType(call.getType()));
6475              placeholder->replaceAllUsesWith(invertedReturn);
6476              gutils->erase(placeholder);
6477              invertedReturn = gutils->cacheForReverse(
6478                  BuilderZ, invertedReturn,
6479                  getIndex(&call, CacheType::Shadow, BuilderZ));
6480            } else {
6481              auto idx = getIndex(&call, CacheType::Shadow, BuilderZ);
6482              invertedReturn =
6483                  gutils->cacheForReverse(BuilderZ, placeholder, idx);
6484              if (idx == IndexMappingError) {
6485                if (placeholder->getType() != invertedReturn->getType())
6486                  llvm::errs() << " place: " << *placeholder
6487                               << " invRet: " << *invertedReturn;
6488                placeholder->replaceAllUsesWith(invertedReturn);
6489                gutils->erase(placeholder);
6490              }
6491            }
6492
6493            gutils->invertedPointers.insert(
6494                std::make_pair((const Value *)&call,
6495                               InvertedPointerVH(gutils, invertedReturn)));
6496          }
6497        }
6498
        // Decide whether the primal return value must be preserved (cached)
        // for the reverse pass.
6499        bool primalNeededInReverse;
6500
6501        if (gutils->knownRecomputeHeuristic.count(&call)) {
6502          primalNeededInReverse = !gutils->knownRecomputeHeuristic[&call];
6503        } else {
6504          std::map<UsageKey, bool> Seen;
6505          for (auto pair : gutils->knownRecomputeHeuristic)
6506            if (!pair.second)
6507              Seen[UsageKey(pair.first, QueryType::Primal)] = false;
6508          primalNeededInReverse =
6510              QueryType::Primal>(gutils, &call, Mode, Seen, oldUnreachable);
6511        }
6512        if (subretused && primalNeededInReverse) {
6513          if (normalReturn != newCall) {
6514            assert(normalReturn->getType() == newCall->getType());
6515            gutils->replaceAWithB(newCall, normalReturn);
6516            BuilderZ.SetInsertPoint(newCall->getNextNode());
6517            gutils->erase(newCall);
6518          }
6519          normalReturn = gutils->cacheForReverse(
6520              BuilderZ, normalReturn,
6521              getIndex(&call, CacheType::Self, BuilderZ));
6522        } else {
6523          if (normalReturn && normalReturn != newCall) {
6524            assert(normalReturn->getType() == newCall->getType());
6526            gutils->replaceAWithB(newCall, normalReturn);
6527            BuilderZ.SetInsertPoint(newCall->getNextNode());
6528            gutils->erase(newCall);
6529          } else if (Mode == DerivativeMode::ReverseModeGradient)
6530            eraseIfUnused(call, /*erase*/ true, /*check*/ false);
6531        }
6532        return;
6533      }
6534    }
6535
    // Case: OpenMP parallel-region fork calls get dedicated handling.
6536    if (called) {
6537      if (funcName == "__kmpc_fork_call") {
6538        visitOMPCall(call);
6539        return;
6540      }
6541    }
6542
    // Case: a known builtin/library derivative rule fully handles the call.
6543    if (handleKnownCallDerivatives(call, called, funcName,
6544                                   subsequent_calls_may_write, overwritten_args,
6545                                   newCall))
6546      return;
6547
    // Case: inactive-call fallback (the condition determining
    // useConstantFallback is elided in this rendering).
6548    bool useConstantFallback =
6550            gutils, call, QueryType::Primal, nullptr);
6551    if (!useConstantFallback) {
6552      if (gutils->isConstantInstruction(&call) &&
6553          gutils->isConstantValue(&call)) {
        // NOTE(review): remark name "ConstnatFallback" is misspelled
        // ("ConstantFallback"); changing it would alter emitted remarks.
6554        EmitWarning("ConstnatFallback", call,
6555                    "Call was deduced inactive but still doing differential "
6556                    "rewrite as it may escape an allocation",
6557                    call);
6558      }
6559    }
6560    if (useConstantFallback) {
      // Inactive value: drop any shadow placeholder that was created.
6561      if (!gutils->isConstantValue(&call)) {
6562        auto found = gutils->invertedPointers.find(&call);
6563        if (found != gutils->invertedPointers.end()) {
6564          PHINode *placeholder = cast<PHINode>(&*found->second);
6565          gutils->invertedPointers.erase(found);
6566          gutils->erase(placeholder);
6567        }
6568      }
      // Determine whether this callee can be guaranteed not to free memory
      // that the reverse pass still needs (mode predicate elided).
6569      bool noFree = Mode == DerivativeMode::ForwardMode ||
6571      noFree |= call.hasFnAttr(Attribute::NoFree);
6572      if (!noFree && called) {
6573        noFree |= called->hasFnAttribute(Attribute::NoFree);
6574      }
6575
6576      std::map<UsageKey, bool> CacheResults;
6577      for (auto pair : gutils->knownRecomputeHeuristic) {
6578        if (!pair.second || gutils->unnecessaryIntermediates.count(
6579                                cast<Instruction>(pair.first))) {
6580          CacheResults[UsageKey(pair.first, QueryType::Primal)] = false;
6581        }
6582      }
6583
      // Conservatively scan arguments: if any pointer argument is active or
      // may alias data needed in reverse, the call might free needed memory.
6584      if (!noFree && !EnzymeGlobalActivity) {
6585        bool mayActiveFree = false;
6586        for (unsigned i = 0; i < call.arg_size(); ++i) {
6587          Value *a = call.getOperand(i);
6588
6589          if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType()))
6590            continue;
6591          // if could not be a pointer, it cannot be freed
6592          if (!TR.query(a)[{-1}].isPossiblePointer())
6593            continue;
6594          // if active value, we need to do memory preservation
6595          if (!gutils->isConstantValue(a)) {
6596            mayActiveFree = true;
6597            break;
6598          }
6599          // if used in reverse (even if just primal), need to do
6600          // memory preservation
6601          const auto obj = getBaseObject(a);
6602          // If not allocation/allocainst, it is possible this aliases
6603          // a pointer needed in the reverse pass
6604          bool isAllocation = false;
          // Chase through single-return wrapper functions to find whether
          // the base object is a fresh allocation.
6605          for (auto objv = obj;;) {
6606            if (isAllocationCall(objv, gutils->TLI)) {
6607              isAllocation = true;
6608              break;
6609            }
6610            if (auto objC = dyn_cast<CallBase>(objv))
6611              if (auto F = getFunctionFromCall(objC))
6612                if (!F->empty()) {
6613                  SmallPtrSet<Value *, 1> set;
6614                  for (auto &B : *F) {
6615                    if (auto RI = dyn_cast<ReturnInst>(B.getTerminator())) {
6616                      auto v = getBaseObject(RI->getOperand(0));
6617                      if (isa<ConstantPointerNull>(v))
6618                        continue;
6619                      set.insert(v);
6620                    }
6621                  }
6622                  if (set.size() == 1) {
6623                    objv = *set.begin();
6624                    continue;
6625                  }
6626                }
6627            break;
6628          }
6629          if (!isAllocation) {
6630            mayActiveFree = true;
6631            break;
6632          }
          // Allocation: needed-in-reverse queries (call expressions elided
          // in this rendering) decide whether freeing it would be unsafe.
6633          {
6634            auto found = gutils->knownRecomputeHeuristic.find(obj);
6635            if (found != gutils->knownRecomputeHeuristic.end()) {
6636              if (!found->second) {
6637                auto CacheResults2(CacheResults);
6638                CacheResults2.erase(UsageKey(obj, QueryType::Primal));
6640                        QueryType::Primal>(gutils, obj,
6642                                           CacheResults2, oldUnreachable)) {
6643                  mayActiveFree = true;
6644                  break;
6645                }
6646              }
6647              continue;
6648            }
6649          }
6650          auto CacheResults2(CacheResults);
6652                  QueryType::Primal>(gutils, obj,
6654                                     CacheResults2, oldUnreachable)) {
6655            mayActiveFree = true;
6656            break;
6657          }
6658        }
6659        if (!mayActiveFree)
6660          noFree = true;
6661      }
      // If the call may free needed memory, swap in a NoFree-wrapped callee.
6662      if (!noFree) {
6663        auto callval = call.getCalledOperand();
6664        if (!isa<Constant>(callval))
6665          callval = gutils->getNewFromOriginal(callval);
6666        newCall->setCalledOperand(gutils->Logic.CreateNoFree(
6667            RequestContext(&call, &BuilderZ), callval));
6668      }
      // If the heuristic says this value must not be recomputed, cache it.
6669      if (gutils->knownRecomputeHeuristic.find(&call) !=
6670          gutils->knownRecomputeHeuristic.end()) {
6671        if (!gutils->knownRecomputeHeuristic[&call]) {
6672          gutils->cacheForReverse(BuilderZ, newCall,
6673                                  getIndex(&call, CacheType::Self, BuilderZ));
6674          eraseIfUnused(call);
6675          return;
6676        }
6677      }
6678
6679      // If we need this value and it is illegal to recompute it (it writes or
6680      // may load overwritten data)
6681      // Store and reload it
6684          Mode != DerivativeMode::ForwardModeError && subretused &&
6685          (call.mayWriteToMemory() ||
6686           !gutils->legalRecompute(&call, ValueToValueMapTy(), nullptr))) {
6687        if (!gutils->unnecessaryIntermediates.count(&call)) {
6688
6689          std::map<UsageKey, bool> Seen;
6690          bool primalNeededInReverse = false;
6691          for (auto pair : gutils->knownRecomputeHeuristic)
6692            if (!pair.second) {
6693              if (pair.first == &call) {
6694                primalNeededInReverse = true;
6695                break;
6696              } else {
6697                Seen[UsageKey(pair.first, QueryType::Primal)] = false;
6698              }
6699            }
6700          if (!primalNeededInReverse) {
6701
6702            auto minCutMode = (Mode == DerivativeMode::ReverseModePrimal)
6704                                  : Mode;
6705            primalNeededInReverse =
6707                    QueryType::Primal>(gutils, &call, minCutMode, Seen,
6708                                       oldUnreachable);
6709          }
6710          if (primalNeededInReverse) {
6711            gutils->cacheForReverse(BuilderZ, newCall,
6712                                    getIndex(&call, CacheType::Self, BuilderZ));
6713            eraseIfUnused(call);
6714            return;
6715          }
6716        }
6717        // Force erasure in reverse pass, since cached if needed
6720          eraseIfUnused(call, /*erase*/ true, /*check*/ false);
6721        else
6722          eraseIfUnused(call);
6723        return;
6724      }
6725
6726      // If this call may write to memory and is a copy (in the just reverse
6727      // pass), erase it
6728      // Any uses of it should be handled by the case above so it is safe to
6729      // RAUW
6730      if (call.mayWriteToMemory() &&
6733        eraseIfUnused(call, /*erase*/ true, /*check*/ false);
6734        return;
6735      }
6736
6737      // if call does not write memory and isn't used, we can erase it
6738      if (!call.mayWriteToMemory() && !subretused) {
6739        eraseIfUnused(call, /*erase*/ true, /*check*/ false);
6740        return;
6741      }
6742
6743      return;
6744    }
6745
    // No special case applied: fall through to recursive differentiation of
    // the callee (the leading call line is elided in this rendering; see
    // recursivelyHandleSubfunction).
6747        call, called, subsequent_calls_may_write, overwritten_args,
6748        shadowReturnUsed, subretType, subretused);
6749  }
6750};
6751
6752#endif // ENZYME_ADJOINT_GENERATOR_H
static bool isZero(llvm::Constant *cst)
std::pair< const llvm::Value *, QueryType > UsageKey
FunctionType * getFunctionTypeForClone(llvm::FunctionType *FTy, DerivativeMode mode, unsigned width, llvm::Type *additionalArg, llvm::ArrayRef< DIFFE_TYPE > constant_args, bool diffeReturnArg, bool returnTape, bool returnPrimal, bool returnShadow)
static bool isAllocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI)
static bool isReadOnly(Operation *op)
static Operation * getFunctionFromCall(CallOpInterface iface)
llvm::cl::opt< bool > EnzymeGlobalActivity
constexpr const char * to_string(ActivityAnalyzer::UseActivity UA)
std::pair< SmallVector< Type *, 4 >, SmallVector< Type *, 4 > > getDefaultFunctionTypeForAugmentation(FunctionType *called, bool returnUsed, DIFFE_TYPE retType)
assuming not top level
bool legalCombinedForwardReverse(CallInst *origop, const std::map< ReturnInst *, StoreInst * > &replacedReturns, SmallVectorImpl< Instruction * > &postCreate, SmallVectorImpl< Instruction * > &userReplace, const GradientUtils *gutils, const SmallPtrSetImpl< const Instruction * > &unnecessaryInstructions, const SmallPtrSetImpl< BasicBlock * > &oldUnreachable, const bool subretused)
bool shouldAugmentCall(CallInst *op, const GradientUtils *gutils)
std::pair< SmallVector< Type *, 4 >, SmallVector< Type *, 4 > > getDefaultFunctionTypeForGradient(FunctionType *called, DIFFE_TYPE retType, ArrayRef< DIFFE_TYPE > tys)
assuming not top level
CacheType
Definition EnzymeLogic.h:80
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
llvm::cl::opt< bool > looseTypeAnalysis
llvm::cl::opt< bool > nonmarkedglobals_inactiveloads
StringMap< std::pair< std::function< bool(IRBuilder<> &, CallInst *, GradientUtils &, Value *&, Value *&, Value *&)>, std::function< void(IRBuilder<> &, CallInst *, DiffeGradientUtils &, Value *)> > > customCallHandlers
SmallVector< unsigned int, 9 > MD_ToCopy
StringMap< std::function< bool(IRBuilder<> &, CallInst *, GradientUtils &, Value *&, Value *&)> > customFwdCallHandlers
void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode, Type *secretty, Intrinsic::ID intrinsic, unsigned dstalign, unsigned srcalign, unsigned offset, bool dstConstant, Value *shadow_dst, bool srcConstant, Value *shadow_src, Value *length, Value *isVolatile, llvm::CallInst *MTI, bool allowForward, bool shadowsLookedUp, bool backwardsShadow)
constexpr int IndexMappingError
static TypeTree parseTBAA(TBAAStructTypeNode AccessType, llvm::Instruction &I, const llvm::DataLayout &DL, std::shared_ptr< llvm::ModuleSlotTracker > MST)
Given a TBAA access node return the corresponding TypeTree This includes recursively parsing the acce...
Definition TBAA.h:439
TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, bool intIsPointer)
static bool isMemFreeLibMFunction(llvm::StringRef str, llvm::Intrinsic::ID *ID=nullptr)
CallInst * CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree)
Definition Utils.cpp:742
Function * getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
Create function for type that is equivalent to memcpy but adds to destination rather than a direct co...
Definition Utils.cpp:1009
void EmitNoTypeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2)
Definition Utils.cpp:4377
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
Definition Utils.cpp:62
llvm::Value * EmitNoDerivativeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2, llvm::Value *condition)
Definition Utils.cpp:4295
void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Value *shadow, const char *Message, llvm::DebugLoc &&loc, llvm::Instruction *orig)
Definition Utils.cpp:902
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
Definition Utils.cpp:3731
llvm::Constant * getUndefinedValueForType(llvm::Module &M, llvm::Type *T, bool forceZero)
Definition Utils.cpp:3623
std::vector< std::tuple< llvm::Type *, size_t, size_t > > parseTrueType(const llvm::MDNode *md, DerivativeMode Mode, bool const_src)
Definition Utils.cpp:4411
static llvm::Value * checkedMul(bool strongZero, llvm::IRBuilder<> &Builder2, llvm::Value *idiff, llvm::Value *pres, const llvm::Twine &Name="")
Definition Utils.h:2018
static bool isIntelSubscriptIntrinsic(const llvm::IntrinsicInst &II)
Definition Utils.h:1445
static std::string getRenamedPerCallingConv(llvm::StringRef caller, llvm::StringRef callee)
Definition Utils.h:2413
static llvm::PointerType * getPointerType(llvm::Type *T, unsigned AddressSpace=0)
Definition Utils.h:1165
static void dumpMap(const llvm::ValueMap< T, N > &o, llvm::function_ref< bool(const llvm::Value *)> shouldPrint=[](T) { return true;})
Print a map, optionally with a shouldPrint function to decide to print a given value.
Definition Utils.h:286
@ Args
Return is a struct of all args.
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
Definition Utils.h:713
DIFFE_TYPE
Potential differentiable argument classifications.
Definition Utils.h:374
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
static std::map< K, V >::iterator insert_or_assign2(std::map< K, V > &map, K key, V val)
Insert into a map.
Definition Utils.h:846
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
Definition Utils.h:2263
static bool isNoCapture(const llvm::CallBase *call, size_t idx)
Definition Utils.h:1840
static void addFunctionNoCapture(llvm::Function *call, size_t idx)
Definition Utils.h:2299
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
Definition Utils.h:1507
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
Definition Utils.h:339
static bool containsOnlyAtMostTopBit(const llvm::Value *V, llvm::Type *FT, const llvm::DataLayout &dl, llvm::Type **vFT=nullptr)
Definition Utils.h:2047
static bool shouldDisableNoWrite(const llvm::CallInst *CI)
Definition Utils.h:1423
void EmitWarning(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::BasicBlock *BB, const Args &...args)
Definition Utils.h:133
static bool isSpecialPtr(llvm::Type *Ty)
Definition Utils.h:2354
llvm::cl::opt< bool > EnzymeJuliaAddrLoad
@ MixedActivityError
static llvm::Type * IntToFloatTy(llvm::Type *T)
Convert a integer type to a floating point type of the same size.
Definition Utils.h:665
static llvm::Value * CreateSelect(llvm::IRBuilder<> &Builder2, llvm::Value *cmp, llvm::Value *tval, llvm::Value *fval, const llvm::Twine &Name="")
Definition Utils.h:2005
static std::string convertSRetTypeToString(llvm::Type *T)
Definition Utils.h:2428
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
Definition Utils.h:1269
static bool isWriteOnly(const llvm::Function *F, ssize_t arg=-1)
Definition Utils.h:1788
ValueType
Classification of value as an original program variable, a derivative variable, neither,...
Definition Utils.h:409
static llvm::Attribute::AttrKind PrimalParamAttrsToPreserve[]
Definition Utils.h:2208
static llvm::Attribute::AttrKind ShadowParamAttrsToPreserve[]
Definition Utils.h:2240
DerivativeMode
Definition Utils.h:390
static DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode, bool integersAreConstant, std::set< llvm::Type * > &seen)
Attempt to automatically detect the differentiable classification based off of a given type.
Definition Utils.h:519
llvm::Value * lookup(llvm::Value *val, llvm::IRBuilder<> &Builder)
void visitExtractElementInst(llvm::ExtractElementInst &EEI)
void forwardModeInvertedPointerFallback(llvm::Instruction &I)
llvm::Value * diffe(llvm::Value *val, llvm::IRBuilder<> &Builder)
void visitShuffleVectorInst(llvm::ShuffleVectorInst &SVI)
void createSelectInstAdjoint(llvm::SelectInst &SI)
void DifferentiableMemCopyFloats(llvm::CallInst &call, llvm::Value *origArg, llvm::Value *dsto, llvm::Value *srco, llvm::Value *len_arg, llvm::IRBuilder<> &Builder2, llvm::ArrayRef< llvm::OperandBundleDef > ReverseDefs)
void applyChainRule(llvm::ArrayRef< llvm::Value * > diffs, llvm::IRBuilder<> &Builder, Func rule)
Unwraps an collection of constant vector derivatives from their internal representations and applies ...
bool handleKnownCallDerivatives(llvm::CallInst &call, llvm::Function *called, llvm::StringRef funcName, bool subsequent_calls_may_write, const std::vector< bool > &overwritten_args, llvm::CallInst *const newCall)
void setDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &Builder)
void visitBinaryOperator(llvm::BinaryOperator &BO)
void visitInsertValueInst(llvm::InsertValueInst &IVI)
void visitCallInst(llvm::CallInst &call)
void visitGetElementPtrInst(llvm::GetElementPtrInst &gep)
llvm::Value * MPI_TYPE_SIZE(llvm::Value *DT, llvm::IRBuilder<> &B, llvm::Type *intType, llvm::Function *caller)
void visitFCmpInst(llvm::FCmpInst &I)
void visitExtractValueInst(llvm::ExtractValueInst &EVI)
void visitLoadLike(llvm::Instruction &I, llvm::MaybeAlign alignment, bool constantval, llvm::Value *mask=nullptr, llvm::Value *orig_maskInit=nullptr)
void visitCastInst(llvm::CastInst &I)
llvm::Value * applyChainRule(llvm::Type *diffType, llvm::IRBuilder<> &Builder, Func rule, Args... args)
Unwraps a vector derivative from its internal representation and applies a function f to each element...
void visitInsertElementInst(llvm::InsertElementInst &IEI)
void visitCommonStore(llvm::Instruction &I, llvm::Value *orig_ptr, llvm::Value *orig_val, llvm::MaybeAlign prevalign, bool isVolatile, llvm::AtomicOrdering ordering, llvm::SyncScope::ID syncScope, llvm::Value *mask)
void visitICmpInst(llvm::ICmpInst &I)
void visitFenceInst(llvm::FenceInst &FI)
void visitIntrinsicInst(llvm::IntrinsicInst &II)
void createBinaryOperatorAdjoint(llvm::BinaryOperator &BO)
void visitPHINode(llvm::PHINode &phi)
llvm::SmallVector< llvm::SelectInst *, 4 > addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *mask=nullptr)
void eraseIfUnused(llvm::Instruction &I, bool erase=true, bool check=true)
AdjointGenerator(DerivativeMode Mode, GradientUtils *gutils, llvm::ArrayRef< DIFFE_TYPE > constant_args, DIFFE_TYPE retType, std::function< unsigned(llvm::Instruction *, CacheType, llvm::IRBuilder<> &)> getIndex, const std::map< llvm::CallInst *, std::pair< bool, const std::vector< bool > > > overwritten_args_map, const AugmentedReturn *augmentedReturn, const std::map< llvm::ReturnInst *, llvm::StoreInst * > *replacedReturns, const llvm::SmallPtrSetImpl< const llvm::Value * > &unnecessaryValues, const llvm::SmallPtrSetImpl< const llvm::Instruction * > &unnecessaryInstructions, const llvm::SmallPtrSetImpl< const llvm::Instruction * > &unnecessaryStores, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable)
void visitInstruction(llvm::Instruction &inst)
void visitAtomicRMWInst(llvm::AtomicRMWInst &I)
void visitLoadInst(llvm::LoadInst &LI)
void visitMemSetInst(llvm::MemSetInst &MS)
void visitMemTransferInst(llvm::MemTransferInst &MTI)
void visitOMPCall(llvm::CallInst &call)
void visitSelectInst(llvm::SelectInst &SI)
void createBinaryOperatorDual(llvm::BinaryOperator &BO)
void getForwardBuilder(llvm::IRBuilder<> &Builder2)
void getReverseBuilder(llvm::IRBuilder<> &Builder2, bool original=true)
void applyChainRule(llvm::IRBuilder<> &Builder, Func rule, Args... args)
Unwraps a vector derivative from its internal representation and applies a function f to each element...
void visitAllocaInst(llvm::AllocaInst &I)
void recursivelyHandleSubfunction(llvm::CallInst &call, llvm::Function *called, bool subsequent_calls_may_write, const std::vector< bool > &overwritten_args, bool shadowReturnUsed, DIFFE_TYPE subretType, bool subretused)
llvm::Value * MPI_COMM_SIZE(llvm::Value *comm, llvm::IRBuilder<> &B, llvm::Type *rankTy, llvm::Function *caller)
void visitMemTransferCommon(llvm::Intrinsic::ID ID, llvm::MaybeAlign srcAlign, llvm::MaybeAlign dstAlign, llvm::CallInst &MTI, llvm::Value *orig_dst, llvm::Value *orig_src, llvm::Value *new_size, llvm::Value *isVolatile)
llvm::Value * MPI_COMM_RANK(llvm::Value *comm, llvm::IRBuilder<> &B, llvm::Type *rankTy, llvm::Function *caller)
bool handleAdjointForIntrinsic(llvm::Intrinsic::ID ID, llvm::Instruction &I, llvm::SmallVectorImpl< llvm::Value * > &orig_ops)
void handleMPI(llvm::CallInst &call, llvm::Function *called, llvm::StringRef funcName)
void visitStoreInst(llvm::StoreInst &SI)
void visitMemSetCommon(llvm::CallInst &MS)
return structtype if recursive function
llvm::Function * fn
const std::vector< DIFFE_TYPE > constant_args
std::set< ssize_t > tapeIndiciesToFree
std::map< const llvm::CallInst *, const AugmentedReturn * > subaugmentations
Map from original call to sub augmentation data.
std::map< std::pair< llvm::Instruction *, CacheType >, int > tapeIndices
llvm::Type * tapeType
return structtype if recursive function
std::map< AugmentedStruct, int > returns
Map from information desired from a augmented return to its index in the returned struct.
llvm::Function *const newFunc
The function whose instructions we are caching.
llvm::TargetLibraryInfo & TLI
Various analysis results of newFunc.
llvm::BasicBlock * inversionAllocs
Concrete SubType of a given value.
llvm::Function * CreateNoFree(RequestContext context, llvm::Function *todiff)
const AugmentedReturn & CreateAugmentedPrimal(RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, TypeAnalysis &TA, bool returnUsed, bool shadowReturnUsed, const FnTypeInfo &typeInfo, bool subsequent_calls_may_write, const std::vector< bool > _overwritten_args, bool forceAnonymousTape, bool runtimeActivity, bool strongZero, unsigned width, bool AtomicAdd, bool omp=false)
Create an augmented forward pass.
llvm::Function * CreateForwardDiff(RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, TypeAnalysis &TA, bool returnValue, DerivativeMode mode, bool freeMemory, bool runtimeActivity, bool strongZero, unsigned width, llvm::Type *additionalArg, const FnTypeInfo &typeInfo, bool subsequent_calls_may_write, const std::vector< bool > _overwritten_args, const AugmentedReturn *augmented, bool omp=false)
Create the forward (or forward split) mode derivative function.
llvm::Function * CreatePrimalAndGradient(RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA, const AugmentedReturn *augmented, bool omp=false)
Create the reverse pass, or combined forward+reverse derivative function.
llvm::ValueMap< const llvm::Instruction *, AssertingReplacingVH > unwrappedLoads
llvm::SmallPtrSet< llvm::Instruction *, 4 > unnecessaryIntermediates
llvm::ValueMap< const llvm::Value *, InvertedPointerVH > invertedPointers
bool getContext(llvm::BasicBlock *BB, LoopContext &lc)
const std::map< llvm::Instruction *, bool > * can_modref_map
llvm::DebugLoc getNewFromOriginal(const llvm::DebugLoc L) const
static llvm::Value * extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, unsigned off, const llvm::Twine &name="")
Helper routine to extract a nested element from a struct/array.
llvm::SmallPtrSet< llvm::Instruction *, 4 > TapesToPreventRecomputation
A set of tape extractions to enforce caching of, rather than attempting to recompute.
void eraseWithPlaceholder(llvm::Instruction *I, llvm::Instruction *orig, const llvm::Twine &suffix="_replacementA", bool erase=true)
bool legalRecompute(const llvm::Value *val, const llvm::ValueToValueMapTy &available, llvm::IRBuilder<> *BuilderM, bool reverse=false, bool legalRecomputeCache=true) const
TypeResults TR
std::map< const llvm::Value *, bool > knownRecomputeHeuristic
llvm::Value * getOrInsertTotalMultiplicativeProduct(llvm::Value *val, LoopContext &lc)
llvm::BasicBlock * addReverseBlock(llvm::BasicBlock *currentBlock, llvm::Twine const &name, bool forkCache=true, bool push=true)
llvm::ValueMap< llvm::Value *, Rematerializer > rematerializableAllocations
llvm::ValueMap< llvm::PHINode *, llvm::WeakTrackingVH > fictiousPHIs
llvm::BasicBlock * getOriginalFromNew(const llvm::BasicBlock *newinst) const
unsigned getWidth()
llvm::Function * oldFunc
llvm::SmallVector< llvm::OperandBundleDef, 2 > getInvertedBundles(llvm::CallInst *orig, llvm::ArrayRef< ValueType > types, llvm::IRBuilder<> &Builder2, bool lookup, const llvm::ValueToValueMapTy &available=llvm::ValueToValueMapTy())
llvm::LoopInfo * OrigLI
void replaceAWithB(llvm::Value *A, llvm::Value *B, bool storeInCache=false) override
Replace this instruction both in LLVM modules and any local data-structures.
EnzymeLogic & Logic
DIFFE_TYPE getDiffeType(llvm::Value *v, bool foreignFunction) const
void replaceAndRemoveUnwrapCacheFor(llvm::Value *A, llvm::Value *B)
llvm::Value * isOriginal(const llvm::Value *newinst) const
llvm::Value * unwrapM(llvm::Value *const val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &available, UnwrapMode unwrapMode, llvm::BasicBlock *scope=nullptr, bool permitCache=true) override final
if full unwrap, don't just unwrap this instruction, but also its operands, etc
llvm::Value * lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &incoming_availalble=llvm::ValueToValueMapTy(), bool tryLegalRecomputeCheck=true, llvm::BasicBlock *scope=nullptr) override
High-level utility to get the value an instruction at a new location specified by BuilderM.
llvm::MDNode * getDerivativeAliasScope(const llvm::Value *origptr, ssize_t newptr)
llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > originalToNewFn
llvm::Value * getOrInsertConditionalIndex(llvm::Value *val, LoopContext &lc, bool pickTrue)
DIFFE_TYPE getReturnDiffeType(llvm::Value *orig, bool *primalReturnUsedP, bool *shadowReturnUsedP, DerivativeMode cmode) const
void setPtrDiffe(llvm::Instruction *orig, llvm::Value *ptr, llvm::Value *newval, llvm::IRBuilder<> &BuilderM, llvm::MaybeAlign align, unsigned start, unsigned size, bool isVolatile, llvm::AtomicOrdering ordering, llvm::SyncScope::ID syncScope, llvm::Value *mask, llvm::ArrayRef< llvm::Metadata * > noAlias, llvm::ArrayRef< llvm::Metadata * > scopes, bool needs_post_cache=false)
bool isConstantInstruction(const llvm::Instruction *inst) const
llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > newToOriginalFn
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
llvm::Value * cacheForReverse(llvm::IRBuilder<> &BuilderQ, llvm::Value *malloc, int idx, bool replace=true)
llvm::ValueMap< llvm::Value *, ShadowRematerializer > backwardsOnlyShadows
Only loaded from and stored to (not captured), mapped to the stores (and memset).
bool isConstantValue(llvm::Value *val) const
llvm::Value * invertPointerM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, bool nullShadow=false)
void erase(llvm::Instruction *I) override
Erase this instruction both from LLVM modules and any local data-structures.
TypeAnalysis & interprocedural
Calling TypeAnalysis to be used in the case of calls to other functions.
std::map< llvm::Value *, TypeTree > analysis
Intermediate conservative, but correct Type analysis results.
A holder class representing the results of running TypeAnalysis on a given function.
ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound=true, bool pointerIntSame=false) const
llvm::Type * addingType(size_t num, llvm::Value *val, size_t start=0) const
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
TypeAnalyzer * analyzer
bool anyPointer(llvm::Value *val) const
Whether any part of the top level register can contain a pointer e.g.
std::set< int64_t > knownIntegralValues(llvm::Value *val) const
The set of values val will take on during this program.
TypeTree query(llvm::Value *val) const
The TypeTree of a particular Value.
llvm::Function * getFunction() const
FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn) const
ConcreteType firstPointer(size_t num, llvm::Value *val, llvm::Instruction *I, bool errIfNotFound=true, bool pointerIntSame=false) const
Returns whether in the first num bytes there is a pointer, int, float, or none. If pointerIntSame is set...
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
Definition TypeTree.h:72
TypeTree Only(int Off, llvm::Instruction *orig) const
Prepend an offset to all mappings.
Definition TypeTree.h:471
TypeTree Data0() const
Peel off the outermost index at offset 0.
Definition TypeTree.h:513
static TypeTree parse(llvm::StringRef str, llvm::LLVMContext &ctx)
Definition TypeTree.h:86
TypeTree ShiftIndices(const llvm::DataLayout &dl, const int offset, const int maxSize, size_t addOffset=0) const
Replace mappings in the range in [offset, offset+maxSize] with those in.
Definition TypeTree.h:840
TypeTree Lookup(size_t len, const llvm::DataLayout &dl) const
Select all submappings whose first index is in range [0, len) and remove the first index.
Definition TypeTree.h:593
llvm::Type * IsAllFloat(const size_t size, const llvm::DataLayout &dl) const
Definition TypeTree.h:814
std::string str() const
Returns a string representation of this TypeTree.
Definition TypeTree.h:1383
bool insert(const std::vector< int > Seq, ConcreteType CT, bool PointerIntSame=false)
Return if changed.
Definition TypeTree.h:234
TypeTree PurgeAnything() const
Keep only mappings where the type is not an Anything
Definition TypeTree.h:1041
bool callShouldNotUseDerivative(const GradientUtils *gutils, llvm::CallBase &orig, QueryType qtype, const llvm::Value *val)
Return whether or not this is a constant and should use reverse pass.
bool is_value_needed_in_reverse(const GradientUtils *gutils, const llvm::Value *inst, DerivativeMode mode, std::map< UsageKey, bool > &seen, const llvm::SmallPtrSetImpl< llvm::BasicBlock * > &oldUnreachable)
Struct containing all contextual type information for a particular function call.
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
TypeTree Return
Type of return.
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to be represented by an argument, if constant.
Container for all loop information to synthesize gradients.
llvm::BasicBlock * header
Header of this loop.
llvm::AssertingVH< llvm::Instruction > incvar
Increment of the induction.
llvm::SmallPtrSet< llvm::BasicBlock *, 8 > exitBlocks
All blocks this loop exits to.
llvm::BasicBlock * preheader
Preheader of this loop.
todiff is the function to differentiate retType is the activity info of the return.