28#include "llvm/ADT/MapVector.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallPtrSet.h"
32#include "llvm/Analysis/AliasAnalysis.h"
33#include "llvm/Analysis/ValueTracking.h"
34#include "llvm/IR/Attributes.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
37#include "llvm/IR/IntrinsicInst.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/Operator.h"
40#include "llvm/IR/Type.h"
42#include "llvm/IR/Function.h"
43#include "llvm/IR/IntrinsicInst.h"
45#include "llvm/IR/ValueMap.h"
46#include "llvm/Support/Casting.h"
47#include "llvm/Support/raw_ostream.h"
49#include "llvm/Support/CommandLine.h"
51#include "llvm/ADT/SetVector.h"
52#include "llvm/ADT/StringMap.h"
54#include "llvm/IR/Dominators.h"
55#include "llvm/IR/IntrinsicsAMDGPU.h"
56#include "llvm/IR/IntrinsicsNVPTX.h"
61#if LLVM_VERSION_MAJOR >= 16
65#include "llvm/IR/DiagnosticInfo.h"
67#include "llvm/Analysis/OptimizationRemarkEmitter.h"
101 const void *, LLVMValueRef,
105llvm::SmallVector<llvm::Instruction *, 2>
PostCacheStore(llvm::StoreInst *SI,
106 llvm::IRBuilder<> &B);
109 llvm::Value *
Count,
const llvm::Twine &Name =
"",
110 llvm::CallInst **caller =
nullptr,
111 llvm::Instruction **ZeroMem =
nullptr,
112 bool isDefault =
false);
113llvm::CallInst *
CreateDealloc(llvm::IRBuilder<> &B, llvm::Value *ToFree);
114void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj,
118 llvm::Type *T, llvm::Value *OuterCount,
119 llvm::Value *InnerCount,
120 const llvm::Twine &Name =
"",
121 llvm::CallInst **caller =
nullptr,
122 bool ZeroMem =
false);
127extern llvm::StringMap<std::function<llvm::Value *(
128 llvm::IRBuilder<> &, llvm::CallInst *, llvm::ArrayRef<llvm::Value *>,
132template <
typename...
Args>
134 const llvm::DiagnosticLocation &Loc,
135 const llvm::BasicBlock *BB,
const Args &...args) {
138 if (Ctx.getDiagHandlerPtr()->isPassedOptRemarkEnabled(
"enzyme")) {
140 llvm::raw_string_ostream ss(
str);
142 auto R = llvm::OptimizationRemark(
"enzyme", RemarkName, Loc, BB)
148 (llvm::errs() << ... << args) <<
"\n";
151template <
typename...
Args>
152void EmitWarning(llvm::StringRef RemarkName,
const llvm::Instruction &I,
153 const Args &...args) {
154 EmitWarning(RemarkName, I.getDebugLoc(), I.getParent(), args...);
159 EnzymeWarning(
const llvm::Twine &Msg,
const llvm::DiagnosticLocation &Loc,
160 const llvm::Instruction *CodeRegion);
161 EnzymeWarning(
const llvm::Twine &Msg,
const llvm::DiagnosticLocation &Loc,
162 const llvm::Function *CodeRegion);
165template <
typename...
Args>
167 const Args &...args) {
168 llvm::LLVMContext &Ctx = F.getContext();
170 llvm::raw_string_ostream ss(
str);
172 auto R = llvm::OptimizationRemark(
"enzyme", RemarkName, &F) << ss.str();
173 Ctx.diagnose((
EnzymeWarning(ss.str(), F.getSubprogram(), &F)));
176template <
typename...
Args>
177void EmitWarning(llvm::StringRef RemarkName,
const llvm::Function &F,
178 const Args &...args) {
179 llvm::LLVMContext &Ctx = F.getContext();
180 if (Ctx.getDiagHandlerPtr()->isPassedOptRemarkEnabled(
"enzyme")) {
182 llvm::raw_string_ostream ss(
str);
184 auto R = llvm::OptimizationRemark(
"enzyme", RemarkName, &F) << ss.str();
188 (llvm::errs() << ... << args) <<
"\n";
193 EnzymeFailure(
const llvm::Twine &Msg,
const llvm::DiagnosticLocation &Loc,
194 const llvm::Instruction *CodeRegion);
195 EnzymeFailure(
const llvm::Twine &Msg,
const llvm::DiagnosticLocation &Loc,
196 const llvm::Function *CodeRegion);
202template <
typename...
Args>
204 const llvm::DiagnosticLocation &Loc,
205 const llvm::Instruction *CodeRegion,
Args &...args) {
206 std::string *
str =
new std::string();
207 llvm::raw_string_ostream ss(*
str);
209 CodeRegion->getContext().diagnose(
213template <
typename...
Args>
215 const llvm::DiagnosticLocation &Loc,
216 const llvm::Function *CodeRegion,
Args &...args) {
217 std::string *
str =
new std::string();
218 llvm::raw_string_ostream ss(*
str);
220 CodeRegion->getContext().diagnose(
224template <
typename...
Args>
229 EmitFailure(RemarkName, FirstFunc->getSubprogram(), FirstFunc, args...);
232 std::string *
str =
new std::string();
233 llvm::raw_string_ostream ss(*
str);
235 llvm::report_fatal_error(llvm::StringRef(*
str));
240 if (llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(val)) {
241 return CI->getCalledFunction();
249 llvm::Instruction &inst,
251 llvm::Value *condition =
nullptr);
262template <
typename T>
static inline T
max(T a, T b) {
268template <
typename T>
static inline T
min(T a, T b) {
276static inline std::string
to_string(
const std::set<T> &us) {
278 for (
const auto &y : us)
279 s += std::to_string(y) +
",";
285template <
typename T,
typename N>
287 const llvm::ValueMap<T, N> &o,
288 llvm::function_ref<
bool(
const llvm::Value *)> shouldPrint = [](T) {
291 llvm::errs() <<
"<begin dump>\n";
293 if (shouldPrint(a.first))
294 llvm::errs() <<
"key=" << *a.first <<
" val=" << *a.second <<
"\n";
296 llvm::errs() <<
"</end dump>\n";
301static inline void dumpSet(
const llvm::SmallPtrSetImpl<T *> &o) {
302 llvm::errs() <<
"<begin dump>\n";
304 llvm::errs() << *a <<
"\n";
305 llvm::errs() <<
"</end dump>\n";
309static inline void dumpSet(
const llvm::SetVector<T *> &o) {
310 llvm::errs() <<
"<begin dump>\n";
312 llvm::errs() << *a <<
"\n";
313 llvm::errs() <<
"</end dump>\n";
317static inline llvm::Instruction *
319 for (llvm::Instruction *I = Z->getNextNode(); I; I = I->getNextNode())
320 if (!llvm::isa<llvm::DbgInfoIntrinsic>(I))
326static inline llvm::Instruction *
331 llvm::errs() << *Z->getParent() <<
"\n";
332 llvm::errs() << *Z <<
"\n";
333 llvm_unreachable(
"No valid subsequent non debug instruction");
339static inline llvm::MDNode *
hasMetadata(
const llvm::GlobalObject *O,
340 llvm::StringRef kind) {
341 return O->getMetadata(kind);
345static inline llvm::MDNode *
hasMetadata(
const llvm::Instruction *O,
346 llvm::StringRef kind) {
347 return O->getMetadata(kind);
349static inline llvm::MDNode *
hasMetadata(
const llvm::Instruction *O,
351 return O->getMetadata(kind);
432 llvm_unreachable(
"illegal valuetype");
435static inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
443 return "ForwardMode";
445 return "ForwardModeError";
447 return "ForwardModeSplit";
449 return "ReverseModePrimal";
451 return "ReverseModeGradient";
453 return "ReverseModeCombined";
455 llvm_unreachable(
"illegal derivative mode");
458static inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
475 assert(0 &&
"illegal diffetype");
480static inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
489 return "ArgsWithReturn";
491 return "ArgsWithTwoReturns";
495 return "TapeAndReturn";
497 return "TapeAndTwoReturns";
507 llvm_unreachable(
"illegal ReturnType");
510static inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
520 bool integersAreConstant,
521 std::set<llvm::Type *> &seen) {
523 if (seen.find(arg) != seen.end())
527 if (arg->isVoidTy() || arg->isEmptyTy()) {
531 if (arg->isPointerTy()) {
532#if LLVM_VERSION_MAJOR >= 17
535#if LLVM_VERSION_MAJOR >= 15
536 if (!arg->getContext().supportsTypedPointers()) {
539#elif LLVM_VERSION_MAJOR >= 13
540 if (arg->isOpaquePointerTy()) {
544 switch (
whatType(arg->getPointerElementType(), mode, integersAreConstant,
553 llvm_unreachable(
"impossible case");
556 llvm::errs() <<
"arg: " << *arg <<
"\n";
557 assert(0 &&
"Cannot handle type0");
560 }
else if (arg->isArrayTy()) {
561 return whatType(llvm::cast<llvm::ArrayType>(arg)->getElementType(), mode,
562 integersAreConstant, seen);
563 }
else if (arg->isStructTy()) {
564 auto st = llvm::cast<llvm::StructType>(arg);
565 if (st->getNumElements() == 0)
569 for (
unsigned i = 0; i < st->getNumElements(); ++i) {
571 whatType(st->getElementType(i), mode, integersAreConstant, seen);
583 llvm_unreachable(
"impossible case");
597 llvm_unreachable(
"impossible case");
603 llvm_unreachable(
"impossible case");
607 }
else if (arg->isIntOrIntVectorTy() || arg->isFunctionTy()) {
609 }
else if (arg->isFPOrFPVectorTy()) {
617 llvm::errs() <<
"arg: " << *arg <<
"\n";
618 assert(0 &&
"Cannot handle type");
623llvm::Value *
get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res);
626 std::set<llvm::Type *> seen;
627 return whatType(arg, mode,
true, seen);
632 for (
const auto a : inst->users()) {
633 if (llvm::isa<llvm::ReturnInst>(a))
642 assert(T->isFPOrFPVectorTy());
643 if (
auto ty = llvm::dyn_cast<llvm::VectorType>(T)) {
644 return llvm::VectorType::get(
FloatToIntTy(ty->getElementType()),
645 ty->getElementCount());
648 return llvm::IntegerType::get(T->getContext(), 16);
650 return llvm::IntegerType::get(T->getContext(), 16);
652 return llvm::IntegerType::get(T->getContext(), 32);
654 return llvm::IntegerType::get(T->getContext(), 64);
655 if (T->isX86_FP80Ty())
656 return llvm::IntegerType::get(T->getContext(), 80);
658 return llvm::IntegerType::get(T->getContext(), 128);
659 assert(0 &&
"unknown floating point type");
666 assert(T->isIntOrIntVectorTy());
667 if (
auto ty = llvm::dyn_cast<llvm::VectorType>(T)) {
668 return llvm::VectorType::get(
IntToFloatTy(ty->getElementType()),
669 ty->getElementCount());
671 if (
auto ty = llvm::dyn_cast<llvm::IntegerType>(T)) {
672 switch (ty->getBitWidth()) {
674 return llvm::Type::getHalfTy(T->getContext());
677 return llvm::Type::getFloatTy(T->getContext());
679 return llvm::Type::getDoubleTy(T->getContext());
681 return llvm::Type::getX86_FP80Ty(T->getContext());
683 return llvm::Type::getFP128Ty(T->getContext());
686 assert(0 &&
"unknown int to floating point type");
693 if (called->getName() ==
"llvm.enzyme.lifetime_start" ||
694 called->getName() ==
"llvm.enzyme.lifetime_end") {
697 switch (called->getIntrinsicID()) {
698 case llvm::Intrinsic::dbg_declare:
699 case llvm::Intrinsic::dbg_value:
700 case llvm::Intrinsic::dbg_label:
701#if LLVM_VERSION_MAJOR <= 16
702 case llvm::Intrinsic::dbg_addr:
704 case llvm::Intrinsic::lifetime_start:
705 case llvm::Intrinsic::lifetime_end:
713static inline bool startsWith(llvm::StringRef
string, llvm::StringRef prefix) {
714#if LLVM_VERSION_MAJOR >= 18
715 return string.starts_with(prefix);
717 return string.startswith(prefix);
721static inline bool endsWith(llvm::StringRef
string, llvm::StringRef suffix) {
722#if LLVM_VERSION_MAJOR >= 18
723 return string.ends_with(suffix);
725 return string.endswith(suffix);
730 if (name ==
"printf" || name ==
"puts" || name ==
"fprintf" ||
731 name ==
"putchar" || name ==
"fputc" ||
733 "_ZStlsISt11char_traitsIcEERSt13basic_ostreamIcT_ES5_") ||
737 startsWith(name,
"_ZN3std2io5stdio6_print") ||
751 llvm::Type *
fpType(llvm::LLVMContext &ctx,
bool to_scalar =
false)
const;
752 llvm::IntegerType *
intType(llvm::LLVMContext &ctx)
const;
755#if LLVM_VERSION_MAJOR >= 16
756std::optional<BlasInfo>
extractBLAS(llvm::StringRef in);
758llvm::Optional<BlasInfo>
extractBLAS(llvm::StringRef in);
761std::vector<std::tuple<llvm::Type *, size_t, size_t>>
767 llvm::Module &M, llvm::Type *T,
unsigned dstalign,
unsigned srcalign,
768 unsigned dstaddr,
unsigned srcaddr,
unsigned bitwidth);
772 llvm::ArrayRef<llvm::Value *> args,
773 llvm::Type *cublas_retty,
774 llvm::ArrayRef<llvm::OperandBundleDef> bundles);
778 BlasInfo blas, llvm::ArrayRef<llvm::Value *> args,
779 llvm::ArrayRef<llvm::OperandBundleDef> bundles);
782 llvm::IntegerType *IT, llvm::Type *BlasCT,
783 llvm::Type *BlasFPT, llvm::Type *BlasPT,
784 llvm::Type *BlasIT, llvm::Type *fpTy,
785 llvm::ArrayRef<llvm::Value *> args,
786 const llvm::ArrayRef<llvm::OperandBundleDef> bundles,
787 bool byRef,
bool julia_decl);
791 llvm::IntegerType *IT, llvm::Type *BlasPT,
792 llvm::Type *BlasIT, llvm::Type *fpTy,
793 llvm::ArrayRef<llvm::Value *> args,
794 const llvm::ArrayRef<llvm::OperandBundleDef> bundles,
795 bool byRef,
bool cublas,
bool julia_decl);
799 llvm::Type *elementType,
800 llvm::PointerType *T, llvm::Type *IT,
801 unsigned dstalign,
unsigned srcalign);
805 llvm::PointerType *PT,
806 llvm::IntegerType *IT,
unsigned dstalign,
810 llvm::Module &M, llvm::Type *elementType, llvm::PointerType *PT,
811 llvm::IntegerType *IT, llvm::IntegerType *CT,
unsigned dstalign,
812 unsigned srcalign,
bool zeroSrc);
817 llvm::Module &M, llvm::Type *T,
unsigned dstalign,
unsigned srcalign,
818 unsigned dstaddr,
unsigned srcaddr,
unsigned bitwidth);
821 llvm::Type *Type,
unsigned width);
825 llvm::ArrayRef<llvm::Type *> T,
827 llvm::StringRef caller);
833template <
typename K,
typename V>
834static inline typename std::map<K, V>::iterator
836 auto found = map.find(key);
837 if (found != map.end()) {
840 return map.emplace(key, val).first;
844template <
typename K,
typename V>
845static inline typename std::map<K, V>::iterator
847 auto found = map.find(key);
848 if (found != map.end()) {
851 return map.emplace(key, val).first;
854template <
typename K,
typename V>
856 auto found = map.find(key);
857 if (found == map.end())
859 V *val = &found->second;
863#include "llvm/IR/CFG.h"
870 llvm::function_ref<
bool(llvm::Instruction *)> f) {
872 for (
auto uinst = inst->getNextNode(); uinst !=
nullptr;
873 uinst = uinst->getNextNode()) {
878 std::deque<llvm::BasicBlock *> todo;
879 std::set<llvm::BasicBlock *> done;
880 for (
auto suc : llvm::successors(inst->getParent())) {
883 while (todo.size()) {
884 auto BB = todo.front();
889 for (
auto &ni : *BB) {
895 for (
auto suc : llvm::successors(BB)) {
905 llvm::function_ref<
bool(llvm::Instruction *)> f) {
907 for (
auto uinst = inst->getPrevNode(); uinst !=
nullptr;
908 uinst = uinst->getPrevNode()) {
913 std::deque<llvm::BasicBlock *> todo;
914 std::set<llvm::BasicBlock *> done;
915 for (
auto suc : llvm::predecessors(inst->getParent())) {
918 while (todo.size()) {
919 auto BB = todo.front();
925 llvm::BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend();
926 for (; I != E; ++I) {
932 for (
auto suc : llvm::predecessors(BB)) {
942 llvm::function_ref<
bool(llvm::Instruction *)> f) {
944 for (
auto uinst = inst->getPrevNode(); uinst !=
nullptr;
945 uinst = uinst->getPrevNode()) {
950 std::deque<llvm::BasicBlock *> todo;
951 std::set<llvm::BasicBlock *> done;
952 for (
auto suc : llvm::predecessors(inst->getParent())) {
955 while (todo.size()) {
956 auto BB = todo.front();
962 if (DT.properlyDominates(BB, inst->getParent())) {
963 llvm::BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend();
964 for (; I != E; ++I) {
970 for (
auto suc : llvm::predecessors(BB)) {
981 llvm::function_ref<
bool(llvm::Instruction *)> f,
982 llvm::function_ref<
void()> preEntry) {
984 for (
auto uinst = inst->getPrevNode(); uinst !=
nullptr;
985 uinst = uinst->getPrevNode()) {
986 if (
auto II = llvm::dyn_cast<llvm::IntrinsicInst>(uinst)) {
987 if (II->getIntrinsicID() == llvm::Intrinsic::amdgcn_s_barrier) {
990#if LLVM_VERSION_MAJOR > 20
991 if (II->getIntrinsicID() ==
992 llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all) {
996 if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0) {
1005 std::deque<llvm::BasicBlock *> todo;
1006 std::set<llvm::BasicBlock *> done;
1007 for (
auto suc : llvm::predecessors(inst->getParent())) {
1008 todo.push_back(suc);
1010 while (todo.size()) {
1011 auto BB = todo.front();
1018 llvm::BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend();
1019 for (; I != E; ++I) {
1020 if (
auto II = llvm::dyn_cast<llvm::IntrinsicInst>(&*I)) {
1021 if (II->getIntrinsicID() == llvm::Intrinsic::amdgcn_s_barrier) {
1025#if LLVM_VERSION_MAJOR > 20
1026 if (II->getIntrinsicID() ==
1027 llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all) {
1029 if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0) {
1041 for (
auto suc : llvm::predecessors(BB)) {
1042 todo.push_back(suc);
1044 if (&BB->getParent()->getEntryBlock() == BB) {
1051#include "llvm/Analysis/LoopInfo.h"
1053static inline llvm::Loop *
getAncestor(llvm::Loop *R1, llvm::Loop *R2) {
1056 for (llvm::Loop *L1 = R1; L1; L1 = L1->getParentLoop())
1057 for (llvm::Loop *L2 = R2; L2; L2 = L2->getParentLoop()) {
1068 llvm::Instruction *inst,
1069 const llvm::SmallPtrSetImpl<llvm::Instruction *> &stores,
1070 const llvm::Loop *region);
1074 llvm::TargetLibraryInfo &TLI,
1075 llvm::Instruction *maybeReader,
1076 llvm::Instruction *maybeWriter);
1090 llvm::TargetLibraryInfo &TLI,
1091 llvm::ScalarEvolution &SE, llvm::LoopInfo &LI,
1092 llvm::DominatorTree &DT,
1093 llvm::Instruction *maybeReader,
1094 llvm::Instruction *maybeWriter,
1095 llvm::Loop *scope =
nullptr);
1100 llvm::Instruction *inst2,
1101 llvm::function_ref<
bool(llvm::Instruction *)> f) {
1102 assert(inst1->getParent()->getParent() == inst2->getParent()->getParent());
1103 for (
auto uinst = inst1->getNextNode(); uinst !=
nullptr;
1104 uinst = uinst->getNextNode()) {
1111 std::set<llvm::Instruction *> instructions;
1113 llvm::Loop *l1 = LI.getLoopFor(inst1->getParent());
1114 while (l1 && !l1->contains(inst2->getParent()))
1115 l1 = l1->getParentLoop();
1119 std::deque<llvm::BasicBlock *> todo;
1120 std::set<llvm::BasicBlock *> done;
1121 for (
auto suc : llvm::successors(inst1->getParent())) {
1122 todo.push_back(suc);
1124 while (todo.size()) {
1125 auto BB = todo.front();
1131 for (
auto &ni : *BB) {
1132 instructions.insert(&ni);
1134 for (
auto suc : llvm::successors(BB)) {
1135 if (!l1 || suc != l1->getHeader()) {
1136 todo.push_back(suc);
1143 if (instructions.find(I) == instructions.end())
1167#if LLVM_VERSION_MAJOR >= 17
1168 return llvm::PointerType::get(T->getContext(),
AddressSpace);
1184 using namespace llvm;
1185 auto i64 = Type::getInt64Ty(Context);
1193 Type::getInt8Ty(Context),
1196 return StructType::get(Context, types,
false);
1199template <MPI_Elem E,
bool Po
inter = true>
1202 using namespace llvm;
1203 auto i64 = Type::getInt64Ty(V->getContext());
1204 auto i32 = Type::getInt32Ty(V->getContext());
1205 auto c0_64 = ConstantInt::get(i64, 0);
1208 return B.CreateInBoundsGEP(T, V,
1209 {c0_64, ConstantInt::get(i32, (uint64_t)E)});
1211 return B.CreateExtractValue(V, {(unsigned)E});
1217 llvm::Type *intType, llvm::IRBuilder<> &B2);
1226 assert(0 &&
"attempted to delete value with remaining handle use");
1227 llvm_unreachable(
"attempted to delete value with remaining handle use");
1231 setValPtr(new_value);
1237 const llvm::Function *called =
nullptr;
1238 using namespace llvm;
1239 const llvm::Value *callVal;
1240 callVal = op->getCalledOperand();
1242 if (
auto castinst = dyn_cast<ConstantExpr>(callVal))
1243 if (castinst->isCast()) {
1244 callVal = castinst->getOperand(0);
1247 if (
auto fn = dyn_cast<Function>(callVal)) {
1251 if (
auto alias = dyn_cast<GlobalAlias>(callVal)) {
1252 callVal = alias->getAliasee();
1257 return called ?
const_cast<llvm::Function *
>(called) :
nullptr;
1261 if (called->hasFnAttribute(
"enzyme_math"))
1262 return called->getFnAttribute(
"enzyme_math").getValueAsString();
1263 else if (called->hasFnAttribute(
"enzyme_allocator"))
1264 return "enzyme_allocator";
1266 return called->getName();
1271 op->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
1272 if (AttrList.hasAttribute(
"enzyme_math"))
1273 return AttrList.getAttribute(
"enzyme_math").getValueAsString();
1274 if (AttrList.hasAttribute(
"enzyme_allocator"))
1275 return "enzyme_allocator";
1284 using namespace llvm;
1285 if (
auto CB = dyn_cast<CallBase>(op)) {
1287 if (called->hasFnAttribute(
"enzyme_nocache"))
1294 if (
auto I = dyn_cast<Instruction>(op))
1299 if (
auto PT = dyn_cast<PointerType>(op->getType())) {
1300 if (PT->getAddressSpace() == 11 || PT->getAddressSpace() == 13) {
1301 if (isa<CastInst>(op) || isa<GetElementPtrInst>(op))
1306 if (
auto IT = dyn_cast<IntegerType>(op->getType()))
1313#if LLVM_VERSION_MAJOR >= 16
1314static inline std::optional<size_t>
1317static inline llvm::Optional<size_t>
1322 op->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
1323 if (AttrList.hasAttribute(
"enzyme_allocator")) {
1325 bool b = AttrList.getAttribute(
"enzyme_allocator")
1327 .getAsInteger(10, res);
1330#if LLVM_VERSION_MAJOR >= 16
1331 return std::optional<size_t>(res);
1333 return llvm::Optional<size_t>(res);
1338 if (called->hasFnAttribute(
"enzyme_allocator")) {
1340 bool b = called->getFnAttribute(
"enzyme_allocator")
1342 .getAsInteger(10, res);
1345#if LLVM_VERSION_MAJOR >= 16
1346 return std::optional<size_t>(res);
1348 return llvm::Optional<size_t>(res);
1352#if LLVM_VERSION_MAJOR >= 16
1353 return std::optional<size_t>();
1355 return llvm::Optional<size_t>();
1359template <
typename T>
1361 if (
auto MD =
hasMetadata(op,
"enzyme_deallocator_fn")) {
1362 auto md2 = llvm::cast<llvm::MDTuple>(MD);
1363 assert(md2->getNumOperands() == 1);
1364 return llvm::cast<llvm::Function>(
1365 llvm::cast<llvm::ConstantAsMetadata>(md2->getOperand(0))->getValue());
1368 if (
auto MD =
hasMetadata(called,
"enzyme_deallocator_fn")) {
1369 auto md2 = llvm::cast<llvm::MDTuple>(MD);
1370 assert(md2->getNumOperands() == 1);
1371 return llvm::cast<llvm::Function>(
1372 llvm::cast<llvm::ConstantAsMetadata>(md2->getOperand(0))->getValue());
1375 llvm::errs() <<
"dealloc fn: " << *op->getParent()->getParent()->getParent()
1377 llvm_unreachable(
"Illegal deallocatorfn");
1380template <
typename T>
1382 llvm::StringRef res =
"";
1384 op->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
1385 if (AttrList.hasAttribute(
"enzyme_deallocator"))
1386 res = AttrList.getAttribute(
"enzyme_deaellocator").getValueAsString();
1389 if (called->hasFnAttribute(
"enzyme_deallocator"))
1390 res = called->getFnAttribute(
"enzyme_deallocator").getValueAsString();
1392 if (res.size() == 0)
1393 llvm_unreachable(
"Illegal deallocator");
1394 llvm::SmallVector<llvm::StringRef, 1> inds;
1395 res.split(inds,
",");
1396 std::vector<ssize_t> vinds;
1397 for (
auto ind : inds) {
1399 bool b = ind.getAsInteger(10, Result);
1402 vinds.push_back(Result);
1409 llvm::ArrayRef<llvm::Type *> T,
1410 llvm::PointerType *reqType);
1413 llvm::Value *shadow,
const char *Message,
1414 llvm::DebugLoc &&loc, llvm::Instruction *orig);
1421 size_t preOffset = 0);
1427 if (CI->hasFnAttr(
"enzyme_preserve_primal") ||
1432 (F->hasFnAttribute(
"enzyme_preserve_primal") ||
1439 if (funcName ==
"MPI_Wait" || funcName ==
"MPI_Waitall") {
1450 if (
auto II = llvm::dyn_cast<llvm::IntrinsicInst>(val)) {
1457 bool includephi =
true,
1458 bool includebin =
true) {
1459 if (llvm::isa<llvm::CastInst>(V) || llvm::isa<llvm::GetElementPtrInst>(V) ||
1460 (includephi && llvm::isa<llvm::PHINode>(V)))
1464 if (
auto BI = llvm::dyn_cast<llvm::BinaryOperator>(V)) {
1465 switch (BI->getOpcode()) {
1466 case llvm::BinaryOperator::Add:
1467 case llvm::BinaryOperator::Sub:
1468 case llvm::BinaryOperator::Mul:
1469 case llvm::BinaryOperator::SDiv:
1470 case llvm::BinaryOperator::UDiv:
1471 case llvm::BinaryOperator::SRem:
1472 case llvm::BinaryOperator::URem:
1473 case llvm::BinaryOperator::Or:
1474 case llvm::BinaryOperator::And:
1475 case llvm::BinaryOperator::Shl:
1476 case llvm::BinaryOperator::LShr:
1477 case llvm::BinaryOperator::AShr:
1488 if (
auto *
Call = llvm::dyn_cast<llvm::CallInst>(V)) {
1490 if (funcName ==
"julia.pointer_from_objref") {
1493 if (funcName ==
"julia.gc_loaded") {
1496 if (funcName.contains(
"__enzyme_todense")) {
1499 if (funcName.contains(
"__enzyme_ignore_derivatives")) {
1508 bool offsetAllowed =
true) {
1510 if (
auto CI = llvm::dyn_cast<llvm::CastInst>(V)) {
1511 V = CI->getOperand(0);
1513 }
else if (
auto CI = llvm::dyn_cast<llvm::GetElementPtrInst>(V)) {
1514 if (offsetAllowed || CI->hasAllZeroIndices()) {
1515 V = CI->getOperand(0);
1518 }
else if (
auto II = llvm::dyn_cast<llvm::IntrinsicInst>(V);
1520 if (offsetAllowed) {
1521 V = II->getOperand(3);
1524 }
else if (
auto CI = llvm::dyn_cast<llvm::PHINode>(V)) {
1525 if (CI->getNumIncomingValues() == 1) {
1526 V = CI->getOperand(0);
1529 }
else if (
auto *GA = llvm::dyn_cast<llvm::GlobalAlias>(V)) {
1530 if (GA->isInterposable())
1532 V = GA->getAliasee();
1534 }
else if (
auto CE = llvm::dyn_cast<llvm::ConstantExpr>(V)) {
1535 if (CE->isCast() || CE->getOpcode() == llvm::Instruction::GetElementPtr) {
1536 V = CE->getOperand(0);
1539 }
else if (
auto *
Call = llvm::dyn_cast<llvm::CallInst>(V)) {
1541 auto AttrList =
Call->getAttributes().getAttributes(
1542 llvm::AttributeList::FunctionIndex);
1543 if (AttrList.hasAttribute(
"enzyme_pointermath") && offsetAllowed) {
1545 bool failed = AttrList.getAttribute(
"enzyme_pointermath")
1547 .getAsInteger(10, res);
1550 V =
Call->getArgOperand(res);
1553 if (funcName ==
"julia.pointer_from_objref") {
1554 V =
Call->getArgOperand(0);
1557 if (funcName ==
"julia.gc_loaded") {
1558 V =
Call->getArgOperand(1);
1561 if (funcName ==
"jl_reshape_array" || funcName ==
"ijl_reshape_array") {
1562 V =
Call->getArgOperand(1);
1565 if (funcName.contains(
"__enzyme_ignore_derivatives")) {
1566 V =
Call->getArgOperand(0);
1569 if (funcName.contains(
"__enzyme_todense")) {
1570#if LLVM_VERSION_MAJOR >= 14
1571 size_t numargs =
Call->arg_size();
1573 size_t numargs =
Call->getNumArgOperands();
1576 V =
Call->getArgOperand(2);
1581 auto AttrList = fn->getAttributes().getAttributes(
1582 llvm::AttributeList::FunctionIndex);
1583 if (AttrList.hasAttribute(
"enzyme_pointermath") && offsetAllowed) {
1585 bool failed = AttrList.getAttribute(
"enzyme_pointermath")
1587 .getAsInteger(10, res);
1590 V =
Call->getArgOperand(res);
1594 for (
auto &arg : fn->args()) {
1595 if (arg.hasAttribute(llvm::Attribute::Returned)) {
1597 V =
Call->getArgOperand(arg.getArgNo());
1615 llvm::getArgumentAliasingToReturnedPointer(
Call,
false)) {
1622 if (
auto I = llvm::dyn_cast<llvm::Instruction>(V)) {
1623#if LLVM_VERSION_MAJOR >= 12
1624 auto V2 = llvm::getUnderlyingObject(I, 100);
1626 auto V2 = llvm::GetUnderlyingObject(
1627 I, I->getParent()->getParent()->getParent()->getDataLayout(), 100);
1642static inline llvm::SetVector<llvm::Value *>
1644 llvm::SmallPtrSet<llvm::Value *, 1> seen;
1645 llvm::SetVector<llvm::Value *> results;
1646 llvm::SmallVector<llvm::Value *, 1> todo = {V};
1648 while (todo.size()) {
1649 auto obj = todo.back();
1651 if (seen.contains(obj))
1655 if (
auto PN = llvm::dyn_cast<llvm::PHINode>(obj)) {
1656 for (
auto &x : PN->incoming_values()) {
1664 todo.push_back(cur);
1668 results.insert(obj);
1673static inline bool isReadOnly(
const llvm::Function *F, ssize_t arg = -1) {
1674 if (F->onlyReadsMemory())
1677 if (F->hasFnAttribute(llvm::Attribute::ReadOnly) ||
1678 F->hasFnAttribute(llvm::Attribute::ReadNone))
1681 if (F->hasParamAttribute(arg, llvm::Attribute::ReadOnly) ||
1682 F->hasParamAttribute(arg, llvm::Attribute::ReadNone))
1691static inline bool isReadOnly(
const llvm::CallBase *call, ssize_t arg = -1) {
1692 if (call->onlyReadsMemory())
1694 if (arg != -1 && call->onlyReadsMemory(arg))
1702 if (F->getCallingConv() == call->getCallingConv())
1723 if (F->hasFnAttribute(
"enzyme_LocalReadOnlyOrThrow") ||
1724 F->hasFnAttribute(
"enzyme_ReadOnlyOrThrow"))
1734 if (call->hasFnAttr(
"enzyme_LocalReadOnlyOrThrow") ||
1735 call->hasFnAttr(
"enzyme_ReadOnlyOrThrow"))
1743 if (F->getCallingConv() == call->getCallingConv())
1763 if (F->hasFnAttribute(
"enzyme_ReadOnlyOrThrow"))
1773 if (call->hasFnAttr(
"enzyme_ReadOnlyOrThrow"))
1781 if (F->getCallingConv() == call->getCallingConv())
1788static inline bool isWriteOnly(
const llvm::Function *F, ssize_t arg = -1) {
1789#if LLVM_VERSION_MAJOR >= 14
1790 if (F->onlyWritesMemory())
1793 if (F->hasFnAttribute(llvm::Attribute::WriteOnly) ||
1794 F->hasFnAttribute(llvm::Attribute::ReadNone))
1797 if (F->hasParamAttribute(arg, llvm::Attribute::WriteOnly) ||
1798 F->hasParamAttribute(arg, llvm::Attribute::ReadNone))
1804static inline bool isWriteOnly(
const llvm::CallBase *call, ssize_t arg = -1) {
1805#if LLVM_VERSION_MAJOR >= 14
1806 if (call->onlyWritesMemory())
1808 if (arg != -1 && call->onlyWritesMemory(arg))
1811 if (call->hasFnAttr(llvm::Attribute::WriteOnly) ||
1812 call->hasFnAttr(llvm::Attribute::ReadNone))
1815 if (call->dataOperandHasImpliedAttr(arg + 1, llvm::Attribute::WriteOnly) ||
1816 call->dataOperandHasImpliedAttr(arg + 1, llvm::Attribute::ReadNone))
1826 if (F->getCallingConv() == call->getCallingConv())
1832static inline bool isReadNone(
const llvm::CallBase *call, ssize_t arg = -1) {
1836static inline bool isReadNone(
const llvm::Function *F, ssize_t arg = -1) {
1840static inline bool isNoCapture(
const llvm::CallBase *call,
size_t idx) {
1841 if (call->doesNotCapture(idx))
1849 if (F->getCallingConv() == call->getCallingConv())
1850 if (idx < F->arg_size() && F->getArg(idx)->hasNoCaptureAttr())
1859 if (call->returnDoesNotAlias())
1863 if (F->returnDoesNotAlias())
1870 if (
auto CB = llvm::dyn_cast<llvm::CallBase>(val))
1872 if (
auto arg = llvm::dyn_cast<llvm::Argument>(val)) {
1873 arg->hasNoAliasAttr();
1879 if (F->hasFnAttribute(
"enzyme_no_escaping_allocation"))
1881 if (F->getName() ==
"llvm.enzyme.lifetime_start" ||
1882 F->getName() ==
"llvm.enzyme.lifetime_end") {
1885 using namespace llvm;
1886 switch (F->getIntrinsicID()) {
1887 case Intrinsic::memset:
1888 case Intrinsic::memcpy:
1889 case Intrinsic::memmove:
1890#if LLVM_VERSION_MAJOR >= 12
1891 case Intrinsic::experimental_noalias_scope_decl:
1893 case Intrinsic::objectsize:
1894 case Intrinsic::floor:
1895 case Intrinsic::ceil:
1896 case Intrinsic::trunc:
1897 case Intrinsic::rint:
1898 case Intrinsic::lrint:
1899 case Intrinsic::llrint:
1900 case Intrinsic::nearbyint:
1901 case Intrinsic::round:
1902 case Intrinsic::roundeven:
1903 case Intrinsic::lround:
1904 case Intrinsic::llround:
1905#if LLVM_VERSION_MAJOR <= 20
1906 case Intrinsic::nvvm_barrier0:
1908 case Intrinsic::nvvm_barrier_cta_sync_aligned_all:
1909 case Intrinsic::nvvm_barrier_cta_sync_aligned_count:
1911#if LLVM_VERSION_MAJOR < 22
1912 case Intrinsic::nvvm_barrier0_popc:
1913 case Intrinsic::nvvm_barrier0_and:
1914 case Intrinsic::nvvm_barrier0_or:
1916 case Intrinsic::nvvm_barrier_cta_red_and_aligned_all:
1917 case Intrinsic::nvvm_barrier_cta_red_and_aligned_count:
1918 case Intrinsic::nvvm_barrier_cta_red_or_aligned_all:
1919 case Intrinsic::nvvm_barrier_cta_red_or_aligned_count:
1920 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_all:
1921 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_count:
1923 case Intrinsic::nvvm_membar_cta:
1924 case Intrinsic::nvvm_membar_gl:
1925 case Intrinsic::nvvm_membar_sys:
1926 case Intrinsic::amdgcn_s_barrier:
1927 case Intrinsic::assume:
1928 case Intrinsic::lifetime_start:
1929 case Intrinsic::lifetime_end:
1930#if LLVM_VERSION_MAJOR <= 16
1931 case Intrinsic::dbg_addr:
1934 case Intrinsic::dbg_declare:
1935 case Intrinsic::dbg_value:
1936 case Intrinsic::dbg_label:
1937 case Intrinsic::invariant_start:
1938 case Intrinsic::invariant_end:
1939 case Intrinsic::var_annotation:
1940 case Intrinsic::ptr_annotation:
1941 case Intrinsic::annotation:
1942 case Intrinsic::codeview_annotation:
1943 case Intrinsic::expect:
1944 case Intrinsic::type_test:
1945 case Intrinsic::donothing:
1946 case Intrinsic::prefetch:
1947 case Intrinsic::trap:
1948 case Intrinsic::is_constant:
1949#if LLVM_VERSION_MAJOR >= 12
1950 case Intrinsic::smax:
1951 case Intrinsic::smin:
1952 case Intrinsic::umax:
1953 case Intrinsic::umin:
1955 case Intrinsic::ctlz:
1956 case Intrinsic::cttz:
1957 case Intrinsic::sadd_with_overflow:
1958 case Intrinsic::ssub_with_overflow:
1959#if LLVM_VERSION_MAJOR >= 12
1960 case Intrinsic::abs:
1962 case Intrinsic::sqrt:
1963 case Intrinsic::exp:
1964 case Intrinsic::cos:
1965 case Intrinsic::sin:
1966#if LLVM_VERSION_MAJOR >= 19
1967 case Intrinsic::tanh:
1968 case Intrinsic::cosh:
1969 case Intrinsic::sinh:
1971 case Intrinsic::copysign:
1972 case Intrinsic::fabs:
1983 call->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
1984 if (AttrList.hasAttribute(
"enzyme_no_escaping_allocation"))
1999 bool forceZero =
false);
2002 llvm::IRBuilder<> &BuilderM,
2003 llvm::Value *mask =
nullptr);
2006 llvm::Value *cmp, llvm::Value *tval,
2008 const llvm::Twine &Name =
"") {
2009 if (
auto cmpi = llvm::dyn_cast<llvm::ConstantInt>(cmp)) {
2015 return Builder2.CreateSelect(cmp, tval, fval, Name);
2019 llvm::IRBuilder<> &Builder2,
2020 llvm::Value *idiff, llvm::Value *pres,
2021 const llvm::Twine &Name =
"") {
2022 llvm::Value *res = Builder2.CreateFMul(idiff, pres, Name);
2024 llvm::Value *zero = llvm::Constant::getNullValue(idiff->getType());
2025 if (
auto C = llvm::dyn_cast<llvm::ConstantFP>(pres))
2026 if (!C->isInfinity() && !C->isNaN())
2028 res = Builder2.CreateSelect(Builder2.CreateFCmpOEQ(idiff, zero), zero, res);
2033 llvm::IRBuilder<> &Builder2,
2034 llvm::Value *idiff, llvm::Value *pres,
2035 const llvm::Twine &Name =
"") {
2036 llvm::Value *res = Builder2.CreateFDiv(idiff, pres, Name);
2038 llvm::Value *zero = llvm::Constant::getNullValue(idiff->getType());
2039 if (
auto C = llvm::dyn_cast<llvm::ConstantFP>(pres))
2040 if (!C->isZero() && !C->isNaN())
2042 res = Builder2.CreateSelect(Builder2.CreateFCmpOEQ(idiff, zero), zero, res);
2049 const llvm::DataLayout &dl,
2050 llvm::Type **vFT =
nullptr) {
2051 using namespace llvm;
2052 if (
auto CI = dyn_cast_or_null<ConstantInt>(V)) {
2058 if (dl.getTypeSizeInBits(FT) == dl.getTypeSizeInBits(CI->getType())) {
2059 if (CI->isNegative() && CI->isMinValue(
true)) {
2066 if (
auto CV = dyn_cast_or_null<ConstantVector>(V)) {
2068 for (
size_t i = 0, end = CV->getNumOperands(); i < end; ++i) {
2072#if LLVM_VERSION_MAJOR >= 12
2073 *vFT = VectorType::get(FT, CV->getType()->getElementCount());
2075 *vFT = VectorType::get(FT, CV->getType()->getNumElements());
2081 if (
auto CV = dyn_cast_or_null<ConstantDataVector>(V)) {
2083 for (
size_t i = 0, end = CV->getNumElements(); i < end; ++i) {
2084 auto CI = CV->getElementAsAPInt(i);
2085#if LLVM_VERSION_MAJOR > 16
2089 if (CI.isNullValue())
2092 if (dl.getTypeSizeInBits(FT) !=
2093 dl.getTypeSizeInBits(CV->getElementType())) {
2097 if (!CI.isMinSignedValue()) {
2103#if LLVM_VERSION_MAJOR >= 12
2104 *vFT = VectorType::get(FT, CV->getType()->getElementCount());
2106 *vFT = VectorType::get(FT, CV->getType()->getNumElements());
2111 if (
auto BO = dyn_cast<BinaryOperator>(V)) {
2112 if (BO->getOpcode() == Instruction::And) {
2113 for (
size_t i = 0; i < 2; i++) {
2124 llvm::SmallVectorImpl<llvm::Value *> &cacheValues,
2125 llvm::IRBuilder<> &BuilderZ,
const llvm::Twine &name =
"");
2127llvm::Value *
load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType,
2128 llvm::Value *V,
bool byRef);
2131 BlasInfo blas,
bool byRef, llvm::Value *layout,
2132 llvm::Value *uplo, llvm::Value *A, llvm::Value *lda,
2138 bool cublas, llvm::IntegerType *julia_decl,
2139 llvm::IRBuilder<> &entryBuilder,
2140 llvm::Twine
const & =
"");
2142 bool byRef, llvm::Type *julia_decl,
2143 llvm::IRBuilder<> &entryBuilder,
2144 llvm::Twine
const & =
"");
2147 llvm::ArrayRef<llvm::Value *> trans,
2148 llvm::Value *arg_ld, llvm::Value *dim_1,
2149 llvm::Value *dim_2,
bool cacheMat,
bool byRef,
2152template <
typename T>
2153static inline void append(llvm::SmallVectorImpl<T> &vec) {}
2154template <
typename T,
typename... T2>
2155static inline void append(llvm::SmallVectorImpl<T> &vec, llvm::ArrayRef<T> vals,
2157 vec.append(vals.begin(), vals.end());
2158 append(vec, std::forward<T2>(ts)...);
2160template <
typename... T>
2162 llvm::SmallVector<llvm::Value *, 1> res;
2163 append(res, std::forward<T>(t)...);
2167llvm::Value *
is_normal(llvm::IRBuilder<> &B, llvm::Value *trans,
bool byRef,
2169llvm::Value *
is_left(llvm::IRBuilder<> &B, llvm::Value *side,
bool byRef,
2171llvm::Value *
is_lower(llvm::IRBuilder<> &B, llvm::Value *uplo,
bool byRef,
2173llvm::Value *
is_nonunit(llvm::IRBuilder<> &B, llvm::Value *uplo,
bool byRef,
2177 llvm::Value *layout, llvm::Value *base,
2178 llvm::Value *lda, llvm::Value *row,
2182llvm::Value *
transpose(std::string floatType, llvm::IRBuilder<> &B,
2183 llvm::Value *V,
bool cublas);
2185llvm::Value *
transpose(std::string floatType, llvm::IRBuilder<> &B,
2186 llvm::Value *V,
bool byRef,
bool cublas,
2187 llvm::IntegerType *IT, llvm::IRBuilder<> &entryBuilder,
2188 const llvm::Twine &name);
2189llvm::SmallVector<llvm::Value *, 1>
2191 llvm::ArrayRef<llvm::Value *> row,
2192 llvm::ArrayRef<llvm::Value *> col,
bool byRef,
bool cublas);
2194llvm::SmallVector<llvm::Value *, 1>
2196 bool byRef,
bool cublas);
2199#pragma clang diagnostic push
2200#pragma clang diagnostic ignored "-Wunused-variable"
2202#pragma GCC diagnostic push
2203#pragma GCC diagnostic ignored "-Wunused-variable"
2209 llvm::Attribute::AttrKind::ReadOnly,
2210 llvm::Attribute::AttrKind::WriteOnly,
2211 llvm::Attribute::AttrKind::ZExt,
2212 llvm::Attribute::AttrKind::SExt,
2213 llvm::Attribute::AttrKind::InReg,
2214 llvm::Attribute::AttrKind::ByVal,
2215#if LLVM_VERSION_MAJOR >= 12
2216 llvm::Attribute::AttrKind::ByRef,
2218 llvm::Attribute::AttrKind::Preallocated,
2219 llvm::Attribute::AttrKind::InAlloca,
2220#if LLVM_VERSION_MAJOR >= 13
2221 llvm::Attribute::AttrKind::ElementType,
2223#if LLVM_VERSION_MAJOR >= 15
2224 llvm::Attribute::AttrKind::AllocAlign,
2226 llvm::Attribute::AttrKind::NoFree,
2227 llvm::Attribute::AttrKind::Alignment,
2228 llvm::Attribute::AttrKind::StackAlignment,
2229#if LLVM_VERSION_MAJOR >= 20
2230 llvm::Attribute::AttrKind::Captures,
2232 llvm::Attribute::AttrKind::NoCapture,
2234 llvm::Attribute::AttrKind::ReadNone
2241 llvm::Attribute::AttrKind::ZExt,
2242 llvm::Attribute::AttrKind::SExt,
2243#if LLVM_VERSION_MAJOR >= 13
2244 llvm::Attribute::AttrKind::ElementType,
2246 llvm::Attribute::AttrKind::NoFree,
2247 llvm::Attribute::AttrKind::Alignment,
2248 llvm::Attribute::AttrKind::StackAlignment,
2249#if LLVM_VERSION_MAJOR >= 20
2250 llvm::Attribute::AttrKind::Captures,
2252 llvm::Attribute::AttrKind::NoCapture,
2254 llvm::Attribute::AttrKind::ReadNone,
2257#pragma clang diagnostic pop
2259#pragma GCC diagnostic pop
2262static inline llvm::Function *
2264 llvm::ArrayRef<llvm::Type *> Tys = {}) {
2265#if LLVM_VERSION_MAJOR >= 20
2266 return llvm::Intrinsic::getOrInsertDeclaration(M,
id, Tys);
2268 return llvm::Intrinsic::getDeclaration(M,
id, Tys);
2273#if LLVM_VERSION_MAJOR >= 20
2274 return &*B->getFirstNonPHIOrDbg();
2276 return B->getFirstNonPHIOrDbg();
2280static inline llvm::Instruction *
2282#if LLVM_VERSION_MAJOR >= 20
2283 return &*B->getFirstNonPHIOrDbgOrLifetime();
2285 return B->getFirstNonPHIOrDbgOrLifetime();
2290#if LLVM_VERSION_MAJOR > 20
2292 idx, llvm::Attribute::get(call->getContext(), llvm::Attribute::Captures,
2293 llvm::CaptureInfo::none().toIntValue()));
2295 call->addParamAttr(idx, llvm::Attribute::NoCapture);
2300#if LLVM_VERSION_MAJOR > 20
2302 idx, llvm::Attribute::get(call->getContext(), llvm::Attribute::Captures,
2303 llvm::CaptureInfo::none().toIntValue()));
2305 call->addParamAttr(idx, llvm::Attribute::NoCapture);
2309[[nodiscard]]
static inline llvm::AttributeList
2312 unsigned idxs = {(unsigned)idx};
2313#if LLVM_VERSION_MAJOR > 20
2314 return list.addParamAttribute(
2316 llvm::Attribute::get(ctx, llvm::Attribute::Captures,
2317 llvm::CaptureInfo::none().toIntValue()));
2319 return list.addParamAttribute(ctx, idxs, llvm::Attribute::NoCapture);
2325template <
typename Arg1,
typename...
Args>
2327 if (
auto AT = llvm::dyn_cast<llvm::ArrayType>(T))
2328 return getSubType(AT->getElementType(), args...);
2329 if (
auto VT = llvm::dyn_cast<llvm::VectorType>(T))
2330 return getSubType(VT->getElementType(), args...);
2331 if (
auto ST = llvm::dyn_cast<llvm::StructType>(T)) {
2332 assert((
int)i != -1);
2333 return getSubType(ST->getElementType(i), args...);
2335 llvm::errs() << *T <<
"\n";
2336 llvm_unreachable(
"unknown subtype");
2355 llvm::PointerType *PTy = llvm::dyn_cast<llvm::PointerType>(Ty);
2358 unsigned AS = PTy->getAddressSpace();
2362#if LLVM_VERSION_MAJOR >= 20
2364 llvm::GEPOperator *gep,
const llvm::DataLayout &DL,
unsigned BitWidth,
2365 llvm::SmallMapVector<llvm::Value *, llvm::APInt, 4> &VariableOffsets,
2366 llvm::APInt &ConstantOffset);
2370 llvm::MapVector<llvm::Value *, llvm::APInt> &VariableOffsets,
2371 llvm::APInt &ConstantOffset);
2375 llvm::Intrinsic::ID ID, llvm::Type *RetTy,
2376 llvm::ArrayRef<llvm::Value *>
Args,
2377 llvm::Instruction *FMFSource =
nullptr,
2378 const llvm::Twine &Name =
"");
2380bool isNVLoad(
const llvm::Value *V);
2386 size_t checkLoadCaptured);
2394#if LLVM_VERSION_MAJOR >= 16
2400 llvm::LoopInfo &LI, llvm::Value *op0,
2401 llvm::Value *op1,
bool offsetAllowed =
false);
2403static inline std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>
2406 return {
"", caller,
""};
2408 auto &&[prefix, todo] = caller.split(
"$");
2409 auto &&[name, postfix] = todo.split(
"$");
2410 return std::make_tuple(prefix, name, postfix);
2414 llvm::StringRef callee) {
2423 return (
"P" + callee).str();
2425 return callee.str();
2429 return std::to_string((
size_t)T);
2432static inline llvm::Type *
2434 if (
str ==
"test_type") {
2436 llvm::SmallVector<llvm::Type *, 1> elts;
2437#if LLVM_VERSION_MAJOR >= 17
2440 elts.push_back(llvm::PointerType::get(llvm::StructType::get(*C, {}),
2443 llvm::Type *inner = llvm::StructType::get(*C, elts);
2444 llvm::SmallVector<llvm::Type *, 1> innerElts;
2445 innerElts.push_back(inner);
2446 return llvm::StructType::get(*C, innerElts);
2448 if (
str ==
"test_type2") {
2450 return llvm::ArrayType::get(llvm::Type::getInt64Ty(*C), 6);
2452 if (
str ==
"test_type3") {
2454 llvm::SmallVector<llvm::Type *, 1> elts;
2455 elts.push_back(llvm::Type::getDoubleTy(*C));
2456 return llvm::StructType::get(*C, elts);
2458 if (
str ==
"test_type4") {
2460 llvm::SmallVector<llvm::Type *, 3> elts;
2461 elts.push_back(llvm::ArrayType::get(llvm::Type::getDoubleTy(*C), 2));
2462 elts.push_back(llvm::Type::getDoubleTy(*C));
2463 elts.push_back(llvm::Type::getInt64Ty(*C));
2464 return llvm::StructType::get(*C, elts);
2466 if (
str ==
"test_type5") {
2468 llvm::SmallVector<llvm::Type *, 3> elts;
2469 elts.push_back(llvm::ArrayType::get(llvm::Type::getDoubleTy(*C), 1));
2470 elts.push_back(llvm::Type::getDoubleTy(*C));
2471 elts.push_back(llvm::Type::getInt64Ty(*C));
2472 return llvm::StructType::get(*C, elts);
2475 bool failed =
str.consumeInteger(10, idx);
2478 return (llvm::Type *)idx;
2483 bool failed =
str.consumeInteger(10, idx);
2490 if (CB->hasStructRetAttr())
2492 for (
size_t i = 0; i < CB->arg_size(); i++) {
2493 if (CB->getAttributeAtIndex(llvm::AttributeList::FirstArgIndex + i,
2494 "enzymejl_sret_union_bytes")
2497 if (CB->getAttributeAtIndex(llvm::AttributeList::FirstArgIndex + i,
2498 "enzymejl_returnRoots")
2514 llvm::Value *sret, llvm::Type *root_ty,
2515 llvm::Value *rootRet,
size_t rootOffset,
2519 llvm::Type *dstType, llvm::Value *dst,
2520 llvm::ArrayRef<unsigned> dstPrefix, llvm::Type *srcType,
2521 llvm::Value *src, llvm::ArrayRef<unsigned> srcPrefix,
2527 if (
auto ST = llvm::dyn_cast<llvm::StructType>(T)) {
2528 for (
auto elem : ST->elements()) {
2534 if (
auto AT = llvm::dyn_cast<llvm::ArrayType>(T)) {
2537 if (
auto VT = llvm::dyn_cast<llvm::VectorType>(T)) {
2544 llvm::IRBuilder<> &B);
2549llvm::SmallVector<std::tuple<llvm::Instruction *, llvm::Value *, size_t>, 1>
2553#if LLVM_VERSION_MAJOR >= 23
2554 return BB->hasTerminator();
2556 return BB->getTerminator();
static std::string str(AugmentedStruct c)
static llvm::Value * checkedMul(bool strongZero, llvm::IRBuilder<> &Builder2, llvm::Value *idiff, llvm::Value *pres, const llvm::Twine &Name="")
llvm::Value * EmitNoDerivativeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &B, llvm::Value *condition=nullptr)
static llvm::StringRef getFuncName(llvm::Function *called)
static bool isIntelSubscriptIntrinsic(const llvm::IntrinsicInst &II)
llvm::Function * getOrInsertDifferentialWaitallSave(llvm::Module &M, llvm::ArrayRef< llvm::Type * > T, llvm::PointerType *reqType)
static void allDomPredecessorsOf(llvm::Instruction *inst, llvm::DominatorTree &DT, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen before inst If the function returns true,...
llvm::Function * getOrInsertMemcpyMat(llvm::Module &M, llvm::Type *elementType, llvm::PointerType *PT, llvm::IntegerType *IT, unsigned dstalign, unsigned srcalign)
Turned out to be a faster alternative to LAPACK's lacpy function.
llvm::Value * load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType, llvm::Value *V, bool byRef)
static bool isNoEscapingAllocation(const llvm::Function *F)
llvm::Value * is_nonunit(llvm::IRBuilder<> &B, llvm::Value *uplo, bool byRef, bool cublas)
static std::string getRenamedPerCallingConv(llvm::StringRef caller, llvm::StringRef callee)
void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, bool isTape)
llvm::cl::opt< bool > EnzymeBlasCopy
llvm::SmallVector< llvm::Value *, 1 > get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef< llvm::Value * > trans, llvm::ArrayRef< llvm::Value * > row, llvm::ArrayRef< llvm::Value * > col, bool byRef, bool cublas)
static llvm::Loop * getAncestor(llvm::Loop *R1, llvm::Loop *R2)
static llvm::Function * getDeallocatorFnFromCall(T *op)
llvm::SmallVector< llvm::Instruction *, 2 > PostCacheStore(llvm::StoreInst *SI, llvm::IRBuilder<> &B)
static llvm::Instruction * getNextNonDebugInstructionOrNull(llvm::Instruction *Z)
Get the next non-debug instruction, if one exists.
static T max(T a, T b)
Pick the maximum value.
static llvm::PointerType * getPointerType(llvm::Type *T, unsigned AddressSpace=0)
llvm::PointerType * getDefaultAnonymousTapeType(llvm::LLVMContext &C)
void callSPMVDiagUpdate(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::IntegerType *IT, llvm::Type *BlasCT, llvm::Type *BlasFPT, llvm::Type *BlasPT, llvm::Type *BlasIT, llvm::Type *fpTy, llvm::ArrayRef< llvm::Value * > args, const llvm::ArrayRef< llvm::OperandBundleDef > bundles, bool byRef, bool julia_decl)
static void dumpMap(const llvm::ValueMap< T, N > &o, llvm::function_ref< bool(const llvm::Value *)> shouldPrint=[](T) { return true;})
Print a map, optionally with a shouldPrint function that decides whether to print a given value.
bool attributeKnownFunctions(llvm::Function &F)
static bool anyJuliaObjects(llvm::Type *T)
ReturnType
Potential return type of generated functions.
@ Tape
Return is a tape type.
@ ArgsWithTwoReturns
Return is a struct of all args and two of the original return.
@ TapeAndReturn
Return is a tape type and the original return.
@ Args
Return is a struct of all args.
@ TapeAndTwoReturns
Return is a tape type and two of the original return.
@ ArgsWithReturn
Return is a struct of all args and the original return.
void EmitWarningAlways(llvm::StringRef RemarkName, const llvm::Function &F, const Args &...args)
static V * findInMap(std::map< K, V > &map, K key)
llvm::Value * is_left(llvm::IRBuilder<> &B, llvm::Value *side, bool byRef, bool cublas)
static bool isNoAlias(const llvm::CallBase *call)
static llvm::Type * FloatToIntTy(llvm::Type *T)
Convert a floating point type to an integer type of the same size.
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
DIFFE_TYPE
Potential differentiable argument classifications.
llvm::Value * CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev, llvm::Type *T, llvm::Value *OuterCount, llvm::Value *InnerCount, const llvm::Twine &Name="", llvm::CallInst **caller=nullptr, bool ZeroMem=false)
llvm::Function * getOrInsertDifferentialMPI_Wait(llvm::Module &M, llvm::ArrayRef< llvm::Type * > T, llvm::Type *reqType, llvm::StringRef caller)
Create function for type that performs the derivative MPI_Wait.
void EmitNoTypeError(const std::string &, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &B)
static size_t convertRRootCountFromString(llvm::StringRef str)
llvm::Value * CreateAllocation(llvm::IRBuilder<> &B, llvm::Type *T, llvm::Value *Count, const llvm::Twine &Name="", llvm::CallInst **caller=nullptr, llvm::Instruction **ZeroMem=nullptr, bool isDefault=false)
llvm::CallInst * getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::IntegerType *IT, llvm::Type *BlasPT, llvm::Type *BlasIT, llvm::Type *fpTy, llvm::ArrayRef< llvm::Value * > args, const llvm::ArrayRef< llvm::OperandBundleDef > bundles, bool byRef, bool cublas, bool julia_decl)
static void allPredecessorsOf(llvm::Instruction *inst, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen before inst If the function returns true,...
llvm::Value * get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res)
void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef< llvm::Value * > args, llvm::ArrayRef< llvm::OperandBundleDef > bundles)
Create function for type that performs memcpy using lapack copy.
bool notCapturedBefore(llvm::Value *V, llvm::Instruction *inst, size_t checkLoadCaptured)
Check whether value V can be captured after its definition and before inst is executed.
static llvm::PointerType * getUnqual(llvm::Type *T)
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &B, llvm::Intrinsic::ID ID, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Args, llvm::Instruction *FMFSource=nullptr, const llvm::Twine &Name="")
llvm::Value * nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V)
Create function to compute the nearest power of two.
llvm::Function * getOrInsertCheckedFree(llvm::Module &M, llvm::CallInst *call, llvm::Type *Type, unsigned width)
llvm::Value * transpose(std::string floatType, llvm::IRBuilder<> &B, llvm::Value *V, bool cublas)
static llvm::Value * getMPIMemberPtr(llvm::IRBuilder<> &B, llvm::Value *V, llvm::Type *T)
static bool isCertainPrint(const llvm::StringRef name)
llvm::Optional< bool > arePointersGuaranteedNoAlias(llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, llvm::LoopInfo &LI, llvm::Value *op0, llvm::Value *op1, bool offsetAllowed=false)
llvm::Value * SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM, llvm::Value *mask=nullptr)
void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty, llvm::SmallVectorImpl< llvm::Value * > &cacheValues, llvm::IRBuilder<> &BuilderZ, const llvm::Twine &name="")
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
llvm::Optional< BlasInfo > extractBLAS(llvm::StringRef in)
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
static std::map< K, V >::iterator insert_or_assign2(std::map< K, V > &map, K key, V val)
Insert into a map.
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
llvm::Value * is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef, bool cublas)
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
static bool isNoCapture(const llvm::CallBase *call, size_t idx)
static void addFunctionNoCapture(llvm::Function *call, size_t idx)
static llvm::Type * convertSRetTypeFromString(llvm::StringRef str, llvm::LLVMContext *C=nullptr)
bool notCaptured(llvm::Value *V)
Check whether value V can be captured.
static bool isDebugFunction(llvm::Function *called)
static bool isReturned(llvm::Instruction *inst)
Check whether this instruction is returned.
std::vector< std::tuple< llvm::Type *, size_t, size_t > > parseTrueType(const llvm::MDNode *, DerivativeMode, bool const_src)
static bool isPointerArithmeticInst(const llvm::Value *V, bool includephi=true, bool includebin=true)
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
bool isNVLoad(const llvm::Value *V)
static llvm::Function * isCalledFunction(llvm::Value *val)
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
static bool containsOnlyAtMostTopBit(const llvm::Value *V, llvm::Type *FT, const llvm::DataLayout &dl, llvm::Type **vFT=nullptr)
bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::Instruction *maybeReader, llvm::Instruction *maybeWriter)
Return whether maybeReader can read from memory written to by maybeWriter.
llvm::Value * lookup_with_layout(llvm::IRBuilder<> &B, llvm::Type *fpType, llvm::Value *layout, llvm::Value *base, llvm::Value *lda, llvm::Value *row, llvm::Value *col)
static bool shouldDisableNoWrite(const llvm::CallInst *CI)
static llvm::SetVector< llvm::Value * > getBaseObjects(llvm::Value *V, bool offsetAllowed=true)
bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT, llvm::Instruction *maybeReader, llvm::Instruction *maybeWriter, llvm::Loop *scope=nullptr)
void EmitWarning(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::BasicBlock *BB, const Args &...args)
static bool isReadOnlyOrThrow(const llvm::Function *F)
static bool isSpecialPtr(llvm::Type *Ty)
void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef< llvm::Value * > args, llvm::Type *cublas_retty, llvm::ArrayRef< llvm::OperandBundleDef > bundles)
Create function for type that performs memcpy with a stride using blas copy.
static bool hasNoCache(llvm::Value *op)
static llvm::StructType * getMPIHelper(llvm::LLVMContext &Context)
llvm::CallInst * CreateDealloc(llvm::IRBuilder<> &B, llvm::Value *ToFree)
static bool isReadOnly(const llvm::Function *F, ssize_t arg=-1)
static llvm::Type * getSubType(llvm::Type *T)
static llvm::Instruction * getNextNonDebugInstruction(llvm::Instruction *Z)
Get the next non-debug instruction, erring if none exists.
llvm::StringMap< std::function< llvm::Value *(llvm::IRBuilder<> &, llvm::CallInst *, llvm::ArrayRef< llvm::Value * >, GradientUtils *)> > shadowHandlers
static std::vector< ssize_t > getDeallocationIndicesFromCall(T *op)
llvm::Value * moveSRetToFromRoots(llvm::IRBuilder<> &B, llvm::Type *jltype, llvm::Value *sret, llvm::Type *root_ty, llvm::Value *rootRet, size_t rootOffset, SRetRootMovement direction)
static void allUnsyncdPredecessorsOf(llvm::Instruction *inst, llvm::function_ref< bool(llvm::Instruction *)> f, llvm::function_ref< void()> preEntry)
Call the function f for all instructions that happen before inst If the function returns true,...
static void allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1, llvm::Instruction *inst2, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen between inst1 and inst2 If the function returns ...
void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Value *shadow, const char *Message, llvm::DebugLoc &&loc, llvm::Instruction *orig)
llvm::cl::opt< bool > EnzymeJuliaAddrLoad
static void dumpSet(const llvm::SmallPtrSetImpl< T * > &o)
Print a set.
static llvm::Function * getFunctionFromCall(T *op)
llvm::Function * getOrInsertDifferentialFloatMemmove(llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
Create function for type that performs the derivative memmove on floating point memory.
llvm::Function * getOrInsertDifferentialFloatMemcpy(llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
Create function for type that performs the derivative memcpy on floating point memory.
bool collectOffset(llvm::GEPOperator *gep, const llvm::DataLayout &DL, unsigned BitWidth, llvm::MapVector< llvm::Value *, llvm::APInt > &VariableOffsets, llvm::APInt &ConstantOffset)
llvm::Value * get_cached_mat_width(llvm::IRBuilder<> &B, llvm::ArrayRef< llvm::Value * > trans, llvm::Value *arg_ld, llvm::Value *dim_1, llvm::Value *dim_2, bool cacheMat, bool byRef, bool cublas)
@ IllegalReplaceFicticiousPHIs
static llvm::Value * checkedDiv(bool strongZero, llvm::IRBuilder<> &Builder2, llvm::Value *idiff, llvm::Value *pres, const llvm::Twine &Name="")
static T min(T a, T b)
Pick the minimum value.
static bool hasSRetRRootsOrUnionSRet(llvm::CallBase *CB)
static std::tuple< llvm::StringRef, llvm::StringRef, llvm::StringRef > tripleSplitDollar(llvm::StringRef caller)
llvm::SmallVector< llvm::Value *, 1 > getJuliaObjects(llvm::Value *v, llvm::IRBuilder<> &B)
@ RootPointerToSRetPointer
@ SRetPointerToRootPointer
static void addCallSiteNoCapture(llvm::CallBase *call, size_t idx)
static llvm::Type * IntToFloatTy(llvm::Type *T)
Convert an integer type to a floating-point type of the same size.
static llvm::Instruction * getFirstNonPHIOrDbgOrLifetime(llvm::BasicBlock *B)
llvm::cl::opt< bool > EnzymePrintPerf
Print additional debug info relevant to performance.
llvm::Value * to_blas_callconv(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, bool cublas, llvm::IntegerType *julia_decl, llvm::IRBuilder<> &entryBuilder, llvm::Twine const &="")
llvm::Function * GetFunctionFromValue(llvm::Value *fn)
llvm::Value * is_lower(llvm::IRBuilder<> &B, llvm::Value *uplo, bool byRef, bool cublas)
static bool endsWith(llvm::StringRef string, llvm::StringRef suffix)
static llvm::Value * CreateSelect(llvm::IRBuilder<> &Builder2, llvm::Value *cmp, llvm::Value *tval, llvm::Value *fval, const llvm::Twine &Name="")
llvm::Value * getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr, llvm::Type *OpType, ConcreteType CT, llvm::Type *intType, llvm::IRBuilder<> &B2)
static std::string convertSRetTypeToString(llvm::Type *T)
static llvm::SmallVector< llvm::Value *, 1 > concat_values(T &&...t)
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
llvm::Value * to_blas_fp_callconv(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, llvm::Type *julia_decl, llvm::IRBuilder<> &entryBuilder, llvm::Twine const &="")
llvm::Function * getOrInsertMemcpyStrided(llvm::Module &M, llvm::Type *elementType, llvm::PointerType *T, llvm::Type *IT, unsigned dstalign, unsigned srcalign)
Create function for type that performs memcpy with a stride.
static bool isReadNone(const llvm::CallBase *call, ssize_t arg=-1)
static bool isWriteOnly(const llvm::Function *F, ssize_t arg=-1)
static bool isLocalReadOnlyOrThrow(const llvm::Function *F)
llvm::cl::opt< bool > EnzymeLapackCopy
ValueType
Classification of value as an original program variable, a derivative variable, neither,...
static llvm::Instruction * getFirstNonPHIOrDbg(llvm::BasicBlock *B)
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
static std::string to_string(const std::set< T > &us)
Output a set as a string.
static llvm::Attribute::AttrKind PrimalParamAttrsToPreserve[]
void copyNonJLValueInto(llvm::IRBuilder<> &B, llvm::Type *curType, llvm::Type *dstType, llvm::Value *dst, llvm::ArrayRef< unsigned > dstPrefix, llvm::Type *srcType, llvm::Value *src, llvm::ArrayRef< unsigned > srcPrefix, bool shouldZero)
void mayExecuteAfter(llvm::SmallVectorImpl< llvm::Instruction * > &results, llvm::Instruction *inst, const llvm::SmallPtrSetImpl< llvm::Instruction * > &stores, const llvm::Loop *region)
llvm::cl::opt< bool > EnzymeNonPower2Cache
static llvm::Attribute::AttrKind ShadowParamAttrsToPreserve[]
llvm::SmallVector< std::tuple< llvm::Instruction *, llvm::Value *, size_t >, 1 > findAllUsersOf(llvm::Value *AI)
static llvm::raw_ostream & operator<<(llvm::raw_ostream &os, ValueType mode)
static bool hasTerminator(llvm::BasicBlock *BB)
static std::map< K, V >::iterator insert_or_assign(std::map< K, V > &map, K &key, V &&val)
Insert into a map.
llvm::Constant * getUndefinedValueForType(llvm::Module &M, llvm::Type *T, bool forceZero=false)
static void allFollowersOf(llvm::Instruction *inst, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen after inst If the function returns true,...
void copy_lower_to_upper(llvm::IRBuilder<> &B, llvm::Type *fpType, BlasInfo blas, bool byRef, llvm::Value *layout, llvm::Value *uplo, llvm::Value *A, llvm::Value *lda, llvm::Value *N)
static void append(llvm::SmallVectorImpl< T > &vec)
static llvm::Optional< size_t > getAllocationIndexFromCall(const llvm::CallBase *op)
llvm::Value * simplifyLoad(llvm::Value *LI, size_t valSz=0, size_t preOffset=0)
static DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode, bool integersAreConstant, std::set< llvm::Type * > &seen)
Attempt to automatically detect the differentiable classification based on a given type.
llvm::Function * getFirstFunctionDefinition(llvm::Module &M)
llvm::Function * getOrInsertDifferentialFloatMemcpyMat(llvm::Module &M, llvm::Type *elementType, llvm::PointerType *PT, llvm::IntegerType *IT, llvm::IntegerType *CT, unsigned dstalign, unsigned srcalign, bool zeroSrc)
void deleted() override final
void allUsesReplacedWith(llvm::Value *new_value) override final
virtual ~AssertingReplacingVH()
AssertingReplacingVH()=default
AssertingReplacingVH(llvm::Value *new_value)
Concrete SubType of a given value.
EnzymeFailure(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion)
EnzymeWarning(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion)
bool getContext(llvm::BasicBlock *BB, LoopContext &lc)
A holder class representing the results of running TypeAnalysis on a given function.
llvm::Type * fpType(llvm::LLVMContext &ctx, bool to_scalar=false) const
llvm::IntegerType * intType(llvm::LLVMContext &ctx) const
CountTrackedPointers(llvm::Type *T)