32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/SmallPtrSet.h"
34#include "llvm/ADT/SmallVector.h"
36#include "llvm/IR/BasicBlock.h"
37#include "llvm/IR/DebugInfoMetadata.h"
38#include "llvm/IR/Dominators.h"
39#include "llvm/IR/IRBuilder.h"
40#include "llvm/IR/Instructions.h"
41#include "llvm/IR/Type.h"
42#include "llvm/IR/Value.h"
44#include "llvm/Transforms/Utils/BasicBlockUtils.h"
46#include "llvm/Support/Casting.h"
47#include "llvm/Support/ErrorHandling.h"
55bool elementwiseReadForContext(
const Instruction *orig,
const Value *origptr) {
57 if (
const Function *F = orig->getFunction()) {
58 if (F->hasFnAttribute(
"enzyme_elementwise_read")) {
64 if (
auto *arg = dyn_cast<Argument>(base)) {
65 if (
const Function *F = arg->getParent()) {
66 return F->getAttributes().hasParamAttr(arg->getArgNo(),
67 "enzyme_elementwise_read");
74DiffeGradientUtils::DiffeGradientUtils(
75 EnzymeLogic &Logic, Function *newFunc_, Function *oldFunc_,
77 ValueToValueMapTy &invertedPointers_,
78 const SmallPtrSetImpl<Value *> &constantvalues_,
79 const SmallPtrSetImpl<Value *> &returnvals_,
DIFFE_TYPE ActiveReturn,
80 bool shadowReturnUsed, ArrayRef<DIFFE_TYPE> constant_values,
81 llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
82 DerivativeMode mode,
bool runtimeActivity,
bool strongZero,
unsigned width,
84 :
GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_,
85 constantvalues_, returnvals_, ActiveReturn,
86 shadowReturnUsed, constant_values, origToNew_, mode,
87 runtimeActivity, strongZero, width, omp) {
88 if (oldFunc_->empty())
90 assert(reverseBlocks.size() == 0);
96 for (BasicBlock *BB : originalBlocks) {
97 if (BB == inversionAllocs)
100 BasicBlock::Create(BB->getContext(),
"invert" + BB->getName(), newFunc);
101 reverseBlocks[BB].push_back(RBB);
102 reverseBlockToPrimal[RBB] = BB;
104 assert(reverseBlocks.size() != 0);
109 bool strongZero,
unsigned width, Function *todiff, TargetLibraryInfo &TLI,
111 bool shadowReturn,
bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
112 bool returnTape,
bool returnPrimal, Type *additionalArg,
bool omp) {
120 SmallPtrSet<Instruction *, 4> constants;
121 SmallPtrSet<Instruction *, 20> nonconstant;
122 SmallPtrSet<Value *, 2> returnvals;
123 llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> originalToNew;
125 SmallPtrSet<Value *, 4> constant_values;
126 SmallPtrSet<Value *, 4> nonconstant_values;
141 llvm_unreachable(
"invalid DerivativeMode: ReverseModePrimal\n");
145 prefix += std::to_string(width);
149 nonconstant_values, returnvals, returnTape, returnPrimal,
151 prefix +
oldFunc->getName(), &originalToNew,
152 diffeReturnArg, additionalArg);
159 auto toarg = todiff->arg_begin();
160 auto olarg =
oldFunc->arg_begin();
161 for (; toarg != todiff->arg_end(); ++toarg, ++olarg) {
164 auto fd = oldTypeInfo.
Arguments.find(toarg);
165 assert(fd != oldTypeInfo.
Arguments.end());
167 std::pair<Argument *, TypeTree>(olarg, fd->second));
174 std::pair<Argument *, std::set<int64_t>>(olarg, cfd->second));
186 nonconstant_values, retType, shadowReturn, constant_args, originalToNew,
198 if (
auto arg = dyn_cast<Argument>(val))
199 assert(arg->getParent() ==
oldFunc);
200 if (
auto inst = dyn_cast<Instruction>(val))
201 assert(inst->getParent()->getParent() ==
oldFunc);
208 entryBuilder.setFastMathFlags(
getFast());
210 entryBuilder.CreateAlloca(type,
nullptr, val->getName() +
"'de");
212 oldFunc->getParent()->getDataLayout().getPrefTypeAlign(type);
217#if LLVM_VERSION_MAJOR < 17
218 if (val->getContext().supportsTypedPointers()) {
219 assert(
differentials[val]->getType()->getPointerElementType() == type);
227 if (
auto arg = dyn_cast<Argument>(val))
228 assert(arg->getParent() ==
oldFunc);
229 if (
auto inst = dyn_cast<Instruction>(val))
230 assert(inst->getParent()->getParent() ==
oldFunc);
234 llvm::errs() << *
newFunc <<
"\n";
235 llvm::errs() << *val <<
"\n";
236 assert(0 &&
"getting diffe of constant value");
242 if (val->getType()->isPointerTy()) {
243 llvm::errs() << *
newFunc <<
"\n";
244 llvm::errs() << *val <<
"\n";
246 assert(!val->getType()->isPointerTy());
247 assert(!val->getType()->isVoidTy());
253 Value *val, Value *dif, IRBuilder<> &BuilderM, Type *addingType,
254 unsigned start,
unsigned size, llvm::ArrayRef<llvm::Value *> idxs,
255 llvm::Value *mask,
size_t ignoreFirstSlicesOfDif) {
257 auto &DL =
oldFunc->getParent()->getDataLayout();
258 Type *VT = val->getType();
259 for (
auto cv : idxs) {
260 auto i = dyn_cast<ConstantInt>(cv)->getSExtValue();
261 if (
auto ST = dyn_cast<StructType>(VT)) {
262 VT = ST->getElementType(i);
265 if (
auto AT = dyn_cast<ArrayType>(VT)) {
266 assert((
size_t)i < AT->getNumElements());
267 VT = AT->getElementType();
270 assert(0 &&
"illegal indexing type");
272 auto storeSize = (DL.getTypeSizeInBits(VT) + 7) / 8;
274 assert(start < storeSize);
275 assert(start + size <= storeSize);
279 if (start == 0 && size == storeSize && !isa<StructType>(VT)) {
281 SmallVector<unsigned, 1> eidxs;
282 for (
auto idx : idxs.slice(ignoreFirstSlicesOfDif)) {
283 eidxs.push_back((
unsigned)cast<ConstantInt>(idx)->getZExtValue());
286 addingType, idxs, mask);
288 SmallVector<SelectInst *, 4> res;
289 for (
unsigned j = 0; j <
getWidth(); j++) {
290 SmallVector<Value *, 1> lidxs;
291 SmallVector<unsigned, 1> eidxs = {(unsigned)j};
293 ConstantInt::get(Type::getInt32Ty(val->getContext()), j));
294 for (
auto idx : idxs.slice(ignoreFirstSlicesOfDif)) {
295 eidxs.push_back((
unsigned)cast<ConstantInt>(idx)->getZExtValue());
297 for (
auto idx : idxs) {
298 lidxs.push_back(idx);
301 BuilderM, addingType, lidxs, mask))
307 if (
auto ST = dyn_cast<StructType>(VT)) {
308 auto SL = DL.getStructLayout(ST);
309 auto left_idx = SL->getElementContainingOffset(start);
310 auto right_idx = ST->getNumElements();
311 if (storeSize != start + size) {
312 right_idx = SL->getElementContainingOffset(start + size);
315 if (SL->getElementOffset(right_idx) != start + size)
318 SmallVector<SelectInst *, 4> res;
319 for (
auto i = left_idx; i < right_idx; i++) {
320 auto subType = ST->getElementType(i);
321 SmallVector<Value *, 1> lidxs(idxs.begin(), idxs.end());
322 lidxs.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), i));
324 (i == left_idx) ? (start - (
unsigned)SL->getElementOffset(i)) : 0;
325 auto subTypeSize = (DL.getTypeSizeInBits(subType) + 7) / 8;
326 auto sub_end = (i == right_idx - 1)
327 ?
min(start + size - (
unsigned)SL->getElementOffset(i),
328 (
unsigned)subTypeSize)
331 addToDiffe(val, dif, BuilderM, addingType, sub_start,
332 sub_end - sub_start, lidxs, mask, ignoreFirstSlicesOfDif))
338 if (
auto AT = dyn_cast<ArrayType>(VT)) {
339 auto subType = AT->getElementType();
340 auto subTypeSize = (DL.getTypeSizeInBits(subType) + 7) / 8;
341 auto left_idx = start / subTypeSize;
342 auto right_idx = AT->getNumElements();
343 if (storeSize != start + size) {
344 right_idx = (start + size) / subTypeSize;
347 if (right_idx * subTypeSize != start + size)
350 SmallVector<SelectInst *, 4> res;
351 for (
auto i = left_idx; i < right_idx; i++) {
352 SmallVector<Value *, 1> lidxs(idxs.begin(), idxs.end());
353 lidxs.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), i));
354 auto sub_start = (i == left_idx) ? (start - (i * subTypeSize)) : 0;
355 auto sub_end = (i == right_idx - 1)
356 ?
min(start + size - (
unsigned)(i * subTypeSize),
357 (
unsigned)subTypeSize)
360 addToDiffe(val, dif, BuilderM, addingType, sub_start,
361 sub_end - sub_start, lidxs, mask, ignoreFirstSlicesOfDif))
367 if (
auto VecT = dyn_cast<VectorType>(VT)) {
368 if (!VecT->getElementCount().isScalable()) {
369 Type *elemTy = VecT->getElementType();
370 auto elemBytes = (DL.getTypeSizeInBits(elemTy) + 7) / 8;
373 if (elemBytes != 0 && start % elemBytes == 0 && size % elemBytes == 0) {
374 unsigned left_idx = start / elemBytes;
375 unsigned right_idx = (start + size) / elemBytes;
377 unsigned numElts = VecT->getElementCount().getFixedValue();
378 if (left_idx > numElts)
380 if (right_idx > numElts)
383 auto maskVec = [&](Value *dsub) -> Value * {
384 if (left_idx == 0 && right_idx == numElts)
386 Value *masked = Constant::getNullValue(VT);
387 for (
unsigned i = left_idx; i < right_idx; i++) {
389 ConstantInt::get(Type::getInt32Ty(val->getContext()), i);
390 Value *el = BuilderM.CreateExtractElement(dsub, vidx);
391 masked = BuilderM.CreateInsertElement(masked, el, vidx);
397 SmallVector<unsigned, 1> eidxs;
398 for (
auto idx : idxs.slice(ignoreFirstSlicesOfDif))
399 eidxs.push_back((
unsigned)cast<ConstantInt>(idx)->getZExtValue());
402 return addToDiffe(val, maskVec(subdif), BuilderM, addingType, idxs,
405 SmallVector<SelectInst *, 4> res;
406 for (
unsigned j = 0; j <
getWidth(); j++) {
407 SmallVector<Value *, 1> lidxs;
408 SmallVector<unsigned, 1> eidxs = {(unsigned)j};
411 ConstantInt::get(Type::getInt32Ty(val->getContext()), j));
412 for (
auto idx : idxs.slice(ignoreFirstSlicesOfDif))
413 eidxs.push_back((
unsigned)cast<ConstantInt>(idx)->getZExtValue());
414 for (
auto idx : idxs)
415 lidxs.push_back(idx);
418 for (
auto v :
addToDiffe(val, maskVec(subdif), BuilderM, addingType,
428 llvm::errs() <<
" VT: " << *VT <<
" idxs:{";
429 for (
auto idx : idxs)
430 llvm::errs() << *idx <<
",";
431 llvm::errs() <<
"} start=" << start <<
" size=" << size
432 <<
" storeSize=" << storeSize <<
" val=" << *val <<
"\n";
433 assert(0 &&
"unhandled accumulate with partial sizes");
438#if LLVM_VERSION_MAJOR >= 22
439 return cst->isNullValue() || cst->isNegativeZeroValue();
441 return cst->isZeroValue();
444SmallVector<SelectInst *, 4>
446 Type *addingType, ArrayRef<Value *> idxs,
452 if (
auto arg = dyn_cast<Argument>(val))
453 assert(arg->getParent() ==
oldFunc);
454 if (
auto inst = dyn_cast<Instruction>(val))
455 assert(inst->getParent()->getParent() ==
oldFunc);
458 SmallVector<SelectInst *, 4> addedSelects;
460 auto faddForNeg = [&](Value *old, Value *inc,
bool san) {
461 if (
auto bi = dyn_cast<BinaryOperator>(inc)) {
462 if (
auto ci = dyn_cast<ConstantFP>(bi->getOperand(0))) {
463 if (bi->getOpcode() == BinaryOperator::FSub && ci->isZero()) {
464 Value *res = BuilderM.CreateFSub(old, bi->getOperand(1));
471 Value *res = BuilderM.CreateFAdd(old, inc);
477 auto faddForSelect = [&](Value *old, Value *dif) -> Value * {
479 if (SelectInst *select = dyn_cast<SelectInst>(dif)) {
480 if (
Constant *ci = dyn_cast<Constant>(select->getTrueValue())) {
482 SelectInst *res = cast<SelectInst>(BuilderM.CreateSelect(
483 select->getCondition(), old,
484 faddForNeg(old, select->getFalseValue(),
false)));
485 addedSelects.push_back(res);
489 if (
Constant *ci = dyn_cast<Constant>(select->getFalseValue())) {
491 SelectInst *res = cast<SelectInst>(BuilderM.CreateSelect(
492 select->getCondition(),
493 faddForNeg(old, select->getTrueValue(),
false), old));
494 addedSelects.push_back(res);
501 if (BitCastInst *bc = dyn_cast<BitCastInst>(dif)) {
502 if (SelectInst *select = dyn_cast<SelectInst>(bc->getOperand(0))) {
503 if (
Constant *ci = dyn_cast<Constant>(select->getTrueValue())) {
505 SelectInst *res = cast<SelectInst>(BuilderM.CreateSelect(
506 select->getCondition(), old,
508 BuilderM.CreateCast(bc->getOpcode(),
509 select->getFalseValue(),
512 addedSelects.push_back(res);
516 if (
Constant *ci = dyn_cast<Constant>(select->getFalseValue())) {
518 SelectInst *res = cast<SelectInst>(BuilderM.CreateSelect(
519 select->getCondition(),
521 BuilderM.CreateCast(bc->getOpcode(),
522 select->getTrueValue(),
526 addedSelects.push_back(res);
534 return faddForNeg(old, dif,
true);
537 if (val->getType()->isPointerTy()) {
538 llvm::errs() << *
newFunc <<
"\n";
539 llvm::errs() << *val <<
"\n";
542 llvm::errs() << *
newFunc <<
"\n";
543 llvm::errs() << *val <<
"\n";
545 assert(!val->getType()->isPointerTy());
551 if (idxs.size() != 0) {
552 SmallVector<Value *, 4> sv = {
553 ConstantInt::get(Type::getInt32Ty(val->getContext()), 0)};
556 ptr = BuilderM.CreateGEP(
getShadowType(val->getType()), ptr, sv);
557 cast<GetElementPtrInst>(ptr)->setIsInBounds(
true);
558 old = BuilderM.CreateLoad(
559 GetElementPtrInst::getIndexedType(
getShadowType(val->getType()), sv),
562 old = BuilderM.CreateLoad(
getShadowType(val->getType()), ptr);
564 if (dif->getType() != old->getType()) {
565 if (
auto inst = dyn_cast<Instruction>(val)) {
566 EmitFailure(
"IllegalAddingType", inst->getDebugLoc(), inst,
"val ", *val,
567 " dif ", *dif,
" old ", *old);
570 llvm::errs() <<
" IllegalAddingType val: " << *val <<
" dif: " << *dif
571 <<
" old: " << *old <<
"\n";
572 llvm_unreachable(
"IllegalAddingType");
575 assert(dif->getType() == old->getType());
576 Value *res =
nullptr;
577 if (old->getType()->isIntOrIntVectorTy() || old->getType()->isPointerTy()) {
580 if (old->getType()->isIntegerTy(64))
581 addingType = Type::getDoubleTy(old->getContext());
582 else if (old->getType()->isIntegerTy(32))
583 addingType = Type::getFloatTy(old->getContext());
588 llvm::raw_string_ostream ss(s);
589 ss <<
"oldFunc: " << *
oldFunc <<
"\n";
590 ss <<
"Cannot deduce adding type of: " << *val <<
"\n";
592 for (
auto idx : idxs)
595 if (
auto inst = dyn_cast<Instruction>(val)) {
604 llvm::errs() << ss.str() <<
"\n";
605 llvm_unreachable(
"Cannot deduce adding type");
610 assert(addingType->isFPOrFPVectorTy());
613 oldFunc->getParent()->getDataLayout().getTypeSizeInBits(old->getType());
615 oldFunc->getParent()->getDataLayout().getTypeSizeInBits(addingType);
617 if (oldBitSize == newBitSize) {
618 }
else if (oldBitSize > newBitSize && oldBitSize % newBitSize == 0) {
619 if (!addingType->isVectorTy())
621 VectorType::get(addingType, oldBitSize / newBitSize,
false);
624 llvm::raw_string_ostream ss(s);
625 ss <<
"oldFunc: " << *
oldFunc <<
"\n";
626 ss <<
"Illegal intermediate when adding to: " << *val
627 <<
" with addingType: " << *addingType <<
"\n"
628 <<
" old: " << *old <<
" dif: " << *dif <<
"\n"
629 <<
" oldBitSize: " << oldBitSize <<
" newBitSize: " << newBitSize
637 if (
auto inst = dyn_cast<Instruction>(val))
638 EmitFailure(
"CannotDeduceType", inst->getDebugLoc(), inst, ss.str());
640 llvm::errs() << ss.str() <<
"\n";
641 llvm_unreachable(
"Cannot deduce adding type");
649 Type *intTy =
nullptr;
650 if (old->getType()->isPointerTy()) {
651 auto &DL =
oldFunc->getParent()->getDataLayout();
652 intTy = Type::getIntNTy(old->getContext(), DL.getPointerSizeInBits());
653 bcold = BuilderM.CreatePtrToInt(bcold, intTy);
654 bcdif = BuilderM.CreatePtrToInt(bcdif, intTy);
656 intTy = old->getType();
659 bcold = BuilderM.CreateBitCast(bcold, addingType);
660 bcdif = BuilderM.CreateBitCast(bcdif, addingType);
662 res = faddForSelect(bcold, bcdif);
663 if (SelectInst *select = dyn_cast<SelectInst>(res)) {
664 assert(addedSelects.back() == select);
665 addedSelects.erase(addedSelects.end() - 1);
667 Value *tval = BuilderM.CreateBitCast(select->getTrueValue(), intTy);
668 Value *fval = BuilderM.CreateBitCast(select->getFalseValue(), intTy);
669 if (old->getType()->isPointerTy()) {
670 tval = BuilderM.CreateIntToPtr(tval, old->getType());
671 fval = BuilderM.CreateIntToPtr(fval, old->getType());
673 res = BuilderM.CreateSelect(select->getCondition(), tval, fval);
674 assert(select->getNumUses() == 0);
676 res = BuilderM.CreateBitCast(res, intTy);
677 if (old->getType()->isPointerTy())
678 res = BuilderM.CreateIntToPtr(res, old->getType());
681 BuilderM.CreateStore(res, ptr);
684 Type *tys[] = {res->getType(), ptr->getType()};
686 Intrinsic::masked_store, tys);
687 auto align = cast<AllocaInst>(ptr)->getAlign().value();
690 ConstantInt::get(Type::getInt32Ty(mask->getContext()), align);
691 Value *args[] = {res, ptr, alignv, mask};
692 BuilderM.CreateCall(F, args);
695 }
else if (old->getType()->isFPOrFPVectorTy()) {
697 res = faddForSelect(old, dif);
700 BuilderM.CreateStore(res, ptr);
703 Type *tys[] = {res->getType(), ptr->getType()};
705 Intrinsic::masked_store, tys);
706 auto align = cast<AllocaInst>(ptr)->getAlign().value();
709 ConstantInt::get(Type::getInt32Ty(mask->getContext()), align);
710 Value *args[] = {res, ptr, alignv, mask};
711 BuilderM.CreateCall(F, args);
714 }
else if (
auto st = dyn_cast<StructType>(old->getType())) {
717 llvm_unreachable(
"cannot handle recursive addToDiffe with mask");
718 for (
unsigned i = 0; i < st->getNumElements(); ++i) {
720 if (st->getElementType(i)->isPointerTy())
722 if (st->getElementType(i)->isIntegerTy(8) ||
723 st->getElementType(i)->isIntegerTy(1))
725 Value *v = ConstantInt::get(Type::getInt32Ty(st->getContext()), i);
726 SmallVector<Value *, 2> idx2(idxs.begin(), idxs.end());
731 for (
auto select : selects) {
732 addedSelects.push_back(select);
736 }
else if (
auto at = dyn_cast<ArrayType>(old->getType())) {
739 llvm_unreachable(
"cannot handle recursive addToDiffe with mask");
740 if (at->getElementType()->isPointerTy())
742 for (
unsigned i = 0; i < at->getNumElements(); ++i) {
744 Value *v = ConstantInt::get(Type::getInt32Ty(at->getContext()), i);
745 SmallVector<Value *, 2> idx2(idxs.begin(), idxs.end());
749 for (
auto select : selects) {
750 addedSelects.push_back(select);
755 llvm::errs() <<
" idx: {";
757 llvm::errs() << *i <<
", ";
758 llvm::errs() <<
"}\n";
760 llvm::errs() <<
" addingType: " << *addingType <<
"\n";
762 llvm::errs() <<
" addingType: null\n";
763 llvm::errs() <<
" oldType:" << *old->getType() <<
" old:" << *old <<
"\n";
764 llvm_unreachable(
"unknown type to add to diffe");
770 IRBuilder<> &BuilderM) {
772 if (
auto arg = dyn_cast<Argument>(val))
773 assert(arg->getParent() ==
oldFunc);
774 if (
auto inst = dyn_cast<Instruction>(val))
775 assert(inst->getParent()->getParent() ==
oldFunc);
777 llvm::errs() << *
newFunc <<
"\n";
778 llvm::errs() << *val <<
"\n";
789 auto placeholder0 = &*found->second;
790 auto placeholder = cast<PHINode>(placeholder0);
793 placeholder->replaceAllUsesWith(toset);
800#if LLVM_VERSION_MAJOR < 17
801 if (toset->getContext().supportsTypedPointers()) {
802 if (toset->getType() != tostore->getType()->getPointerElementType()) {
803 llvm::errs() <<
"toset:" << *toset <<
"\n";
804 llvm::errs() <<
"tostore:" << *tostore <<
"\n";
806 assert(toset->getType() == tostore->getType()->getPointerElementType());
809 BuilderM.CreateStore(toset, tostore);
814 AllocaInst *alloc, llvm::Type *T,
815 ConstantInt *byteSizeOfType,
816 Value *storeInto, MDNode *InvariantMD) {
822 tbuild.setFastMathFlags(
getFast());
825 if (tbuild.GetInsertBlock()->size() &&
826 tbuild.GetInsertBlock()->getTerminator()) {
827 tbuild.SetInsertPoint(tbuild.GetInsertBlock()->getTerminator());
830 ValueToValueMapTy antimap;
831 for (
int j = sublimits.size() - 1; j >= i; j--) {
832 auto &innercontainedloops = sublimits[j].second;
833 for (
auto riter = innercontainedloops.rbegin(),
834 rend = innercontainedloops.rend();
835 riter != rend; ++riter) {
836 const auto &idx = riter->first;
839 tbuild.CreateLoad(idx.var->getType(), idx.antivaralloc);
844 Value *metaforfree =
unwrapM(storeInto, tbuild, antimap,
847#if LLVM_VERSION_MAJOR < 17
848 if (metaforfree->getContext().supportsTypedPointers()) {
849 assert(T == metaforfree->getType()->getPointerElementType());
853 LoadInst *forfree = cast<LoadInst>(tbuild.CreateLoad(T, metaforfree));
854 forfree->setMetadata(LLVMContext::MD_invariant_group, InvariantMD);
855 forfree->setMetadata(LLVMContext::MD_dereferenceable,
856 MDNode::get(forfree->getContext(),
857 ArrayRef<Metadata *>(ConstantAsMetadata::get(
859 forfree->setName(
"forfree");
861 (
unsigned)
newFunc->getParent()->getDataLayout().getPointerSize());
862 forfree->setAlignment(Align(align));
867 ci->setDebugLoc(DILocation::get(
newFunc->getContext(), 0, 0,
875 Value *origVal, Type *addingType,
876 unsigned start,
unsigned size,
877 Value *origptr, Value *dif,
878 IRBuilder<> &BuilderM,
879 MaybeAlign align, Value *mask) {
880 auto &DL =
oldFunc->getParent()->getDataLayout();
882 auto addingSize = (DL.getTypeSizeInBits(addingType) + 1) / 8;
883 if (addingSize != size) {
884 assert(size > addingSize);
886 VectorType::get(addingType, size / addingSize,
false);
887 size = (size / addingSize) * addingSize;
899 assert(
false &&
"Invalid derivative mode (ReverseModePrimal)");
907 bool needsCast =
false;
908#if LLVM_VERSION_MAJOR < 17
909 if (isa<PointerType>(origptr->getType()) &&
910 origptr->getContext().supportsTypedPointers()) {
911 needsCast = origptr->getType()->getPointerElementType() != addingType;
916 if (start != 0 || needsCast || !isa<PointerType>(origptr->getType())) {
917 auto rule = [&](Value *ptr) {
918 if (!isa<PointerType>(origptr->getType())) {
919 ptr = BuilderM.CreateIntToPtr(ptr,
getUnqual(addingType));
922 auto i8 = Type::getInt8Ty(ptr->getContext());
923 ptr = BuilderM.CreatePointerCast(
924 ptr, PointerType::get(
925 i8, cast<PointerType>(ptr->getType())->getAddressSpace()));
926 auto off = ConstantInt::get(Type::getInt64Ty(ptr->getContext()), start);
927 ptr = BuilderM.CreateInBoundsGEP(i8, ptr, off);
930 ptr = BuilderM.CreatePointerCast(
931 ptr, PointerType::get(
933 cast<PointerType>(ptr->getType())->getAddressSpace()));
940 isa<PointerType>(origptr->getType())
941 ? cast<PointerType>(origptr->getType())->getAddressSpace()
943 BuilderM, rule, ptr);
947 needsCast = dif->getType() != addingType;
948 else if (
auto AT = cast<ArrayType>(dif->getType()))
949 needsCast = AT->getElementType() != addingType;
952 cast<VectorType>(dif->getType())->getElementType() != addingType;
954 if (start != 0 || needsCast) {
955 auto rule = [&](Value *dif) {
958 auto i8 = Type::getInt8Ty(ptr->getContext());
959 auto prevSize = (DL.getTypeSizeInBits(dif->getType()) + 1) / 8;
960 Type *tys[] = {ArrayType::get(i8, start), addingType,
961 ArrayType::get(i8, prevSize - start - size)};
962 auto ST = StructType::get(i8->getContext(), tys,
true);
963 auto Al = A.CreateAlloca(ST);
964 BuilderM.CreateStore(
965 dif, BuilderM.CreatePointerCast(Al,
getUnqual(dif->getType())));
967 ConstantInt::get(Type::getInt64Ty(ptr->getContext()), 0),
968 ConstantInt::get(Type::getInt32Ty(ptr->getContext()), 1)};
970 auto difp = BuilderM.CreateInBoundsGEP(ST, Al, idxs);
971 dif = BuilderM.CreateLoad(addingType, difp);
973 if (dif->getType() != addingType) {
974 auto difSize = (DL.getTypeSizeInBits(dif->getType()) + 1) / 8;
975 if (difSize < size) {
976 llvm::errs() <<
" ds: " << difSize <<
" as: " << size <<
"\n";
977 llvm::errs() <<
" dif: " << *dif <<
" adding: " << *addingType
980 assert(difSize >= size);
981 if (CastInst::castIsValid(Instruction::CastOps::BitCast, dif,
983 dif = BuilderM.CreateBitCast(dif, addingType);
986 auto Al = A.CreateAlloca(addingType);
987 BuilderM.CreateStore(
988 dif, BuilderM.CreatePointerCast(Al,
getUnqual(dif->getType())));
989 dif = BuilderM.CreateLoad(addingType, Al);
1001 auto Arch = llvm::Triple(
newFunc->getParent()->getTargetTriple()).getArch();
1005 if (isa<AllocaInst>(TmpOrig) &&
1006 (Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
1007 Arch == Triple::amdgcn)) {
1015 if (Atomic && elementwiseReadForContext(orig, origptr))
1021 if (Arch == Triple::amdgcn &&
1022 cast<PointerType>(origptr->getType())->getAddressSpace() == 4) {
1023 auto rule = [&](Value *ptr) {
1024 return BuilderM.CreateAddrSpaceCast(ptr,
1025 PointerType::get(addingType, 1));
1028 applyChainRule(PointerType::get(addingType, 1), BuilderM, rule, ptr);
1033 llvm::raw_string_ostream ss(s);
1034 ss <<
"Unimplemented masked atomic fadd for ptr:" << *ptr
1035 <<
" dif:" << *dif <<
" mask: " << *mask <<
" orig: " << *orig <<
"\n";
1042 EmitFailure(
"NoDerivative", orig->getDebugLoc(), orig, ss.str());
1057 AtomicRMWInst::BinOp op = AtomicRMWInst::FAdd;
1058 if (
auto vt = dyn_cast<VectorType>(addingType)) {
1059 assert(!vt->getElementCount().isScalable());
1060 size_t numElems = vt->getElementCount().getKnownMinValue();
1061 auto rule = [&](Value *dif, Value *ptr) {
1062 for (
size_t i = 0; i < numElems; ++i) {
1063 auto vdif = BuilderM.CreateExtractElement(dif, i);
1066 ConstantInt::get(Type::getInt64Ty(vt->getContext()), 0),
1067 ConstantInt::get(Type::getInt32Ty(vt->getContext()), i)};
1068 auto vptr = BuilderM.CreateGEP(addingType, ptr, Idxs);
1069 MaybeAlign alignv = align;
1073 assert((*alignv).value() != 0);
1074 if (start % (*alignv).value() != 0) {
1079 BuilderM.CreateAtomicRMW(op, vptr, vdif, alignv,
1080 AtomicOrdering::Monotonic,
1086 auto rule = [&](Value *dif, Value *ptr) {
1088 MaybeAlign alignv = align;
1092 assert((*alignv).value() != 0);
1093 if (start % (*alignv).value() != 0) {
1098 BuilderM.CreateAtomicRMW(op, ptr, dif, alignv,
1099 AtomicOrdering::Monotonic, SyncScope::System);
1109 auto rule = [&](Value *ptr, Value *dif) {
1110 auto LI = BuilderM.CreateLoad(addingType, ptr);
1112 Value *res = BuilderM.CreateFAdd(
LI, dif);
1114 StoreInst *st = BuilderM.CreateStore(res, ptr);
1116 SmallVector<Metadata *, 1> scopeMD = {
1118 if (
auto origValI = dyn_cast_or_null<Instruction>(origVal))
1119 if (
auto MD = origValI->getMetadata(LLVMContext::MD_alias_scope)) {
1120 auto MDN = cast<MDNode>(MD);
1121 for (
auto &o : MDN->operands())
1122 scopeMD.push_back(o);
1124 auto scope = MDNode::get(
LI->getContext(), scopeMD);
1125 LI->setMetadata(LLVMContext::MD_alias_scope, scope);
1126 st->setMetadata(LLVMContext::MD_alias_scope, scope);
1128 SmallVector<Metadata *, 1> MDs;
1129 for (ssize_t j = -1; j <
getWidth(); j++) {
1130 if (j != (ssize_t)idx)
1133 if (
auto origValI = dyn_cast_or_null<Instruction>(origVal))
1134 if (
auto MD = origValI->getMetadata(LLVMContext::MD_noalias)) {
1135 auto MDN = cast<MDNode>(MD);
1136 for (
auto &o : MDN->operands())
1140 auto noscope = MDNode::get(ptr->getContext(), MDs);
1141 LI->setMetadata(LLVMContext::MD_noalias, noscope);
1142 st->setMetadata(LLVMContext::MD_noalias, noscope);
1144 if (origVal && isa<Instruction>(origVal) && start == 0 &&
1145 size == (DL.getTypeSizeInBits(origVal->getType()) + 7) / 8) {
1146 auto origValI = cast<Instruction>(origVal);
1148 unsigned int StoreData[] = {LLVMContext::MD_tbaa,
1149 LLVMContext::MD_tbaa_struct};
1150 for (
auto MD : StoreData)
1151 st->setMetadata(MD, origValI->getMetadata(MD));
1158 auto alignv = align ? (*align).value() : 0;
1162 if (start % alignv != 0) {
1167 LI->setAlignment(Align(alignv));
1168 st->setAlignment(Align(alignv));
1174 Type *tys[] = {addingType, origptr->getType()};
1176 Intrinsic::masked_load, tys);
1178 Intrinsic::masked_store, tys);
1179 unsigned aligni = align ? align->value() : 0;
1184 if (start % aligni != 0) {
1189 ConstantInt::get(Type::getInt32Ty(mask->getContext()), aligni);
1190 auto rule = [&](Value *ptr, Value *dif) {
1191 Value *largs[] = {ptr, alignv, mask,
1192 Constant::getNullValue(dif->getType())};
1193 Value *
LI = BuilderM.CreateCall(LF, largs);
1194 Value *res = BuilderM.CreateFAdd(
LI, dif);
1196 Value *sargs[] = {res, ptr, alignv, mask};
1197 BuilderM.CreateCall(SF, sargs);
1204 llvm::Instruction *orig, llvm::Value *origVal,
TypeTree vd,
1205 unsigned LoadSize, llvm::Value *origptr, llvm::Value *prediff,
1206 llvm::IRBuilder<> &Builder2, MaybeAlign alignment, llvm::Value *premask)
1211 unsigned size = LoadSize;
1215 BasicBlock *merge =
nullptr;
1218 unsigned nextStart = size;
1221 for (
size_t i = start; i < size; ++i) {
1223 dt.checkedOrIn(vd[{(int)i}],
true, Legal);
1229 if (!dt.isKnown()) {
1231 llvm::errs() <<
" vd:" << vd.
str() <<
" start:" << start
1232 <<
" size: " << size <<
" dt:" << dt.str() <<
"\n";
1234 assert(dt.isKnown());
1236 if (Type *isfloat = dt.isFloat()) {
1239 if (start == 0 && nextStart == LoadSize) {
1246 auto i8 = Type::getInt8Ty(tostore->getContext());
1248 tostore = Builder2.CreatePointerCast(
1252 cast<PointerType>(tostore->getType())->getAddressSpace()));
1253 auto off = ConstantInt::get(Type::getInt64Ty(tostore->getContext()),
1255 tostore = Builder2.CreateInBoundsGEP(i8, tostore, off);
1257 auto AT = ArrayType::get(i8, nextStart - start);
1258 tostore = Builder2.CreatePointerCast(
1262 cast<PointerType>(tostore->getType())->getAddressSpace()));
1263 Builder2.CreateStore(Constant::getNullValue(AT), tostore);
1279 shadow_val =
extractMeta(Builder2, shadow_val, 0);
1281 Value *shadow = Builder2.CreateICmpNE(primal_val, shadow_val);
1283 BasicBlock *current = Builder2.GetInsertBlock();
1284 BasicBlock *conditional =
1287 Builder2.CreateCondBr(shadow, conditional, merge);
1288 Builder2.SetInsertPoint(conditional);
1292 assert(start == 0 && nextStart == LoadSize);
1294 origptr, prediff, Builder2, alignment, premask);
1298 if (nextStart == size)
1303 Builder2.CreateBr(merge);
1304 Builder2.SetInsertPoint(merge);
@ AttemptFullUnwrapWithLookup
static bool isZero(llvm::Constant *cst)
static bool isAllocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI)
llvm::cl::opt< bool > looseTypeAnalysis
SmallVector< unsigned int, 9 > MD_ToCopy
CallInst * CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree)
void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, bool isTape)
void EmitNoTypeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2)
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
llvm::Value * SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM, llvm::Value *mask)
DIFFE_TYPE
Potential differentiable argument classifications.
static llvm::PointerType * getUnqual(llvm::Type *T)
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
static T min(T a, T b)
Pick the minimum value.
llvm::Function *const newFunc
The function whose instructions we are caching.
std::map< llvm::AllocaInst *, std::set< llvm::AssertingVH< llvm::CallInst > > > scopeFrees
A map of allocations to a set of instructions which free memory as part of the cache.
llvm::TargetLibraryInfo & TLI
Various analysis results of newFunc.
llvm::BasicBlock * inversionAllocs
llvm::SmallVector< std::pair< llvm::Value *, llvm::SmallVector< std::pair< LoopContext, llvm::Value * >, 4 > >, 0 > SubLimitType
Given a LimitContext ctx, representing a location inside a loop nest, break each of the loops up into...
unsigned getCacheAlignment(unsigned bsize) const
llvm::ValueMap< const llvm::Value *, llvm::TrackingVH< llvm::AllocaInst > > differentials
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal, llvm::Type *addingType, unsigned start, unsigned size, llvm::Value *origptr, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::MaybeAlign align=llvm::MaybeAlign(), llvm::Value *mask=nullptr)
align is the alignment that should be specified for load/store to pointer
llvm::Value * diffe(llvm::Value *val, llvm::IRBuilder<> &BuilderM)
llvm::AllocaInst * getDifferential(llvm::Value *val)
bool FreeMemory
Whether to free memory in reverse pass or split forward.
static DiffeGradientUtils * CreateFromClone(EnzymeLogic &Logic, DerivativeMode mode, bool runtimeActivity, bool strongZero, unsigned width, llvm::Function *todiff, llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturnArg, bool diffeReturnArg, llvm::ArrayRef< DIFFE_TYPE > constant_args, bool returnTape, bool returnPrimal, llvm::Type *additionalArg, bool omp)
llvm::CallInst * freeCache(llvm::BasicBlock *forwardPreheader, const SubLimitType &sublimits, int i, llvm::AllocaInst *alloc, llvm::Type *myType, llvm::ConstantInt *byteSizeOfType, llvm::Value *storeInto, llvm::MDNode *InvariantMD) override
If an allocation is requested to be freed, this subclass will be called to choose how and where to free it.
void setDiffe(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM)
llvm::SmallVector< llvm::SelectInst *, 4 > addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::Type *addingType, llvm::ArrayRef< llvm::Value * > idxs={}, llvm::Value *mask=nullptr)
Returns created select instructions, if any.
llvm::ValueMap< const llvm::Value *, InvertedPointerVH > invertedPointers
llvm::DebugLoc getNewFromOriginal(const llvm::DebugLoc L) const
static llvm::Value * extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, unsigned off, const llvm::Twine &name="")
Helper routine to extract a nested element from a struct/array aggregate.
llvm::Value * applyChainRule(llvm::Type *diffType, llvm::IRBuilder<> &Builder, Func rule, Args... args)
Unwraps a vector derivative from its internal representation and applies a function f to each element...
llvm::BasicBlock * addReverseBlock(llvm::BasicBlock *currentBlock, llvm::Twine const &name, bool forkCache=true, bool push=true)
std::map< llvm::BasicBlock *, llvm::SmallVector< llvm::BasicBlock *, 4 > > reverseBlocks
Map of primal block to corresponding block(s) in reverse.
void replaceAWithB(llvm::Value *A, llvm::Value *B, bool storeInCache=false) override
Replace this instruction both in LLVM modules and any local data-structures.
llvm::Value * unwrapM(llvm::Value *const val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &available, UnwrapMode unwrapMode, llvm::BasicBlock *scope=nullptr, bool permitCache=true) override final
if full unwrap, don't just unwrap this instruction, but also its operands, etc
llvm::Value * lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &incoming_availalble=llvm::ValueToValueMapTy(), bool tryLegalRecomputeCheck=true, llvm::BasicBlock *scope=nullptr) override
High-level utility to get the value an instruction at a new location specified by BuilderM.
llvm::MDNode * getDerivativeAliasScope(const llvm::Value *origptr, ssize_t newptr)
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
llvm::ValueMap< llvm::Value *, ShadowRematerializer > backwardsOnlyShadows
Only loaded from and stored to (not captured), mapped to the stores (and memset).
bool isConstantValue(llvm::Value *val) const
llvm::Value * invertPointerM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, bool nullShadow=false)
void erase(llvm::Instruction *I) override
Erase this instruction both from LLVM modules and any local data-structures.
llvm::Function * CloneFunctionWithReturns(DerivativeMode mode, unsigned width, llvm::Function *&F, llvm::ValueToValueMapTy &ptrInputs, llvm::ArrayRef< DIFFE_TYPE > constant_args, llvm::SmallPtrSetImpl< llvm::Value * > &constants, llvm::SmallPtrSetImpl< llvm::Value * > &nonconstant, llvm::SmallPtrSetImpl< llvm::Value * > &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const llvm::Twine &name, llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > *VMapO, bool diffeReturnArg, llvm::Type *additionalArg=nullptr)
Full interprocedural TypeAnalysis.
TypeResults analyzeFunction(const FnTypeInfo &fn)
Analyze a particular function, returning the results.
A holder class representing the results of running TypeAnalysis on a given function.
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
llvm::Function * getFunction() const
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
std::string str() const
Returns a string representation of this TypeTree.
Struct containing all contextual type information for a particular function call.
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
TypeTree Return
Type of return.
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to represented by an argument, if constant.