26#include "GradientUtils.h"
29#if LLVM_VERSION_MAJOR >= 16
30#include "llvm/Analysis/ScalarEvolution.h"
31#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
33#include "SCEV/ScalarEvolution.h"
34#include "SCEV/ScalarEvolutionExpander.h"
38#include "llvm/IR/BasicBlock.h"
39#include "llvm/IR/DerivedTypes.h"
40#include "llvm/IR/Function.h"
41#include "llvm/IR/GetElementPtrTypeIterator.h"
42#include "llvm/IR/IRBuilder.h"
43#include "llvm/IR/InlineAsm.h"
44#include "llvm/IR/Module.h"
45#include "llvm/IR/Type.h"
46#include "llvm/IR/Verifier.h"
48#if LLVM_VERSION_MAJOR >= 16
49#include "llvm/TargetParser/Triple.h"
51#include "llvm/ADT/Triple.h"
54#include "llvm-c/Core.h"
56#include "BlasAttributor.inc"
63 const void *, LLVMValueRef,
64 LLVMBuilderRef) =
nullptr;
67 LLVMValueRef, uint8_t,
68 LLVMValueRef *) =
nullptr;
70 LLVMValueRef, uint8_t) =
nullptr;
73 LLVMValueRef) =
nullptr;
74LLVMValueRef *(*EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef,
75 uint64_t *size) =
nullptr;
82 LLVMValueRef) =
nullptr;
89 cl::desc(
"Use blas copy calls to cache matrices"));
92 cl::desc(
"Use blas copy calls to cache vectors"));
95 cl::desc(
"Use fast math on derivative compuation"));
97 "enzyme-memmove-warning", cl::init(
true), cl::Hidden,
98 cl::desc(
"Warn if using memmove implementation as a fallback for memmove"));
100 "enzyme-runtime-error", cl::init(
false), cl::Hidden,
101 cl::desc(
"Emit Runtime errors instead of compile time ones"));
104 "enzyme-check-nan", cl::init(
false), cl::Hidden,
105 cl::desc(
"Add NaN checks to all derivative intermediate values"));
108 "enzyme-non-power2-cache", cl::init(
false), cl::Hidden,
109 cl::desc(
"Disable caching of integers which are not a power of 2"));
112#define addAttribute addAttributeAtIndex
113#define getAttribute getAttributeAtIndex
115 bool changed =
false;
116 if (F.getName() ==
"fprintf") {
117 for (
auto &arg : F.args()) {
118 if (arg.getType()->isPointerTy()) {
124 if (F.getName().contains(
"__enzyme_float") ||
125 F.getName().contains(
"__enzyme_double") ||
126 F.getName().contains(
"__enzyme_integer") ||
127 F.getName().contains(
"__enzyme_pointer") ||
128 F.getName().contains(
"__enzyme_todense") ||
129 F.getName().contains(
"__enzyme_ignore_derivatives") ||
130 F.getName().contains(
"__enzyme_iter") ||
131 F.getName().contains(
"__enzyme_virtualreverse")) {
133#if LLVM_VERSION_MAJOR >= 16
134 F.setOnlyReadsMemory();
135 F.setOnlyWritesMemory();
137 F.addFnAttr(Attribute::ReadNone);
139 if (!(F.getName().contains(
"__enzyme_todense") ||
140 F.getName().contains(
"__enzyme_ignore_derivatives"))) {
141 for (
auto &arg : F.args()) {
142 if (arg.getType()->isPointerTy()) {
143 arg.addAttr(Attribute::ReadNone);
149 if (F.getName() ==
"memcmp") {
151#if LLVM_VERSION_MAJOR >= 16
152 F.setOnlyAccessesArgMemory();
153 F.setOnlyReadsMemory();
155 F.addFnAttr(Attribute::ArgMemOnly);
156 F.addFnAttr(Attribute::ReadOnly);
158 F.addFnAttr(Attribute::NoUnwind);
159 F.addFnAttr(Attribute::NoRecurse);
160 F.addFnAttr(Attribute::WillReturn);
161 F.addFnAttr(Attribute::NoFree);
162 F.addFnAttr(Attribute::NoSync);
163 for (
int i = 0; i < 2; i++)
164 if (F.getFunctionType()->getParamType(i)->isPointerTy()) {
166 F.addParamAttr(i, Attribute::ReadOnly);
171 "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_createERmm") {
173 F.addFnAttr(Attribute::NoFree);
175 if (F.getName() ==
"MPI_Irecv" || F.getName() ==
"PMPI_Irecv") {
176 auto FT = F.getFunctionType();
177 bool PointerABI =
true;
179 F.addFnAttr(Attribute::NoUnwind);
180 F.addFnAttr(Attribute::NoRecurse);
181 F.addFnAttr(Attribute::WillReturn);
182 F.addFnAttr(Attribute::NoFree);
183 F.addFnAttr(Attribute::NoSync);
184 if (FT->getParamType(0)->isPointerTy()) {
185 F.addParamAttr(0, Attribute::WriteOnly);
190 if (FT->getParamType(2)->isPointerTy()) {
192 F.addParamAttr(2, Attribute::WriteOnly);
194 if (FT->getParamType(6)->isPointerTy()) {
195 F.addParamAttr(6, Attribute::WriteOnly);
200#if LLVM_VERSION_MAJOR >= 16
201 F.setOnlyAccessesInaccessibleMemOrArgMem();
203 F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
208 if (name ==
"MPI_Isend" || name ==
"PMPI_Isend") {
209 auto FT = F.getFunctionType();
210 bool PointerABI =
true;
212 F.addFnAttr(Attribute::NoUnwind);
213 F.addFnAttr(Attribute::NoRecurse);
214 F.addFnAttr(Attribute::WillReturn);
215 F.addFnAttr(Attribute::NoFree);
216 F.addFnAttr(Attribute::NoSync);
217 if (FT->getParamType(0)->isPointerTy()) {
218 F.addParamAttr(0, Attribute::ReadOnly);
223 if (FT->getParamType(2)->isPointerTy()) {
225 F.addParamAttr(2, Attribute::ReadOnly);
227 if (FT->getParamType(6)->isPointerTy()) {
228 F.addParamAttr(6, Attribute::WriteOnly);
233#if LLVM_VERSION_MAJOR >= 16
234 F.setOnlyAccessesInaccessibleMemOrArgMem();
236 F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
240 if (name ==
"MPI_Comm_rank" || name ==
"PMPI_Comm_rank" ||
241 name ==
"MPI_Comm_size" || name ==
"PMPI_Comm_size") {
242 auto FT = F.getFunctionType();
243 bool PointerABI =
true;
245 F.addFnAttr(Attribute::NoUnwind);
246 F.addFnAttr(Attribute::NoRecurse);
247 F.addFnAttr(Attribute::WillReturn);
248 F.addFnAttr(Attribute::NoFree);
249 F.addFnAttr(Attribute::NoSync);
252 if (FT->getParamType(0)->isPointerTy()) {
254 F.addParamAttr(0, Attribute::ReadOnly);
256 if (FT->getParamType(1)->isPointerTy()) {
257 F.addParamAttr(1, Attribute::WriteOnly);
263#if LLVM_VERSION_MAJOR >= 16
264 F.setOnlyAccessesInaccessibleMemOrArgMem();
266 F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
270 if (name ==
"MPI_Wait" || name ==
"PMPI_Wait") {
272 F.addFnAttr(Attribute::NoUnwind);
273 F.addFnAttr(Attribute::NoRecurse);
274 F.addFnAttr(Attribute::WillReturn);
275 F.addFnAttr(Attribute::NoFree);
276 F.addFnAttr(Attribute::NoSync);
277 if (F.getFunctionType()->getParamType(0)->isPointerTy()) {
280 if (F.getFunctionType()->getParamType(1)->isPointerTy()) {
281 F.addParamAttr(1, Attribute::WriteOnly);
285 if (name ==
"MPI_Waitall" || name ==
"PMPI_Waitall") {
287 F.addFnAttr(Attribute::NoUnwind);
288 F.addFnAttr(Attribute::NoRecurse);
289 F.addFnAttr(Attribute::WillReturn);
290 F.addFnAttr(Attribute::NoFree);
291 F.addFnAttr(Attribute::NoSync);
292 if (F.getFunctionType()->getParamType(1)->isPointerTy()) {
295 if (F.getFunctionType()->getParamType(2)->isPointerTy()) {
296 F.addParamAttr(2, Attribute::WriteOnly);
301 std::map<std::string, int> MPI_TYPE_ARGS = {
302 {
"MPI_Send", 2}, {
"MPI_Ssend", 2}, {
"MPI_Bsend", 2},
303 {
"MPI_Recv", 2}, {
"MPI_Brecv", 2}, {
"PMPI_Send", 2},
304 {
"PMPI_Ssend", 2}, {
"PMPI_Bsend", 2}, {
"PMPI_Recv", 2},
307 {
"MPI_Isend", 2}, {
"MPI_Irecv", 2}, {
"PMPI_Isend", 2},
310 {
"MPI_Reduce", 3}, {
"PMPI_Reduce", 3},
312 {
"MPI_Allreduce", 3}, {
"PMPI_Allreduce", 3}};
314 auto found = MPI_TYPE_ARGS.find(name.str());
315 if (found != MPI_TYPE_ARGS.end()) {
316 for (
auto user : F.users()) {
317 if (
auto CI = dyn_cast<CallBase>(user))
318 if (CI->getCalledFunction() == &F) {
320 dyn_cast<Constant>(CI->getArgOperand(found->second))) {
321 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
322 C = CE->getOperand(0);
324 if (
auto GV = dyn_cast<GlobalVariable>(C)) {
325 if (GV->getName() ==
"ompi_mpi_cxx_bool") {
328 AttributeList::FunctionIndex,
329 Attribute::get(CI->getContext(),
"enzyme_inactive"));
338 if (F.getName() ==
"omp_get_max_threads" ||
339 F.getName() ==
"omp_get_thread_num") {
341#if LLVM_VERSION_MAJOR >= 16
342 F.setOnlyAccessesInaccessibleMemory();
343 F.setOnlyReadsMemory();
345 F.addFnAttr(Attribute::InaccessibleMemOnly);
346 F.addFnAttr(Attribute::ReadOnly);
349 if (F.getName() ==
"frexp" || F.getName() ==
"frexpf" ||
350 F.getName() ==
"frexpl") {
352#if LLVM_VERSION_MAJOR >= 16
353 F.setOnlyAccessesArgMemory();
355 F.addFnAttr(Attribute::ArgMemOnly);
357 F.addParamAttr(1, Attribute::WriteOnly);
359 if (F.getName() ==
"__fd_sincos_1" || F.getName() ==
"__fd_cos_1" ||
360 F.getName() ==
"__mth_i_ipowi") {
362#if LLVM_VERSION_MAJOR >= 16
363 F.setOnlyReadsMemory();
364 F.setOnlyWritesMemory();
366 F.addFnAttr(Attribute::ReadNone);
370 const char *NonEscapingFns[] = {
372 "julia.get_pgcstack",
375 "_ZNSt6chrono3_V212steady_clock3nowEv",
376 "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_"
378 "_ZNKSt8__detail20_Prime_rehash_policy14_M_need_rehashEmmm",
386 "cublasSetStream_v2",
388 "cuDeviceGetMemPool",
389 "cuStreamSynchronize",
395 "cuDriverGetVersion",
396 "cudaRuntimeGetVersion",
398 "cuMemPoolGetAttribute",
400 "cuDeviceGetAttribute",
401 "cuDevicePrimaryCtxRetain",
403 for (
auto fname : NonEscapingFns)
407 AttributeList::FunctionIndex,
408 Attribute::get(F.getContext(),
"enzyme_no_escaping_allocation"));
410 changed |= attributeTablegen(F);
414void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj,
417 CustomZero(wrap(&Builder), wrap(T), wrap(obj), isTape);
419 Builder.CreateStore(Constant::getNullValue(T), obj);
424 llvm::IRBuilder<> &B) {
425 SmallVector<llvm::Instruction *, 2> res;
429 for (
size_t i = 0; i < size; i++) {
430 res.push_back(cast<Instruction>(unwrap(ptr[i])));
444 bool ZeroInit, llvm::Type *RT) {
446 llvm::PointerType *allocType;
448 auto i64 = Type::getInt64Ty(newFunc->getContext());
449 BasicBlock *BB = BasicBlock::Create(M.getContext(),
"entry", newFunc);
451 auto P = B.CreatePHI(i64, 1);
452 CallInst *malloccall;
453 Instruction *SubZero =
nullptr;
456 custom = F->getName() !=
"malloc";
458 allocType = cast<PointerType>(malloccall->getType());
459 if (ZeroInit && !SubZero)
461 BB->eraseFromParent();
464 Type *types[] = {allocType, Type::getInt64Ty(M.getContext()),
465 Type::getInt64Ty(M.getContext())};
466 std::string name =
"__enzyme_exponentialallocation";
470 name +=
".custom@" + std::to_string((
size_t)RT);
472 FunctionType *FT = FunctionType::get(allocType, types,
false);
474 if (newFunc->hasFnAttribute(
"enzymejl_world")) {
475 AL = AL.addFnAttribute(newFunc->getContext(),
476 newFunc->getFnAttribute(
"enzymejl_world"));
478 Function *F = cast<Function>(M.getOrInsertFunction(name, FT, AL).getCallee());
483 F->setLinkage(Function::LinkageTypes::InternalLinkage);
484 F->addFnAttr(Attribute::AlwaysInline);
485 F->addFnAttr(Attribute::NoUnwind);
486 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
487 BasicBlock *grow = BasicBlock::Create(M.getContext(),
"grow", F);
488 BasicBlock *ok = BasicBlock::Create(M.getContext(),
"ok", F);
490 IRBuilder<> B(entry);
492 Argument *ptr = F->arg_begin();
494 Argument *size = ptr + 1;
495 size->setName(
"size");
496 Argument *tsize = size + 1;
497 tsize->setName(
"tsize");
499 Value *hasOne = B.CreateICmpNE(
500 B.CreateAnd(size, ConstantInt::get(size->getType(), 1,
false)),
501 ConstantInt::get(size->getType(), 0,
false));
505 B.CreateAnd(B.CreateICmpULT(B.CreateCall(popCnt, {size}),
506 ConstantInt::get(types[1], 3,
false)),
510 B.SetInsertPoint(grow);
514 {size, ConstantInt::getTrue(M.getContext())});
516 B.CreateShl(tsize, B.CreateSub(ConstantInt::get(types[1], 64,
false), lz,
522 B.CreateSelect(B.CreateICmpEQ(size, ConstantInt::get(size->getType(), 1)),
523 ConstantInt::get(next->getType(), 0),
524 B.CreateLShr(next, ConstantInt::get(next->getType(), 1)));
526 auto Arch = llvm::Triple(M.getTargetTriple()).getArch();
527 bool forceMalloc = Arch == Triple::nvptx || Arch == Triple::nvptx64;
529 if (!custom && !forceMalloc) {
530 auto reallocF = M.getOrInsertFunction(
"realloc", allocType, allocType,
531 Type::getInt64Ty(M.getContext()));
533 Value *args[] = {B.CreatePointerCast(ptr, allocType), next};
534 gVal = B.CreateCall(reallocF, args);
536 Value *tsize = ConstantInt::get(
538 newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(RT) / 8);
539 auto elSize = B.CreateUDiv(next, tsize,
"",
true);
540 Instruction *SubZero =
nullptr;
544 PointerType::get(Type::getInt8Ty(gVal->getContext()),
545 cast<PointerType>(gVal->getType())->getAddressSpace());
546 gVal = B.CreatePointerCast(gVal, bTy);
547 auto pVal = B.CreatePointerCast(ptr, gVal->getType());
549 Value *margs[] = {gVal, pVal, prevSize,
550 ConstantInt::getFalse(M.getContext())};
551 Type *tys[] = {margs[0]->getType(), margs[1]->getType(),
552 margs[2]->getType()};
554 B.CreateCall(memsetF, margs);
557 IRBuilder<> BB(SubZero);
558 Value *zeroSize = BB.CreateSub(next, prevSize);
559 Value *tmp = SubZero->getOperand(0);
560 Type *tmpT = tmp->getType();
561 tmp = BB.CreatePointerCast(tmp, bTy);
562 tmp = BB.CreateInBoundsGEP(Type::getInt8Ty(tmp->getContext()), tmp,
564 tmp = BB.CreatePointerCast(tmp, tmpT);
565 SubZero->setOperand(0, tmp);
566 SubZero->setOperand(2, zeroSize);
571 Value *zeroSize = B.CreateSub(next, prevSize);
573 Value *margs[] = {B.CreateInBoundsGEP(B.getInt8Ty(), gVal, prevSize),
574 B.getInt8(0), zeroSize, B.getFalse()};
575 Type *tys[] = {margs[0]->getType(), margs[2]->getType()};
577 B.CreateCall(memsetF, margs);
579 gVal = B.CreatePointerCast(gVal, ptr->getType());
582 B.SetInsertPoint(ok);
583 auto phi = B.CreatePHI(ptr->getType(), 2);
584 phi->addIncoming(gVal, grow);
585 phi->addIncoming(ptr, entry);
591 llvm::Type *T, llvm::Value *OuterCount,
592 llvm::Value *InnerCount,
593 const llvm::Twine &Name,
594 llvm::CallInst **caller,
bool ZeroMem) {
595 auto newFunc = B.GetInsertBlock()->getParent();
597 Value *tsize = ConstantInt::get(
598 InnerCount->getType(),
599 newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T) / 8);
607 B.CreateMul(tsize, InnerCount,
"",
true,
612 newFunc, ZeroMem, T),
615 *caller = realloccall;
620 const Twine &Name, CallInst **caller,
621 Instruction **ZeroMem,
bool isDefault) {
623 auto &M = *Builder.GetInsertBlock()->getParent()->getParent();
624 auto AlignI = M.getDataLayout().getTypeAllocSizeInBits(T) / 8;
625 auto Align = ConstantInt::get(
Count->getType(), AlignI);
626 CallInst *malloccall =
nullptr;
628 LLVMValueRef wzeromem =
nullptr;
630 wrap(Align), isDefault,
631 ZeroMem ? &wzeromem :
nullptr));
632 if (isa<UndefValue>(res))
634 if (isa<Constant>(res))
636 if (
auto I = dyn_cast<Instruction>(res))
639 malloccall = dyn_cast<CallInst>(res);
640 if (malloccall ==
nullptr) {
641 malloccall = cast<CallInst>(cast<Instruction>(res)->getOperand(0));
644 *ZeroMem = cast_or_null<Instruction>(unwrap(wzeromem));
648#if LLVM_VERSION_MAJOR > 17
650 Builder.CreateMalloc(
Count->getType(), T, Align,
Count,
nullptr, Name);
652 if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) {
653 res = CallInst::CreateMalloc(Builder.GetInsertBlock(),
Count->getType(),
654 T, Align,
Count,
nullptr, Name);
655 Builder.SetInsertPoint(Builder.GetInsertBlock());
657 res = CallInst::CreateMalloc(&*Builder.GetInsertPoint(),
Count->getType(),
658 T, Align,
Count,
nullptr, Name);
660 if (!cast<Instruction>(res)->getParent())
661 Builder.Insert(cast<Instruction>(res));
664 malloccall = dyn_cast<CallInst>(res);
665 if (malloccall ==
nullptr) {
666 malloccall = cast<CallInst>(cast<Instruction>(res)->getOperand(0));
670 if (
auto BI = dyn_cast<BinaryOperator>(malloccall->getArgOperand(0))) {
671 if (BI->getOpcode() == BinaryOperator::Mul) {
672 if ((BI->getOperand(0) == Align && BI->getOperand(1) ==
Count) ||
673 (BI->getOperand(1) == Align && BI->getOperand(0) ==
Count))
674 BI->setHasNoSignedWrap(
true);
675 BI->setHasNoUnsignedWrap(
true);
679 if (
auto ci = dyn_cast<ConstantInt>(
Count)) {
680#if LLVM_VERSION_MAJOR >= 14
681 malloccall->addDereferenceableRetAttr(ci->getLimitedValue() * AlignI);
682#if !defined(FLANG) && !defined(ROCM)
683 AttrBuilder B(ci->getContext());
687 B.addDereferenceableOrNullAttr(ci->getLimitedValue() * AlignI);
688 malloccall->setAttributes(malloccall->getAttributes().addRetAttributes(
689 malloccall->getContext(), B));
691 malloccall->addDereferenceableAttr(llvm::AttributeList::ReturnIndex,
692 ci->getLimitedValue() * AlignI);
693 malloccall->addDereferenceableOrNullAttr(llvm::AttributeList::ReturnIndex,
694 ci->getLimitedValue() * AlignI);
699#if LLVM_VERSION_MAJOR >= 14
700 malloccall->addAttributeAtIndex(AttributeList::ReturnIndex,
702 malloccall->addAttributeAtIndex(AttributeList::ReturnIndex,
705 malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
706 malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
710 *caller = malloccall;
713 auto PT = cast<PointerType>(malloccall->getType());
714 Value *tozero = malloccall;
716 bool needsCast =
false;
717#if LLVM_VERSION_MAJOR < 17
718#if LLVM_VERSION_MAJOR >= 15
719 if (PT->getContext().supportsTypedPointers()) {
721 needsCast = !PT->getPointerElementType()->isIntegerTy(8);
722#if LLVM_VERSION_MAJOR >= 15
727 tozero = Builder.CreatePointerCast(
728 tozero, PointerType::get(Type::getInt8Ty(PT->getContext()),
729 PT->getAddressSpace()));
731 tozero, ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0),
732 Builder.CreateMul(Align,
Count,
"",
true,
true),
733 ConstantInt::getFalse(malloccall->getContext())};
734 Type *tys[] = {args[0]->getType(), args[2]->getType()};
736 *ZeroMem = Builder.CreateCall(
743 CallInst *res =
nullptr;
746 res = dyn_cast_or_null<CallInst>(
751 Builder.CreatePointerCast(ToFree,
getInt8PtrTy(ToFree->getContext()));
752#if LLVM_VERSION_MAJOR > 17
753 res = cast<CallInst>(Builder.CreateFree(ToFree));
755 if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) {
756 res = cast<CallInst>(
757 CallInst::CreateFree(ToFree, Builder.GetInsertBlock()));
758 Builder.SetInsertPoint(Builder.GetInsertBlock());
760 res = cast<CallInst>(
761 CallInst::CreateFree(ToFree, &*Builder.GetInsertPoint()));
763 if (!cast<Instruction>(res)->getParent())
764 Builder.Insert(cast<Instruction>(res));
766#if LLVM_VERSION_MAJOR >= 14
767 res->addAttributeAtIndex(AttributeList::FirstArgIndex, Attribute::NonNull);
769 res->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
776 const llvm::DiagnosticLocation &Loc,
777 const llvm::Instruction *CodeRegion)
778 :
EnzymeWarning(RemarkName, Loc, CodeRegion->getParent()->getParent()) {}
781 const llvm::DiagnosticLocation &Loc,
782 const llvm::Function *CodeRegion)
783 : DiagnosticInfoUnsupported(*CodeRegion, RemarkName, Loc, DS_Warning) {}
786 const llvm::DiagnosticLocation &Loc,
787 const llvm::Instruction *CodeRegion)
788 :
EnzymeFailure(RemarkName, Loc, CodeRegion->getParent()->getParent()) {}
791 const llvm::DiagnosticLocation &Loc,
792 const llvm::Function *CodeRegion)
793 : DiagnosticInfoUnsupported(*CodeRegion, RemarkName, Loc) {}
797 if (
auto VT = dyn_cast<VectorType>(T)) {
798#if LLVM_VERSION_MAJOR >= 12
799 auto len = VT->getElementCount().getFixedValue();
801 auto len = VT->getNumElements();
803 return "vec" + std::to_string(len) +
tofltstr(VT->getElementType());
805 switch (T->getTypeID()) {
808 case Type::FloatTyID:
810 case Type::DoubleTyID:
812 case Type::X86_FP80TyID:
814 case Type::BFloatTyID:
816 case Type::FP128TyID:
818 case Type::PPC_FP128TyID:
821 llvm_unreachable(
"Invalid floating type");
826 llvm::Constant *s = llvm::ConstantDataArray::getString(M.getContext(), Str);
827 auto *gv =
new llvm::GlobalVariable(
828 M, s->getType(),
true, llvm::GlobalValue::PrivateLinkage, s,
".str");
829 gv->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
830 Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(M.getContext()), 0),
831 ConstantInt::get(Type::getInt32Ty(M.getContext()), 0)};
832 return ConstantExpr::getInBoundsGetElementPtr(s->getType(), gv, Idxs);
836 SmallPtrSet<llvm::Instruction *, 8> visited;
838 if (visited.contains(inst))
840 visited.insert(inst);
843 if (
auto dbgLoc = inst->getDebugLoc()) {
844 auto *loc = dbgLoc.get();
846 if (
auto *scope = loc->getScope()) {
847 StringRef name = scope->getName();
849 while (!name.empty() && name.back() ==
';')
850 name = name.drop_back();
851 if (
auto *file = scope->getFile()) {
852 StringRef dir = file->getDirectory();
853 StringRef fn = file->getFilename();
854 ss <<
" in '" << name <<
"' at ";
857 ss << fn <<
":" << loc->getLine() <<
"\n";
859 ss <<
" in '" << name <<
"' at unknown:" << loc->getLine() <<
"\n";
862 loc = loc->getInlinedAt();
867 Function *f = inst->getParent()->getParent();
870 SmallVector<CallInst *, 4> callersWithDbg;
871 for (
auto *U : f->users()) {
872 auto *CI = dyn_cast<CallInst>(U);
875 if (!CI->getDebugLoc())
877 callersWithDbg.push_back(CI);
880 if (callersWithDbg.empty())
884 SmallVector<CallInst *, 4> uniqueCallSites;
885 SmallPtrSet<const MDNode *, 4> seenMD;
886 for (
auto *CI : callersWithDbg) {
887 if (seenMD.insert(CI->getDebugLoc().getAsMDNode()).second)
888 uniqueCallSites.push_back(CI);
891 if (uniqueCallSites.size() > 1) {
892 ss <<
" (multiple call sites)\n";
894 }
else if (uniqueCallSites.size() == 1) {
895 inst = uniqueCallSites[0];
903 llvm::Value *shadow,
const char *Message,
904 llvm::DebugLoc &&loc, llvm::Instruction *orig) {
905 Module &M = *B.GetInsertBlock()->getParent()->getParent();
906 std::string name =
"__enzyme_runtimeinactiveerr";
908 static int count = 0;
909 name += std::to_string(count);
912 FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()),
913 {getInt8PtrTy(M.getContext()),
914 getInt8PtrTy(M.getContext()),
915 getInt8PtrTy(M.getContext())},
918 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
921 F->setLinkage(Function::LinkageTypes::InternalLinkage);
922 F->addFnAttr(Attribute::AlwaysInline);
926 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
927 BasicBlock *error = BasicBlock::Create(M.getContext(),
"error", F);
928 BasicBlock *end = BasicBlock::Create(M.getContext(),
"end", F);
930 auto prim = F->arg_begin();
931 prim->setName(
"primal");
932 auto shadow = prim + 1;
933 shadow->setName(
"shadow");
937 IRBuilder<> EB(entry);
938 EB.CreateCondBr(EB.CreateICmpEQ(prim, shadow), error, end);
940 EB.SetInsertPoint(error);
946 FunctionType::get(Type::getInt32Ty(M.getContext()),
947 {getInt8PtrTy(M.getContext())},
false);
949 auto PutsF = M.getOrInsertFunction(
"puts", FT);
950 EB.CreateCall(PutsF, msg);
953 FunctionType::get(Type::getVoidTy(M.getContext()),
954 {Type::getInt32Ty(M.getContext())},
false);
956 auto ExitF = M.getOrInsertFunction(
"exit", FT2);
958 ConstantInt::get(Type::getInt32Ty(M.getContext()), 1));
960 EB.CreateUnreachable();
962 EB.SetInsertPoint(end);
966 std::string Message2 = Message;
969 raw_string_ostream ss(
str);
970 ss << Message <<
"\n";
974 Value *args[] = {B.CreatePointerCast(primal,
getInt8PtrTy(M.getContext())),
975 B.CreatePointerCast(shadow,
getInt8PtrTy(M.getContext())),
977 auto call = B.CreateCall(F, args);
978 call->setDebugLoc(loc);
983 return Type::getDoubleTy(ctx);
985 return Type::getFloatTy(ctx);
988 return Type::getFloatTy(ctx);
989 return VectorType::get(Type::getFloatTy(ctx), 2,
false);
992 return Type::getDoubleTy(ctx);
993 return VectorType::get(Type::getDoubleTy(ctx), 2,
false);
995 assert(
false &&
"Unreachable");
1002 return IntegerType::get(ctx, 64);
1004 return IntegerType::get(ctx, 32);
1012 unsigned dstaddr,
unsigned srcaddr,
1013 unsigned bitwidth) {
1014 assert(elementType->isFloatingPointTy());
1015 std::string name =
"__enzyme_memcpy";
1017 name += std::to_string(bitwidth);
1018 name +=
"add_" +
tofltstr(elementType) +
"da" + std::to_string(dstalign) +
1019 "sa" + std::to_string(srcalign);
1021 name +=
"dadd" + std::to_string(dstaddr);
1023 name +=
"sadd" + std::to_string(srcaddr);
1025 FunctionType::get(Type::getVoidTy(M.getContext()),
1026 {PointerType::get(elementType, dstaddr),
1027 PointerType::get(elementType, srcaddr),
1028 IntegerType::get(M.getContext(), bitwidth)},
1031 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
1036 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1037#if LLVM_VERSION_MAJOR >= 16
1038 F->setOnlyAccessesArgMemory();
1040 F->addFnAttr(Attribute::ArgMemOnly);
1042 F->addFnAttr(Attribute::NoUnwind);
1043 F->addFnAttr(Attribute::AlwaysInline);
1047 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
1048 BasicBlock *body = BasicBlock::Create(M.getContext(),
"for.body", F);
1049 BasicBlock *end = BasicBlock::Create(M.getContext(),
"for.end", F);
1051 auto dst = F->arg_begin();
1052 dst->setName(
"dst");
1054 src->setName(
"src");
1056 num->setName(
"num");
1059 IRBuilder<> B(entry);
1060 B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)),
1064 auto elSize = (M.getDataLayout().getTypeSizeInBits(elementType) + 7) / 8;
1066 IRBuilder<> B(body);
1067 B.setFastMathFlags(
getFast());
1068 PHINode *idx = B.CreatePHI(num->getType(), 2,
"idx");
1069 idx->addIncoming(ConstantInt::get(num->getType(), 0), entry);
1071 Value *dsti = B.CreateInBoundsGEP(elementType, dst, idx,
"dst.i");
1072 LoadInst *dstl = B.CreateLoad(elementType, dsti,
"dst.i.l");
1073 StoreInst *dsts = B.CreateStore(Constant::getNullValue(elementType), dsti);
1078 if (elSize % dstalign == 0) {
1080 }
else if (dstalign % elSize == 0) {
1095 if (elSize % srcalign == 0) {
1097 }
else if (srcalign % elSize == 0) {
1110 dstl->setAlignment(Align(dstalign));
1111 dsts->setAlignment(Align(dstalign));
1114 Value *srci = B.CreateInBoundsGEP(elementType, src, idx,
"src.i");
1115 LoadInst *srcl = B.CreateLoad(elementType, srci,
"src.i.l");
1116 StoreInst *srcs = B.CreateStore(B.CreateFAdd(srcl, dstl), srci);
1118 srcl->setAlignment(Align(srcalign));
1119 srcs->setAlignment(Align(srcalign));
1123 B.CreateNUWAdd(idx, ConstantInt::get(num->getType(), 1),
"idx.next");
1124 idx->addIncoming(next, body);
1125 B.CreateCondBr(B.CreateICmpEQ(num, next), end, body);
1136 Value *
const base, Value *lda, Value *row,
1138 Type *intType = row->getType();
1140 layout ? B.CreateICmpEQ(layout, ConstantInt::get(layout->getType(), 101))
1142 Value *offset =
nullptr;
1144 offset = B.CreateMul(
1145 row,
CreateSelect(B, is_row_maj, lda, ConstantInt::get(intType, 1)));
1146 offset = B.CreateAdd(
1149 ConstantInt::get(intType, 1), lda)));
1151 offset = B.CreateMul(row, lda);
1157 if (base->getType()->isIntegerTy())
1158 ptr = B.CreateIntToPtr(ptr,
getUnqual(fpType));
1160#if LLVM_VERSION_MAJOR < 17
1161#if LLVM_VERSION_MAJOR >= 15
1162 if (ptr->getContext().supportsTypedPointers()) {
1164 if (fpType != ptr->getType()->getPointerElementType()) {
1165 ptr = B.CreatePointerCast(
1168 fpType, cast<PointerType>(ptr->getType())->getAddressSpace()));
1170#if LLVM_VERSION_MAJOR >= 15
1174 ptr = B.CreateGEP(fpType, ptr, offset);
1176 if (base->getType()->isIntegerTy()) {
1177 ptr = B.CreatePtrToInt(ptr, base->getType());
1178 }
else if (ptr->getType() != base->getType()) {
1179 ptr = B.CreatePointerCast(ptr, base->getType());
1185 BlasInfo blas,
bool byRef, llvm::Value *layout,
1186 llvm::Value *islower, llvm::Value *A, llvm::Value *lda,
1189 const bool cublasv2 =
1190 blas.
prefix ==
"cublas" && StringRef(blas.
suffix).contains(
"v2");
1192 const bool cublas = blas.
prefix ==
"cublas";
1193 auto &M = *B.GetInsertBlock()->getParent()->getParent();
1195 llvm::Type *intType = N->getType();
1197 auto fnc_name =
"__enzyme_copy_lower_to_upper" + blas.
floatType +
1200 SmallVector<Type *, 1> tys = {islower->getType(), A->getType(),
1201 lda->getType(), N->getType()};
1203 tys.insert(tys.begin(), layout->getType());
1204 auto ltuFT = FunctionType::get(B.getVoidTy(), tys,
false);
1206 auto F0 = M.getOrInsertFunction(fnc_name, ltuFT);
1208 SmallVector<Value *, 1> args = {islower, A, lda, N};
1210 args.insert(args.begin(), layout);
1211 auto C = B.CreateCall(F0, args);
1219 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1220#if LLVM_VERSION_MAJOR >= 16
1221 F->setOnlyAccessesArgMemory();
1223 F->addFnAttr(Attribute::ArgMemOnly);
1225 F->addFnAttr(Attribute::NoUnwind);
1226 F->addFnAttr(Attribute::AlwaysInline);
1227 if (A->getType()->isPointerTy())
1230 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
1231 BasicBlock *loop = BasicBlock::Create(M.getContext(),
"loop", F);
1232 BasicBlock *end = BasicBlock::Create(M.getContext(),
"for.end", F);
1234 auto arg = F->arg_begin();
1235 Argument *layoutarg =
nullptr;
1238 layoutarg->setName(
"layout");
1241 auto islowerarg = arg;
1242 islowerarg->setName(
"islower");
1248 ldaarg->setName(
"lda");
1253 IRBuilder<> EB(entry);
1255 auto one = ConstantInt::get(intType, 1);
1256 auto zero = ConstantInt::get(intType, 0);
1258 Value *N_minus_1 = EB.CreateSub(Narg, one);
1260 IRBuilder<> LB(loop);
1262 auto i = LB.CreatePHI(intType, 2);
1263 i->addIncoming(zero, entry);
1264 auto i_plus_one = LB.CreateAdd(i, one,
"",
true,
true);
1265 i->addIncoming(i_plus_one, loop);
1267 Value *copyArgs[] = {
1278 byRef, cublas,
nullptr, EB),
1287 byRef, cublas,
nullptr, EB)};
1289 Type *copyTys[] = {copyArgs[0]->getType(), copyArgs[1]->getType(),
1290 copyArgs[2]->getType(), copyArgs[3]->getType(),
1291 copyArgs[4]->getType()};
1293 FunctionType *FT = FunctionType::get(B.getVoidTy(), copyTys,
false);
1296 (cublasv2 ?
"" : blas.
suffix);
1298 auto copyfn = M.getOrInsertFunction(copy_name, FT);
1299 LB.CreateCall(copyfn, copyArgs);
1302 LB.CreateCondBr(LB.CreateICmpEQ(i_plus_one, N_minus_1), end, loop);
1304 EB.CreateCondBr(EB.CreateICmpSLE(N_minus_1, zero), end, loop);
1310 if (llvm::verifyFunction(*F, &llvm::errs())) {
1311 llvm::errs() << *F <<
"\n";
1312 report_fatal_error(
"helper function failed verification");
1317 llvm::ArrayRef<llvm::Value *> args,
1318 llvm::Type *copy_retty,
1319 llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
1320 const bool cublasv2 =
1321 blas.
prefix ==
"cublas" && StringRef(blas.
suffix).contains(
"v2");
1323 (cublasv2 ?
"" : blas.
suffix);
1325 SmallVector<Type *, 1> tys;
1326 for (
auto arg : args)
1327 tys.push_back(arg->getType());
1329 FunctionType *FT = FunctionType::get(copy_retty, tys,
false);
1330 auto fn = M.getOrInsertFunction(copy_name, FT);
1331 B.CreateCall(fn, args, bundles);
1337 BlasInfo blas, llvm::ArrayRef<llvm::Value *> args,
1338 llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
1342 SmallVector<Type *, 1> tys;
1343 for (
auto arg : args)
1344 tys.push_back(arg->getType());
1346 auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys,
false);
1347 auto fn = M.getOrInsertFunction(copy_name, FT);
1348 B.CreateCall(fn, args, bundles);
// Builds (once) an internal, always-inline helper named by `fnc_name` with the
// eight-parameter signature
//   (BlasCT uplo, BlasIT n, BlasFPT alpha, BlasPT x, BlasIT incx,
//    BlasPT dy, BlasIT incy, BlasPT dAP)
// and emits a call to it with `args` / `bundles`.
// For every iteration the helper subtracts alpha * x[iter*incx] * dy[iter*incy]
// from dAP[k], where k is advanced differently on the "uper" and "lower" paths.
// NOTE(review): the header line, the `fnc_name` definition, the early-return
// guard and the load of `n` are not visible in this listing.
//
// Fix: `blasdy` was derived from `blasx + 1`, which is the same argument as
// `blasincx`. That shifted dy/incy/dAP one slot to the left, so the stride was
// used as the dy pointer and the eighth parameter was never bound. It is now
// derived from `blasincx + 1`, matching the function type and the parameter
// attribute indices (3, 5, 7) used for the three pointers.
1355 IntegerType *IT, Type *BlasCT, Type *BlasFPT,
1356 Type *BlasPT, Type *BlasIT, Type *fpTy,
1357 ArrayRef<Value *> args,
1358 ArrayRef<OperandBundleDef> bundles,
bool byRef,
1364 auto FDiagUpdateT = FunctionType::get(
1366 {BlasCT, BlasIT, BlasFPT, BlasPT, BlasIT, BlasPT, BlasIT, BlasPT},
false);
1368 cast<Function>(M.getOrInsertFunction(fnc_name, FDiagUpdateT).getCallee());
// Call emitted when the helper already exists (the guarding condition is on a
// line that is missing from this listing).
1371 B.CreateCall(F, args, bundles);
1376 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1377#if LLVM_VERSION_MAJOR >= 16
1378 F->setOnlyAccessesArgMemory();
1380 F->addFnAttr(Attribute::ArgMemOnly);
1382 F->addFnAttr(Attribute::NoUnwind);
1383 F->addFnAttr(Attribute::AlwaysInline);
// Pointer parameters: x (3), dy (5), dAP (7). x and dy are only read.
1388 F->addParamAttr(3, Attribute::NoAlias);
1389 F->addParamAttr(5, Attribute::NoAlias);
1390 F->addParamAttr(7, Attribute::NoAlias);
1391 F->addParamAttr(3, Attribute::ReadOnly);
1392 F->addParamAttr(5, Attribute::ReadOnly);
1395 F->addParamAttr(2, Attribute::NoAlias);
1396 F->addParamAttr(2, Attribute::ReadOnly);
1400 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
1401 BasicBlock *init = BasicBlock::Create(M.getContext(),
"init", F);
1402 BasicBlock *uper_code = BasicBlock::Create(M.getContext(),
"uper", F);
1403 BasicBlock *lower_code = BasicBlock::Create(M.getContext(),
"lower", F);
1404 BasicBlock *end = BasicBlock::Create(M.getContext(),
"for.end", F);
// Name the arguments in declaration order; each one is the successor of the
// previous argument.
1407 auto blasuplo = F->arg_begin();
1408 blasuplo->setName(
"blasuplo");
1409 auto blasn = blasuplo + 1;
1410 blasn->setName(
"blasn");
1411 auto blasalpha = blasn + 1;
1412 blasalpha->setName(
"blasalpha");
1413 auto blasx = blasalpha + 1;
1414 blasx->setName(
"blasx");
1415 auto blasincx = blasx + 1;
1416 blasincx->setName(
"blasincx");
// dy follows incx (was `blasx + 1`, which aliased incx).
1417 auto blasdy = blasincx + 1;
1418 blasdy->setName(
"blasdy");
1419 auto blasincy = blasdy + 1;
1420 blasincy->setName(
"blasincy");
1421 auto blasdAP = blasincy + 1;
1422 blasdAP->setName(
"blasdAP");
// entry: materialise the scalar operands and skip everything when n == 0.
1442 IRBuilder<> B1(entry);
1444 Value *incx =
load_if_ref(B1, IT, blasincx, byRef);
1445 Value *incy =
load_if_ref(B1, IT, blasincy, byRef);
1446 Value *alpha = blasalpha;
1448 auto VP = B1.CreatePointerCast(
1452 cast<PointerType>(blasalpha->getType())->getAddressSpace()));
1453 alpha = B1.CreateLoad(fpTy, VP);
1455 Value *is_l =
is_lower(B1, blasuplo, byRef,
false);
1456 B1.CreateCondBr(B1.CreateICmpEQ(n, ConstantInt::get(IT, 0)), end, init);
// init: cast the three buffers to fpTy pointers and pick the loop variant.
1458 IRBuilder<> B2(init);
1459 Value *xfloat = B2.CreatePointerCast(
1462 fpTy, cast<PointerType>(blasx->getType())->getAddressSpace()));
1463 Value *dyfloat = B2.CreatePointerCast(
1466 fpTy, cast<PointerType>(blasdy->getType())->getAddressSpace()));
1467 Value *dAPfloat = B2.CreatePointerCast(
1470 fpTy, cast<PointerType>(blasdAP->getType())->getAddressSpace()));
1471 B2.CreateCondBr(is_l, lower_code, uper_code);
// "uper" loop: k advances by iter + 1 on every iteration.
1473 IRBuilder<> B3(uper_code);
1474 B3.setFastMathFlags(
getFast());
1476 PHINode *iter = B3.CreatePHI(IT, 2,
"iteration");
1477 PHINode *kval = B3.CreatePHI(IT, 2,
"k");
1478 iter->addIncoming(ConstantInt::get(IT, 0), init);
1479 kval->addIncoming(ConstantInt::get(IT, 0), init);
1481 B3.CreateAdd(iter, ConstantInt::get(IT, 1),
"iter.next");
1483 Value *kvalnext = B3.CreateAdd(kval, iternext,
"k.next");
1484 iter->addIncoming(iternext, uper_code);
1485 kval->addIncoming(kvalnext, uper_code);
1487 Value *xidx = B3.CreateNUWMul(iter, incx,
"x.idx");
1488 Value *yidx = B3.CreateNUWMul(iter, incy,
"y.idx");
1489 Value *x = B3.CreateInBoundsGEP(fpTy, xfloat, xidx,
"x.ptr");
1490 Value *y = B3.CreateInBoundsGEP(fpTy, dyfloat, yidx,
"y.ptr");
1491 Value *xval = B3.CreateLoad(fpTy, x,
"x.val");
1492 Value *yval = B3.CreateLoad(fpTy, y,
"y.val");
1493 Value *xy = B3.CreateFMul(xval, yval,
"xy");
1494 Value *xyalpha = B3.CreateFMul(xy, alpha,
"xy.alpha");
1495 Value *kptr = B3.CreateInBoundsGEP(fpTy, dAPfloat, kval,
"k.ptr");
1496 Value *kvalloaded = B3.CreateLoad(fpTy, kptr,
"k.val");
1497 Value *kvalnew = B3.CreateFSub(kvalloaded, xyalpha,
"k.val.new");
1498 B3.CreateStore(kvalnew, kptr);
1500 B3.CreateCondBr(B3.CreateICmpEQ(iternext, n), end, uper_code);
// "lower" loop: identical body, but k advances by (n + 1) - (iter + 1).
1503 IRBuilder<> B4(lower_code);
1504 B4.setFastMathFlags(
getFast());
1506 PHINode *iter = B4.CreatePHI(IT, 2,
"iteration");
1507 PHINode *kval = B4.CreatePHI(IT, 2,
"k");
1508 iter->addIncoming(ConstantInt::get(IT, 0), init);
1509 kval->addIncoming(ConstantInt::get(IT, 0), init);
1511 B4.CreateAdd(iter, ConstantInt::get(IT, 1),
"iter.next");
1512 Value *ktmp = B4.CreateAdd(n, ConstantInt::get(IT, 1),
"tmp.val");
1513 Value *ktmp2 = B4.CreateSub(ktmp, iternext,
"tmp.val.other");
1514 Value *kvalnext = B4.CreateAdd(kval, ktmp2,
"k.next");
1515 iter->addIncoming(iternext, lower_code);
1516 kval->addIncoming(kvalnext, lower_code);
1518 Value *xidx = B4.CreateNUWMul(iter, incx,
"x.idx");
1519 Value *yidx = B4.CreateNUWMul(iter, incy,
"y.idx");
1520 Value *x = B4.CreateInBoundsGEP(fpTy, xfloat, xidx,
"x.ptr");
1521 Value *y = B4.CreateInBoundsGEP(fpTy, dyfloat, yidx,
"y.ptr");
1522 Value *xval = B4.CreateLoad(fpTy, x,
"x.val");
1523 Value *yval = B4.CreateLoad(fpTy, y,
"y.val");
1524 Value *xy = B4.CreateFMul(xval, yval,
"xy");
1525 Value *xyalpha = B4.CreateFMul(xy, alpha,
"xy.alpha");
1526 Value *kptr = B4.CreateInBoundsGEP(fpTy, dAPfloat, kval,
"k.ptr");
1527 Value *kvalloaded = B4.CreateLoad(fpTy, kptr,
"k.val");
1528 Value *kvalnew = B4.CreateFSub(kvalloaded, xyalpha,
"k.val.new");
1529 B4.CreateStore(kvalnew, kptr);
1531 B4.CreateCondBr(B4.CreateICmpEQ(iternext, n), end, lower_code);
// end block, then the call to the freshly built helper at the caller's
// insertion point.
1534 IRBuilder<> B5(end);
1537 B.CreateCall(F, args, bundles);
// Builds (once) an internal, always-inline, read-only helper named
// "__enzyme_inner_prod" + floatType + suffix with the signature
//   fpTy (BlasIT m, BlasIT n, BlasPT A, BlasIT lda, BlasPT B)
// and returns a call to it. The helper sums dot products of A and B:
//   * if m * n == 0 the result is 0.0;
//   * if m == lda a single dot call over m * n elements is used ("fast.path");
//   * otherwise one dot call of length m is issued per iteration, with A
//     advancing by lda and B advancing by m ("for.body").
// NOTE(review): the header line, the `dot_name` definition, the loads of
// m / n / lda and the statements that define `blasOne`, `blasSize`, `fastSum`
// and `newDot` are not visible in this listing.
//
// Fixes:
//  1. The next B offset was computed from `Aidx` instead of `Bidx`, so B was
//     read at the wrong position whenever lda != m (the only case in which
//     this loop runs).
//  2. The loop exit compared `iter` (not `iternext`) with n, which issued one
//     extra dot call on a row past the end before leaving; the result PHI
//     compensated by taking `sum` instead of `sumnext`. The exit now tests
//     `iternext == n` and the result takes `sumnext`, so exactly n dot calls
//     are made and the same total is returned.
1543 IntegerType *IT, Type *BlasPT, Type *BlasIT, Type *fpTy,
1544 llvm::ArrayRef<llvm::Value *> args,
1545 const llvm::ArrayRef<llvm::OperandBundleDef> bundles,
1546 bool byRef,
bool cublas,
bool julia_decl) {
1547 assert(fpTy->isFloatingPointTy());
1550 std::string prod_name =
"__enzyme_inner_prod" + blas.
floatType + blas.
suffix;
1552 FunctionType::get(fpTy, {BlasIT, BlasIT, BlasPT, BlasIT, BlasPT},
false);
1554 cast<Function>(M.getOrInsertFunction(prod_name, FInnerProdT).getCallee());
// Call emitted when the helper already exists (the guarding condition is on a
// line that is missing from this listing).
1557 return B.CreateCall(F, args, bundles);
// External dot routine: (n, x, incx, y, incy) -> fpTy.
1562 FunctionType::get(fpTy, {BlasIT, BlasPT, BlasIT, BlasPT, BlasIT},
false);
1563 auto FDot = M.getOrInsertFunction(dot_name, FDotT);
1566 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1567#if LLVM_VERSION_MAJOR >= 16
1568 F->setOnlyAccessesArgMemory();
1569 F->setOnlyReadsMemory();
1571 F->addFnAttr(Attribute::ArgMemOnly);
1572 F->addFnAttr(Attribute::ReadOnly);
1574 F->addFnAttr(Attribute::NoUnwind);
1575 F->addFnAttr(Attribute::AlwaysInline);
// The two matrix pointers (parameters 2 and 4) are noalias and read-only.
1579 F->addParamAttr(2, Attribute::NoAlias);
1580 F->addParamAttr(4, Attribute::NoAlias);
1581 F->addParamAttr(2, Attribute::ReadOnly);
1582 F->addParamAttr(4, Attribute::ReadOnly);
1585 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
1586 BasicBlock *init = BasicBlock::Create(M.getContext(),
"init.idx", F);
1587 BasicBlock *fastPath = BasicBlock::Create(M.getContext(),
"fast.path", F);
1588 BasicBlock *body = BasicBlock::Create(M.getContext(),
"for.body", F);
1589 BasicBlock *end = BasicBlock::Create(M.getContext(),
"for.end", F);
1596 auto blasm = F->arg_begin();
1597 blasm->setName(
"blasm");
1598 auto blasn = blasm + 1;
1599 blasn->setName(
"blasn");
1600 auto matA = blasn + 1;
1602 auto blaslda = matA + 1;
1603 blaslda->setName(
"lda");
1604 auto matB = blaslda + 1;
// entry: build the constant stride and the total element count, and return 0
// for an empty matrix.
1608 IRBuilder<> B1(entry);
1610 cublas,
nullptr, B1,
"constant.one");
1612 if (blasOne->getType() != BlasIT)
1613 blasOne = B1.CreatePointerCast(blasOne, BlasIT,
"intcast.constant.one");
1617 Value *size = B1.CreateNUWMul(m, n,
"mat.size");
1619 B1, size, byRef, cublas, julia_decl ? IT :
nullptr, B1,
"mat.size");
1621 if (blasSize->getType() != BlasIT)
1622 blasSize = B1.CreatePointerCast(blasSize, BlasIT,
"intcast.mat.size");
1623 B1.CreateCondBr(B1.CreateICmpEQ(size, ConstantInt::get(IT, 0)), end, init);
// init: typed views of both matrices; take the fast path when m == lda.
1625 IRBuilder<> B2(init);
1626 B2.setFastMathFlags(
getFast());
1628 Value *Afloat = B2.CreatePointerCast(
1629 matA, PointerType::get(
1630 fpTy, cast<PointerType>(matA->getType())->getAddressSpace()));
1631 Value *Bfloat = B2.CreatePointerCast(
1632 matB, PointerType::get(
1633 fpTy, cast<PointerType>(matB->getType())->getAddressSpace()));
1634 B2.CreateCondBr(B2.CreateICmpEQ(m, lda), fastPath, body);
// fast.path: one dot call over all m * n elements with stride 1.
1639 IRBuilder<> B3(fastPath);
1640 B3.setFastMathFlags(
getFast());
1641 Value *blasA = B3.CreatePointerCast(matA, BlasPT);
1642 Value *blasB = B3.CreatePointerCast(matB, BlasPT);
1644 B3.CreateCall(FDot, {blasSize, blasA, blasOne, blasB, blasOne});
// for.body: one dot call of length m per iteration.
1647 IRBuilder<> B4(body);
1648 B4.setFastMathFlags(
getFast());
1649 PHINode *Aidx = B4.CreatePHI(IT, 2,
"Aidx");
1650 PHINode *Bidx = B4.CreatePHI(IT, 2,
"Bidx");
1651 PHINode *iter = B4.CreatePHI(IT, 2,
"iteration");
1652 PHINode *sum = B4.CreatePHI(fpTy, 2,
"sum");
1653 Aidx->addIncoming(ConstantInt::get(IT, 0), init);
1654 Bidx->addIncoming(ConstantInt::get(IT, 0), init);
1655 iter->addIncoming(ConstantInt::get(IT, 0), init);
1656 sum->addIncoming(ConstantFP::get(fpTy, 0.0), init);
1658 Value *Ai = B4.CreateInBoundsGEP(fpTy, Afloat, Aidx,
"A.i");
1659 Value *Bi = B4.CreateInBoundsGEP(fpTy, Bfloat, Bidx,
"B.i");
1660 Value *AiDot = B4.CreatePointerCast(Ai, BlasPT);
1661 Value *BiDot = B4.CreatePointerCast(Bi, BlasPT);
1663 B4.CreateCall(FDot, {blasm, AiDot, blasOne, BiDot, blasOne});
1665 Value *Anext = B4.CreateNUWAdd(Aidx, lda,
"Aidx.next");
// B advances from its own previous offset (was `Aidx`).
1666 Value *Bnext = B4.CreateNUWAdd(Bidx, m,
"Bidx.next");
1667 Value *iternext = B4.CreateAdd(iter, ConstantInt::get(IT, 1),
"iter.next");
1668 Value *sumnext = B4.CreateFAdd(sum, newDot);
1670 iter->addIncoming(iternext, body);
1671 Aidx->addIncoming(Anext, body);
1672 Bidx->addIncoming(Bnext, body);
1673 sum->addIncoming(sumnext, body);
// Leave after exactly n iterations (was `iter`, which ran n + 1 times).
1675 B4.CreateCondBr(B4.CreateICmpEQ(iternext, n), end, body);
// for.end: merge the three possible results.
1677 IRBuilder<> B5(end);
1678 PHINode *res = B5.CreatePHI(fpTy, 3,
"res");
1679 res->addIncoming(ConstantFP::get(fpTy, 0.0), entry);
// The value leaving the loop is the updated running sum.
1680 res->addIncoming(sumnext, body);
1681 res->addIncoming(fastSum, fastPath);
1685 auto res = B.CreateCall(F, args, bundles);
// Gets or builds an internal, always-inline helper
//   void __enzyme_memcpy_<flt>_<intbits>_da<dstalign>sa<srcalign>stride(
//       T dst, T src, IT num, IT stride)
// that copies `num` elements from `src` (stepping by `stride`) into
// consecutive slots of `dst`.
// NOTE(review): the header line, the definition of `T`, the early return for
// an already-built helper and the tail of the function are not visible in this
// listing.
1692 Type *IT,
unsigned dstalign,
1693 unsigned srcalign) {
1694 assert(elementType->isFloatingPointTy());
// The helper name encodes element type, index width and both alignments.
1695 std::string name =
"__enzyme_memcpy_" +
tofltstr(elementType) +
"_" +
1696 std::to_string(cast<IntegerType>(IT)->getBitWidth()) +
1697 "_da" + std::to_string(dstalign) +
"sa" +
1698 std::to_string(srcalign) +
"stride";
1700 FunctionType::get(Type::getVoidTy(M.getContext()), {T, T, IT, IT},
false);
1702 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
1707 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1708#if LLVM_VERSION_MAJOR >= 16
1709 F->setOnlyAccessesArgMemory();
1711 F->addFnAttr(Attribute::ArgMemOnly);
1713 F->addFnAttr(Attribute::NoUnwind);
1714 F->addFnAttr(Attribute::AlwaysInline);
// dst (0) is write-only, src (1) is read-only; both are noalias.
1716 F->addParamAttr(0, Attribute::NoAlias);
1718 F->addParamAttr(1, Attribute::NoAlias);
1719 F->addParamAttr(0, Attribute::WriteOnly);
1720 F->addParamAttr(1, Attribute::ReadOnly);
1722 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
1723 BasicBlock *init = BasicBlock::Create(M.getContext(),
"init.idx", F);
1724 BasicBlock *body = BasicBlock::Create(M.getContext(),
"for.body", F);
1725 BasicBlock *end = BasicBlock::Create(M.getContext(),
"for.end", F);
1727 auto dst = F->arg_begin();
1728 dst->setName(
"dst");
1730 src->setName(
"src");
1732 num->setName(
"num");
1733 auto stride = num + 1;
1734 stride->setName(
"stride");
// entry: branch on num == 0 (the branch targets are on a missing line).
1737 IRBuilder<> B(entry);
1738 B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)),
// init: for a negative stride the source index starts at (1 - num) * stride,
// otherwise at 0.
1743 IRBuilder<> B2(init);
1744 B2.setFastMathFlags(
getFast());
1745 Value *a = B2.CreateNSWSub(ConstantInt::get(num->getType(), 1), num,
"a");
1746 Value *negidx = B2.CreateNSWMul(a, stride,
"negidx");
1751 B2.CreateICmpSLT(stride, ConstantInt::get(num->getType(), 0),
"is.neg");
1752 Value *startidx = B2.CreateSelect(
1753 isneg, negidx, ConstantInt::get(num->getType(), 0),
"startidx");
// for.body: dst[idx] = src[sidx]; idx += 1; sidx += stride.
1758 IRBuilder<> B(body);
1759 B.setFastMathFlags(
getFast());
1760 PHINode *idx = B.CreatePHI(num->getType(), 2,
"idx");
1761 PHINode *sidx = B.CreatePHI(num->getType(), 2,
"sidx");
1762 idx->addIncoming(ConstantInt::get(num->getType(), 0), init);
1763 sidx->addIncoming(startidx, init);
1765 Value *dsti = B.CreateInBoundsGEP(elementType, dst, idx,
"dst.i");
1766 Value *srci = B.CreateInBoundsGEP(elementType, src, sidx,
"src.i");
1767 LoadInst *srcl = B.CreateLoad(elementType, srci,
"src.i.l");
1768 StoreInst *dsts = B.CreateStore(srcl, dsti);
// NOTE(review): the conditions guarding these alignment updates are on lines
// missing from this listing.
1771 dsts->setAlignment(Align(dstalign));
1774 srcl->setAlignment(Align(srcalign));
1778 B.CreateNSWAdd(idx, ConstantInt::get(num->getType(), 1),
"idx.next");
1779 Value *snext = B.CreateNSWAdd(sidx, stride,
"sidx.next");
1780 idx->addIncoming(next, body);
1781 sidx->addIncoming(snext, body);
// Exit once the incremented destination index equals num.
1782 B.CreateCondBr(B.CreateICmpEQ(num, next), end, body);
// Gets or builds an internal, always-inline helper
//   void __enzyme_memcpy_<flt>_mat_<intbits>(PT dst, PT src, IT, IT, IT)
// that copies element i + j * LDA of `src` to element i + j * M of `dst`,
// with i running to M in the inner loop and j running to N in the outer loop.
// NOTE(review): the header line, the naming of the M / N / LDA arguments, the
// early return and several preprocessor closers are not visible in this
// listing.
1794 IntegerType *IT,
unsigned dstalign,
1795 unsigned srcalign) {
1796 assert(elementType->isFPOrFPVectorTy());
// On LLVM versions that still have typed pointers, check that PT points at
// elementType.
1797#if LLVM_VERSION_MAJOR < 17
1798#if LLVM_VERSION_MAJOR >= 15
1799 if (Mod.getContext().supportsTypedPointers()) {
1801#if LLVM_VERSION_MAJOR >= 13
1802 if (!PT->isOpaquePointerTy())
1804 assert(PT->getPointerElementType() == elementType);
1805#if LLVM_VERSION_MAJOR >= 15
1809 std::string name =
"__enzyme_memcpy_" +
tofltstr(elementType) +
"_mat_" +
1810 std::to_string(cast<IntegerType>(IT)->getBitWidth());
1811 FunctionType *FT = FunctionType::get(Type::getVoidTy(Mod.getContext()),
1812 {PT, PT, IT, IT, IT},
false);
1814 Function *F = cast<Function>(Mod.getOrInsertFunction(name, FT).getCallee());
1819 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1820#if LLVM_VERSION_MAJOR >= 16
1821 F->setOnlyAccessesArgMemory();
1823 F->addFnAttr(Attribute::ArgMemOnly);
1825 F->addFnAttr(Attribute::NoUnwind);
1826 F->addFnAttr(Attribute::AlwaysInline);
// dst (0) is write-only, src (1) is read-only; both are noalias.
1828 F->addParamAttr(0, Attribute::NoAlias);
1830 F->addParamAttr(1, Attribute::NoAlias);
1831 F->addParamAttr(0, Attribute::WriteOnly);
1832 F->addParamAttr(1, Attribute::ReadOnly);
1834 BasicBlock *entry = BasicBlock::Create(F->getContext(),
"entry", F);
1835 BasicBlock *init = BasicBlock::Create(F->getContext(),
"init.idx", F);
1836 BasicBlock *body = BasicBlock::Create(F->getContext(),
"for.body", F);
1837 BasicBlock *initend = BasicBlock::Create(F->getContext(),
"init.end", F);
1838 BasicBlock *end = BasicBlock::Create(F->getContext(),
"for.end", F);
1840 auto dst = F->arg_begin();
1841 dst->setName(
"dst");
1843 src->setName(
"src");
1849 LDA->setName(
"LDA");
// entry: nothing to copy when either dimension is zero.
1852 IRBuilder<> B(entry);
1853 Value *l0 = B.CreateICmpEQ(M, ConstantInt::get(IT, 0));
1854 Value *l1 = B.CreateICmpEQ(N, ConstantInt::get(IT, 0));
1856 B.CreateCondBr(B.CreateOr(l0, l1), end, init);
// init.idx: header of the outer loop over j.
1861 IRBuilder<> B(init);
1862 j = B.CreatePHI(IT, 2,
"j");
1863 j->addIncoming(ConstantInt::get(IT, 0), entry);
// for.body: inner loop over i; index arithmetic carries the nuw and nsw flags.
1868 IRBuilder<> B(body);
1869 PHINode *i = B.CreatePHI(IT, 2,
"i");
1870 i->addIncoming(ConstantInt::get(IT, 0), init);
1872 Value *dsti = B.CreateInBoundsGEP(
1874 B.CreateAdd(i, B.CreateMul(j, M,
"",
true,
true),
"",
true,
true),
1876 Value *srci = B.CreateInBoundsGEP(
1878 B.CreateAdd(i, B.CreateMul(j, LDA,
"",
true,
true),
"",
true,
true),
1880 LoadInst *srcl = B.CreateLoad(elementType, srci,
"src.i.l");
1882 StoreInst *dsts = B.CreateStore(srcl, dsti);
// NOTE(review): the conditions guarding these alignment updates are on lines
// missing from this listing.
1885 dsts->setAlignment(Align(dstalign));
1888 srcl->setAlignment(Align(srcalign));
1892 B.CreateAdd(i, ConstantInt::get(IT, 1),
"i.next",
true,
true);
1893 i->addIncoming(nexti, body);
1894 B.CreateCondBr(B.CreateICmpEQ(nexti, M), initend, body);
// init.end: latch of the outer loop; leaves once j + 1 == N.
1898 IRBuilder<> B(initend);
1900 B.CreateAdd(j, ConstantInt::get(IT, 1),
"j.next",
true,
true);
1901 j->addIncoming(nextj, initend);
1902 B.CreateCondBr(B.CreateICmpEQ(nextj, N), end, init);
// Gets or builds an internal, always-inline helper
//   void __enzyme_dmemcpy_<flt>_mat_<intbits>[_zero](
//       CT uplo, IT, IT, PT dst, IT ldst, PT src, IT lsrc)
// that accumulates src into dst element-wise (dst[i + j*ldst] += src[i + j*lsrc]).
// Three loop nests are generated, selected by a switch on `uplo`: 'U', 'L',
// and a default ("G") variant.
// NOTE(review): the header line, the naming of the M / N arguments, the early
// return, and the statements that set `istart` / `iend` for the 'L' and default
// cases are not visible in this listing.
1914 Module &Mod, Type *elementType, PointerType *PT, IntegerType *IT,
1915 IntegerType *CT,
unsigned dstalign,
unsigned srcalign,
bool zeroSrc) {
1916 assert(elementType->isFPOrFPVectorTy());
// On LLVM versions that still have typed pointers, check that PT points at
// elementType.
1917#if LLVM_VERSION_MAJOR < 17
1918#if LLVM_VERSION_MAJOR >= 15
1919 if (Mod.getContext().supportsTypedPointers()) {
1921#if LLVM_VERSION_MAJOR >= 13
1922 if (!PT->isOpaquePointerTy())
1924 assert(PT->getPointerElementType() == elementType);
1925#if LLVM_VERSION_MAJOR >= 15
// The "_zero" suffix distinguishes the variant built with zeroSrc set.
1929 std::string name =
"__enzyme_dmemcpy_" +
tofltstr(elementType) +
"_mat_" +
1930 std::to_string(cast<IntegerType>(IT)->getBitWidth()) +
1931 (zeroSrc ?
"_zero" :
"");
1932 FunctionType *FT = FunctionType::get(Type::getVoidTy(Mod.getContext()),
1933 {CT, IT, IT, PT, IT, PT, IT},
false);
1935 Function *F = cast<Function>(Mod.getOrInsertFunction(name, FT).getCallee());
1940 F->setLinkage(Function::LinkageTypes::InternalLinkage);
1941#if LLVM_VERSION_MAJOR >= 16
1942 F->setOnlyAccessesArgMemory();
1944 F->addFnAttr(Attribute::ArgMemOnly);
1946 F->addFnAttr(Attribute::NoUnwind);
1947 F->addFnAttr(Attribute::AlwaysInline);
// dst (3) and src (5) are noalias.
1948 F->addParamAttr(3, Attribute::NoAlias);
1949 F->addParamAttr(5, Attribute::NoAlias);
1951 BasicBlock *entry = BasicBlock::Create(F->getContext(),
"entry", F);
1952 BasicBlock *swtch = BasicBlock::Create(F->getContext(),
"swtch", F);
1953 BasicBlock *Ginit = BasicBlock::Create(F->getContext(),
"Ginit.idx", F);
1954 BasicBlock *Uinit = BasicBlock::Create(F->getContext(),
"Uinit.idx", F);
1955 BasicBlock *Linit = BasicBlock::Create(F->getContext(),
"Linit.idx", F);
1956 BasicBlock *end = BasicBlock::Create(F->getContext(),
"for.end", F);
1958 auto uplo = F->arg_begin();
1959 uplo->setName(
"uplo");
1966 dst->setName(
"dst");
1967 auto ldst = dst + 1;
1968 ldst->setName(
"ldst");
1969 auto src = ldst + 1;
1970 src->setName(
"src");
1971 auto lsrc = src + 1;
1972 lsrc->setName(
"lsrc");
// entry: nothing to do when either dimension is zero.
1975 IRBuilder<> B(entry);
1976 Value *l0 = B.CreateICmpEQ(M, ConstantInt::get(IT, 0));
1977 Value *l1 = B.CreateICmpEQ(N, ConstantInt::get(IT, 0));
1979 B.CreateCondBr(B.CreateOr(l0, l1), end, swtch);
// swtch: dispatch on the uplo character; anything other than 'U' or 'L' goes
// to the default ("G") loop nest.
1983 IRBuilder<> B(swtch);
1984 auto swtchT = B.CreateSwitch(uplo, Ginit);
1985 swtchT->addCase(ConstantInt::get(CT,
'U'), Uinit);
1986 swtchT->addCase(ConstantInt::get(CT,
'L'), Linit);
// Generate the same loop nest three times, with block and value names prefixed
// by the direction character.
1989 std::pair<char, BasicBlock *> todo[] = {
1990 {
'G', Ginit}, {
'U', Uinit}, {
'L', Linit}};
1991 for (
auto &&[direction, init] : todo) {
1993 std::string dir(1, direction);
1994 BasicBlock *body = BasicBlock::Create(F->getContext(), dir +
"for.body", F);
1995 BasicBlock *initend =
1996 BasicBlock::Create(F->getContext(), dir +
"init.end", F);
1998 Value *istart = ConstantInt::get(IT, 0);
// Outer-loop header over j.
2003 IRBuilder<> B(init);
2004 j = B.CreatePHI(IT, 2, dir +
"j");
2005 j->addIncoming(ConstantInt::get(IT, 0), swtch);
// For 'U' the inner loop bound is min(j + 1, M) (unsigned comparison).
2007 if (direction ==
'L') {
2009 }
else if (direction ==
'U') {
2010 auto jp1 = B.CreateAdd(j, ConstantInt::get(IT, 1),
"",
true,
true);
2011 iend = B.CreateSelect(B.CreateICmpULT(jp1, M), jp1, M);
// Inner loop over i, starting at istart.
2018 IRBuilder<> B(body);
2019 PHINode *i = B.CreatePHI(IT, 2, dir +
"i");
2020 i->addIncoming(istart, init);
2022 Value *srci = B.CreateInBoundsGEP(
2024 B.CreateAdd(i, B.CreateMul(j, lsrc,
"",
true,
true),
"",
true,
true),
2027 Value *dsti = B.CreateInBoundsGEP(
2029 B.CreateAdd(i, B.CreateMul(j, ldst,
"",
true,
true),
"",
true,
true),
// dst element becomes src + dst.
2031 LoadInst *srcl = B.CreateLoad(elementType, srci, dir +
"src.i.l");
2032 LoadInst *dstl = B.CreateLoad(elementType, dsti, dir +
"dst.i.l");
2033 auto res = B.CreateFAdd(srcl, dstl);
2034 StoreInst *dsts = B.CreateStore(res, dsti);
2035 StoreInst *srcs =
nullptr;
// NOTE(review): presumably guarded by `zeroSrc`; the guarding line is missing
// from this listing. Overwrites the source element with zero.
2037 srcs = B.CreateStore(Constant::getNullValue(res->getType()), srci);
// NOTE(review): the conditions guarding these alignment updates are on lines
// missing from this listing.
2039 dsts->setAlignment(Align(dstalign));
2040 dstl->setAlignment(Align(dstalign));
2044 srcs->setAlignment(Align(srcalign));
2045 srcl->setAlignment(Align(srcalign));
2049 B.CreateAdd(i, ConstantInt::get(IT, 1), dir +
"i.next",
true,
true);
2050 i->addIncoming(nexti, body);
2051 B.CreateCondBr(B.CreateICmpEQ(nexti, iend), initend, body);
// Outer-loop latch: leaves once j + 1 == N.
2055 IRBuilder<> B(initend);
2057 B.CreateAdd(j, ConstantInt::get(IT, 1), dir +
"j.next",
true,
true);
2058 j->addIncoming(nextj, initend);
2059 B.CreateCondBr(B.CreateICmpEQ(nextj, N), end, init);
// Fragment of a memmove helper: only part of the parameter list and the
// warning text are visible in this listing.
// NOTE(review): the header line, the stream the warning is written to, any
// condition guarding it, and the returned value are on missing lines.
2074 unsigned srcalign,
unsigned dstaddr,
2075 unsigned srcaddr,
unsigned bitwidth) {
// Warning text stating that memcpy is used as a fallback for memmove.
2078 <<
"warning: didn't implement memmove, using memcpy as fallback "
2079 "which can result in errors\n";
// Builds (or reuses) an internal, always-inline helper
//   void __enzyme_checked_free_<width>[_<callname>](Ty primal,
//       Ty shadow_0 .. shadow_{width-1}, <extra args of the original call>)
// that calls the original deallocation callee on shadow pointers only when
// they differ from the primal pointer.
// NOTE(review): the header line, the definitions of `Ty` / `callname`, the
// early return, and some `else` / closing-brace lines are not visible in this
// listing.
// Capture the callee, its type, attributes and calling convention from `call`
// so they can be replicated on the calls emitted inside the helper.
2086 FunctionType *FreeTy = call->getFunctionType();
2087 Value *Free = call->getCalledOperand();
2088 AttributeList FreeAttributes = call->getAttributes();
2089 CallingConv::ID CallingConvention = call->getCallingConv();
2091 std::string name =
"__enzyme_checked_free_" + std::to_string(width);
// A callee name other than "free" is appended to keep helpers distinct.
2094 if (callname !=
"free")
2095 name +=
"_" + callname.str();
// Parameter list: primal, `width` shadows, then every call argument after the
// first.
2097 SmallVector<Type *, 3> types;
2098 types.push_back(Ty);
2099 for (
unsigned i = 0; i < width; i++) {
2100 types.push_back(Ty);
2102#if LLVM_VERSION_MAJOR >= 14
2103 for (
size_t i = 1; i < call->arg_size(); i++)
2105 for (
size_t i = 1; i < call->getNumArgOperands(); i++)
2108 types.push_back(call->getArgOperand(i)->getType());
2112 FunctionType::get(Type::getVoidTy(M.getContext()), types,
false);
2114 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
2119 F->setLinkage(Function::LinkageTypes::InternalLinkage);
2120#if LLVM_VERSION_MAJOR >= 16
2121 F->setOnlyAccessesArgMemory();
2123 F->addFnAttr(Attribute::ArgMemOnly);
2125 F->addFnAttr(Attribute::NoUnwind);
2126 F->addFnAttr(Attribute::AlwaysInline);
2128 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
2129 BasicBlock *free0 = BasicBlock::Create(M.getContext(),
"free0", F);
2130 BasicBlock *end = BasicBlock::Create(M.getContext(),
"end", F);
2132 IRBuilder<> EntryBuilder(entry);
2133 IRBuilder<> Free0Builder(free0);
2134 IRBuilder<> EndBuilder(end);
2136 auto primal = F->arg_begin();
2137 Argument *first_shadow = F->arg_begin() + 1;
// entry: only go on to free0 when the first shadow differs from the primal.
2141 Value *isNotEqual = EntryBuilder.CreateICmpNE(primal, first_shadow);
2142 EntryBuilder.CreateCondBr(isNotEqual, free0, end);
// Argument list for the deallocation call: the shadow pointer followed by the
// helper's forwarded extra parameters.
2144 SmallVector<Value *, 1> args = {first_shadow};
2145#if LLVM_VERSION_MAJOR >= 14
2146 for (
size_t i = 1; i < call->arg_size(); i++)
2148 for (
size_t i = 1; i < call->getNumArgOperands(); i++)
2151 args.push_back(F->arg_begin() + width + i);
// free0: release the first shadow.
2154 CallInst *CI = Free0Builder.CreateCall(FreeTy, Free, args);
2155 CI->setAttributes(FreeAttributes);
2156 CI->setCallingConv(CallingConvention);
// free1 releases the remaining shadows; it is entered only when all the
// pairwise "adjacent shadows differ" checks accumulated in checkResult hold.
2159 Value *checkResult =
nullptr;
2160 BasicBlock *free1 = BasicBlock::Create(M.getContext(),
"free1", F);
2161 IRBuilder<> Free1Builder(free1);
2163 for (
unsigned i = 0; i < width; i++) {
2165 Argument *shadow = F->arg_begin() + i + 1;
2167 if (i < width - 1) {
2168 Argument *nextShadow = F->arg_begin() + i + 2;
2169 Value *isNotEqual = Free0Builder.CreateICmpNE(shadow, nextShadow);
2170 checkResult = checkResult
2171 ? Free0Builder.CreateAnd(isNotEqual, checkResult)
2174 args[0] = nextShadow;
2175 CallInst *CI = Free1Builder.CreateCall(FreeTy, Free, args);
2176 CI->setAttributes(FreeAttributes);
2177 CI->setCallingConv(CallingConvention);
// NOTE(review): the condition choosing between the conditional branch below and
// the unconditional `CreateBr(end)` is on a line missing from this listing.
2180 Free0Builder.CreateCondBr(checkResult, free1, end);
2181 Free1Builder.CreateBr(end);
2183 Free0Builder.CreateBr(end);
2186 EndBuilder.CreateRetVoid();
// Emits IR that rounds the integer value V up to a power of two using the
// subtract-one / smear-bits / add-one sequence.
// NOTE(review): the header line and the return statement are not visible in
// this listing.
2193 assert(V->getType()->isIntegerTy());
2194 IntegerType *T = cast<IntegerType>(V->getType());
// V - 1, so that an exact power of two maps to itself after the final + 1.
2195 V = B.CreateAdd(V, ConstantInt::get(T, -1));
// OR in copies shifted right by 1, 2, 4, ... up to (but excluding) the bit
// width of T.
2196 for (
size_t i = 1; i < T->getBitWidth(); i *= 2) {
2197 V = B.CreateOr(V, B.CreateLShr(V, ConstantInt::get(T, i)));
2199 V = B.CreateAdd(V, ConstantInt::get(T, 1));
// Gets or builds the internal, always-inline helper
// "__enzyme_differential_waitall_save", which takes (count, req, dreq) and, for
// each index below `count`, stores a value loaded from dreq[idx] into the
// matching slot of `ret`.
// NOTE(review): the header line, the early return, the creation of `ret`, and
// the final return are not visible in this listing.
2204 ArrayRef<llvm::Type *> T,
2205 PointerType *reqType) {
2206 std::string name =
"__enzyme_differential_waitall_save";
2207 FunctionType *FT = FunctionType::get(
getUnqual(reqType), T,
false);
2208 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
2213 F->setLinkage(Function::LinkageTypes::InternalLinkage);
2214 F->addFnAttr(Attribute::NoUnwind);
2215 F->addFnAttr(Attribute::AlwaysInline);
2217 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
// The first argument is the element count; the next two are the request and
// shadow-request arrays.
2219 auto buff = F->arg_begin();
2220 buff->setName(
"count");
2221 Value *count = buff;
2222 Value *req = buff + 1;
2223 req->setName(
"req");
2224 Value *dreq = buff + 2;
2225 dreq->setName(
"dreq");
// Normalise the count to i64 for the loop induction variable.
2227 IRBuilder<> B(entry);
2228 count = B.CreateZExtOrTrunc(count, Type::getInt64Ty(entry->getContext()));
2232 BasicBlock *loopBlock = BasicBlock::Create(M.getContext(),
"loop", F);
2233 BasicBlock *endBlock = BasicBlock::Create(M.getContext(),
"end", F);
// Skip the loop entirely for count == 0.
2235 B.CreateCondBr(B.CreateICmpEQ(count, ConstantInt::get(count->getType(), 0)),
2236 endBlock, loopBlock);
2238 B.SetInsertPoint(loopBlock);
2239 auto idx = B.CreatePHI(count->getType(), 2);
2240 idx->addIncoming(ConstantInt::get(count->getType(), 0), entry);
2241 auto inc = B.CreateAdd(idx, ConstantInt::get(count->getType(), 1));
2242 idx->addIncoming(inc, loopBlock);
// Element addresses in the request, shadow-request and output arrays.
2244 Type *reqT = reqType;
2245 Value *idxs[] = {idx};
2246 Value *ireq = B.CreateInBoundsGEP(reqT, req, idxs);
2247 Value *idreq = B.CreateInBoundsGEP(reqT, dreq, idxs);
2248 Value *iout = B.CreateInBoundsGEP(reqType, ret, idxs);
// When the module defines "ompi_request_null", compare the loaded request
// against that global.
2249 Value *isNull =
nullptr;
2250 if (
auto GV = M.getNamedValue(
"ompi_request_null")) {
2251 Value *reql = B.CreatePointerCast(ireq,
getUnqual(GV->getType()));
2252 reql = B.CreateLoad(GV->getType(), reql);
2253 isNull = B.CreateICmpEQ(reql, GV);
2256 idreq = B.CreatePointerCast(idreq,
getUnqual(reqType));
2257 Value *d_reqp = B.CreateLoad(reqType, idreq);
// NOTE(review): the guard on `isNull` and the select's false operand are on
// lines missing from this listing; the true operand is a null value.
2259 d_reqp = B.CreateSelect(isNull, Constant::getNullValue(d_reqp->getType()),
2262 B.CreateStore(d_reqp, iout);
2264 B.CreateCondBr(B.CreateICmpEQ(inc, count), endBlock, loopBlock);
2266 B.SetInsertPoint(endBlock);
// Gets or builds the internal, always-inline helper
// "__enzyme_differential_mpi_wait" (optionally suffixed with "$prefix$postfix")
// taking (buf, count, datatype, source, tag, comm, fn, d_req). It branches on
// `fn` into one of two blocks: "invertISend" calls `irecvfn` and "invertIRecv"
// calls `isendfn`, both using isendfn's calling convention.
// NOTE(review): the header line, the early return, the lookups of `isendfn` /
// `irecvfn`, the `buf` value, the advancing of `arg` between casts, the
// construction of `args`, and the block terminators are not visible in this
// listing.
2272 ArrayRef<llvm::Type *> T,
2275 llvm::SmallVector<llvm::Type *, 4> types(T.begin(), T.end());
2276 types.push_back(reqType);
2280 std::string name =
"__enzyme_differential_mpi_wait";
// A non-empty prefix or postfix yields a distinct helper name.
2281 if (prefix.size() != 0 || postfix.size() != 0) {
2282 name = (Twine(name) +
"$" + prefix +
"$" + postfix).
str();
2285 FunctionType::get(Type::getVoidTy(M.getContext()), types,
false);
2286 Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
2291 F->setLinkage(Function::LinkageTypes::InternalLinkage);
2292 F->addFnAttr(Attribute::NoUnwind);
2293 F->addFnAttr(Attribute::AlwaysInline);
2295 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
2296 BasicBlock *isend = BasicBlock::Create(M.getContext(),
"invertISend", F);
2297 BasicBlock *irecv = BasicBlock::Create(M.getContext(),
"invertIRecv", F);
2306 Type::getInt8Ty(call.getContext())
// Name the helper's arguments by position.
2309 auto buff = F->arg_begin();
2310 buff->setName(
"buf");
2312 Value *count = buff + 1;
2313 count->setName(
"count");
2314 Value *datatype = buff + 2;
2315 datatype->setName(
"datatype");
2316 Value *source = buff + 3;
2317 source->setName(
"source");
2318 Value *tag = buff + 4;
2319 tag->setName(
"tag");
2320 Value *comm = buff + 5;
2321 comm->setName(
"comm");
2322 Value *fn = buff + 6;
2324 Value *d_req = buff + 7;
2325 d_req->setName(
"d_req");
2330 FunctionType *FuT = isendfn->getFunctionType();
2332 auto irecvfn = cast<Function>(
// entry: coerce each operand to the type of an isendfn parameter (`arg`).
2337 IRBuilder<> B(entry);
2338 auto arg = isendfn->arg_begin();
2339 if (arg->getType()->isIntegerTy())
2340 buf = B.CreatePtrToInt(buf, arg->getType());
2342 count = B.CreateZExtOrTrunc(count, arg->getType());
2344 datatype = B.CreatePointerCast(datatype, arg->getType());
2346 source = B.CreateZExtOrTrunc(source, arg->getType());
2348 tag = B.CreateZExtOrTrunc(tag, arg->getType());
2350 comm = B.CreatePointerCast(comm, arg->getType());
2352 if (arg->getType()->isIntegerTy())
2353 d_req = B.CreatePtrToInt(d_req, arg->getType());
2355 buf, count, datatype, source, tag, comm, d_req,
// Dispatch on the value of `fn` (the compared constant and branch targets are
// on missing lines).
2358 B.CreateCondBr(B.CreateICmpEQ(fn, ConstantInt::get(fn->getType(),
// "invertISend" block: emits a call to irecvfn.
2363 B.SetInsertPoint(isend);
2364 auto fcall = B.CreateCall(irecvfn, args);
2365 fcall->setCallingConv(isendfn->getCallingConv());
// "invertIRecv" block: emits a call to isendfn.
2370 B.SetInsertPoint(irecv);
2371 auto fcall = B.CreateCall(isendfn, args);
2372 fcall->setCallingConv(isendfn->getCallingConv());
// Returns (via builder B2) a load of the global named "__enzyme_mpi_sum" + CT.
// If that global does not exist yet, this code creates:
//   * "<name>_run": an internal function computing dst[i] = src[i] + dst[i]
//     for i in [0, *lenp);
//   * the global `<name>` itself, plus an i1 flag "<name>_initd";
//   * "<name>initializer": a function that calls MPI_Op_create once, guarded by
//     the flag;
// and then emits a call to the initializer before loading the global.
// NOTE(review): the header line, the definitions of `types`, `FlT`, `OpPtr`
// and `OpType`, and several terminators are not visible in this listing.
2380 llvm::Type *intType, IRBuilder<> &B2) {
2381 std::string name =
"__enzyme_mpi_sum" + CT.
str();
// Fast path: the operator global was already created in this module.
2385 if (
auto Glob = M.getGlobalVariable(name)) {
2386 return B2.CreateLoad(Glob->getValueType(), Glob);
2392 FunctionType::get(Type::getVoidTy(M.getContext()), types,
false);
2394 cast<Function>(M.getOrInsertFunction(name +
"_run", FuT).getCallee());
2396 F->setLinkage(Function::LinkageTypes::InternalLinkage);
2397#if LLVM_VERSION_MAJOR >= 16
2398 F->setOnlyAccessesArgMemory();
2400 F->addFnAttr(Attribute::ArgMemOnly);
2402 F->addFnAttr(Attribute::NoUnwind);
2403 F->addFnAttr(Attribute::AlwaysInline);
// Parameters 0 and 2 are read-only; parameter 3 is never dereferenced.
2405 F->addParamAttr(0, Attribute::ReadOnly);
2408 F->addParamAttr(2, Attribute::ReadOnly);
2410 F->addParamAttr(3, Attribute::ReadNone);
2412 BasicBlock *entry = BasicBlock::Create(M.getContext(),
"entry", F);
2413 BasicBlock *body = BasicBlock::Create(M.getContext(),
"for.body", F);
2414 BasicBlock *end = BasicBlock::Create(M.getContext(),
"for.end", F);
2416 auto src = F->arg_begin();
2417 src->setName(
"src");
2419 dst->setName(
"dst");
2420 auto lenp = dst + 1;
2421 lenp->setName(
"lenp");
// entry: the element count is passed by pointer; branch on len == 0 (targets
// are on a missing line).
2427 IRBuilder<> B(entry);
2428 len = B.CreateLoad(intType, lenp);
2429 B.CreateCondBr(B.CreateICmpEQ(len, ConstantInt::get(len->getType(), 0)),
// for.body: dst[idx] = src[idx] + dst[idx].
2434 IRBuilder<> B(body);
2435 B.setFastMathFlags(
getFast());
2436 PHINode *idx = B.CreatePHI(len->getType(), 2,
"idx");
2437 idx->addIncoming(ConstantInt::get(len->getType(), 0), entry);
2439 Value *dsti = B.CreateInBoundsGEP(FlT, dst, idx,
"dst.i");
2440 LoadInst *dstl = B.CreateLoad(FlT, dsti,
"dst.i.l");
2442 Value *srci = B.CreateInBoundsGEP(FlT, src, idx,
"src.i");
2443 LoadInst *srcl = B.CreateLoad(FlT, srci,
"src.i.l");
2444 B.CreateStore(B.CreateFAdd(srcl, dstl), dsti);
2447 B.CreateNUWAdd(idx, ConstantInt::get(len->getType(), 1),
"idx.next");
2448 idx->addIncoming(next, body);
2449 B.CreateCondBr(B.CreateICmpEQ(len, next), end, body);
// Declaration of MPI_Op_create as intType(i8*, intType, OpPtr); an existing
// symbol of that name is bit-cast to this type.
2457 llvm::Type *rtypes[] = {
getInt8PtrTy(M.getContext()), intType, OpPtr};
2458 FunctionType *RFT = FunctionType::get(intType, rtypes,
false);
2460 Constant *RF = M.getNamedValue(
"MPI_Op_create");
2463 cast<Function>(M.getOrInsertFunction(
"MPI_Op_create", RFT).getCallee());
2465 RF = ConstantExpr::getBitCast(RF,
getUnqual(RFT));
// Storage for the created operator, initially undef.
2468 GlobalVariable *GV =
2469 new GlobalVariable(M, OpType,
false, GlobalVariable::InternalLinkage,
2470 UndefValue::get(OpType), name);
// One-bit "already initialised" flag, initially false.
2472 Type *i1Ty = Type::getInt1Ty(M.getContext());
2473 GlobalVariable *initD =
new GlobalVariable(
2474 M, i1Ty,
false, GlobalVariable::InternalLinkage,
2475 ConstantInt::getFalse(M.getContext()), name +
"_initd");
// Initializer function: void(), internal, nounwind.
2479 FunctionType *IFT = FunctionType::get(Type::getVoidTy(M.getContext()),
2480 ArrayRef<Type *>(),
false);
2481 Function *initializerFunction = cast<Function>(
2482 M.getOrInsertFunction(name +
"initializer", IFT).getCallee());
2484 initializerFunction->setLinkage(Function::LinkageTypes::InternalLinkage);
2485 initializerFunction->addFnAttr(Attribute::NoUnwind);
2489 BasicBlock::Create(M.getContext(),
"entry", initializerFunction);
2491 BasicBlock::Create(M.getContext(),
"run", initializerFunction);
2493 BasicBlock::Create(M.getContext(),
"end", initializerFunction);
2494 IRBuilder<> B(entry);
// Skip the registration when the flag is already set.
2496 B.CreateCondBr(B.CreateLoad(initD->getValueType(), initD), end, run);
// run: MPI_Op_create(<name>_run, 1, &GV), then set the flag.
2498 B.SetInsertPoint(run);
2499 Value *args[] = {ConstantExpr::getPointerCast(F, rtypes[0]),
2500 ConstantInt::get(rtypes[1], 1,
false),
2501 ConstantExpr::getPointerCast(GV, rtypes[2])};
2502 B.CreateCall(RFT, RF, args);
2503 B.CreateStore(ConstantInt::getTrue(M.getContext()), initD);
2505 B.SetInsertPoint(end);
// At the caller's insertion point: run the initializer, then load the operator.
2509 B2.CreateCall(M.getFunction(name +
"initializer"));
2510 return B2.CreateLoad(GV->getValueType(), GV);
// Collects into `results` the instructions from `stores` that are selected by
// a same-block position scan or by a forward walk over the CFG successors of
// `inst`'s block.
// NOTE(review): the header line, the body of the same-block scan, the `else`
// branches, the worklist pushes / pops, and the handling of the loop-header
// case are not visible in this listing, so the exact inclusion rules cannot be
// confirmed from here.
2514 llvm::Instruction *inst,
2515 const llvm::SmallPtrSetImpl<Instruction *> &stores,
2516 const llvm::Loop *region) {
2517 using namespace llvm;
// Stores grouped by parent block, pending a reachability decision.
2518 std::map<BasicBlock *, SmallVector<Instruction *, 1>> maybeBlocks;
2519 BasicBlock *instBlk = inst->getParent();
2520 for (
auto store : stores) {
2521 BasicBlock *storeBlk = store->getParent();
// Same block as `inst`: scan from the block start until either the store or
// `inst` is reached to learn their relative order.
2522 if (instBlk == storeBlk) {
2525 if (store != inst) {
2526 BasicBlock::const_iterator It = storeBlk->begin();
2527 for (; &*It != store && &*It != inst; ++It)
2532 results.push_back(store);
2535 maybeBlocks[storeBlk].push_back(store);
2537 maybeBlocks[storeBlk].push_back(store);
// Nothing left to resolve through the CFG walk.
2541 if (maybeBlocks.size() == 0)
// Worklist seeded from the successors of inst's block; successors equal to the
// header of `region` get special treatment (on a missing line).
2544 llvm::SmallVector<BasicBlock *, 2> todo;
2545 for (
auto B : successors(instBlk)) {
2546 if (region && region->getHeader() == B) {
2552 SmallPtrSet<BasicBlock *, 2> seen;
2553 while (todo.size()) {
2554 auto cur = todo.back();
2556 if (seen.count(cur))
// Every pending store in a visited block is reported, and the block is
// removed from the pending map.
2559 auto found = maybeBlocks.find(cur);
2560 if (found != maybeBlocks.end()) {
2561 for (
auto store : found->second)
2562 results.push_back(store);
2563 maybeBlocks.erase(found);
2565 for (
auto B : successors(cur)) {
2566 if (region && region->getHeader() == B) {
// NOTE(review): this excerpt starts inside a parameter list; the function name
// and return type are not visible here.
// Visible behaviour: given SCEV start/end expressions for a reader
// (LoadStart/LoadEnd) and a writer (StoreStart/StoreEnd), two overlap queries
// are made through the local `hasOverlap` lambda, one per ordering of the two
// ranges. After each query the loops between `anc` and `scope` are checked
// against the set of loops recorded during that query.
2575 llvm::ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT,
2576 llvm::Instruction *maybeReader,
const llvm::SCEV *LoadStart,
2577 const llvm::SCEV *LoadEnd, llvm::Instruction *maybeWriter,
2578 const llvm::SCEV *StoreStart,
const llvm::SCEV *StoreEnd,
2579 llvm::Loop *scope) {
// `anc` is whatever getAncestor() returns for the loops enclosing the reader's
// and the writer's blocks.
2587 Loop *anc =
getAncestor(LI.getLoopFor(maybeReader->getParent()),
2588 LI.getLoopFor(maybeWriter->getParent()));
2593 assert(scope == anc || scope->contains(anc));
// Loops recorded by skipLoop() while one hasOverlap() query runs; cleared
// before each query.
2629 SmallPtrSet<const Loop *, 1> visitedAncestors;
// skipLoop captures by reference, so it reads the values `anc` and `scope`
// hold at the time of each call.
2630 auto skipLoop = [&](
const Loop *L) {
2632 if (scope && L->contains(scope))
2635 if (anc && (anc == L || anc->contains(L))) {
2636 visitedAncestors.insert(L);
// hasOverlap walks candidate values of StartNext (`slim`) and EndPrev (`elim`)
// in nested loops. Add-recurrences are peeled one level at a time: depending
// on the known sign of the step they are replaced by their start value or by
// their value at the backedge-taken count of their loop. A candidate pair is
// tested by asking whether `slim - elim` is known non-negative.
2646 auto hasOverlap = [&](
const SCEV *EndPrev,
const SCEV *StartNext,
2648 for (
auto slim = StartNext; slim != SE.getCouldNotCompute();) {
2651 if (
auto startL = dyn_cast<SCEVAddRecExpr>(slim))
2652 if (skipLoop(startL->getLoop()) &&
2653 SE.isKnownNonPositive(startL->getStepRecurrence(SE))) {
2658 for (
auto elim = EndPrev; elim != SE.getCouldNotCompute();) {
2663 if (
auto endL = dyn_cast<SCEVAddRecExpr>(elim)) {
2664 if (skipLoop(endL->getLoop()) &&
2665 SE.isKnownNonNegative(endL->getStepRecurrence(SE))) {
// Both sides are add-recurrences: their loop headers are compared, and the
// case where neither header dominates the other is singled out.
2676 if (
auto endL = dyn_cast<SCEVAddRecExpr>(elim)) {
2677 auto EH = endL->getLoop()->getHeader();
2678 if (
auto startL = dyn_cast<SCEVAddRecExpr>(slim)) {
2679 auto SH = startL->getLoop()->getHeader();
2680 if (EH != SH && !DT.dominates(EH, SH) &&
2681 !DT.dominates(SH, EH))
2687 auto sub = SE.getMinusSCEV(slim, elim);
2688 if (sub != SE.getCouldNotCompute() && SE.isKnownNonNegative(sub))
// Advance `elim`: non-positive step -> start value; non-negative step -> value
// at the backedge-taken count.
2693 if (
auto endL = dyn_cast<SCEVAddRecExpr>(elim)) {
2694 if (SE.isKnownNonPositive(endL->getStepRecurrence(SE))) {
2695 elim = endL->getStart();
2697 }
else if (SE.isKnownNonNegative(endL->getStepRecurrence(SE))) {
2698#if LLVM_VERSION_MAJOR >= 12
2699 auto ebd = SE.getSymbolicMaxBackedgeTakenCount(endL->getLoop());
2701 auto ebd = SE.getBackedgeTakenCount(endL->getLoop());
2703 if (ebd == SE.getCouldNotCompute())
2705 elim = endL->evaluateAtIteration(ebd, SE);
// Advance `slim` with the mirrored sign convention: non-negative step -> start
// value; non-positive step -> value at the backedge-taken count.
2712 if (
auto startL = dyn_cast<SCEVAddRecExpr>(slim)) {
2713 if (SE.isKnownNonNegative(startL->getStepRecurrence(SE))) {
2714 slim = startL->getStart();
2716 }
else if (SE.isKnownNonPositive(startL->getStepRecurrence(SE))) {
2717#if LLVM_VERSION_MAJOR >= 12
2718 auto sbd = SE.getSymbolicMaxBackedgeTakenCount(startL->getLoop());
2720 auto sbd = SE.getBackedgeTakenCount(startL->getLoop());
2722 if (sbd == SE.getCouldNotCompute())
2724 slim = startL->evaluateAtIteration(sbd, SE);
// First query: end of the store range against start of the load range.
2735 visitedAncestors.clear();
2736 if (!hasOverlap(StoreEnd, LoadStart,
true)) {
// NOTE(review): the loop below declares `L` but its condition and step operate
// on `anc`. `L` therefore keeps its initial value on every iteration while
// `anc` (also read by skipLoop through its by-reference capture) is walked up
// towards `scope`. This looks like it was meant to be
// `L != scope; L = L->getParentLoop()` - confirm before changing.
2740 for (
const Loop *L = anc; anc != scope; anc = anc->getParentLoop()) {
2741 if (!visitedAncestors.count(L))
// Second query with the ranges swapped: end of the load range against start of
// the store range.
2750 visitedAncestors.clear();
2751 if (!hasOverlap(LoadEnd, StoreStart,
false)) {
// NOTE(review): same `L` / `anc` pattern as the loop after the first query. If
// that loop ran to completion, `anc` already equals `scope` when this point is
// reached and this loop body does not execute.
2755 for (
const Loop *L = anc; anc != scope; anc = anc->getParentLoop()) {
2756 if (!visitedAncestors.count(L))
// NOTE(review): this excerpt starts inside a parameter list; the function name
// and leading parameters are not visible here.
// Visible behaviour: builds SCEV address ranges for the reader (a load, or the
// source operand of a memory transfer) and the writer (a store, a memset, or
// the destination operand of a memory transfer). Each range is
// Begin = SCEV of the pointer and End = Begin + access size. For memory
// intrinsics End is only computed when the length operand is a ConstantInt;
// otherwise it stays "could not compute".
2766 llvm::TargetLibraryInfo &TLI, ScalarEvolution &SE,
2767 llvm::LoopInfo &LI, llvm::DominatorTree &DT,
2768 llvm::Instruction *maybeReader,
2769 llvm::Instruction *maybeWriter,
2770 llvm::Loop *scope) {
2771 using namespace llvm;
// All four bounds default to the "could not compute" SCEV.
2774 const SCEV *LoadBegin = SE.getCouldNotCompute();
2775 const SCEV *LoadEnd = SE.getCouldNotCompute();
2777 const SCEV *StoreBegin = SE.getCouldNotCompute();
2778 const SCEV *StoreEnd = SE.getCouldNotCompute();
2780 Value *loadPtr =
nullptr;
2781 Value *storePtr =
nullptr;
// Reader is a load. NOTE(review): the local `LI` introduced by this `if`
// shadows the `LoopInfo &LI` parameter for the duration of the branch.
2782 if (
auto LI = dyn_cast<LoadInst>(maybeReader)) {
2783 loadPtr = LI->getPointerOperand();
2784 LoadBegin = SE.getSCEV(LI->getPointerOperand());
2785 if (LoadBegin != SE.getCouldNotCompute() &&
2786 !LoadBegin->getType()->isIntegerTy()) {
2787 auto &DL = maybeWriter->getModule()->getDataLayout();
2788 auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType()))
2790#if LLVM_VERSION_MAJOR >= 18
2791 auto TS = SE.getConstant(
2792 APInt(width, (int64_t)DL.getTypeStoreSize(LI->getType())));
2794 auto TS = SE.getConstant(
2795 APInt(width, DL.getTypeStoreSize(LI->getType()).getFixedSize()));
2797 LoadEnd = SE.getAddExpr(LoadBegin, TS);
// Writer is a store: size is the store size of the stored value's type.
2800 if (
auto SI = dyn_cast<StoreInst>(maybeWriter)) {
2801 storePtr = SI->getPointerOperand();
2802 StoreBegin = SE.getSCEV(SI->getPointerOperand());
2803 if (StoreBegin != SE.getCouldNotCompute() &&
2804 !StoreBegin->getType()->isIntegerTy()) {
2805 auto &DL = maybeWriter->getModule()->getDataLayout();
2806 auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
2808#if LLVM_VERSION_MAJOR >= 18
2810 SE.getConstant(APInt(width, (int64_t)DL.getTypeStoreSize(
2811 SI->getValueOperand()->getType())));
2813 auto TS = SE.getConstant(
2814 APInt(width, DL.getTypeStoreSize(SI->getValueOperand()->getType())
2817 StoreEnd = SE.getAddExpr(StoreBegin, TS);
// Writer is a memset: destination is operand 0, length is operand 2.
2820 if (
auto MS = dyn_cast<MemSetInst>(maybeWriter)) {
2821 storePtr = MS->getArgOperand(0);
2822 StoreBegin = SE.getSCEV(MS->getArgOperand(0));
2823 if (StoreBegin != SE.getCouldNotCompute() &&
2824 !StoreBegin->getType()->isIntegerTy()) {
2825 if (
auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
2826 auto &DL = MS->getModule()->getDataLayout();
2827 auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
2830 SE.getConstant(APInt(width, Len->getValue().getLimitedValue()));
2831 StoreEnd = SE.getAddExpr(StoreBegin, TS);
// Writer is a memory transfer: destination is operand 0, length is operand 2.
2835 if (
auto MS = dyn_cast<MemTransferInst>(maybeWriter)) {
2836 storePtr = MS->getArgOperand(0);
2837 StoreBegin = SE.getSCEV(MS->getArgOperand(0));
2838 if (StoreBegin != SE.getCouldNotCompute() &&
2839 !StoreBegin->getType()->isIntegerTy()) {
2840 if (
auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
2841 auto &DL = MS->getModule()->getDataLayout();
2842 auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
2845 SE.getConstant(APInt(width, Len->getValue().getLimitedValue()));
2846 StoreEnd = SE.getAddExpr(StoreBegin, TS);
// Reader is a memory transfer: source is operand 1, length is operand 2.
2850 if (
auto MS = dyn_cast<MemTransferInst>(maybeReader)) {
2851 loadPtr = MS->getArgOperand(1);
2852 LoadBegin = SE.getSCEV(MS->getArgOperand(1));
2853 if (LoadBegin != SE.getCouldNotCompute() &&
2854 !LoadBegin->getType()->isIntegerTy()) {
2855 if (
auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
2856 auto &DL = MS->getModule()->getDataLayout();
2857 auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType()))
2860 SE.getConstant(APInt(width, Len->getValue().getLimitedValue()));
2861 LoadEnd = SE.getAddExpr(LoadBegin, TS);
// When both pointers were identified a further check is made; its body, and
// the callee that receives the trailing arguments below, are not visible in
// this excerpt.
2866 if (loadPtr && storePtr)
2873 maybeWriter, StoreBegin, StoreEnd, scope))
// NOTE(review): this excerpt starts inside a parameter list; the function name
// and leading parameters (the code below uses `TR` and `AA`) are not visible
// here.
// Visible behaviour: answers a may-write/may-read question for the pair
// (maybeWriter, maybeReader). A series of special cases keyed on the kind of
// instruction and on the callee name comes first; the result returned by most
// of them is on lines not visible here. The tail falls back to alias-analysis
// mod/ref queries and aborts via llvm_unreachable when no case matched.
2881 llvm::TargetLibraryInfo &TLI,
2882 llvm::Instruction *maybeReader,
2883 llvm::Instruction *maybeWriter) {
// Both instructions must belong to the same function.
2884 assert(maybeReader->getParent()->getParent() ==
2885 maybeWriter->getParent()->getParent());
2886 using namespace llvm;
2887 if (isa<StoreInst>(maybeReader))
2889 if (isa<FenceInst>(maybeReader)) {
// Writer is a call: special cases by callee name. `funcName` is defined on a
// line not visible in this excerpt.
2892 if (
auto call = dyn_cast<CallInst>(maybeWriter)) {
2906 if (funcName ==
"jl_array_copy" || funcName ==
"ijl_array_copy")
2909 if (funcName ==
"jl_genericmemory_copy_slice" ||
2910 funcName ==
"ijl_genericmemory_copy_slice")
2913 if (funcName ==
"jl_new_array" || funcName ==
"ijl_new_array")
2916 if (funcName ==
"julia.safepoint")
2919 if (funcName ==
"jl_idtable_rehash" || funcName ==
"ijl_idtable_rehash")
2923 if (funcName ==
"MPI_Send" || funcName ==
"PMPI_Send") {
// Wait / Waitall: `off` is 0 for the Wait names and 1 otherwise, and call
// operands `off + 1` and `off + 0` are each tested for being referenced by the
// reader.
2927 if (funcName ==
"MPI_Wait" || funcName ==
"PMPI_Wait" ||
2928 funcName ==
"MPI_Waitall" || funcName ==
"PMPI_Waitall") {
2929#if LLVM_VERSION_MAJOR > 11
2930 auto loc = LocationSize::afterPointer();
2932 auto loc = MemoryLocation::UnknownSize;
2934 size_t off = (funcName ==
"MPI_Wait" || funcName ==
"PMPI_Wait") ? 0 : 1;
2936 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(off + 1),
2939 if (!isRefSet(AA.getModRefInfo(maybeReader,
2940 call->getArgOperand(off + 0), loc)))
2944 maybeReader->getParent()->getParent()->getParent()->getDataLayout(),
// Isend: call operand 6 is tested for being referenced by the reader.
2954 if (funcName ==
"MPI_Isend" || funcName ==
"PMPI_Isend") {
2957 maybeReader->getParent()->getParent()->getParent()->getDataLayout(),
2964#if LLVM_VERSION_MAJOR > 11
2965 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6),
2966 LocationSize::afterPointer())))
2969 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6),
2970 MemoryLocation::UnknownSize)))
// Irecv / Recv: call operand 2 is stripped of constant expressions and, if it
// is one of the two named globals, mapped to a double or float ConcreteType
// which is then compared against `R` (computed on lines not visible here).
2975 if (funcName ==
"MPI_Irecv" || funcName ==
"PMPI_Irecv" ||
2976 funcName ==
"MPI_Recv" || funcName ==
"PMPI_Recv") {
2978 if (
Constant *C = dyn_cast<Constant>(call->getArgOperand(2))) {
2979 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
2980 C = CE->getOperand(0);
2982 if (
auto GV = dyn_cast<GlobalVariable>(C)) {
2983 if (GV->getName() ==
"ompi_mpi_double") {
2984 type =
ConcreteType(Type::getDoubleTy(C->getContext()));
2985 }
else if (GV->getName() ==
"ompi_mpi_float") {
2986 type =
ConcreteType(Type::getFloatTy(C->getContext()));
2993 maybeReader->getParent()->getParent()->getParent()->getDataLayout(),
2995 if (R.isKnown() && type != R) {
2998 if (funcName ==
"MPI_Recv" || funcName ==
"PMPI_Recv" ||
3001#if LLVM_VERSION_MAJOR > 11
3002 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6),
3003 LocationSize::afterPointer())))
3006 if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6),
3007 MemoryLocation::UnknownSize)))
// Writer is one of a fixed list of intrinsics.
3013 if (
auto II = dyn_cast<IntrinsicInst>(call)) {
3014 if (II->getIntrinsicID() == Intrinsic::stacksave)
3016 if (II->getIntrinsicID() == Intrinsic::stackrestore)
3018 if (II->getIntrinsicID() == Intrinsic::trap)
3020#if LLVM_VERSION_MAJOR >= 13
3021 if (II->getIntrinsicID() == Intrinsic::experimental_noalias_scope_decl)
// Writer is inline assembly whose text contains "exit".
3026 if (
auto iasm = dyn_cast<InlineAsm>(call->getCalledOperand())) {
3027 if (StringRef(iasm->getAsmString()).contains(
"exit"))
// Reader is a call: the same intrinsic list is tested on the reader side.
3031 if (
auto call = dyn_cast<CallInst>(maybeReader)) {
3046 if (
auto II = dyn_cast<IntrinsicInst>(call)) {
3047 if (II->getIntrinsicID() == Intrinsic::stacksave)
3049 if (II->getIntrinsicID() == Intrinsic::stackrestore)
3051 if (II->getIntrinsicID() == Intrinsic::trap)
3053#if LLVM_VERSION_MAJOR >= 13
3054 if (II->getIntrinsicID() == Intrinsic::experimental_noalias_scope_decl)
// Writer is an invoke: a shorter list of callee names plus the inline-asm
// "exit" test.
3059 if (
auto call = dyn_cast<InvokeInst>(maybeWriter)) {
3073 if (funcName ==
"jl_array_copy" || funcName ==
"ijl_array_copy")
3076 if (funcName ==
"jl_genericmemory_copy_slice" ||
3077 funcName ==
"ijl_genericmemory_copy_slice")
3080 if (funcName ==
"jl_idtable_rehash" || funcName ==
"ijl_idtable_rehash")
3083 if (
auto iasm = dyn_cast<InlineAsm>(call->getCalledOperand())) {
3084 if (StringRef(iasm->getAsmString()).contains(
"exit"))
3088 if (
auto call = dyn_cast<InvokeInst>(maybeReader)) {
// From here on the writer must be able to write and the reader to read.
3103 assert(maybeWriter->mayWriteToMemory());
3104 assert(maybeReader->mayReadFromMemory());
// Load reader: type information for the loaded value (and, for a store
// writer, for the stored value or the store's pointee) is queried from `TR`;
// how the two are compared is not visible here. The branch ends by asking
// alias analysis whether the writer may modify the loaded location.
3106 if (
auto li = dyn_cast<LoadInst>(maybeReader)) {
3108 auto TT = TR->
query(li)[{-1}];
3110 if (
auto si = dyn_cast<StoreInst>(maybeWriter)) {
3111 auto TT2 = TR->
query(si->getValueOperand())[{-1}];
3116 auto &dl = li->getParent()->getParent()->getParent()->getDataLayout();
3118 (dl.getTypeSizeInBits(si->getValueOperand()->getType()) + 7) / 8;
3119 TT2 = TR->
query(si->getPointerOperand()).
Lookup(len, dl)[{-1}];
3127 return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(li)));
// Remaining reader kinds with a single memory location: does the writer modify
// that location?
3129 if (
auto rmw = dyn_cast<AtomicRMWInst>(maybeReader)) {
3130 return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(rmw)));
3132 if (
auto xch = dyn_cast<AtomicCmpXchgInst>(maybeReader)) {
3133 return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(xch)));
3135 if (
auto mti = dyn_cast<MemTransferInst>(maybeReader)) {
3137 AA.getModRefInfo(maybeWriter, MemoryLocation::getForSource(mti)));
// Writer kinds with a single memory location: does the reader reference that
// location?
3140 if (
auto si = dyn_cast<StoreInst>(maybeWriter)) {
3141 return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(si)));
3143 if (
auto rmw = dyn_cast<AtomicRMWInst>(maybeWriter)) {
3144 return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(rmw)));
3146 if (
auto xch = dyn_cast<AtomicCmpXchgInst>(maybeWriter)) {
3147 return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(xch)));
3149 if (
auto mti = dyn_cast<MemIntrinsic>(maybeWriter)) {
3151 AA.getModRefInfo(maybeReader, MemoryLocation::getForDest(mti)));
// Reader is a call or invoke: any mod or ref interaction with the writer
// counts.
3154 if (
auto cb = dyn_cast<CallInst>(maybeReader)) {
3155 return isModOrRefSet(AA.getModRefInfo(maybeWriter, cb));
3157 if (
auto cb = dyn_cast<InvokeInst>(maybeReader)) {
3158 return isModOrRefSet(AA.getModRefInfo(maybeWriter, cb));
// No case matched: print both instructions and abort.
3160 llvm::errs() <<
" maybeReader: " << *maybeReader
3161 <<
" maybeWriter: " << *maybeWriter <<
"\n";
3162 llvm_unreachable(
"unknown inst2");
3170 if (
auto CI = dyn_cast<CastInst>(ptr)) {
3171 ptr = CI->getOperand(0);
3174 if (
auto CI = dyn_cast<GetElementPtrInst>(ptr)) {
3175 auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout();
3176#if LLVM_VERSION_MAJOR >= 20
3177 SmallMapVector<Value *, APInt, 4> VariableOffsets;
3179 MapVector<Value *, APInt> VariableOffsets;
3181 auto width =
sizeof(size_t) * 8;
3182 APInt Offset(width, 0);
3183 bool success =
collectOffset(cast<GEPOperator>(CI), DL, width,
3184 VariableOffsets, Offset);
3185 if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) {
3188 offset += Offset.getZExtValue();
3189 ptr = CI->getOperand(0);
3192 if (isa<AllocaInst>(ptr)) {
3195 if (
auto LI = dyn_cast<LoadInst>(ptr)) {
3203 return cast<AllocaInst>(ptr);
// NOTE(review): the line carrying the function name and parameters is not
// visible in this excerpt; the body refers to a value `AI` as its root.
// Visible behaviour: starting from `AI` at offset 0, follows casts and
// constant, non-negative-offset GEPs transitively, and collects every other
// user as a tuple (user instruction, pointer it uses, byte offset of that
// pointer from the root).
3209SmallVector<std::tuple<Instruction *, Value *, size_t>, 1>
// Worklist of (derived pointer, accumulated offset).
3211 SmallVector<std::pair<Value *, size_t>, 1> todo;
3212 todo.emplace_back(AI, 0);
3214 SmallVector<std::tuple<Instruction *, Value *, size_t>, 1> users;
3215 while (todo.size()) {
3216 auto pair = todo.pop_back_val();
3217 Value *ptr = pair.first;
3218 size_t suboff = pair.second;
3220 for (
auto U : ptr->users()) {
// A cast keeps the same offset.
3221 if (
auto CI = dyn_cast<CastInst>(U)) {
3222 todo.emplace_back(CI, suboff);
// A GEP is followed only when collectOffset() succeeds, reports no variable
// offsets and the constant offset is not negative; otherwise the GEP itself is
// recorded as a user.
3225 if (
auto CI = dyn_cast<GetElementPtrInst>(U)) {
3226 auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout();
3227#if LLVM_VERSION_MAJOR >= 20
3228 SmallMapVector<Value *, APInt, 4> VariableOffsets;
3230 MapVector<Value *, APInt> VariableOffsets;
// NOTE(review): the offset bit width comes from the host's sizeof(size_t), not
// from the module's data layout - confirm this is intended when host and
// target differ.
3232 auto width =
sizeof(size_t) * 8;
3233 APInt Offset(width, 0);
3234 bool success =
collectOffset(cast<GEPOperator>(CI), DL, width,
3235 VariableOffsets, Offset);
3237 if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) {
3238 users.emplace_back(cast<Instruction>(U), ptr, suboff);
3241 todo.emplace_back(CI, suboff + Offset.getZExtValue());
// Any other user is recorded with the pointer and offset it was reached by.
3244 users.emplace_back(cast<Instruction>(U), ptr, suboff);
// NOTE(review): the line carrying the function name and parameters is not
// visible in this excerpt; the body uses `offset`, `valSz` and a worklist
// `todo` that are declared on lines not shown.
// Visible behaviour: drains a worklist of (user, pointer, sub-offset) tuples
// and collects into `options` pairs of (stored value, offset inside that
// value) for stores that cover the byte range [offset, offset + valSz).
3254SmallVector<std::pair<Value *, size_t>, 1>
3257 SmallVector<std::pair<Value *, size_t>, 1> options;
// Tuples already taken from the worklist.
3260 std::set<std::tuple<Instruction *, Value *, size_t>> seen;
3262 while (todo.size()) {
3263 auto pair = todo.pop_back_val();
3264 if (seen.count(pair))
3267 Instruction *U = std::get<0>(pair);
3268 Value *ptr = std::get<1>(pair);
3269 size_t suboff = std::get<2>(pair);
3272 if (isa<LoadInst>(U)) {
// A memory transfer whose operand 0 (the destination) is not `ptr`.
3275 if (
auto MTI = dyn_cast<MemTransferInst>(U))
3276 if (MTI->getOperand(0) != ptr) {
// NOTE(review): `U` is already an `Instruction *`, so this dyn_cast cannot
// fail for a non-null `U`.
3279 if (
auto I = dyn_cast<Instruction>(U)) {
3280 if (!I->mayWriteToMemory() && I->getType()->isVoidTy())
3284 if (
auto SI = dyn_cast<StoreInst>(U)) {
3285 auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout();
// Store *to* `ptr`: the two comparisons below test the cases where the stored
// bytes end at or before `offset`, or start at or after `offset + valSz`.
3288 if (SI->getPointerOperand() == ptr) {
3290 (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) /
3293 if (storeSz + suboff <= offset)
3296 if (offset + valSz <= suboff)
// NOTE(review): the guards visible in this excerpt do not by themselves imply
// `offset >= suboff`; confirm that lines not shown here establish it.
3299 if (valSz <= storeSz) {
3300 assert(offset >= suboff);
3301 options.emplace_back(SI->getValueOperand(), offset - suboff);
// Store *of* `ptr` (the pointer itself is the stored value): further tuples
// are pushed onto the worklist from `subPtrs`, which is computed on lines not
// visible here.
3308 if (SI->getValueOperand() == ptr) {
3310 size_t mid_offset = 0;
3313 bool sublegal =
true;
3314 auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8;
3321 for (
auto &&[subPtr, subOff] : subPtrs) {
3325 todo.emplace_back(std::move(pair3));
// Lifetime markers, matched both by callee name and by intrinsic id.
3334 if (
auto II = dyn_cast<IntrinsicInst>(U)) {
3335 if (II->getCalledFunction()->getName() ==
"llvm.enzyme.lifetime_start" ||
3336 II->getCalledFunction()->getName() ==
"llvm.enzyme.lifetime_end")
3338 if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
3339 II->getIntrinsicID() == Intrinsic::lifetime_end)
// Constant-length memory transfer writing into `ptr`: the same two
// range comparisons as for stores, then a case for a transfer that starts at
// sub-offset 0 and is at least `offset + valSz` bytes long.
3345 if (
auto MTI = dyn_cast<MemTransferInst>(U)) {
3346 if (
auto CI = dyn_cast<ConstantInt>(MTI->getLength())) {
3347 if (MTI->getOperand(0) == ptr) {
3348 auto storeSz = CI->getValue();
3351 if ((storeSz + suboff).ule(offset))
3355 if (offset + valSz <= suboff)
3358 if (suboff == 0 && CI->getValue().uge(offset + valSz)) {
3359 size_t midoffset = 0;
3365 if (midoffset != 0) {
3370 todo.emplace_back(std::move(pair3));
3387 if (
auto LI = dyn_cast<LoadInst>(V)) {
3389 auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout();
3390 valSz = (DL.getTypeSizeInBits(LI->getType()) + 7) / 8;
3393 Value *ptr = LI->getPointerOperand();
3403 offset += preOffset;
3411 std::set<Value *> res;
3412 for (
auto &&[opt, startOff] : opts) {
3419 if (res.size() != 1) {
3422 Value *retval = *res.begin();
3425 if (
auto EVI = dyn_cast<ExtractValueInst>(V)) {
3429 EVI->getIndices(),
"",
false);
3430 if (em !=
nullptr) {
3435 if (
auto LI = dyn_cast<LoadInst>(EVI->getAggregateOperand())) {
3436 auto offset = preOffset;
3438 auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout();
3439 SmallVector<Value *, 4> vec;
3440 vec.push_back(ConstantInt::get(Type::getInt64Ty(EVI->getContext()), 0));
3441 for (
auto ind : EVI->getIndices()) {
3443 ConstantInt::get(Type::getInt32Ty(EVI->getContext()), ind));
3445 auto ud = UndefValue::get(
getUnqual(EVI->getOperand(0)->getType()));
3447 GetElementPtrInst::Create(EVI->getOperand(0)->getType(), ud, vec);
3448 APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
3449 g2->accumulateConstantOffset(DL, ai);
3454 offset += (size_t)ai.getLimitedValue();
3457 auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout();
3458 valSz = (DL.getTypeSizeInBits(EVI->getType()) + 7) / 8;
3467 while (!isa<Function>(fn)) {
3468 if (
auto ci = dyn_cast<CastInst>(fn)) {
3469 fn = ci->getOperand(0);
3472 if (
auto ci = dyn_cast<ConstantExpr>(fn)) {
3474 fn = ci->getOperand(0);
3478 if (
auto ci = dyn_cast<BlockAddress>(fn)) {
3479 fn = ci->getFunction();
3482 if (
auto *GA = dyn_cast<GlobalAlias>(fn)) {
3483 fn = GA->getAliasee();
3486 if (
auto *
Call = dyn_cast<CallInst>(fn)) {
3487 if (
auto F =
Call->getCalledFunction()) {
3488 SmallPtrSet<Value *, 1> ret;
3489 for (
auto &BB : *F) {
3490 if (
auto RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
3491 ret.insert(RI->getReturnValue());
3494 if (ret.size() == 1) {
3495 auto val = *ret.begin();
3497 if (isa<Constant>(val)) {
3501 if (
auto arg = dyn_cast<Argument>(val)) {
3502 fn =
Call->getArgOperand(arg->getArgNo());
3508 if (
auto *
Call = dyn_cast<InvokeInst>(fn)) {
3509 if (
auto F =
Call->getCalledFunction()) {
3510 SmallPtrSet<Value *, 1> ret;
3511 for (
auto &BB : *F) {
3512 if (
auto RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
3513 ret.insert(RI->getReturnValue());
3516 if (ret.size() == 1) {
3517 auto val = *ret.begin();
3518 while (isa<LoadInst>(val)) {
3526 if (isa<Constant>(val)) {
3530 if (
auto arg = dyn_cast<Argument>(val)) {
3531 fn =
Call->getArgOperand(arg->getArgNo());
3553 if (!F.isDeclaration()) {
3560#if LLVM_VERSION_MAJOR >= 16
// Tries to recognise `in` as the name of a BLAS/LAPACK-style routine by
// brute-force comparison against every combination of prefix, type letter,
// routine name and suffix from the tables below. Three naming schemes are
// tried in turn: ("" | "cblas_") + lowercase type letter + routine + suffix;
// "cublas" + uppercase type letter + routine + cuBLAS suffix; and
// "cublas_" + lowercase type letter + routine.
// What is built and returned on a match (and when nothing matches) is on
// lines not visible in this excerpt.
3561std::optional<BlasInfo>
extractBLAS(llvm::StringRef in)
// Routine names that can be recognised.
// NOTE(review): "spmv" and "potrs" each appear twice in this table; the
// duplicates only cause repeated comparisons.
3566 const char *extractable[] = {
3567 "dot",
"scal",
"axpy",
"gemv",
"gemm",
"spmv",
"syrk",
"nrm2",
3568 "trmm",
"trmv",
"symm",
"potrf",
"potrs",
"copy",
"spmv",
"syr2k",
3569 "potrs",
"getrf",
"getrs",
"trtrs",
"getri",
"symv",
"lacpy",
"trsv",
3571 const char *floatType[] = {
"s",
"d",
"c",
"z"};
3572 const char *prefixes[] = {
"" ,
"cblas_"};
3573 const char *suffixes[] = {
"",
"_",
"64_",
"_64_"};
// Scheme 1: a candidate string is built for every combination and compared
// with `in`; `is64` is set when the matching suffix contains "64".
3574 for (
auto t : floatType) {
3575 for (
auto f : extractable) {
3576 for (
auto p : prefixes) {
3577 for (
auto s : suffixes) {
3578 if (in == (Twine(p) + t + f + s).
str()) {
3579 bool is64 = llvm::StringRef(s).contains(
"64");
3589 const char *cuCFloatType[] = {
"S",
"D",
"C",
"Z"};
3590 const char *cuFFloatType[] = {
"s",
"d",
"c",
"z"};
3591 const char *cuCPrefixes[] = {
"cublas"};
3592 const char *cuSuffixes[] = {
"",
"_v2",
"_64",
"_v2_64"};
// Scheme 2: the uppercase type letters are enumerated so the index of the
// matching letter is available alongside its value.
3593 for (
auto t : llvm::enumerate(cuCFloatType)) {
3594 for (
auto f : extractable) {
3595 for (
auto p : cuCPrefixes) {
3596 for (
auto s : cuSuffixes) {
3597 if (in == (Twine(p) + t.value() + f + s).str()) {
3598 bool is64 = llvm::StringRef(s).contains(
"64");
3600 t.value(), p, s, f, is64,
// Scheme 3: "cublas_" prefix, lowercase type letter, no suffix.
3608 const char *cuFPrefixes[] = {
"cublas_"};
3609 for (
auto t : cuFFloatType) {
3610 for (
auto f : extractable) {
3611 for (
auto p : cuFPrefixes) {
3612 if (in == (Twine(p) + t + f).
str()) {
3626 return cast<Constant>(
3629 return Constant::getNullValue(T);
3631 return UndefValue::get(T);
// NOTE(review): this excerpt starts inside a parameter list; the function name
// and the parameters `val` and `toset` used below are not visible here.
// Visible behaviour: emits a call to a per-type helper function named
// "__enzyme_sanitize_nan_" + <printed type of toset>. The helper is created on
// first use; it tests its first argument for NaN with an unordered
// self-comparison and, on the failing path, prints the message passed as its
// second argument with puts and calls exit(1).
3635 llvm::IRBuilder<> &BuilderM,
3636 llvm::Value *mask) {
3638 auto current_bb = BuilderM.GetInsertBlock();
3639 auto fn = current_bb->getParent();
3640 auto mod = fn->getParent();
3641 auto &Context = mod->getContext();
// The helper's name embeds the textual form of the value's type.
3643 std::string type_str;
3644 llvm::raw_string_ostream type_ss(type_str);
3645 toset->getType()->print(type_ss);
3646 std::string fn_name =
"__enzyme_sanitize_nan_" + type_str;
3648 llvm::FunctionType *SanitizeFT = llvm::FunctionType::get(
3649 llvm::Type::getVoidTy(Context),
3652 auto SanitizeFCallee = mod->getOrInsertFunction(fn_name, SanitizeFT);
3653 llvm::Function *SanitizeF =
3654 llvm::cast<llvm::Function>(SanitizeFCallee.getCallee());
// The body is only built when the function has no basic blocks yet.
3656 if (SanitizeF->empty()) {
3657 SanitizeF->setLinkage(Function::LinkageTypes::InternalLinkage);
3658 llvm::BasicBlock *entry =
3659 llvm::BasicBlock::Create(Context,
"entry", SanitizeF);
3660 llvm::BasicBlock *good =
3661 llvm::BasicBlock::Create(Context,
"good", SanitizeF);
3662 llvm::BasicBlock *bad =
3663 llvm::BasicBlock::Create(Context,
"bad", SanitizeF);
3665 llvm::IRBuilder<> B(entry);
3666 llvm::Value *inp = SanitizeF->getArg(0);
3667 llvm::Value *msg_ptr = SanitizeF->getArg(1);
// `x uno x` is true exactly when x is NaN.
3669 llvm::Value *cmp = B.CreateFCmpUNO(inp, inp);
// Vector input: the per-lane results are OR-ed together into `res`.
// NOTE(review): the branch below tests `cmp`; presumably `res` is assigned
// back to `cmp` on a line not visible here - confirm.
3670 if (
auto VT = llvm::dyn_cast<llvm::VectorType>(inp->getType())) {
3671#if LLVM_VERSION_MAJOR >= 12
3672 unsigned len = VT->getElementCount().getKnownMinValue();
3674 unsigned len = VT->getNumElements();
3676 llvm::Value *res = B.CreateExtractElement(cmp, (uint64_t)0);
3677 for (
unsigned i = 1; i < len; ++i) {
3678 res = B.CreateOr(res, B.CreateExtractElement(cmp, (uint64_t)i));
3682 B.CreateCondBr(cmp, bad, good);
3684 B.SetInsertPoint(good);
3687 B.SetInsertPoint(bad);
3690 wrap(msg_ptr), wrap(&B));
// Failure path visible here: puts(msg) followed by exit(1), then unreachable.
3692 llvm::FunctionType *PutsFT = llvm::FunctionType::get(
3693 llvm::Type::getInt32Ty(Context), {
getInt8PtrTy(Context)},
false);
3694 auto PutsF = mod->getOrInsertFunction(
"puts", PutsFT);
3695 B.CreateCall(PutsF, msg_ptr);
3697 llvm::FunctionType *ExitFT =
3698 llvm::FunctionType::get(llvm::Type::getVoidTy(Context),
3699 {llvm::Type::getInt32Ty(Context)},
false);
3700 auto ExitF = mod->getOrInsertFunction(
"exit", ExitFT);
3702 ExitF, llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 1));
3704 B.CreateUnreachable();
// Message text: a fixed prefix followed by a printed form of `val`.
3707 std::string stringv =
"Enzyme: Found nan while computing derivative of ";
3710 llvm::raw_string_ostream ss(
str);
3711 if (
auto inst = llvm::dyn_cast<llvm::Instruction>(val)) {
3712 ss << *inst <<
"\n";
3717 stringv += ss.str();
// Call site inserted at the caller's builder position.
3722 BuilderM.CreateCall(SanitizeFCallee, {toset,
getString(*mod, stringv)});
3727 wrap(&BuilderM), wrap(mask)));
3732 llvm::FastMathFlags f;
// NOTE(review): this excerpt starts inside a parameter list; the function name
// and the parameters `arg` and `ty` used below are not visible here.
// Visible behaviour: appends a value of type `ty` to `cacheValues`. A
// non-pointer `arg` is appended as is; a pointer `arg` is (on older LLVM
// versions) cast to a pointer to `ty` when needed, loaded as `ty`, and the
// loaded value is appended.
3739 llvm::SmallVectorImpl<llvm::Value *> &cacheValues,
3740 llvm::IRBuilder<> &BuilderZ,
const Twine &name) {
// By-value argument: must already have the expected type.
3743 if (!arg->getType()->isPointerTy()) {
3744 assert(arg->getType() == ty);
3745 cacheValues.push_back(arg);
3748#if LLVM_VERSION_MAJOR < 17
3749 auto PT = cast<PointerType>(arg->getType());
3750#if LLVM_VERSION_MAJOR <= 14
3751 if (PT->getElementType() != ty)
3752 arg = BuilderZ.CreatePointerCast(
3753 arg, PointerType::get(ty, PT->getAddressSpace()),
"pcld." + name);
// NOTE(review): isOpaqueOrPointeeTypeMatches() is handed the pointer type
// `PT2`, not the pointee type `ty`; for a typed pointer that compares the
// pointee against a pointer type - confirm whether `ty` was intended.
3755 auto PT2 = PointerType::get(ty, PT->getAddressSpace());
3756 if (!PT->isOpaqueOrPointeeTypeMatches(PT2))
3757 arg = BuilderZ.CreatePointerCast(
3758 arg, PointerType::get(ty, PT->getAddressSpace()),
"pcld." + name);
// Load through the pointer and cache the loaded value.
3761 arg = BuilderZ.CreateLoad(ty, arg,
"avld." + name);
3762 cacheValues.push_back(arg);
// NOTE(review): this excerpt starts inside a parameter list; the function
// name, the leading parameters (`B` and `V` are used below) and the conditions
// guarding the statements are not visible here.
// Visible behaviour: allocates a stack slot of V's type through
// `entryBuilder`, stores V into it through `B`, and casts the slot address to
// an i8 pointer.
3768 bool cublas, IntegerType *julia_decl,
3769 IRBuilder<> &entryBuilder,
3770 llvm::Twine
const &name) {
3775 entryBuilder.CreateAlloca(V->getType(),
nullptr,
"byref." + name);
3776 B.CreateStore(V, allocV);
3779 allocV = B.CreatePointerCast(allocV,
getInt8PtrTy(V->getContext()),
// NOTE(review): this excerpt starts inside a parameter list; the function
// name, the leading parameters (`B` and `V` are used below) and the conditions
// guarding the statements are not visible here.
// Visible behaviour: allocates a stack slot of V's type through
// `entryBuilder`, stores V into it through `B`, and casts the slot address to
// `fpTy`.
3785 Type *fpTy, IRBuilder<> &entryBuilder,
3786 llvm::Twine
const &name) {
3791 entryBuilder.CreateAlloca(V->getType(),
nullptr,
"byref." + name);
3792 B.CreateStore(V, allocV);
3795 allocV = B.CreatePointerCast(allocV, fpTy,
"fpcast." + name);
// Builds an i1 value that is true when `uplo` selects the lower triangle.
// The conditions that choose between the code paths below (presumably based on
// `byRef` and `cublas`) are on lines not visible in this excerpt.
//   - A ConstantInt 'L'/'l' folds to true and 'U'/'u' folds to false.
//   - One path loads an i8 through `uplo` and tests it against 'L' and 'l'.
//   - One path tests the value against 122 as well as 'L' and 'l'; 122 is
//     presumably the CBLAS enum value for "lower" - TODO confirm against
//     cblas.h.
3800Value *
is_lower(IRBuilder<> &B, Value *uplo,
bool byRef,
bool cublas) {
3802 Value *isNormal =
nullptr;
3803 isNormal = B.CreateICmpEQ(
3804 uplo, ConstantInt::get(uplo->getType(),
// Constant folding for a literal character argument.
3808 if (
auto CI = dyn_cast<ConstantInt>(uplo)) {
3809 if (CI->getValue() ==
'L' || CI->getValue() ==
'l')
3810 return ConstantInt::getTrue(B.getContext());
3811 if (CI->getValue() ==
'U' || CI->getValue() ==
'u')
3812 return ConstantInt::getFalse(B.getContext());
// `uplo` is a pointer here: load the character it points to.
// NOTE(review): the loaded value is named "loaded.trans" although it holds
// `uplo`; this only affects the IR value name.
3816 IntegerType *charTy = IntegerType::get(uplo->getContext(), 8);
3817 uplo = B.CreateLoad(charTy, uplo,
"loaded.trans");
3819 auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
'L'));
3820 auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
'l'));
3822 return B.CreateOr(isl, isL);
// By-value path: accept the integer 122 or either character spelling.
3825 auto capi = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 122));
3828 auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
'L'));
3829 auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
'l'));
3830 return B.CreateOr(capi, B.CreateOr(isl, isL));
// Builds an i1 value that is true when the argument (named `uplo` here, but
// compared against 'N'/'U') selects a non-unit diagonal.
// The conditions that choose between the code paths below (presumably based on
// `byRef` and `cublas`) are on lines not visible in this excerpt.
//   - A ConstantInt 'N'/'n' folds to true and 'U'/'u' folds to false.
//   - One path loads an i8 through the pointer and tests it against 'N', 'n'.
//   - One path tests the value against 131 as well as 'N' and 'n'; 131 is
//     presumably the CBLAS enum value for "non-unit" - TODO confirm against
//     cblas.h.
3834Value *
is_nonunit(IRBuilder<> &B, Value *uplo,
bool byRef,
bool cublas) {
3836 Value *isNormal =
nullptr;
3838 B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
// Constant folding for a literal character argument.
3842 if (
auto CI = dyn_cast<ConstantInt>(uplo)) {
3843 if (CI->getValue() ==
'N' || CI->getValue() ==
'n')
3844 return ConstantInt::getTrue(B.getContext());
3845 if (CI->getValue() ==
'U' || CI->getValue() ==
'u')
3846 return ConstantInt::getFalse(B.getContext());
// Pointer argument: load the character it points to. The locals are called
// isL/isl but hold the comparisons against 'N' and 'n'.
3850 IntegerType *charTy = IntegerType::get(uplo->getContext(), 8);
3851 uplo = B.CreateLoad(charTy, uplo,
"loaded.nonunit");
3853 auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
'N'));
3854 auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
'n'));
3856 return B.CreateOr(isl, isL);
// By-value path: accept the integer 131 or either character spelling.
3859 auto capi = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 131));
3862 auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
'N'));
3863 auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(),
'n'));
3864 return B.CreateOr(capi, B.CreateOr(isl, isL));
// Builds an i1 value that is true when `trans` selects the non-transposed
// form. The remaining parameters and the conditions that choose between the
// code paths below are on lines not visible in this excerpt.
//   - A ConstantInt 'N'/'n' folds to a constant true.
//   - One path loads an i8 through `trans` and tests it against 'N' and 'n'.
//   - One path tests the value against 111 as well as 'N' and 'n'; 111 is
//     presumably the CBLAS enum value for "no transpose" - TODO confirm
//     against cblas.h.
3868llvm::Value *
is_normal(IRBuilder<> &B, llvm::Value *trans,
bool byRef,
3871 Value *isNormal =
nullptr;
3872 isNormal = B.CreateICmpEQ(
3873 trans, ConstantInt::get(trans->getType(),
// Constant folding for a literal character argument.
3878 if (
auto CI = dyn_cast<ConstantInt>(trans)) {
3879 if (CI->getValue() ==
'N' || CI->getValue() ==
'n')
3880 return ConstantInt::getTrue(
// Pointer argument: load the character it points to.
3885 IntegerType *charTy = IntegerType::get(trans->getContext(), 8);
3886 trans = B.CreateLoad(charTy, trans,
"loaded.trans");
3888 auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(),
'N'));
3889 auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(),
'n'));
3891 return B.CreateOr(isn, isN);
// By-value path: accept the integer 111 or either character spelling.
3896 auto capi = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
3897 auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(),
'N'));
3898 auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(),
'n'));
3900 return B.CreateOr(capi, B.CreateOr(isn, isN));
// Builds an i1 value that is true when `side` selects the left side.
// The remaining parameters and the conditions that choose between the code
// paths below are on lines not visible in this excerpt.
//   - A ConstantInt 'L'/'l' folds to true and 'R'/'r' folds to false.
//   - One path loads an i8 through `side` and tests it against 'L' and 'l'.
//   - One path tests the value against 141 as well as 'L' and 'l'; 141 is
//     presumably the CBLAS enum value for "left" - TODO confirm against
//     cblas.h.
3904llvm::Value *
is_left(IRBuilder<> &B, llvm::Value *side,
bool byRef,
3907 Value *isNormal =
nullptr;
3908 isNormal = B.CreateICmpEQ(
3909 side, ConstantInt::get(side->getType(),
// Constant folding for a literal character argument.
3914 if (
auto CI = dyn_cast<ConstantInt>(side)) {
3915 if (CI->getValue() ==
'L' || CI->getValue() ==
'l')
3916 return ConstantInt::getTrue(B.getContext());
3917 if (CI->getValue() ==
'R' || CI->getValue() ==
'r')
3918 return ConstantInt::getFalse(B.getContext());
// Pointer argument: load the character it points to.
3922 IntegerType *charTy = IntegerType::get(side->getContext(), 8);
3923 side = B.CreateLoad(charTy, side,
"loaded.side");
3925 auto isL = B.CreateICmpEQ(side, ConstantInt::get(side->getType(),
'L'));
3926 auto isl = B.CreateICmpEQ(side, ConstantInt::get(side->getType(),
'l'));
3928 return B.CreateOr(isl, isL);
// By-value path: accept the integer 141 or either character spelling.
3933 auto capi = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 141));
3934 auto isL = B.CreateICmpEQ(side, ConstantInt::get(side->getType(),
'L'));
3935 auto isl = B.CreateICmpEQ(side, ConstantInt::get(side->getType(),
'l'));
3937 return B.CreateOr(capi, B.CreateOr(isl, isL));
// Emits IR that maps a runtime "trans" flag `V` to the flag for the opposite
// operation, using a chain of selects. The remaining parameters and the
// condition of the first branch are on lines not visible in this excerpt.
//   - first branch:  1 -> 0, 0 -> 1, anything else -> 42
//   - i8 flag, floatType "z"/"c":  n -> c, N -> C, c -> n, C -> N, else 0
//   - i8 flag, other floatType:    n -> t, N -> T, t -> n, T -> N, else 'N'
//   - i32 flag:                    111 -> 112, 112 -> 111, else 0
// Any other type of `V` is reported as a failure.
3947llvm::Value *
transpose(std::string floatType, IRBuilder<> &B, llvm::Value *V,
3949 llvm::Type *T = V->getType();
3951 auto isT1 = B.CreateICmpEQ(V, ConstantInt::get(T, 1));
3952 auto isT0 = B.CreateICmpEQ(V, ConstantInt::get(T, 0));
3953 return B.CreateSelect(isT1, ConstantInt::get(V->getType(), 0),
3954 B.CreateSelect(isT0,
3955 ConstantInt::get(V->getType(), 1),
3956 ConstantInt::get(V->getType(), 42)));
3957 }
else if (T->isIntegerTy(8)) {
// Complex element types swap between "no transpose" and the 'c'/'C' spelling.
// NOTE(review): the fallback for an unrecognised character is 0 here but 'N'
// in the non-complex branch below - confirm the difference is intended.
3958 if (floatType ==
"z" || floatType ==
"c") {
3959 auto isn = B.CreateICmpEQ(V, ConstantInt::get(V->getType(),
'n'));
3960 auto sel1 = B.CreateSelect(isn, ConstantInt::get(V->getType(),
'c'),
3961 ConstantInt::get(V->getType(), 0));
3963 auto isN = B.CreateICmpEQ(V, ConstantInt::get(V->getType(),
'N'));
3965 B.CreateSelect(isN, ConstantInt::get(V->getType(),
'C'), sel1);
3967 auto ist = B.CreateICmpEQ(V, ConstantInt::get(V->getType(),
'c'));
3969 B.CreateSelect(ist, ConstantInt::get(V->getType(),
'n'), sel2);
3971 auto isT = B.CreateICmpEQ(V, ConstantInt::get(V->getType(),
'C'));
3972 return B.CreateSelect(isT, ConstantInt::get(V->getType(),
'N'), sel3);
// Other element types swap between "no transpose" and the 't'/'T' spelling.
3975 auto isn = B.CreateICmpEQ(V, ConstantInt::get(V->getType(),
'n'));
3976 auto sel1 = B.CreateSelect(isn, ConstantInt::get(V->getType(),
't'),
3977 ConstantInt::get(V->getType(),
'N'));
3979 auto isN = B.CreateICmpEQ(V, ConstantInt::get(V->getType(),
'N'));
3981 B.CreateSelect(isN, ConstantInt::get(V->getType(),
'T'), sel1);
3983 auto ist = B.CreateICmpEQ(V, ConstantInt::get(V->getType(),
't'));
3985 B.CreateSelect(ist, ConstantInt::get(V->getType(),
'n'), sel2);
3987 auto isT = B.CreateICmpEQ(V, ConstantInt::get(V->getType(),
'T'));
3988 return B.CreateSelect(isT, ConstantInt::get(V->getType(),
'N'), sel3);
3991 }
else if (T->isIntegerTy(32)) {
// 32-bit integer flag: 111 and 112 are swapped; presumably the CBLAS enum
// values for "no transpose" and "transpose" - TODO confirm against cblas.h.
3992 auto is111 = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111));
3993 auto sel1 = B.CreateSelect(
3994 B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 112)),
3995 ConstantInt::get(V->getType(), 111), ConstantInt::get(V->getType(), 0));
3996 return B.CreateSelect(is111, ConstantInt::get(V->getType(), 112), sel1);
// Unsupported flag type: build a diagnostic and report it.
// NOTE(review): `V` is streamed as a pointer, not as `*V`, so the message
// carries an address rather than the printed IR value - confirm intent.
3999 llvm::raw_string_ostream ss(s);
4000 ss <<
"cannot handle unknown trans blas value\n" << V;
4003 nullptr,
nullptr,
nullptr);
4005 EmitFailure(
"unknown trans blas value", B.getCurrentDebugLocation(),
4006 B.GetInsertBlock()->getParent(), ss.str());
4019 llvm::ArrayRef<llvm::Value *> trans,
4020 llvm::Value *arg_ld, llvm::Value *dim1,
4021 llvm::Value *dim2,
bool cacheMat,
bool byRef,
4026 assert(trans.size() == 1);
4028 llvm::Value *width =
// Overload of transpose that also takes the calling-convention flags `byRef`
// and `cublas`. The conditions that choose between the code paths below are
// on lines not visible in this excerpt.
// Visible behaviour:
//   - a ConstantInt flag is folded directly for a few character values;
//   - one path swaps the integers 111 and 112 with a single select (any value
//     other than 111 yields 111);
//   - one path first loads an i8 through `V`;
//   - the tail produces a value named "transpose." + name.
4034llvm::Value *
transpose(std::string floatType, llvm::IRBuilder<> &B,
4035 llvm::Value *V,
bool byRef,
bool cublas,
4036 llvm::IntegerType *julia_decl,
4037 llvm::IRBuilder<> &entryBuilder,
4038 const llvm::Twine &name) {
// Constant folding.
// NOTE(review): for complex types the second case maps 'c' to 'c', which is
// not a transpose; the non-complex branch maps 'n' to 't' at the same
// position, so 'n' -> 'c' may have been intended - confirm.
4042 if (
auto CI = dyn_cast<ConstantInt>(V)) {
4043 if (floatType ==
"c" || floatType ==
"z") {
4044 if (CI->getValue() ==
'N')
4045 return ConstantInt::get(CI->getType(),
'C');
4046 if (CI->getValue() ==
'c')
4047 return ConstantInt::get(CI->getType(),
'c');
4049 if (CI->getValue() ==
'N')
4050 return ConstantInt::get(CI->getType(),
'T');
4051 if (CI->getValue() ==
'n')
4052 return ConstantInt::get(CI->getType(),
't');
// Integer-enum path: 111 becomes 112, everything else becomes 111.
4058 return B.CreateSelect(
4059 B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)),
4060 ConstantInt::get(V->getType(), 112),
4061 ConstantInt::get(V->getType(), 111));
// Pointer argument: load the character it points to.
4065 auto charType = IntegerType::get(V->getContext(), 8);
4066 V = B.CreateLoad(charType, V,
"ld." + name);
4072 "transpose." + name);
4076 llvm::Value *V,
bool byRef) {
4080 if (V->getType()->isIntegerTy())
4081 V = B.CreateIntToPtr(V,
getUnqual(intType));
4083 V = B.CreatePointerCast(
4084 V, PointerType::get(
4085 intType, cast<PointerType>(V->getType())->getAddressSpace()));
4086 return B.CreateLoad(intType, V);
4090 ArrayRef<llvm::Value *> transA,
4091 bool byRef,
bool cublas) {
4092 assert(transA.size() == 1);
4093 auto trans = transA[0];
4095 auto charType = IntegerType::get(trans->getContext(), 8);
4096 trans = B.CreateLoad(charType, trans,
"ld.row.trans");
4099 Value *cond =
nullptr;
4103 cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
4105 auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(),
'n'));
4106 auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(),
'N'));
4107 cond = B.CreateOr(isN, isn);
4112 cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0));
4117 ArrayRef<llvm::Value *> transA,
4118 ArrayRef<llvm::Value *> row,
4119 ArrayRef<llvm::Value *> col,
4120 bool byRef,
bool cublas) {
4122 assert(row.size() == col.size());
4123 SmallVector<Value *, 1> toreturn;
4124 for (
size_t i = 0; i < row.size(); i++) {
4127 if (lhs->getType() != rhs->getType())
4128 rhs = B.CreatePointerCast(rhs, lhs->getType());
4129 toreturn.push_back(B.CreateSelect(conds[0], lhs, rhs));
4137 if (isa<PointerType>(T)) {
4143 }
else if (isa<StructType>(T) || isa<ArrayType>(T) || isa<VectorType>(T)) {
4144 for (Type *ElT : T->subtypes()) {
4150 if (isa<ArrayType>(T))
4151 count *= cast<ArrayType>(T)->getNumElements();
4152 else if (isa<VectorType>(T)) {
4153#if LLVM_VERSION_MAJOR >= 12
4154 count *= cast<VectorType>(T)->getElementCount().getKnownMinValue();
4156 count *= cast<VectorType>(T)->getNumElements();
4164#if LLVM_VERSION_MAJOR >= 20
4165bool collectOffset(GEPOperator *gep,
const DataLayout &DL,
unsigned BitWidth,
4166 SmallMapVector<Value *, APInt, 4> &VariableOffsets,
4167 APInt &ConstantOffset)
4170 MapVector<Value *, APInt> &VariableOffsets,
4171 APInt &ConstantOffset)
4174#if LLVM_VERSION_MAJOR >= 13
4175 return gep->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset);
4177 assert(BitWidth == DL.getIndexSizeInBits(gep->getPointerAddressSpace()) &&
4178 "The offset bit width does not match DL specification.");
4180 auto CollectConstantOffset = [&](APInt Index, uint64_t Size) {
4181 Index = Index.sextOrTrunc(BitWidth);
4182 APInt IndexedSize = APInt(BitWidth, Size);
4183 ConstantOffset += Index * IndexedSize;
4186 for (gep_type_iterator GTI = gep_type_begin(gep), GTE = gep_type_end(gep);
4187 GTI != GTE; ++GTI) {
4189 bool ScalableType = isa<ScalableVectorType>(GTI.getIndexedType());
4191 Value *V = GTI.getOperand();
4192 StructType *STy = GTI.getStructTypeOrNull();
4194 if (
auto ConstOffset = dyn_cast<ConstantInt>(V)) {
4195 if (ConstOffset->isZero())
4206 unsigned ElementIdx = ConstOffset->getZExtValue();
4207 const StructLayout *SL = DL.getStructLayout(STy);
4209 CollectConstantOffset(APInt(BitWidth, SL->getElementOffset(ElementIdx)),
4213 CollectConstantOffset(ConstOffset->getValue(),
4214 DL.getTypeAllocSize(GTI.getIndexedType()));
4218 if (STy || ScalableType)
4221 APInt(BitWidth, DL.getTypeAllocSize(GTI.getIndexedType()));
4224 if (IndexedSize != 0) {
4225 VariableOffsets.insert({V, APInt(BitWidth, 0)});
4226 VariableOffsets[V] += IndexedSize;
4234 llvm::Intrinsic::ID ID, llvm::Type *RetTy,
4235 llvm::ArrayRef<llvm::Value *>
Args,
4236 llvm::Instruction *FMFSource,
4237 const llvm::Twine &Name) {
4238#if LLVM_VERSION_MAJOR >= 16
4239 llvm::CallInst *nres = B.CreateIntrinsic(RetTy, ID,
Args, FMFSource, Name);
4241 SmallVector<Intrinsic::IITDescriptor, 1> Table;
4242 Intrinsic::getIntrinsicInfoTableEntries(ID, Table);
4243 ArrayRef<Intrinsic::IITDescriptor> TableRef(Table);
4245 SmallVector<Type *, 2> ArgTys;
4246 ArgTys.reserve(
Args.size());
4247 for (
auto &I :
Args)
4248 ArgTys.push_back(I->getType());
4249 FunctionType *FTy = FunctionType::get(RetTy, ArgTys,
false);
4250 SmallVector<Type *, 2> OverloadTys;
4251 Intrinsic::MatchIntrinsicTypesResult Res =
4252 matchIntrinsicSignature(FTy, TableRef, OverloadTys);
4254 assert(Res == Intrinsic::MatchIntrinsicTypes_Match && TableRef.empty() &&
4255 "Wrong types for intrinsic!");
4256 Function *Fn = Intrinsic::getDeclaration(B.GetInsertPoint()->getModule(), ID,
4258 CallInst *nres = B.CreateCall(Fn,
Args, {}, Name);
4260 nres->copyFastMathFlags(FMFSource);
4272llvm::Value *
get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res) {
4273 auto ty = res->getType();
4274 unsigned tsize = builder.GetInsertBlock()
4278 .getTypeSizeInBits(ty);
4280 auto ity = IntegerType::get(ty->getContext(), tsize);
4282 auto as_int = builder.CreateBitCast(res, ity);
4283 auto masked = builder.CreateXor(as_int, ConstantInt::get(ity, 1));
4284 auto neighbor = builder.CreateBitCast(masked, ty);
4286 auto diff = builder.CreateFSub(res, neighbor);
4288 auto absres = builder.CreateIntrinsic(Intrinsic::fabs,
4289 ArrayRef<Type *>(diff->getType()),
4290 ArrayRef<Value *>(diff));
4296 llvm::Instruction &inst,
4298 llvm::IRBuilder<> &Builder2,
4299 llvm::Value *condition) {
4303 wrap(condition), wrap(&Builder2)));
4305 auto &M = *inst.getParent()->getParent()->getParent();
4306 FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()),
4307 {getInt8PtrTy(M.getContext())},
false);
4309 raw_string_ostream ss(
str);
4310 ss << message <<
"\n";
4314 auto PutsF = M.getOrInsertFunction(
"puts", FT);
4315 Builder2.CreateCall(PutsF, msg);
4318 FunctionType::get(Type::getVoidTy(M.getContext()),
4319 {Type::getInt32Ty(M.getContext())},
false);
4321 auto ExitF = M.getOrInsertFunction(
"exit", FT2);
4322 Builder2.CreateCall(ExitF,
4323 ConstantInt::get(Type::getInt32Ty(M.getContext()), 1));
4326 if (StringRef(message).
contains(
"cannot handle above cast")) {
4329 EmitFailure(
"NoDerivative", inst.getDebugLoc(), &inst, message);
4336 Value *toshow = todiff;
4338 toshow = context.
req;
4342 nullptr, wrap(todiff), wrap(context.
ip));
4345 auto &M = *context.
ip->GetInsertBlock()->getParent()->getParent();
4346 FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()),
4347 {getInt8PtrTy(M.getContext())},
false);
4349 raw_string_ostream ss(
str);
4350 ss << message <<
"\n";
4351 if (
auto inst = dyn_cast<Instruction>(todiff))
4354 auto PutsF = M.getOrInsertFunction(
"puts", FT);
4355 context.
ip->CreateCall(PutsF, msg);
4358 FunctionType::get(Type::getVoidTy(M.getContext()),
4359 {Type::getInt32Ty(M.getContext())},
false);
4361 auto ExitF = M.getOrInsertFunction(
"exit", FT2);
4362 context.
ip->CreateCall(
4363 ExitF, ConstantInt::get(Type::getInt32Ty(M.getContext()), 1));
4365 }
else if (context.
req) {
4369 }
else if (
auto arg = dyn_cast<Instruction>(todiff)) {
4370 auto loc = arg->getDebugLoc();
4381 gutils->
TR.
analyzer,
nullptr, wrap(&Builder2));
4383 auto &M = *inst.getParent()->getParent()->getParent();
4384 FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()),
4385 {getInt8PtrTy(M.getContext())},
false);
4387 raw_string_ostream ss(
str);
4388 ss << message <<
"\n";
4391 auto PutsF = M.getOrInsertFunction(
"puts", FT);
4392 Builder2.CreateCall(PutsF, msg);
4395 FunctionType::get(Type::getVoidTy(M.getContext()),
4396 {Type::getInt32Ty(M.getContext())},
false);
4398 auto ExitF = M.getOrInsertFunction(
"exit", FT2);
4399 Builder2.CreateCall(ExitF,
4400 ConstantInt::get(Type::getInt32Ty(M.getContext()), 1));
4403 raw_string_ostream ss(
str);
4404 ss << message <<
"\n";
4406 EmitFailure(
"CannotDeduceType", inst.getDebugLoc(), &inst, ss.str());
4410std::vector<std::tuple<llvm::Type *, size_t, size_t>>
4412 std::vector<std::pair<ConcreteType, size_t>> parsed;
4413 for (
size_t i = 0; i < md->getNumOperands(); i += 2) {
4415 llvm::cast<llvm::MDString>(md->getOperand(i))->getString(),
4417 auto size = llvm::cast<llvm::ConstantInt>(
4418 llvm::cast<llvm::ConstantAsMetadata>(md->getOperand(i + 1))
4421 parsed.emplace_back(base, size);
4424 std::vector<std::tuple<llvm::Type *, size_t, size_t>> toIterate;
4426 while (idx < parsed.size()) {
4428 auto dt = parsed[idx].first;
4429 size_t start = parsed[idx].second;
4430 size_t end = 0x0fffffff;
4431 for (idx = idx + 1; idx < parsed.size(); ++idx) {
4434 auto next = parsed[idx].first;
4435 tmp.checkedOrIn(next,
true, Legal);
4452 if ((parsed[idx].first.isFloat() ==
nullptr) ==
4453 (parsed[idx - 1].first.isFloat() ==
nullptr)) {
4461 end = parsed[idx].second;
4467 assert(dt.isKnown());
4468 toIterate.emplace_back(dt.isFloat(), start, end - start);
4473void dumpModule(llvm::Module *mod) { llvm::errs() << *mod <<
"\n"; }
4475void dumpValue(llvm::Value *val) { llvm::errs() << *val <<
"\n"; }
4477void dumpBlock(llvm::BasicBlock *blk) { llvm::errs() << *blk <<
"\n"; }
4479void dumpType(llvm::Type *ty) { llvm::errs() << *ty <<
"\n"; }
4484 auto II = dyn_cast<IntrinsicInst>(V);
4487 switch (II->getIntrinsicID()) {
4488 case Intrinsic::nvvm_ldu_global_i:
4489 case Intrinsic::nvvm_ldu_global_p:
4490 case Intrinsic::nvvm_ldu_global_f:
4491#if LLVM_VERSION_MAJOR < 20
4492 case Intrinsic::nvvm_ldg_global_i:
4493 case Intrinsic::nvvm_ldg_global_p:
4494 case Intrinsic::nvvm_ldg_global_f:
4504 size_t checkLoadCaptures) {
4505 Instruction *VI = dyn_cast<Instruction>(V);
4507 VI = &*inst->getParent()->getParent()->getEntryBlock().begin();
4509 VI = VI->getNextNode();
4510 SmallPtrSet<BasicBlock *, 1> regionBetween;
4512 SmallVector<BasicBlock *, 1> todo;
4513 todo.push_back(VI->getParent());
4514 while (todo.size()) {
4515 auto cur = todo.pop_back_val();
4516 if (regionBetween.count(cur))
4518 regionBetween.insert(cur);
4519 if (cur == inst->getParent())
4521 for (
auto BB : successors(cur))
4525 SmallVector<std::tuple<Instruction *, size_t, Value *>, 1> todo;
4526 for (
auto U : V->users()) {
4527 todo.emplace_back(cast<Instruction>(U), checkLoadCaptures, V);
4529 std::set<std::tuple<Value *, size_t, Value *>> seen;
4530 while (todo.size()) {
4531 auto pair = todo.pop_back_val();
4532 if (seen.count(pair))
4535 auto UI = std::get<0>(pair);
4536 auto level = std::get<1>(pair);
4537 auto prev = std::get<2>(pair);
4539 if (!regionBetween.count(UI->getParent()))
4541 if (UI->getParent() == VI->getParent()) {
4542 if (UI->comesBefore(VI))
4545 if (UI->getParent() == inst->getParent())
4546 if (inst->comesBefore(UI))
4552 for (
auto U2 : UI->users()) {
4553 auto UI2 = cast<Instruction>(U2);
4554 todo.emplace_back(UI2, level, UI);
4559 if (isa<MemSetInst>(UI))
4562 if (isa<MemTransferInst>(UI)) {
4565 if (UI->getOperand(1) != prev)
4569 if (
auto CI = dyn_cast<CallBase>(UI)) {
4570#if LLVM_VERSION_MAJOR >= 14
4571 for (
size_t i = 0, size = CI->arg_size(); i < size; i++)
4573 for (
size_t i = 0, size = CI->getNumArgOperands(); i < size; i++)
4576 if (prev == CI->getArgOperand(i)) {
4585 if (isa<CmpInst>(UI)) {
4588 if (isa<LoadInst>(UI)) {
4590 for (
auto U2 : UI->users()) {
4591 auto UI2 = cast<Instruction>(U2);
4592 todo.emplace_back(UI2, level - 1, UI);
4598 if (
auto SI = dyn_cast<StoreInst>(UI)) {
4599 if (SI->getValueOperand() != prev) {
4613#if LLVM_VERSION_MAJOR >= 16
4619 llvm::LoopInfo &LI, llvm::Value *op0,
4620 llvm::Value *op1,
bool offsetAllowed) {
4627 if (
auto i1 = dyn_cast<Instruction>(op1))
4628 if (isa<ConstantPointerNull>(op0) &&
4632 if (
auto i0 = dyn_cast<Instruction>(op0))
4633 if (isa<ConstantPointerNull>(op1) &&
4638 if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy())
4644 bool noalias[2] = {noalias_lhs, noalias_rhs};
4646 for (
int i = 0; i < 2; i++) {
4647 Value *start = (i == 0) ? lhs : rhs;
4648 Value *end = (i == 0) ? rhs : lhs;
4650 if (noalias[1 - i]) {
4653 if (isa<Argument>(end)) {
4656 if (
auto endi = dyn_cast<Instruction>(end)) {
4662 if (
auto ld = dyn_cast<LoadInst>(start)) {
4665 if (isa<Argument>(end))
4667 if (
auto endi = dyn_cast<Instruction>(end))
4669 Instruction *starti = dyn_cast<Instruction>(start);
4671 if (!isa<Argument>(start))
4674 &cast<Argument>(start)->getParent()->getEntryBlock().front();
4677 bool overwritten =
false;
4679 LI, starti, endi, [&](Instruction *I) ->
bool {
4680 if (!I->mayWriteToMemory())
4705 ArrayRef<unsigned> path) {
4706 SmallVector<Value *, 2> vals;
4707 vals.push_back(ConstantInt::get(B.getInt64Ty(), 0));
4708 for (
auto v : path) {
4709 vals.push_back(ConstantInt::get(B.getInt32Ty(), v));
4711 return B.CreateInBoundsGEP(type, value, vals);
4715 llvm::Value *sret, llvm::Type *root_ty,
4716 llvm::Value *rootRet,
size_t rootOffset,
4718 std::deque<std::pair<llvm::Type *, std::vector<unsigned>>> todo = {
4720 SmallVector<Value *> extracted;
4722 auto rootOffset0 = rootOffset;
4723 while (!todo.empty()) {
4724 auto cur = std::move(todo[0]);
4726 auto path = std::move(cur.second);
4727 auto ty = cur.first;
4729 if (
auto PT = dyn_cast<PointerType>(ty)) {
4733 Value *loc =
nullptr;
4734 switch (direction) {
4742 llvm_unreachable(
"Unhandled");
4744 switch (direction) {
4747 outloc = B.CreateLoad(ty, outloc);
4748 B.CreateStore(outloc, loc);
4753 outloc = B.CreatePointerCast(
4754 outloc, PointerType::get(StructType::get(outloc->getContext(), {}),
4756 B.CreateStore(outloc, loc);
4760 loc = B.CreateLoad(ty, loc);
4761 val = B.CreateInsertValue(val, loc, path);
4766 *B.GetInsertBlock()->getParent()->getParent(), ty,
false);
4767 val = B.CreateInsertValue(val, loc, path);
4772 loc = B.CreateLoad(ty, loc);
4773 extracted.push_back(loc);
4774 B.CreateStore(loc, outloc);
4778 llvm_unreachable(
"Unhandled");
4786 if (
auto AT = dyn_cast<ArrayType>(ty)) {
4787 for (
size_t i = 0, E = AT->getNumElements(); i < E; i++) {
4788 std::vector<unsigned> path2(path);
4789 path2.push_back(E - 1 - i);
4790 todo.emplace_front(AT->getElementType(), path2);
4795 if (
auto VT = dyn_cast<VectorType>(ty)) {
4796 for (
size_t i = 0, E = VT->getElementCount().getKnownMinValue(); i < E;
4798 std::vector<unsigned> path2(path);
4799 path2.push_back(E - 1 - i);
4800 todo.emplace_front(VT->getElementType(), path2);
4805 if (
auto ST = dyn_cast<StructType>(ty)) {
4806 for (
size_t i = 0, E = ST->getNumElements(); i < E; i++) {
4807 std::vector<unsigned> path2(path);
4808 path2.push_back(E - 1 - i);
4809 todo.emplace_front(ST->getTypeAtIndex(E - 1 - i), path2);
4817 auto PT = cast<PointerType>(obj->getType());
4818 assert(PT->getAddressSpace() == 0 || PT->getAddressSpace() == 10);
4819 if (PT->getAddressSpace() == 10 && extracted.size()) {
4820 extracted.insert(extracted.begin(), obj);
4821 auto JLT = PointerType::get(StructType::get(PT->getContext(), {}), 10);
4822 auto FT = FunctionType::get(JLT, {},
true);
4824 B.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
4825 "julia.write_barrier", FT);
4826 assert(obj->getType() == JLT);
4827 B.CreateCall(wb, extracted);
4832 assert(rootOffset - rootOffset0 == tracked.
count);
4838 llvm::Type *dstType, llvm::Value *dst,
4839 llvm::ArrayRef<unsigned> dstPrefix0,
4840 llvm::Type *srcType, llvm::Value *src,
4841 llvm::ArrayRef<unsigned> srcPrefix0,
bool shouldZero) {
4843 std::tuple<llvm::Type *, std::vector<unsigned>, std::vector<unsigned>>>
4845 std::vector<unsigned>(dstPrefix0.begin(), dstPrefix0.end()),
4846 std::vector<unsigned>(srcPrefix0.begin(), srcPrefix0.end())}};
4848 auto &M = *B.GetInsertBlock()->getParent()->getParent();
4850 size_t numRootsSeen = 0;
4852 while (!todo.empty()) {
4853 auto cur = std::move(todo[0]);
4854 auto &&[ty, dstPrefix, srcPrefix] = cur;
4857 if (
auto PT = dyn_cast<PointerType>(ty)) {
4858 if (PT->getAddressSpace() == 10) {
4862 if (dstPrefix.size() > 0)
4871 if (
auto AT = dyn_cast<ArrayType>(ty)) {
4872 for (
size_t i = 0, E = AT->getNumElements(); i < E; i++) {
4873 std::vector<unsigned> nextDst(dstPrefix);
4874 std::vector<unsigned> nextSrc(srcPrefix);
4875 nextDst.push_back(E - 1 - i);
4876 nextSrc.push_back(E - 1 - i);
4877 todo.emplace_front(AT->getElementType(), std::move(nextDst),
4878 std::move(nextSrc));
4883 if (
auto ST = dyn_cast<StructType>(ty)) {
4884 for (
size_t i = 0, E = ST->getNumElements(); i < E; i++) {
4885 std::vector<unsigned> nextDst(dstPrefix);
4886 std::vector<unsigned> nextSrc(srcPrefix);
4887 nextDst.push_back(E - 1 - i);
4888 nextSrc.push_back(E - 1 - i);
4889 todo.emplace_front(ST->getElementType(E - 1 - i), std::move(nextDst),
4890 std::move(nextSrc));
4896 if (dstPrefix.size() > 0)
4900 if (srcPrefix.size() > 0)
4903 auto ld = B.CreateLoad(ty, in);
4904 B.CreateStore(ld, out);
4908 assert(numRootsSeen == tracked.
count);
4914 llvm::IRBuilder<> &B) {
4915 std::deque<Value *> todo = {v};
4916 SmallVector<Value *, 1> done;
4917 while (todo.size()) {
4918 auto cur = todo.front();
4920 auto T = cur->getType();
4925 done.push_back(cur);
4928 if (
auto ST = dyn_cast<StructType>(T)) {
4929 for (
size_t i = 0, E = ST->getNumElements(); i < E; i++) {
4930 auto T2 = ST->getElementType(E - 1 - i);
4932 auto V2 = B.CreateExtractValue(cur, E - 1 - i);
4933 todo.push_front(V2);
4938 if (
auto AT = dyn_cast<ArrayType>(T)) {
4939 for (
size_t i = 0, E = AT->getNumElements(); i < E; i++) {
4940 todo.push_front(B.CreateExtractValue(cur, E - 1 - i));
4944 if (
auto VT = dyn_cast<VectorType>(T)) {
4945 assert(!VT->getElementCount().isScalable());
4946 size_t numElems = VT->getElementCount().getKnownMinValue();
4947 for (
size_t i = 0; i < numElems; i++) {
4948 todo.push_front(B.CreateExtractElement(cur, numElems - 1 - i));
4952 llvm_unreachable(
"unknown source of julia type");
static bool contains(ArrayRef< int > ar, int v)
static bool isDeallocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory deallocation function For updating below one ...
static bool isAllocationFunction(const llvm::StringRef name, const llvm::TargetLibraryInfo &TLI)
Return whether a given function is a known C/C++ memory allocation function For updating below one sh...
static bool isAllocationCall(const llvm::Value *TmpOrig, llvm::TargetLibraryInfo &TLI)
static Operation * getFunctionFromCall(CallOpInterface iface)
static std::string str(AugmentedStruct c)
static TypeTree parseTBAA(TBAAStructTypeNode AccessType, llvm::Instruction &I, const llvm::DataLayout &DL, std::shared_ptr< llvm::ModuleSlotTracker > MST)
Given a TBAA access node return the corresponding TypeTree This includes recursively parsing the acce...
static bool isMemFreeLibMFunction(llvm::StringRef str, llvm::Intrinsic::ID *ID=nullptr)
llvm::Value * get_cached_mat_width(llvm::IRBuilder<> &B, llvm::ArrayRef< llvm::Value * > trans, llvm::Value *arg_ld, llvm::Value *dim1, llvm::Value *dim2, bool cacheMat, bool byRef, bool cublas)
void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty, llvm::SmallVectorImpl< llvm::Value * > &cacheValues, llvm::IRBuilder<> &BuilderZ, const Twine &name)
llvm::Optional< bool > arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA, llvm::LoopInfo &LI, llvm::Value *op0, llvm::Value *op1, bool offsetAllowed)
Function * getOrInsertDifferentialFloatMemmove(Module &M, Type *T, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
llvm::Value * load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType, llvm::Value *V, bool byRef)
CallInst * CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree)
llvm::Value * CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev, llvm::Type *T, llvm::Value *OuterCount, llvm::Value *InnerCount, const llvm::Twine &Name, llvm::CallInst **caller, bool ZeroMem)
AllocaInst * getBaseAndOffset(Value *ptr, size_t &offset)
Value * lookup_with_layout(IRBuilder<> &B, Type *fpType, Value *layout, Value *const base, Value *lda, Value *row, Value *col)
llvm::cl::opt< bool > EnzymeZeroCache
llvm::Value * to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, Type *fpTy, IRBuilder<> &entryBuilder, llvm::Twine const &name)
void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, bool isTape)
llvm::SmallVector< llvm::Instruction *, 2 > PostCacheStore(llvm::StoreInst *SI, llvm::IRBuilder<> &B)
Function * getFirstFunctionDefinition(Module &M)
llvm::cl::opt< bool > EnzymeFastMath("enzyme-fast-math", cl::init(true), cl::Hidden, cl::desc("Use fast math on derivative compuation"))
llvm::Function * getOrInsertDifferentialMPI_Wait(llvm::Module &M, ArrayRef< llvm::Type * > T, Type *reqType, StringRef caller)
llvm::PointerType * getDefaultAnonymousTapeType(llvm::LLVMContext &C)
llvm::Value * getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr, llvm::Type *OpType, ConcreteType CT, llvm::Type *intType, IRBuilder<> &B2)
llvm::Function * getOrInsertDifferentialWaitallSave(llvm::Module &M, ArrayRef< llvm::Type * > T, PointerType *reqType)
Value * simplifyLoad(Value *V, size_t valSz, size_t preOffset)
bool attributeKnownFunctions(llvm::Function &F)
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &B, llvm::Intrinsic::ID ID, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Args, llvm::Instruction *FMFSource, const llvm::Twine &Name)
llvm::Value * get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res)
void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef< llvm::Value * > args, llvm::ArrayRef< llvm::OperandBundleDef > bundles)
Create function for type that performs memcpy using lapack copy.
LLVMValueRef(* EnzymeUndefinedValueForType)(LLVMModuleRef, LLVMTypeRef, uint8_t)
Function * getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
Create function for type that is equivalent to memcpy but adds to destination rather than a direct co...
llvm::Value * nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V)
Create function to compute nearest power of two.
Function * getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT, IntegerType *IT, unsigned dstalign, unsigned srcalign)
llvm::Value * transpose(std::string floatType, IRBuilder<> &B, llvm::Value *V, bool cublas)
void EmitNoTypeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2)
Value * is_lower(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas)
LLVMValueRef(* CustomDeallocator)(LLVMBuilderRef, LLVMValueRef)
llvm::cl::opt< bool > EnzymeCheckDerivativeNaN("enzyme-check-nan", cl::init(false), cl::Hidden, cl::desc("Add NaN checks to all derivative intermediate values"))
llvm::Optional< BlasInfo > extractBLAS(llvm::StringRef in)
Value * CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count, const Twine &Name, CallInst **caller, Instruction **ZeroMem, bool isDefault)
LLVMValueRef(* CustomAllocator)(LLVMBuilderRef, LLVMTypeRef, LLVMValueRef, LLVMValueRef, uint8_t, LLVMValueRef *)
void dumpBlock(llvm::BasicBlock *blk)
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
Function * getOrInsertExponentialAllocator(Module &M, Function *newFunc, bool ZeroInit, llvm::Type *RT)
bool notCaptured(llvm::Value *V)
Check whether the value is captured.
Value * GetFunctionValFromValue(Value *fn)
bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT, llvm::Instruction *maybeReader, llvm::Instruction *maybeWriter, llvm::Loop *scope)
bool isNVLoad(const llvm::Value *V)
llvm::Value * is_left(IRBuilder<> &B, llvm::Value *side, bool byRef, bool cublas)
llvm::Value * EmitNoDerivativeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &Builder2, llvm::Value *condition)
Constant * getString(Module &M, StringRef Str)
llvm::cl::opt< bool > EnzymeMemmoveWarning("enzyme-memmove-warning", cl::init(true), cl::Hidden, cl::desc("Warn if using memmove implementation as a fallback for memmove"))
void dumpTypeResults(TypeResults &TR)
bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::Instruction *maybeReader, llvm::Instruction *maybeWriter)
Return whether maybeReader can read from memory written to by maybeWriter.
Function * getOrInsertDifferentialFloatMemcpyMat(Module &Mod, Type *elementType, PointerType *PT, IntegerType *IT, IntegerType *CT, unsigned dstalign, unsigned srcalign, bool zeroSrc)
SmallVector< std::pair< Value *, size_t >, 1 > getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, size_t valSz, bool &legal)
void dumpType(llvm::Type *ty)
void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef< llvm::Value * > args, llvm::Type *copy_retty, llvm::ArrayRef< llvm::OperandBundleDef > bundles)
Create function for type that performs memcpy with a stride using blas copy.
void(* CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef, LLVMValueRef)
LLVMValueRef *(* EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef, uint64_t *size)
llvm::Value * moveSRetToFromRoots(llvm::IRBuilder<> &B, llvm::Type *jltype, llvm::Value *sret, llvm::Type *root_ty, llvm::Value *rootRet, size_t rootOffset, SRetRootMovement direction)
void copy_lower_to_upper(llvm::IRBuilder<> &B, llvm::Type *fpType, BlasInfo blas, bool byRef, llvm::Value *layout, llvm::Value *islower, llvm::Value *A, llvm::Value *lda, llvm::Value *N)
void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Value *shadow, const char *Message, llvm::DebugLoc &&loc, llvm::Instruction *orig)
SmallVector< llvm::Value *, 1 > get_blas_row(llvm::IRBuilder<> &B, ArrayRef< llvm::Value * > transA, bool byRef, bool cublas)
Value * is_nonunit(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas)
Function * GetFunctionFromValue(Value *fn)
llvm::SmallVector< llvm::Value *, 1 > getJuliaObjects(llvm::Value *v, llvm::IRBuilder<> &B)
void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas, IntegerType *IT, Type *BlasCT, Type *BlasFPT, Type *BlasPT, Type *BlasIT, Type *fpTy, ArrayRef< Value * > args, ArrayRef< OperandBundleDef > bundles, bool byRef, bool julia_decl)
bool overwritesToMemoryReadByLoop(llvm::ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT, llvm::Instruction *maybeReader, const llvm::SCEV *LoadStart, const llvm::SCEV *LoadEnd, llvm::Instruction *maybeWriter, const llvm::SCEV *StoreStart, const llvm::SCEV *StoreEnd, llvm::Loop *scope)
void dumpValue(llvm::Value *val)
LLVMTypeRef(* EnzymeDefaultTapeType)(LLVMContextRef)
void emit_backtrace(llvm::Instruction *inst, llvm::raw_ostream &ss)
void mayExecuteAfter(llvm::SmallVectorImpl< llvm::Instruction * > &results, llvm::Instruction *inst, const llvm::SmallPtrSetImpl< Instruction * > &stores, const llvm::Loop *region)
llvm::Value * to_blas_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, bool cublas, IntegerType *julia_decl, IRBuilder<> &entryBuilder, llvm::Twine const &name)
LLVMValueRef(* EnzymeSanitizeDerivatives)(LLVMValueRef, LLVMValueRef toset, LLVMBuilderRef, LLVMValueRef)
llvm::cl::opt< bool > EnzymeRuntimeError("enzyme-runtime-error", cl::init(false), cl::Hidden, cl::desc("Emit Runtime errors instead of compile time ones"))
llvm::CallInst * getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, IntegerType *IT, Type *BlasPT, Type *BlasIT, Type *fpTy, llvm::ArrayRef< llvm::Value * > args, const llvm::ArrayRef< llvm::OperandBundleDef > bundles, bool byRef, bool cublas, bool julia_decl)
SmallVector< std::tuple< Instruction *, Value *, size_t >, 1 > findAllUsersOf(Value *AI)
Function * getOrInsertMemcpyStrided(Module &M, Type *elementType, PointerType *T, Type *IT, unsigned dstalign, unsigned srcalign)
static Value * constantInBoundsGEPHelper(llvm::IRBuilder<> &B, llvm::Type *type, llvm::Value *value, ArrayRef< unsigned > path)
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
void copyNonJLValueInto(llvm::IRBuilder<> &B, llvm::Type *curType, llvm::Type *dstType, llvm::Value *dst, llvm::ArrayRef< unsigned > dstPrefix0, llvm::Type *srcType, llvm::Value *src, llvm::ArrayRef< unsigned > srcPrefix0, bool shouldZero)
bool notCapturedBefore(llvm::Value *V, Instruction *inst, size_t checkLoadCaptures)
static std::string tofltstr(Type *T)
Convert a floating type to a string.
void dumpModule(llvm::Module *mod)
void(* CustomZero)(LLVMBuilderRef, LLVMTypeRef, LLVMValueRef, uint8_t)
llvm::Constant * getUndefinedValueForType(llvm::Module &M, llvm::Type *T, bool forceZero)
Function * getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty, unsigned width)
std::vector< std::tuple< llvm::Type *, size_t, size_t > > parseTrueType(const llvm::MDNode *md, DerivativeMode Mode, bool const_src)
llvm::Value * is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef, bool cublas)
llvm::Value * SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM, llvm::Value *mask)
bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, MapVector< Value *, APInt > &VariableOffsets, APInt &ConstantOffset)
static llvm::StringRef getFuncName(llvm::Function *called)
static std::string getRenamedPerCallingConv(llvm::StringRef caller, llvm::StringRef callee)
llvm::cl::opt< bool > EnzymeBlasCopy
static llvm::Loop * getAncestor(llvm::Loop *R1, llvm::Loop *R2)
static bool anyJuliaObjects(llvm::Type *T)
@ Args
Return is a struct of all args.
static bool isNoAlias(const llvm::CallBase *call)
static llvm::PointerType * getUnqual(llvm::Type *T)
static bool isCertainPrint(const llvm::StringRef name)
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
static bool isNoCapture(const llvm::CallBase *call, size_t idx)
static void addFunctionNoCapture(llvm::Function *call, size_t idx)
static bool isDebugFunction(llvm::Function *called)
static bool isPointerArithmeticInst(const llvm::Value *V, bool includephi=true, bool includebin=true)
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
static bool isSpecialPtr(llvm::Type *Ty)
static void allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1, llvm::Instruction *inst2, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen between inst1 and inst2 If the function returns ...
static std::tuple< llvm::StringRef, llvm::StringRef, llvm::StringRef > tripleSplitDollar(llvm::StringRef caller)
@ RootPointerToSRetPointer
@ SRetPointerToRootPointer
static llvm::Value * CreateSelect(llvm::IRBuilder<> &Builder2, llvm::Value *cmp, llvm::Value *tval, llvm::Value *fval, const llvm::Twine &Name="")
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
llvm::cl::opt< bool > EnzymeLapackCopy
llvm::cl::opt< bool > EnzymeNonPower2Cache
Concrete SubType of a given value.
bool isKnown() const
Whether this ConcreteType has information (is not unknown)
std::string str() const
Convert the ConcreteType to a string.
llvm::Type * isFloat() const
Return the floating point type, if this is a float.
EnzymeFailure(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion)
EnzymeWarning(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion)
static llvm::Value * extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, unsigned off, const llvm::Twine &name="")
Helper routine to extract a nested element from a struct/array. This is.
A holder class representing the results of running TypeAnalysis on a given function.
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
TypeTree query(llvm::Value *val) const
The TypeTree of a particular Value.
TypeTree Lookup(size_t len, const llvm::DataLayout &dl) const
Select all submappings whose first index is in range [0, len) and remove the first index.
llvm::Type * fpType(llvm::LLVMContext &ctx, bool to_scalar=false) const
llvm::IntegerType * intType(llvm::LLVMContext &ctx) const
CountTrackedPointers(llvm::Type *T)