Enzyme main
Loading...
Searching...
No Matches
Utils.h
Go to the documentation of this file.
1//===- Utils.h - Declaration of miscellaneous utilities -------------------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @misc{enzymeGithub,
11// author = {William S. Moses and Valentin Churavy},
12// title = {Enzyme: High Performance Automatic Differentiation of LLVM},
13// year = {2020},
14// howpublished = {\url{https://github.com/wsmoses/Enzyme}},
15// note = {commit xxxxxxx}
16// }
17//
18//===----------------------------------------------------------------------===//
19//
20// This file declares miscellaneous utilities that are used as part of the
21// AD process.
22//
23//===----------------------------------------------------------------------===//
24
25#ifndef ENZYME_UTILS_H
26#define ENZYME_UTILS_H
27
28#include "llvm/ADT/MapVector.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallPtrSet.h"
31
32#include "llvm/Analysis/AliasAnalysis.h"
33#include "llvm/Analysis/ValueTracking.h"
34#include "llvm/IR/Attributes.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
37#include "llvm/IR/IntrinsicInst.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/Operator.h"
40#include "llvm/IR/Type.h"
41
42#include "llvm/IR/Function.h"
43#include "llvm/IR/IntrinsicInst.h"
44
45#include "llvm/IR/ValueMap.h"
46#include "llvm/Support/Casting.h"
47#include "llvm/Support/raw_ostream.h"
48
49#include "llvm/Support/CommandLine.h"
50
51#include "llvm/ADT/SetVector.h"
52#include "llvm/ADT/StringMap.h"
53
54#include "llvm/IR/Dominators.h"
55#include "llvm/IR/IntrinsicsAMDGPU.h"
56#include "llvm/IR/IntrinsicsNVPTX.h"
57
58#include <map>
59#include <set>
#include <utility>
60
61#if LLVM_VERSION_MAJOR >= 16
62#include <optional>
63#endif
64
65#include "llvm/IR/DiagnosticInfo.h"
66
67#include "llvm/Analysis/OptimizationRemarkEmitter.h"
68
70
// Forward declaration: TypeResults is only named by pointer
// (const TypeResults *) in the declarations further down this header.
71class TypeResults;
72
// Forward declaration so llvm::ScalarEvolution can be taken by reference
// (see overwritesToMemoryReadBy) without including its header here.
73namespace llvm {
74class ScalarEvolution;
75}
76
/// Category of error passed as the third argument of CustomErrorHandler.
// NOTE(review): this listing is missing source lines 80, 82, 84, 85 and 86,
// i.e. the enumerators with values 2, 4, 6, 7 and 8. Restore them from
// upstream before compiling; do not renumber the visible ones, since the
// values cross the extern "C" boundary.
77enum class ErrorType {
78 NoDerivative = 0,
79 NoShadow = 1,
81 NoType = 3,
83 InternalError = 5,
87 GetIndexError = 9,
88 NoTruncate = 10,
89 GCRewrite = 11,
90 NaNError = 12,
91};
92
// Globals defined in another translation unit and exported with C linkage.
93extern "C" {
94/// Print additional debug info relevant to performance
// (EmitWarning also echoes its message to llvm::errs() when this is set.)
95extern llvm::cl::opt<bool> EnzymePrintPerf;
// NOTE(review): the remaining flags are undocumented here; their meaning is
// presumably what the names suggest - confirm at their definitions.
96extern llvm::cl::opt<bool> EnzymeNonPower2Cache;
97extern llvm::cl::opt<bool> EnzymeBlasCopy;
98extern llvm::cl::opt<bool> EnzymeLapackCopy;
99extern llvm::cl::opt<bool> EnzymeJuliaAddrLoad;
// Function-pointer hook for reporting errors. Arguments are
// (message, value, error category, opaque pointer, value, builder).
// NOTE(review): the roles of the two LLVMValueRef arguments, the opaque
// pointer, and whether this may be null are not visible here - confirm.
100extern LLVMValueRef (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType,
101 const void *, LLVMValueRef,
102 LLVMBuilderRef);
103}
104
// Allocation helpers. All are declared here and defined in another
// translation unit; behavior beyond the signatures is not visible here.

// Takes a store SI and a builder; returns up to a small number of instructions.
// NOTE(review): presumably emits follow-up instructions after a cache store -
// confirm at the definition.
105llvm::SmallVector<llvm::Instruction *, 2> PostCacheStore(llvm::StoreInst *SI,
106 llvm::IRBuilder<> &B);
107
// Allocate Count elements of type T. Optional out-parameters:
// `caller` (a CallInst**) and `ZeroMem` (an Instruction**), both defaulting
// to nullptr.
108llvm::Value *CreateAllocation(llvm::IRBuilder<> &B, llvm::Type *T,
109 llvm::Value *Count, const llvm::Twine &Name = "",
110 llvm::CallInst **caller = nullptr,
111 llvm::Instruction **ZeroMem = nullptr,
112 bool isDefault = false);
// Emit a deallocation of ToFree; returns the emitted call.
113llvm::CallInst *CreateDealloc(llvm::IRBuilder<> &B, llvm::Value *ToFree);
// Zero the memory at obj, interpreted as type T; isTape presumably selects
// tape-specific handling (TODO confirm).
114void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj,
115 bool isTape);
116
// Reallocate prev to hold elements of type T, given an outer and an inner
// count; optionally zero new memory (ZeroMem) and report the call (caller).
117llvm::Value *CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev,
118 llvm::Type *T, llvm::Value *OuterCount,
119 llvm::Value *InnerCount,
120 const llvm::Twine &Name = "",
121 llvm::CallInst **caller = nullptr,
122 bool ZeroMem = false);
123
// Pointer type used for an anonymous (untyped) tape in context C.
124llvm::PointerType *getDefaultAnonymousTapeType(llvm::LLVMContext &C);
125
126class GradientUtils;
// A name -> callback registry. Each callback takes
// (builder, call, argument values, gradient utils) and returns a Value*.
// NOTE(review): this declaration is incomplete in this listing - the
// declarator name on source line 130 is missing, leaving only the type.
// Restore it from upstream before compiling. The type also names
// std::function while <functional> is only included explicitly further
// down this header, so it currently relies on a transitive include.
127extern llvm::StringMap<std::function<llvm::Value *(
128 llvm::IRBuilder<> &, llvm::CallInst *, llvm::ArrayRef<llvm::Value *>,
129 GradientUtils *)>>
131
132template <typename... Args>
133void EmitWarning(llvm::StringRef RemarkName,
134 const llvm::DiagnosticLocation &Loc,
135 const llvm::BasicBlock *BB, const Args &...args) {
136
137 llvm::LLVMContext &Ctx = BB->getContext();
138 if (Ctx.getDiagHandlerPtr()->isPassedOptRemarkEnabled("enzyme")) {
139 std::string str;
140 llvm::raw_string_ostream ss(str);
141 (ss << ... << args);
142 auto R = llvm::OptimizationRemark("enzyme", RemarkName, Loc, BB)
143 << ss.str();
144 Ctx.diagnose(R);
145 }
146
147 if (EnzymePrintPerf)
148 (llvm::errs() << ... << args) << "\n";
149}
150
151template <typename... Args>
152void EmitWarning(llvm::StringRef RemarkName, const llvm::Instruction &I,
153 const Args &...args) {
154 EmitWarning(RemarkName, I.getDebugLoc(), I.getParent(), args...);
155}
156
/// Diagnostic used to surface Enzyme warnings through LLVMContext::diagnose
/// (EmitWarningAlways emits one). It attaches Msg and a source location to
/// either an instruction or a function. The constructors are defined out of
/// line, so the severity they choose is not visible in this header.
157class EnzymeWarning final : public llvm::DiagnosticInfoUnsupported {
158public:
159 EnzymeWarning(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc,
160 const llvm::Instruction *CodeRegion);
161 EnzymeWarning(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc,
162 const llvm::Function *CodeRegion);
163};
164
165template <typename... Args>
166void EmitWarningAlways(llvm::StringRef RemarkName, const llvm::Function &F,
167 const Args &...args) {
168 llvm::LLVMContext &Ctx = F.getContext();
169 std::string str;
170 llvm::raw_string_ostream ss(str);
171 (ss << ... << args);
172 auto R = llvm::OptimizationRemark("enzyme", RemarkName, &F) << ss.str();
173 Ctx.diagnose((EnzymeWarning(ss.str(), F.getSubprogram(), &F)));
174}
175
176template <typename... Args>
177void EmitWarning(llvm::StringRef RemarkName, const llvm::Function &F,
178 const Args &...args) {
179 llvm::LLVMContext &Ctx = F.getContext();
180 if (Ctx.getDiagHandlerPtr()->isPassedOptRemarkEnabled("enzyme")) {
181 std::string str;
182 llvm::raw_string_ostream ss(str);
183 (ss << ... << args);
184 auto R = llvm::OptimizationRemark("enzyme", RemarkName, &F) << ss.str();
185 Ctx.diagnose(R);
186 }
187 if (EnzymePrintPerf)
188 (llvm::errs() << ... << args) << "\n";
189}
190
/// Diagnostic used by the EmitFailure helpers to report Enzyme errors through
/// LLVMContext::diagnose, attached to an instruction or a function. The
/// constructors are defined out of line.
191class EnzymeFailure final : public llvm::DiagnosticInfoUnsupported {
192public:
193 EnzymeFailure(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc,
194 const llvm::Instruction *CodeRegion);
195 EnzymeFailure(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc,
196 const llvm::Function *CodeRegion);
197};
198
199// Forward declaration needed for EmitFailure template
// Returns a function of M usable as diagnostic context. The Module overload
// of EmitFailure treats a null result as "no function to attach to".
200llvm::Function *getFirstFunctionDefinition(llvm::Module &M);
201
202template <typename... Args>
203void EmitFailure(llvm::StringRef RemarkName,
204 const llvm::DiagnosticLocation &Loc,
205 const llvm::Instruction *CodeRegion, Args &...args) {
206 std::string *str = new std::string();
207 llvm::raw_string_ostream ss(*str);
208 (ss << ... << args);
209 CodeRegion->getContext().diagnose(
210 (EnzymeFailure("Enzyme: " + ss.str(), Loc, CodeRegion)));
211}
212
213template <typename... Args>
214void EmitFailure(llvm::StringRef RemarkName,
215 const llvm::DiagnosticLocation &Loc,
216 const llvm::Function *CodeRegion, Args &...args) {
217 std::string *str = new std::string();
218 llvm::raw_string_ostream ss(*str);
219 (ss << ... << args);
220 CodeRegion->getContext().diagnose(
221 (EnzymeFailure("Enzyme: " + ss.str(), Loc, CodeRegion)));
222}
223
224template <typename... Args>
225void EmitFailure(llvm::StringRef RemarkName, llvm::Module &M, Args &...args) {
226 // Use the first function definition in the module as context for the
227 // diagnostic
228 if (llvm::Function *FirstFunc = getFirstFunctionDefinition(M)) {
229 EmitFailure(RemarkName, FirstFunc->getSubprogram(), FirstFunc, args...);
230 } else {
231 // Fallback if no functions in module
232 std::string *str = new std::string();
233 llvm::raw_string_ostream ss(*str);
234 (ss << ... << args);
235 llvm::report_fatal_error(llvm::StringRef(*str));
236 }
237}
238
239static inline llvm::Function *isCalledFunction(llvm::Value *val) {
240 if (llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(val)) {
241 return CI->getCalledFunction();
242 }
243 return nullptr;
244}
245
246class GradientUtils;
247struct RequestContext;
// Report that no derivative is available, with `message`, for `inst` while
// generating code with `gutils`/`B`. `condition` defaults to nullptr.
// NOTE(review): presumably it may emit a runtime error value guarded by
// `condition` - confirm at the definition.
248llvm::Value *EmitNoDerivativeError(const std::string &message,
249 llvm::Instruction &inst,
250 GradientUtils *gutils, llvm::IRBuilder<> &B,
251 llvm::Value *condition = nullptr);
// Variant keyed on the value being differentiated and a RequestContext;
// the meaning of the bool result is not visible here (TODO confirm).
252bool EmitNoDerivativeError(const std::string &message, llvm::Value *todiff,
253 RequestContext &ctx);
254
// Report missing type information for `inst`.
255void EmitNoTypeError(const std::string &, llvm::Instruction &inst,
256 GradientUtils *gutils, llvm::IRBuilder<> &B);
257
258/// Get LLVM fast math flags
259llvm::FastMathFlags getFast();
260
/// Pick the maximum value: \p a if it compares greater than \p b, otherwise
/// \p b (so ties and unordered values such as NaN yield \p b).
template <typename T> static inline T max(T a, T b) {
  return a > b ? a : b;
}
/// Pick the minimum value: \p a if it compares less than \p b, otherwise
/// \p b (so ties and unordered values such as NaN yield \p b).
template <typename T> static inline T min(T a, T b) {
  if (a < b)
    return a;
  return b;
}
273
/// Output a set as a string of the form "{e1,e2,...,}". Elements appear in
/// the set's order, each rendered with std::to_string and followed by a comma.
/// An empty set yields "{}".
template <typename T>
static inline std::string to_string(const std::set<T> &us) {
  std::string out("{");
  for (auto it = us.begin(); it != us.end(); ++it) {
    out += std::to_string(*it);
    out.push_back(',');
  }
  out.push_back('}');
  return out;
}
282
283/// Print a map, optionally with a shouldPrint function
284/// to decide to print a given value
285template <typename T, typename N>
286static inline void dumpMap(
287 const llvm::ValueMap<T, N> &o,
288 llvm::function_ref<bool(const llvm::Value *)> shouldPrint = [](T) {
289 return true;
290 }) {
291 llvm::errs() << "<begin dump>\n";
292 for (auto a : o) {
293 if (shouldPrint(a.first))
294 llvm::errs() << "key=" << *a.first << " val=" << *a.second << "\n";
295 }
296 llvm::errs() << "</end dump>\n";
297}
298
299/// Print a set
300template <typename T>
301static inline void dumpSet(const llvm::SmallPtrSetImpl<T *> &o) {
302 llvm::errs() << "<begin dump>\n";
303 for (auto a : o)
304 llvm::errs() << *a << "\n";
305 llvm::errs() << "</end dump>\n";
306}
307
308template <typename T>
309static inline void dumpSet(const llvm::SetVector<T *> &o) {
310 llvm::errs() << "<begin dump>\n";
311 for (auto a : o)
312 llvm::errs() << *a << "\n";
313 llvm::errs() << "</end dump>\n";
314}
315
316/// Get the next non-debug instruction, if one exists
317static inline llvm::Instruction *
318getNextNonDebugInstructionOrNull(llvm::Instruction *Z) {
319 for (llvm::Instruction *I = Z->getNextNode(); I; I = I->getNextNode())
320 if (!llvm::isa<llvm::DbgInfoIntrinsic>(I))
321 return I;
322 return nullptr;
323}
324
325/// Get the next non-debug instruction, erring if none exists
326static inline llvm::Instruction *
327getNextNonDebugInstruction(llvm::Instruction *Z) {
329 if (z)
330 return z;
331 llvm::errs() << *Z->getParent() << "\n";
332 llvm::errs() << *Z << "\n";
333 llvm_unreachable("No valid subsequent non debug instruction");
334 exit(1);
335 return nullptr;
336}
337
338/// Check if a global has metadata
339static inline llvm::MDNode *hasMetadata(const llvm::GlobalObject *O,
340 llvm::StringRef kind) {
341 return O->getMetadata(kind);
342}
343
344/// Check if an instruction has metadata
345static inline llvm::MDNode *hasMetadata(const llvm::Instruction *O,
346 llvm::StringRef kind) {
347 return O->getMetadata(kind);
348}
349static inline llvm::MDNode *hasMetadata(const llvm::Instruction *O,
350 unsigned kind) {
351 return O->getMetadata(kind);
352}
353
// NOTE(review): several enumerators are missing from this listing (source
// lines 357, 359, 363, 365 and 368). Judging from to_string(ReturnType) they
// are ArgsWithReturn, ArgsWithTwoReturns, TapeAndReturn, TapeAndTwoReturns
// and TwoReturns, in that position order. Restore them from upstream before
// compiling, because enumerator order determines the implicit values.
354/// Potential return type of generated functions
355enum class ReturnType {
356 /// Return is a struct of all args and the original return
358 /// Return is a struct of all args and two of the original return
360 /// Return is a struct of all args
361 Args,
362 /// Return is a tape type and the original return
364 /// Return is a tape type and the two of the original return
366 /// Return is a tape type
367 Tape,
369 Return,
370 Void,
371};
372
/// Potential differentiable argument classifications
enum class DIFFE_TYPE {
  /// Add the differential to an output struct. Only for scalar values in
  /// ReverseMode variants.
  OUT_DIFF = 0,
  /// Duplicate the argument and store the differential inside. For
  /// references, pointers, or integers in ReverseMode variants; for all
  /// types in ForwardMode variants.
  DUP_ARG = 1,
  /// No differential. Usable everywhere.
  CONSTANT = 2,
  /// Duplicate this argument and store the differential inside, but don't
  /// need the forward. Same as DUP_ARG otherwise.
  DUP_NONEED = 3
};
384
/// Batching representation kinds, with explicit values SCALAR = 0 and
/// VECTOR = 1.
/// NOTE(review): presumably distinguishes scalar from vector-packed values
/// in batched derivatives - confirm with the users of this enum.
enum class BATCH_TYPE {
  SCALAR = 0,
  VECTOR = 1,
};
389
// NOTE(review): the enumerators on source lines 392-396 are missing from this
// listing. to_string(DerivativeMode) also names ForwardModeError,
// ForwardModeSplit, ReverseModePrimal, ReverseModeGradient and
// ReverseModeCombined. Their explicit values cannot be recovered from here,
// so restore them from upstream before compiling.
/// Kind of derivative code being generated (forward or reverse variants).
390enum class DerivativeMode {
391 ForwardMode = 0,
397};
398
/// Probabilistic-programming operating modes, with explicit values
/// Likelihood = 0, Trace = 1, Condition = 2.
/// NOTE(review): exact semantics of each mode are defined by its users,
/// which are not visible here.
enum class ProbProgMode {
  Likelihood = 0,
  Trace = 1,
  Condition = 2,
};
404
/// Classification of a value as an original-program variable, a derivative
/// variable, neither, or both. Used by differential use analysis and to
/// describe argument bundles. Primal and Shadow are distinct bits, so Both
/// is their bitwise union.
enum class ValueType {
  None = 0,               ///< Neither an original value nor a derivative.
  Primal = 1,             ///< The original program value.
  Shadow = 2,             ///< The derivative value.
  Both = Primal | Shadow, ///< Both the original value and its shadow.
};
420
421static inline std::string to_string(ValueType mode) {
422 switch (mode) {
423 case ValueType::None:
424 return "None";
426 return "Primal";
428 return "Shadow";
429 case ValueType::Both:
430 return "Both";
431 }
432 llvm_unreachable("illegal valuetype");
433}
434
435static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
436 ValueType mode) {
437 return os << to_string(mode);
438}
439
440static inline std::string to_string(DerivativeMode mode) {
441 switch (mode) {
443 return "ForwardMode";
445 return "ForwardModeError";
447 return "ForwardModeSplit";
449 return "ReverseModePrimal";
451 return "ReverseModeGradient";
453 return "ReverseModeCombined";
454 }
455 llvm_unreachable("illegal derivative mode");
456}
457
458static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
459 DerivativeMode mode) {
460 return os << to_string(mode);
461}
462
463/// Convert DIFFE_TYPE to a string
464static inline std::string to_string(DIFFE_TYPE t) {
465 switch (t) {
467 return "OUT_DIFF";
469 return "CONSTANT";
471 return "DUP_ARG";
473 return "DUP_NONEED";
474 default:
475 assert(0 && "illegal diffetype");
476 return "";
477 }
478}
479
480static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
481 DIFFE_TYPE mode) {
482 return os << to_string(mode);
483}
484
485/// Convert ReturnType to a string
486static inline std::string to_string(ReturnType t) {
487 switch (t) {
489 return "ArgsWithReturn";
491 return "ArgsWithTwoReturns";
492 case ReturnType::Args:
493 return "Args";
495 return "TapeAndReturn";
497 return "TapeAndTwoReturns";
498 case ReturnType::Tape:
499 return "Tape";
501 return "TwoReturns";
503 return "Return";
504 case ReturnType::Void:
505 return "Void";
506 }
507 llvm_unreachable("illegal ReturnType");
508}
509
510static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
511 ReturnType mode) {
512 return os << to_string(mode);
513}
514
515#include <set>
516
// NOTE(review): this listing of whatType is damaged. Source lines 524, 528,
// 546, 548-550, 552, 558, 566, 573, 575-577, 579-580, 582, 586, 588-589, 591,
// 593-594, 596, 600, 602, 611-614 and 619 are missing, including early
// returns and most `case DIFFE_TYPE::...:` labels. It will not compile as
// shown; restore the body from upstream rather than patching it here.
517/// Attempt to automatically detect the differentiable
518/// classification based off of a given type
///
/// Structure visible here: pointer, array, struct, integer/function and
/// floating-point types are classified separately; \p seen records visited
/// types so recursive types are not re-entered; \p integersAreConstant picks
/// CONSTANT vs DUP_ARG for integer and function types.
519static inline DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode,
520 bool integersAreConstant,
521 std::set<llvm::Type *> &seen) {
522 assert(arg);
 // Cycle guard for recursive types; the early-return statement is one of the
 // missing lines.
523 if (seen.find(arg) != seen.end())
525 seen.insert(arg);
526
527 if (arg->isVoidTy() || arg->isEmptyTy()) {
529 }
530
 // Pointers: with opaque pointers (always on LLVM >= 17, or when the context
 // has typed pointers disabled) the answer is DUP_ARG; otherwise the pointee
 // type is classified recursively.
531 if (arg->isPointerTy()) {
532#if LLVM_VERSION_MAJOR >= 17
533 return DIFFE_TYPE::DUP_ARG;
534#else
535#if LLVM_VERSION_MAJOR >= 15
536 if (!arg->getContext().supportsTypedPointers()) {
537 return DIFFE_TYPE::DUP_ARG;
538 }
539#elif LLVM_VERSION_MAJOR >= 13
540 if (arg->isOpaquePointerTy()) {
541 return DIFFE_TYPE::DUP_ARG;
542 }
543#endif
544 switch (whatType(arg->getPointerElementType(), mode, integersAreConstant,
545 seen)) {
547 return DIFFE_TYPE::DUP_ARG;
551 return DIFFE_TYPE::DUP_ARG;
553 llvm_unreachable("impossible case");
554 }
555 assert(arg);
556 llvm::errs() << "arg: " << *arg << "\n";
557 assert(0 && "Cannot handle type0");
559#endif
 // Arrays take the classification of their element type.
560 } else if (arg->isArrayTy()) {
561 return whatType(llvm::cast<llvm::ArrayType>(arg)->getElementType(), mode,
562 integersAreConstant, seen);
 // Structs: start from CONSTANT and combine each element's classification.
563 } else if (arg->isStructTy()) {
564 auto st = llvm::cast<llvm::StructType>(arg);
565 if (st->getNumElements() == 0)
567
568 auto ty = DIFFE_TYPE::CONSTANT;
569 for (unsigned i = 0; i < st->getNumElements(); ++i) {
570 auto midTy =
571 whatType(st->getElementType(i), mode, integersAreConstant, seen);
572 switch (midTy) {
574 switch (ty) {
578 break;
581 return ty;
583 llvm_unreachable("impossible case");
584 }
585 break;
587 switch (ty) {
590 break;
592 break;
595 return ty;
597 llvm_unreachable("impossible case");
598 }
599 break;
601 return DIFFE_TYPE::DUP_ARG;
603 llvm_unreachable("impossible case");
604 }
605 }
606 return ty;
607 } else if (arg->isIntOrIntVectorTy() || arg->isFunctionTy()) {
608 return integersAreConstant ? DIFFE_TYPE::CONSTANT : DIFFE_TYPE::DUP_ARG;
 // Floating point: the result depends on `mode`; the rest of this
 // expression is among the missing lines.
609 } else if (arg->isFPOrFPVectorTy()) {
610 return (mode == DerivativeMode::ForwardMode ||
615 } else {
616 assert(arg);
617 llvm::errs() << "arg: " << *arg << "\n";
618 assert(0 && "Cannot handle type");
620 }
621}
622
623llvm::Value *get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res);
624
625static inline DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode) {
626 std::set<llvm::Type *> seen;
627 return whatType(arg, mode, /*intconst*/ true, seen);
628}
629
630/// Check whether this instruction is returned
631static inline bool isReturned(llvm::Instruction *inst) {
632 for (const auto a : inst->users()) {
633 if (llvm::isa<llvm::ReturnInst>(a))
634 return true;
635 }
636 return false;
637}
638
639/// Convert a floating point type to an integer type
640/// of the same size
641static inline llvm::Type *FloatToIntTy(llvm::Type *T) {
642 assert(T->isFPOrFPVectorTy());
643 if (auto ty = llvm::dyn_cast<llvm::VectorType>(T)) {
644 return llvm::VectorType::get(FloatToIntTy(ty->getElementType()),
645 ty->getElementCount());
646 }
647 if (T->isHalfTy())
648 return llvm::IntegerType::get(T->getContext(), 16);
649 if (T->isBFloatTy())
650 return llvm::IntegerType::get(T->getContext(), 16);
651 if (T->isFloatTy())
652 return llvm::IntegerType::get(T->getContext(), 32);
653 if (T->isDoubleTy())
654 return llvm::IntegerType::get(T->getContext(), 64);
655 if (T->isX86_FP80Ty())
656 return llvm::IntegerType::get(T->getContext(), 80);
657 if (T->isFP128Ty())
658 return llvm::IntegerType::get(T->getContext(), 128);
659 assert(0 && "unknown floating point type");
660 return nullptr;
661}
662
663/// Convert a integer type to a floating point type
664/// of the same size
665static inline llvm::Type *IntToFloatTy(llvm::Type *T) {
666 assert(T->isIntOrIntVectorTy());
667 if (auto ty = llvm::dyn_cast<llvm::VectorType>(T)) {
668 return llvm::VectorType::get(IntToFloatTy(ty->getElementType()),
669 ty->getElementCount());
670 }
671 if (auto ty = llvm::dyn_cast<llvm::IntegerType>(T)) {
672 switch (ty->getBitWidth()) {
673 case 16:
674 return llvm::Type::getHalfTy(T->getContext());
675 // return llvm::Type::getBFloat16Ty(T->getContext());
676 case 32:
677 return llvm::Type::getFloatTy(T->getContext());
678 case 64:
679 return llvm::Type::getDoubleTy(T->getContext());
680 case 80:
681 return llvm::Type::getX86_FP80Ty(T->getContext());
682 case 128:
683 return llvm::Type::getFP128Ty(T->getContext());
684 }
685 }
686 assert(0 && "unknown int to floating point type");
687 return nullptr;
688}
689
690static inline bool isDebugFunction(llvm::Function *called) {
691 if (!called)
692 return false;
693 if (called->getName() == "llvm.enzyme.lifetime_start" ||
694 called->getName() == "llvm.enzyme.lifetime_end") {
695 return true;
696 }
697 switch (called->getIntrinsicID()) {
698 case llvm::Intrinsic::dbg_declare:
699 case llvm::Intrinsic::dbg_value:
700 case llvm::Intrinsic::dbg_label:
701#if LLVM_VERSION_MAJOR <= 16
702 case llvm::Intrinsic::dbg_addr:
703#endif
704 case llvm::Intrinsic::lifetime_start:
705 case llvm::Intrinsic::lifetime_end:
706 return true;
707 default:
708 break;
709 }
710 return false;
711}
712
713static inline bool startsWith(llvm::StringRef string, llvm::StringRef prefix) {
714#if LLVM_VERSION_MAJOR >= 18
715 return string.starts_with(prefix);
716#else
717 return string.startswith(prefix);
718#endif // LLVM_VERSION_MAJOR
719}
720
721static inline bool endsWith(llvm::StringRef string, llvm::StringRef suffix) {
722#if LLVM_VERSION_MAJOR >= 18
723 return string.ends_with(suffix);
724#else
725 return string.endswith(suffix);
726#endif // LLVM_VERSION_MAJOR
727}
728
729static inline bool isCertainPrint(const llvm::StringRef name) {
730 if (name == "printf" || name == "puts" || name == "fprintf" ||
731 name == "putchar" || name == "fputc" ||
732 startsWith(name,
733 "_ZStlsISt11char_traitsIcEERSt13basic_ostreamIcT_ES5_") ||
734 startsWith(name, "_ZNSolsE") || startsWith(name, "_ZNSo9_M_insert") ||
735 startsWith(name, "_ZSt16__ostream_insert") ||
736 startsWith(name, "_ZNSo3put") || startsWith(name, "_ZSt4endl") ||
737 startsWith(name, "_ZN3std2io5stdio6_print") ||
738 startsWith(name, "_ZNSo5flushEv") || startsWith(name, "_ZN4core3fmt") ||
739 name == "vprintf")
740 return true;
741 return false;
742}
743
/// Components of a BLAS/LAPACK routine name, as returned by extractBLAS.
/// NOTE(review): the per-field meanings below are inferred from the field
/// names - confirm against extractBLAS's implementation.
744struct BlasInfo {
745 std::string floatType; // presumably the precision letter, e.g. "d" or "s"
746 std::string prefix; // presumably a library prefix such as "cblas_"
747 std::string suffix; // presumably a symbol suffix such as "_64_"
748 std::string function; // presumably the bare routine name, e.g. "dot"
749 bool is64; // presumably true for the 64-bit-integer interface
750
 // Floating-point / integer LLVM types for this routine (defined elsewhere).
751 llvm::Type *fpType(llvm::LLVMContext &ctx, bool to_scalar = false) const;
752 llvm::IntegerType *intType(llvm::LLVMContext &ctx) const;
753};
754
// Parse a symbol name into BlasInfo; the result is an optional, presumably
// empty when `in` is not a recognized BLAS/LAPACK name (TODO confirm).
755#if LLVM_VERSION_MAJOR >= 16
756std::optional<BlasInfo> extractBLAS(llvm::StringRef in);
757#else
758llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in);
759#endif
760
// Decode a type-describing metadata node into (type, size_t, size_t) tuples.
// NOTE(review): presumably (type, offset, size) - confirm the field order.
761std::vector<std::tuple<llvm::Type *, size_t, size_t>>
762parseTrueType(const llvm::MDNode *, DerivativeMode, bool const_src);
763
// NOTE(review): the function name on source line 766 is missing from this
// listing, so this declaration is incomplete. Restore it from upstream.
764/// Create function for type that performs the derivative memcpy on floating
765/// point memory
767 llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign,
768 unsigned dstaddr, unsigned srcaddr, unsigned bitwidth);
769
770/// Create function for type that performs memcpy with a stride using blas copy
771void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
772 llvm::ArrayRef<llvm::Value *> args,
773 llvm::Type *cublas_retty,
774 llvm::ArrayRef<llvm::OperandBundleDef> bundles);
775
776/// Create function for type that performs memcpy using lapack copy
777void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M,
778 BlasInfo blas, llvm::ArrayRef<llvm::Value *> args,
779 llvm::ArrayRef<llvm::OperandBundleDef> bundles);
780
// Emits a call (returns void) for the symmetric packed matrix-vector (SPMV)
// derivative's diagonal update.
// NOTE(review): summary inferred from the name - confirm at the definition.
781void callSPMVDiagUpdate(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
782 llvm::IntegerType *IT, llvm::Type *BlasCT,
783 llvm::Type *BlasFPT, llvm::Type *BlasPT,
784 llvm::Type *BlasIT, llvm::Type *fpTy,
785 llvm::ArrayRef<llvm::Value *> args,
786 const llvm::ArrayRef<llvm::OperandBundleDef> bundles,
787 bool byRef, bool julia_decl);
788
// Gets or inserts an inner-product helper and returns the emitted call.
// NOTE(review): summary inferred from the name - confirm at the definition.
789llvm::CallInst *
790getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
791 llvm::IntegerType *IT, llvm::Type *BlasPT,
792 llvm::Type *BlasIT, llvm::Type *fpTy,
793 llvm::ArrayRef<llvm::Value *> args,
794 const llvm::ArrayRef<llvm::OperandBundleDef> bundles,
795 bool byRef, bool cublas, bool julia_decl);
796
797/// Create function for type that performs memcpy with a stride
798llvm::Function *getOrInsertMemcpyStrided(llvm::Module &M,
799 llvm::Type *elementType,
800 llvm::PointerType *T, llvm::Type *IT,
801 unsigned dstalign, unsigned srcalign);
802
803/// Turned out to be a faster alternative to LAPACK's lacpy function
804llvm::Function *getOrInsertMemcpyMat(llvm::Module &M, llvm::Type *elementType,
805 llvm::PointerType *PT,
806 llvm::IntegerType *IT, unsigned dstalign,
807 unsigned srcalign);
808
// NOTE(review): the return type and function name on source line 809 are
// missing from this listing. Restore them from upstream.
810 llvm::Module &M, llvm::Type *elementType, llvm::PointerType *PT,
811 llvm::IntegerType *IT, llvm::IntegerType *CT, unsigned dstalign,
812 unsigned srcalign, bool zeroSrc);
813
// NOTE(review): the function name on source line 816 is missing from this
// listing. Restore it from upstream.
814/// Create function for type that performs the derivative memmove on floating
815/// point memory
817 llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign,
818 unsigned dstaddr, unsigned srcaddr, unsigned bitwidth);
819
// Gets or inserts a free wrapper for `call` over `width` values of `Type`.
// NOTE(review): what is "checked" is not visible here - confirm.
820llvm::Function *getOrInsertCheckedFree(llvm::Module &M, llvm::CallInst *call,
821 llvm::Type *Type, unsigned width);
822
823/// Create function for type that performs the derivative MPI_Wait
824llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M,
825 llvm::ArrayRef<llvm::Type *> T,
826 llvm::Type *reqType,
827 llvm::StringRef caller);
828
829/// Emit code to compute the nearest power of two for V
830llvm::Value *nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V);
831
/// Insert \p val under \p key in \p map, replacing any existing entry.
///
/// An existing element is erased and a new one is emplaced, rather than
/// assigned to. This keeps the helper usable with mapped types that are not
/// assignable. Returns an iterator to the newly inserted element.
template <typename K, typename V>
static inline typename std::map<K, V>::iterator
insert_or_assign(std::map<K, V> &map, K &key, V &&val) {
  auto found = map.find(key);
  if (found != map.end()) {
    map.erase(found);
  }
  // A named rvalue reference is an lvalue. Without std::move the value was
  // copied, which is wasteful and fails to compile for move-only types.
  return map.emplace(key, std::move(val)).first;
}
842
/// Insert \p val under \p key in \p map, replacing any existing entry.
///
/// Takes key and value by value. An existing element is erased and a new one
/// is emplaced, so non-assignable mapped types still work. Returns an
/// iterator to the newly inserted element.
template <typename K, typename V>
static inline typename std::map<K, V>::iterator
insert_or_assign2(std::map<K, V> &map, K key, V val) {
  auto found = map.find(key);
  if (found != map.end()) {
    map.erase(found);
  }
  // The parameters are owned copies, so move them into the node instead of
  // copying again. This also makes move-only key/value types work.
  return map.emplace(std::move(key), std::move(val)).first;
}
853
/// Look up \p key in \p map. Returns a pointer to the mapped value, or
/// nullptr if the key is absent. The pointer refers into the map's node and
/// stays valid until that element is erased.
template <typename K, typename V>
static inline V *findInMap(std::map<K, V> &map, K key) {
  auto it = map.find(key);
  return it == map.end() ? nullptr : &it->second;
}
862
863#include "llvm/IR/CFG.h"
864#include <deque>
865#include <functional>
866/// Call the function f for all instructions that happen after inst
867/// If the function returns true, the iteration will early exit
868static inline void
869allFollowersOf(llvm::Instruction *inst,
870 llvm::function_ref<bool(llvm::Instruction *)> f) {
871
872 for (auto uinst = inst->getNextNode(); uinst != nullptr;
873 uinst = uinst->getNextNode()) {
874 if (f(uinst))
875 return;
876 }
877
878 std::deque<llvm::BasicBlock *> todo;
879 std::set<llvm::BasicBlock *> done;
880 for (auto suc : llvm::successors(inst->getParent())) {
881 todo.push_back(suc);
882 }
883 while (todo.size()) {
884 auto BB = todo.front();
885 todo.pop_front();
886 if (done.count(BB))
887 continue;
888 done.insert(BB);
889 for (auto &ni : *BB) {
890 if (f(&ni))
891 return;
892 if (&ni == inst)
893 break;
894 }
895 for (auto suc : llvm::successors(BB)) {
896 todo.push_back(suc);
897 }
898 }
899}
900
901/// Call the function f for all instructions that happen before inst
902/// If the function returns true, the iteration will early exit
903static inline void
904allPredecessorsOf(llvm::Instruction *inst,
905 llvm::function_ref<bool(llvm::Instruction *)> f) {
906
907 for (auto uinst = inst->getPrevNode(); uinst != nullptr;
908 uinst = uinst->getPrevNode()) {
909 if (f(uinst))
910 return;
911 }
912
913 std::deque<llvm::BasicBlock *> todo;
914 std::set<llvm::BasicBlock *> done;
915 for (auto suc : llvm::predecessors(inst->getParent())) {
916 todo.push_back(suc);
917 }
918 while (todo.size()) {
919 auto BB = todo.front();
920 todo.pop_front();
921 if (done.count(BB))
922 continue;
923 done.insert(BB);
924
925 llvm::BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend();
926 for (; I != E; ++I) {
927 if (f(&*I))
928 return;
929 if (&*I == inst)
930 break;
931 }
932 for (auto suc : llvm::predecessors(BB)) {
933 todo.push_back(suc);
934 }
935 }
936}
937
938/// Call the function f for all instructions that happen before inst
939/// If the function returns true, the iteration will early exit
940static inline void
941allDomPredecessorsOf(llvm::Instruction *inst, llvm::DominatorTree &DT,
942 llvm::function_ref<bool(llvm::Instruction *)> f) {
943
944 for (auto uinst = inst->getPrevNode(); uinst != nullptr;
945 uinst = uinst->getPrevNode()) {
946 if (f(uinst))
947 return;
948 }
949
950 std::deque<llvm::BasicBlock *> todo;
951 std::set<llvm::BasicBlock *> done;
952 for (auto suc : llvm::predecessors(inst->getParent())) {
953 todo.push_back(suc);
954 }
955 while (todo.size()) {
956 auto BB = todo.front();
957 todo.pop_front();
958 if (done.count(BB))
959 continue;
960 done.insert(BB);
961
962 if (DT.properlyDominates(BB, inst->getParent())) {
963 llvm::BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend();
964 for (; I != E; ++I) {
965 if (f(&*I))
966 return;
967 if (&*I == inst)
968 break;
969 }
970 for (auto suc : llvm::predecessors(BB)) {
971 todo.push_back(suc);
972 }
973 }
974 }
975}
976
977/// Call the function f for all instructions that happen before inst
978/// If the function returns true, the iteration will early exit
979static inline void
980allUnsyncdPredecessorsOf(llvm::Instruction *inst,
981 llvm::function_ref<bool(llvm::Instruction *)> f,
982 llvm::function_ref<void()> preEntry) {
983
984 for (auto uinst = inst->getPrevNode(); uinst != nullptr;
985 uinst = uinst->getPrevNode()) {
986 if (auto II = llvm::dyn_cast<llvm::IntrinsicInst>(uinst)) {
987 if (II->getIntrinsicID() == llvm::Intrinsic::amdgcn_s_barrier) {
988 return;
989 }
990#if LLVM_VERSION_MAJOR > 20
991 if (II->getIntrinsicID() ==
992 llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all) {
993 return;
994 }
995#else
996 if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0) {
997 return;
998 }
999#endif
1000 }
1001 if (f(uinst))
1002 return;
1003 }
1004
1005 std::deque<llvm::BasicBlock *> todo;
1006 std::set<llvm::BasicBlock *> done;
1007 for (auto suc : llvm::predecessors(inst->getParent())) {
1008 todo.push_back(suc);
1009 }
1010 while (todo.size()) {
1011 auto BB = todo.front();
1012 todo.pop_front();
1013 if (done.count(BB))
1014 continue;
1015 done.insert(BB);
1016
1017 bool syncd = false;
1018 llvm::BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend();
1019 for (; I != E; ++I) {
1020 if (auto II = llvm::dyn_cast<llvm::IntrinsicInst>(&*I)) {
1021 if (II->getIntrinsicID() == llvm::Intrinsic::amdgcn_s_barrier) {
1022 syncd = true;
1023 break;
1024 }
1025#if LLVM_VERSION_MAJOR > 20
1026 if (II->getIntrinsicID() ==
1027 llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all) {
1028#else
1029 if (II->getIntrinsicID() == llvm::Intrinsic::nvvm_barrier0) {
1030#endif
1031 syncd = true;
1032 break;
1033 }
1034 }
1035 if (f(&*I))
1036 return;
1037 if (&*I == inst)
1038 break;
1039 }
1040 if (!syncd) {
1041 for (auto suc : llvm::predecessors(BB)) {
1042 todo.push_back(suc);
1043 }
1044 if (&BB->getParent()->getEntryBlock() == BB) {
1045 preEntry();
1046 }
1047 }
1048 }
1049}
1050
1051#include "llvm/Analysis/LoopInfo.h"
1052
1053static inline llvm::Loop *getAncestor(llvm::Loop *R1, llvm::Loop *R2) {
1054 if (!R1 || !R2)
1055 return nullptr;
1056 for (llvm::Loop *L1 = R1; L1; L1 = L1->getParentLoop())
1057 for (llvm::Loop *L2 = R2; L2; L2 = L2->getParentLoop()) {
1058 if (L1 == L2) {
1059 return L1;
1060 }
1061 }
1062 return nullptr;
1063}
1064
1065// Add all of the stores which may execute after the instruction `inst`
1066// into the results vector (limited to the given `stores` and `region`).
1067void mayExecuteAfter(llvm::SmallVectorImpl<llvm::Instruction *> &results,
1068 llvm::Instruction *inst,
1069 const llvm::SmallPtrSetImpl<llvm::Instruction *> &stores,
1070 const llvm::Loop *region);
1071
1072/// Return whether maybeReader can read from memory written to by maybeWriter
1073bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
1074 llvm::TargetLibraryInfo &TLI,
1075 llvm::Instruction *maybeReader,
1076 llvm::Instruction *maybeWriter);
1077
1078// A more advanced version of writesToMemoryReadBy, where the writing
1079// instruction comes after the reading instruction. Specifically, even if the
1080// two instructions may access the same location, this variant additionally
1081// checks whether ScalarEvolution proves that a subsequent write will not
1082// overwrite the value read by the load.
1083// A simple example: the load/store might write/read from the same
1084// location. However, no store will overwrite a previous load.
1085// for(int i=0; i<N; i++) {
1086// load A[i-1]
1087// store A[i] = ...
1088// }
1089bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
1090 llvm::TargetLibraryInfo &TLI,
1091 llvm::ScalarEvolution &SE, llvm::LoopInfo &LI,
1092 llvm::DominatorTree &DT,
1093 llvm::Instruction *maybeReader,
1094 llvm::Instruction *maybeWriter,
1095 llvm::Loop *scope = nullptr);
1096static inline void
1097/// Call the function f for all instructions that happen between inst1 and inst2
1098/// If the function returns true, the iteration will early exit
1099allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1,
1100 llvm::Instruction *inst2,
1101 llvm::function_ref<bool(llvm::Instruction *)> f) {
1102 assert(inst1->getParent()->getParent() == inst2->getParent()->getParent());
1103 for (auto uinst = inst1->getNextNode(); uinst != nullptr;
1104 uinst = uinst->getNextNode()) {
1105 if (f(uinst))
1106 return;
1107 if (uinst == inst2)
1108 return;
1109 }
1110
1111 std::set<llvm::Instruction *> instructions;
1112
1113 llvm::Loop *l1 = LI.getLoopFor(inst1->getParent());
1114 while (l1 && !l1->contains(inst2->getParent()))
1115 l1 = l1->getParentLoop();
1116
1117 // Do all instructions from inst1 up to first instance of inst2's start block
1118 {
1119 std::deque<llvm::BasicBlock *> todo;
1120 std::set<llvm::BasicBlock *> done;
1121 for (auto suc : llvm::successors(inst1->getParent())) {
1122 todo.push_back(suc);
1123 }
1124 while (todo.size()) {
1125 auto BB = todo.front();
1126 todo.pop_front();
1127 if (done.count(BB))
1128 continue;
1129 done.insert(BB);
1130
1131 for (auto &ni : *BB) {
1132 instructions.insert(&ni);
1133 }
1134 for (auto suc : llvm::successors(BB)) {
1135 if (!l1 || suc != l1->getHeader()) {
1136 todo.push_back(suc);
1137 }
1138 }
1139 }
1140 }
1141
1142 allPredecessorsOf(inst2, [&](llvm::Instruction *I) -> bool {
1143 if (instructions.find(I) == instructions.end())
1144 return /*earlyReturn*/ false;
1145 return f(I);
1146 });
1147}
1148
/// Kind of nonblocking MPI request (ISEND / IRECV).
enum class MPI_CallType {
  ISEND = 1,
  IRECV, // = 2
};

/// Field indices of the struct type built by getMPIHelper.
enum class MPI_Elem {
  Buf = 0,
  Count,    // = 1
  DataType, // = 2
  Src,      // = 3
  Tag,      // = 4
  Comm,     // = 5
  Call,     // = 6
  Old       // = 7
};
1164
1165static inline llvm::PointerType *getPointerType(llvm::Type *T,
1166 unsigned AddressSpace = 0) {
1167#if LLVM_VERSION_MAJOR >= 17
1168 return llvm::PointerType::get(T->getContext(), AddressSpace);
1169#else
1170 return llvm::PointerType::get(T, AddressSpace);
1171#endif
1172}
1173
1174static inline llvm::PointerType *getInt8PtrTy(llvm::LLVMContext &Context,
1175 unsigned AddressSpace = 0) {
1176 return getPointerType(llvm::Type::getInt8Ty(Context), AddressSpace);
1177}
1178
1179static inline llvm::PointerType *getUnqual(llvm::Type *T) {
1180 return getPointerType(T);
1181}
1182
1183static inline llvm::StructType *getMPIHelper(llvm::LLVMContext &Context) {
1184 using namespace llvm;
1185 auto i64 = Type::getInt64Ty(Context);
1186 Type *types[] = {
1187 /*buf 0 */ getInt8PtrTy(Context),
1188 /*count 1 */ i64,
1189 /*datatype 2 */ getInt8PtrTy(Context),
1190 /*src 3 */ i64,
1191 /*tag 4 */ i64,
1192 /*comm 5 */ getInt8PtrTy(Context),
1193 /*fn 6 */ Type::getInt8Ty(Context),
1194 /*old 7 */ getInt8PtrTy(Context),
1195 };
1196 return StructType::get(Context, types, false);
1197}
1198
1199template <MPI_Elem E, bool Pointer = true>
1200static inline llvm::Value *getMPIMemberPtr(llvm::IRBuilder<> &B, llvm::Value *V,
1201 llvm::Type *T) {
1202 using namespace llvm;
1203 auto i64 = Type::getInt64Ty(V->getContext());
1204 auto i32 = Type::getInt32Ty(V->getContext());
1205 auto c0_64 = ConstantInt::get(i64, 0);
1206
1207 if (Pointer) {
1208 return B.CreateInBoundsGEP(T, V,
1209 {c0_64, ConstantInt::get(i32, (uint64_t)E)});
1210 } else {
1211 return B.CreateExtractValue(V, {(unsigned)E});
1212 }
1213}
1214
1215llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr,
1216 llvm::Type *OpType, ConcreteType CT,
1217 llvm::Type *intType, llvm::IRBuilder<> &B2);
1218
/// Value handle (an llvm::CallbackVH) that follows replaceAllUsesWith onto the
/// new value and treats deletion of the tracked value as a fatal error.
1219class AssertingReplacingVH final : public llvm::CallbackVH {
1220public:
1222
// Start tracking new_value.
1223 AssertingReplacingVH(llvm::Value *new_value) { setValPtr(new_value); }
1224
// Deleting the tracked value while this handle still refers to it is a bug:
// assert in debug builds, and mark the path unreachable otherwise.
1225 void deleted() override final {
1226 assert(0 && "attempted to delete value with remaining handle use");
1227 llvm_unreachable("attempted to delete value with remaining handle use");
1228 }
1229
// On RAUW, keep tracking the replacement value.
1230 void allUsesReplacedWith(llvm::Value *new_value) override final {
1231 setValPtr(new_value);
1232 }
1234};
1235
1236template <typename T> static inline llvm::Function *getFunctionFromCall(T *op) {
1237 const llvm::Function *called = nullptr;
1238 using namespace llvm;
1239 const llvm::Value *callVal;
1240 callVal = op->getCalledOperand();
1241 while (!called) {
1242 if (auto castinst = dyn_cast<ConstantExpr>(callVal))
1243 if (castinst->isCast()) {
1244 callVal = castinst->getOperand(0);
1245 continue;
1246 }
1247 if (auto fn = dyn_cast<Function>(callVal)) {
1248 called = fn;
1249 break;
1250 }
1251 if (auto alias = dyn_cast<GlobalAlias>(callVal)) {
1252 callVal = alias->getAliasee();
1253 continue;
1254 }
1255 break;
1256 }
1257 return called ? const_cast<llvm::Function *>(called) : nullptr;
1258}
1259
1260static inline llvm::StringRef getFuncName(llvm::Function *called) {
1261 if (called->hasFnAttribute("enzyme_math"))
1262 return called->getFnAttribute("enzyme_math").getValueAsString();
1263 else if (called->hasFnAttribute("enzyme_allocator"))
1264 return "enzyme_allocator";
1265 else
1266 return called->getName();
1267}
1268
1269static inline llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op) {
1270 auto AttrList =
1271 op->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
1272 if (AttrList.hasAttribute("enzyme_math"))
1273 return AttrList.getAttribute("enzyme_math").getValueAsString();
1274 if (AttrList.hasAttribute("enzyme_allocator"))
1275 return "enzyme_allocator";
1276
1277 if (auto called = getFunctionFromCall(op)) {
1278 return getFuncName(called);
1279 }
1280 return "";
1281}
1282
1283static inline bool hasNoCache(llvm::Value *op) {
1284 using namespace llvm;
1285 if (auto CB = dyn_cast<CallBase>(op)) {
1286 if (auto called = getFunctionFromCall(CB)) {
1287 if (called->hasFnAttribute("enzyme_nocache"))
1288 return true;
1289 }
1290 if (EnzymeJuliaAddrLoad && getFuncNameFromCall(CB) == "julia.gc_loaded") {
1291 return true;
1292 }
1293 }
1294 if (auto I = dyn_cast<Instruction>(op))
1295 if (hasMetadata(I, "enzyme_nocache"))
1296 return true;
1297
1298 if (EnzymeJuliaAddrLoad) {
1299 if (auto PT = dyn_cast<PointerType>(op->getType())) {
1300 if (PT->getAddressSpace() == 11 || PT->getAddressSpace() == 13) {
1301 if (isa<CastInst>(op) || isa<GetElementPtrInst>(op))
1302 return true;
1303 }
1304 }
1305 }
1306 if (auto IT = dyn_cast<IntegerType>(op->getType()))
1307 if (!isPowerOf2_64(IT->getBitWidth()) && !EnzymeNonPower2Cache)
1308 return true;
1309
1310 return false;
1311}
1312
1313#if LLVM_VERSION_MAJOR >= 16
1314static inline std::optional<size_t>
1315getAllocationIndexFromCall(const llvm::CallBase *op)
1316#else
1317static inline llvm::Optional<size_t>
1318getAllocationIndexFromCall(const llvm::CallBase *op)
1319#endif
1320{
1321 auto AttrList =
1322 op->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
1323 if (AttrList.hasAttribute("enzyme_allocator")) {
1324 size_t res;
1325 bool b = AttrList.getAttribute("enzyme_allocator")
1326 .getValueAsString()
1327 .getAsInteger(10, res);
1328 (void)b;
1329 assert(!b);
1330#if LLVM_VERSION_MAJOR >= 16
1331 return std::optional<size_t>(res);
1332#else
1333 return llvm::Optional<size_t>(res);
1334#endif
1335 }
1336
1337 if (auto called = getFunctionFromCall(op)) {
1338 if (called->hasFnAttribute("enzyme_allocator")) {
1339 size_t res;
1340 bool b = called->getFnAttribute("enzyme_allocator")
1341 .getValueAsString()
1342 .getAsInteger(10, res);
1343 (void)b;
1344 assert(!b);
1345#if LLVM_VERSION_MAJOR >= 16
1346 return std::optional<size_t>(res);
1347#else
1348 return llvm::Optional<size_t>(res);
1349#endif
1350 }
1351 }
1352#if LLVM_VERSION_MAJOR >= 16
1353 return std::optional<size_t>();
1354#else
1355 return llvm::Optional<size_t>();
1356#endif
1357}
1358
1359template <typename T>
1360static inline llvm::Function *getDeallocatorFnFromCall(T *op) {
1361 if (auto MD = hasMetadata(op, "enzyme_deallocator_fn")) {
1362 auto md2 = llvm::cast<llvm::MDTuple>(MD);
1363 assert(md2->getNumOperands() == 1);
1364 return llvm::cast<llvm::Function>(
1365 llvm::cast<llvm::ConstantAsMetadata>(md2->getOperand(0))->getValue());
1366 }
1367 if (auto called = getFunctionFromCall(op)) {
1368 if (auto MD = hasMetadata(called, "enzyme_deallocator_fn")) {
1369 auto md2 = llvm::cast<llvm::MDTuple>(MD);
1370 assert(md2->getNumOperands() == 1);
1371 return llvm::cast<llvm::Function>(
1372 llvm::cast<llvm::ConstantAsMetadata>(md2->getOperand(0))->getValue());
1373 }
1374 }
1375 llvm::errs() << "dealloc fn: " << *op->getParent()->getParent()->getParent()
1376 << "\n";
1377 llvm_unreachable("Illegal deallocatorfn");
1378}
1379
1380template <typename T>
1381static inline std::vector<ssize_t> getDeallocationIndicesFromCall(T *op) {
1382 llvm::StringRef res = "";
1383 auto AttrList =
1384 op->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
1385 if (AttrList.hasAttribute("enzyme_deallocator"))
1386 res = AttrList.getAttribute("enzyme_deaellocator").getValueAsString();
1387
1388 if (auto called = getFunctionFromCall(op)) {
1389 if (called->hasFnAttribute("enzyme_deallocator"))
1390 res = called->getFnAttribute("enzyme_deallocator").getValueAsString();
1391 }
1392 if (res.size() == 0)
1393 llvm_unreachable("Illegal deallocator");
1394 llvm::SmallVector<llvm::StringRef, 1> inds;
1395 res.split(inds, ",");
1396 std::vector<ssize_t> vinds;
1397 for (auto ind : inds) {
1398 ssize_t Result;
1399 bool b = ind.getAsInteger(10, Result);
1400 (void)b;
1401 assert(!b);
1402 vinds.push_back(Result);
1403 }
1404 return vinds;
1405}
1406
// Declaration (defined elsewhere) returning an llvm::Function *. Among its
// parameters are the type list T and a required pointer type reqType.
1407llvm::Function *
1409 llvm::ArrayRef<llvm::Type *> T,
1410 llvm::PointerType *reqType);
1411
1412void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal,
1413 llvm::Value *shadow, const char *Message,
1414 llvm::DebugLoc &&loc, llvm::Instruction *orig);
1415
1416llvm::Function *GetFunctionFromValue(llvm::Value *fn);
1417
1418llvm::Function *getFirstFunctionDefinition(llvm::Module &M);
1419
1420llvm::Value *simplifyLoad(llvm::Value *LI, size_t valSz = 0,
1421 size_t preOffset = 0);
1422
1423static inline bool shouldDisableNoWrite(const llvm::CallInst *CI) {
1424 auto F = getFunctionFromCall(CI);
1425 auto funcName = getFuncNameFromCall(CI);
1426
1427 if (CI->hasFnAttr("enzyme_preserve_primal") ||
1428 hasMetadata(CI, "enzyme_augment") || hasMetadata(CI, "enzyme_gradient") ||
1429 hasMetadata(CI, "enzyme_derivative") ||
1430 hasMetadata(CI, "enzyme_splitderivative") ||
1431 (F &&
1432 (F->hasFnAttribute("enzyme_preserve_primal") ||
1433 hasMetadata(F, "enzyme_augment") || hasMetadata(F, "enzyme_gradient") ||
1434 hasMetadata(F, "enzyme_derivative") ||
1435 hasMetadata(F, "enzyme_splitderivative"))) ||
1436 !F) {
1437 return true;
1438 }
1439 if (funcName == "MPI_Wait" || funcName == "MPI_Waitall") {
1440 return true;
1441 }
1442 return false;
1443}
1444
1445static inline bool isIntelSubscriptIntrinsic(const llvm::IntrinsicInst &II) {
1446 return startsWith(getFuncNameFromCall(&II), "llvm.intel.subscript");
1447}
1448
1449static inline bool isIntelSubscriptIntrinsic(const llvm::Value *val) {
1450 if (auto II = llvm::dyn_cast<llvm::IntrinsicInst>(val)) {
1451 return isIntelSubscriptIntrinsic(*II);
1452 }
1453 return false;
1454}
1455
/// Return true if \p V is an operation that derives one value from another
/// in a pointer-arithmetic-like way. Recognized cases include:
/// - casts and GEPs;
/// - PHIs (only if \p includephi);
/// - a fixed set of integer binary operators (only if \p includebin);
/// - calls to julia.pointer_from_objref, julia.gc_loaded,
///   __enzyme_todense and __enzyme_ignore_derivatives.
1456static inline bool isPointerArithmeticInst(const llvm::Value *V,
1457 bool includephi = true,
1458 bool includebin = true) {
1459 if (llvm::isa<llvm::CastInst>(V) || llvm::isa<llvm::GetElementPtrInst>(V) ||
1460 (includephi && llvm::isa<llvm::PHINode>(V)))
1461 return true;
1462
// Integer binary operators listed below also count when includebin is set.
1463 if (includebin)
1464 if (auto BI = llvm::dyn_cast<llvm::BinaryOperator>(V)) {
1465 switch (BI->getOpcode()) {
1466 case llvm::BinaryOperator::Add:
1467 case llvm::BinaryOperator::Sub:
1468 case llvm::BinaryOperator::Mul:
1469 case llvm::BinaryOperator::SDiv:
1470 case llvm::BinaryOperator::UDiv:
1471 case llvm::BinaryOperator::SRem:
1472 case llvm::BinaryOperator::URem:
1473 case llvm::BinaryOperator::Or:
1474 case llvm::BinaryOperator::And:
1475 case llvm::BinaryOperator::Shl:
1476 case llvm::BinaryOperator::LShr:
1477 case llvm::BinaryOperator::AShr:
1478 return true;
1479 default:
1480 break;
1481 }
1482 }
1483
1485 return true;
1486 }
1487
// Calls matched by name (see getFuncNameFromCall) that forward a pointer
// argument.
1488 if (auto *Call = llvm::dyn_cast<llvm::CallInst>(V)) {
1489 auto funcName = getFuncNameFromCall(Call);
1490 if (funcName == "julia.pointer_from_objref") {
1491 return true;
1492 }
1493 if (funcName == "julia.gc_loaded") {
1494 return true;
1495 }
1496 if (funcName.contains("__enzyme_todense")) {
1497 return true;
1498 }
1499 if (funcName.contains("__enzyme_ignore_derivatives")) {
1500 return true;
1501 }
1502 }
1503
1504 return false;
1505}
1506
/// Walk backwards from \p V through pointer-forwarding operations and return
/// the value it is ultimately derived from.
///
/// Operations looked through:
/// - casts, GEPs, llvm.intel.subscript, single-incoming PHIs,
///   non-interposable global aliases, constant casts/GEPs;
/// - calls known to return (a pointer into) one of their arguments:
///   - "enzyme_pointermath" attribute,
///   - Julia helpers (julia.pointer_from_objref, julia.gc_loaded,
///     [i]jl_reshape_array),
///   - __enzyme_ignore_derivatives, 3-argument __enzyme_todense,
///   - `returned` parameters;
/// - otherwise, LLVM's aliasing-to-returned-pointer and getUnderlyingObject
///   helpers.
///
/// When \p offsetAllowed is false, only offset-free steps are taken:
/// - GEP instructions are stripped only if all their indices are zero;
/// - intel subscripts, "enzyme_pointermath" calls,
///   getArgumentAliasingToReturnedPointer and the final getUnderlyingObject
///   step are skipped.
1507static inline llvm::Value *getBaseObject(llvm::Value *V,
1508 bool offsetAllowed = true) {
// Each iteration peels one layer: `continue` re-examines the new V, `break`
// stops with the current V.
1509 while (true) {
1510 if (auto CI = llvm::dyn_cast<llvm::CastInst>(V)) {
1511 V = CI->getOperand(0);
1512 continue;
1513 } else if (auto CI = llvm::dyn_cast<llvm::GetElementPtrInst>(V)) {
// Offsetting GEPs are only stripped when offsets are allowed.
1514 if (offsetAllowed || CI->hasAllZeroIndices()) {
1515 V = CI->getOperand(0);
1516 continue;
1517 }
1518 } else if (auto II = llvm::dyn_cast<llvm::IntrinsicInst>(V);
1519 II && isIntelSubscriptIntrinsic(*II)) {
// llvm.intel.subscript: continue from its operand 3.
1520 if (offsetAllowed) {
1521 V = II->getOperand(3);
1522 continue;
1523 }
1524 } else if (auto CI = llvm::dyn_cast<llvm::PHINode>(V)) {
// A PHI with a single incoming value forwards that value.
// NOTE(review): a single-incoming PHI whose incoming value is the PHI itself
// (possible in unreachable code) would make this loop spin; confirm such PHIs
// cannot reach here.
1525 if (CI->getNumIncomingValues() == 1) {
1526 V = CI->getOperand(0);
1527 continue;
1528 }
1529 } else if (auto *GA = llvm::dyn_cast<llvm::GlobalAlias>(V)) {
// Only follow aliases that cannot be replaced at link time.
1530 if (GA->isInterposable())
1531 break;
1532 V = GA->getAliasee();
1533 continue;
1534 } else if (auto CE = llvm::dyn_cast<llvm::ConstantExpr>(V)) {
// NOTE(review): constant GEPs are stripped even when offsetAllowed is false,
// unlike GEP instructions above; confirm this is intended.
1535 if (CE->isCast() || CE->getOpcode() == llvm::Instruction::GetElementPtr) {
1536 V = CE->getOperand(0);
1537 continue;
1538 }
1539 } else if (auto *Call = llvm::dyn_cast<llvm::CallInst>(V)) {
1540 auto funcName = getFuncNameFromCall(Call);
1541 auto AttrList = Call->getAttributes().getAttributes(
1542 llvm::AttributeList::FunctionIndex);
// A call-site "enzyme_pointermath" attribute names the argument index the
// result is derived from.
1543 if (AttrList.hasAttribute("enzyme_pointermath") && offsetAllowed) {
1544 size_t res = 0;
1545 bool failed = AttrList.getAttribute("enzyme_pointermath")
1546 .getValueAsString()
1547 .getAsInteger(10, res);
1548 (void)failed;
1549 assert(!failed);
1550 V = Call->getArgOperand(res);
1551 continue;
1552 }
1553 if (funcName == "julia.pointer_from_objref") {
1554 V = Call->getArgOperand(0);
1555 continue;
1556 }
1557 if (funcName == "julia.gc_loaded") {
1558 V = Call->getArgOperand(1);
1559 continue;
1560 }
1561 if (funcName == "jl_reshape_array" || funcName == "ijl_reshape_array") {
1562 V = Call->getArgOperand(1);
1563 continue;
1564 }
1565 if (funcName.contains("__enzyme_ignore_derivatives")) {
1566 V = Call->getArgOperand(0);
1567 continue;
1568 }
// Only the 3-argument form of __enzyme_todense is looked through.
1569 if (funcName.contains("__enzyme_todense")) {
1570#if LLVM_VERSION_MAJOR >= 14
1571 size_t numargs = Call->arg_size();
1572#else
1573 size_t numargs = Call->getNumArgOperands();
1574#endif
1575 if (numargs == 3) {
1576 V = Call->getArgOperand(2);
1577 continue;
1578 }
1579 }
// Same checks on the resolved callee: its "enzyme_pointermath" attribute, or
// a parameter marked `returned` (the last such parameter wins).
1580 if (auto fn = getFunctionFromCall(Call)) {
1581 auto AttrList = fn->getAttributes().getAttributes(
1582 llvm::AttributeList::FunctionIndex);
1583 if (AttrList.hasAttribute("enzyme_pointermath") && offsetAllowed) {
1584 size_t res = 0;
1585 bool failed = AttrList.getAttribute("enzyme_pointermath")
1586 .getValueAsString()
1587 .getAsInteger(10, res);
1588 (void)failed;
1589 assert(!failed);
1590 V = Call->getArgOperand(res);
1591 continue;
1592 }
1593 bool found = false;
1594 for (auto &arg : fn->args()) {
1595 if (arg.hasAttribute(llvm::Attribute::Returned)) {
1596 found = true;
1597 V = Call->getArgOperand(arg.getArgNo());
1598 }
1599 }
1600 if (found)
1601 continue;
1602 }
1603
1604 // CaptureTracking can know about special capturing properties of some
1605 // intrinsics like launder.invariant.group, that can't be expressed with
1606 // the attributes, but have properties like returning aliasing pointer.
1607 // Because some analysis may assume that nocaptured pointer is not
1608 // returned from some special intrinsic (because function would have to
1609 // be marked with returns attribute), it is crucial to use this function
1610 // because it should be in sync with CaptureTracking. Not using it may
1611 // cause weird miscompilations where 2 aliasing pointers are assumed to
1612 // noalias.
1613 if (offsetAllowed)
1614 if (auto *RP =
1615 llvm::getArgumentAliasingToReturnedPointer(Call, false)) {
1616 V = RP;
1617 continue;
1618 }
1619 }
1620
// Last resort: defer to LLVM's getUnderlyingObject (lookup limit 100). Its
// result is returned as-is, without further peeling by this loop.
1621 if (offsetAllowed)
1622 if (auto I = llvm::dyn_cast<llvm::Instruction>(V)) {
1623#if LLVM_VERSION_MAJOR >= 12
1624 auto V2 = llvm::getUnderlyingObject(I, 100);
1625#else
1626 auto V2 = llvm::GetUnderlyingObject(
1627 I, I->getParent()->getParent()->getParent()->getDataLayout(), 100);
1628#endif
1629 if (V2 != V) {
1630 V = V2;
1631 break;
1632 }
1633 }
1634 break;
1635 }
1636 return V;
1637}
1638static inline const llvm::Value *getBaseObject(const llvm::Value *V) {
1639 return getBaseObject(const_cast<llvm::Value *>(V));
1640}
1641
1642static inline llvm::SetVector<llvm::Value *>
1643getBaseObjects(llvm::Value *V, bool offsetAllowed = true) {
1644 llvm::SmallPtrSet<llvm::Value *, 1> seen;
1645 llvm::SetVector<llvm::Value *> results;
1646 llvm::SmallVector<llvm::Value *, 1> todo = {V};
1647
1648 while (todo.size()) {
1649 auto obj = todo.back();
1650 todo.pop_back();
1651 if (seen.contains(obj))
1652 continue;
1653 seen.insert(obj);
1654
1655 if (auto PN = llvm::dyn_cast<llvm::PHINode>(obj)) {
1656 for (auto &x : PN->incoming_values()) {
1657 todo.push_back(x);
1658 }
1659 continue;
1660 }
1661
1662 auto cur = getBaseObject(obj, offsetAllowed);
1663 if (cur != obj) {
1664 todo.push_back(cur);
1665 continue;
1666 }
1667
1668 results.insert(obj);
1669 }
1670 return results;
1671}
1672
1673static inline bool isReadOnly(const llvm::Function *F, ssize_t arg = -1) {
1674 if (F->onlyReadsMemory())
1675 return true;
1676
1677 if (F->hasFnAttribute(llvm::Attribute::ReadOnly) ||
1678 F->hasFnAttribute(llvm::Attribute::ReadNone))
1679 return true;
1680 if (arg != -1) {
1681 if (F->hasParamAttribute(arg, llvm::Attribute::ReadOnly) ||
1682 F->hasParamAttribute(arg, llvm::Attribute::ReadNone))
1683 return true;
1684 // if (F->getAttributes().hasParamAttribute(arg, "enzyme_ReadOnly") ||
1685 // F->getAttributes().hasParamAttribute(arg, "enzyme_ReadNone"))
1686 // return true;
1687 }
1688 return false;
1689}
1690
1691static inline bool isReadOnly(const llvm::CallBase *call, ssize_t arg = -1) {
1692 if (call->onlyReadsMemory())
1693 return true;
1694 if (arg != -1 && call->onlyReadsMemory(arg))
1695 return true;
1696
1697 if (auto F = getFunctionFromCall(call)) {
1698 // Do not use function attrs for if different calling conv, such as a julia
1699 // call wrapping args into an array. This is because the wrapped array
1700 // may be nocapure/readonly, but the actual arg (which will be put in the
1701 // array) may not be.
1702 if (F->getCallingConv() == call->getCallingConv())
1703 if (isReadOnly(F, arg))
1704 return true;
1705 }
1706 return false;
1707}
1708
1709// Whether the function does not write to memory visible before the function in
1710// all cases that it doesn't error. In other words, the legal operations here
1711// are:
1712//. 1) Throw [in which case any operation guaranteed to throw is valid]
1713//. 2) Read from any memory
1714//. 3) Write to memory which did not exist did not exist prior to the function
1715// call. This means that one can write . to memory whose allocation happened
1716// within the call to F (including a local alloca, a malloc call, even if .
1717// returned). This is also legal to write to an sret and/or returnroots
1718// parameter (which must be an alloca).
1719static inline bool isLocalReadOnlyOrThrow(const llvm::Function *F) {
1720 if (isReadOnly(F))
1721 return true;
1722
1723 if (F->hasFnAttribute("enzyme_LocalReadOnlyOrThrow") ||
1724 F->hasFnAttribute("enzyme_ReadOnlyOrThrow"))
1725 return true;
1726
1727 return false;
1728}
1729
// Call-site variant of isLocalReadOnlyOrThrow. Returns true if:
// - the call is read-only, or
// - the call carries "enzyme_LocalReadOnlyOrThrow" or
//   "enzyme_ReadOnlyOrThrow".
// Otherwise the callee is consulted, but only when the calling conventions
// of call and callee match.
1730static inline bool isLocalReadOnlyOrThrow(const llvm::CallBase *call) {
1731 if (isReadOnly(call))
1732 return true;
1733
1734 if (call->hasFnAttr("enzyme_LocalReadOnlyOrThrow") ||
1735 call->hasFnAttr("enzyme_ReadOnlyOrThrow"))
1736 return true;
1737
1738 if (auto F = getFunctionFromCall(call)) {
1739 // Do not use function attrs if the calling conv differs, such as a julia
1740 // call wrapping args into an array. This is because the wrapped array
1741 // may be nocapture/readonly, but the actual arg (which will be put in the
1742 // array) may not be.
1743 if (F->getCallingConv() == call->getCallingConv())
1745 return true;
1746 }
1747 return false;
1748}
1749
1750// Whether the function does not write to memory visible outside the function in
1751// all cases that it doesn't error. In other words, the legal operations here
1752// are:
1753//. 1) Throw [in which case any operation guaranteed to throw is valid]
1754//. 2) Read from any memory
1755//. 3) Write to memory which did not exist did not exist prior to the function
1756// call. This means that one can write . to memory whose lifetime is
1757// entirely contained within F (including a local alloca, a malloc call locally
1758// freed, but not . a returned malloc call).
1759static inline bool isReadOnlyOrThrow(const llvm::Function *F) {
1760 if (isReadOnly(F))
1761 return true;
1762
1763 if (F->hasFnAttribute("enzyme_ReadOnlyOrThrow"))
1764 return true;
1765
1766 return false;
1767}
1768
1769static inline bool isReadOnlyOrThrow(const llvm::CallBase *call) {
1770 if (isReadOnly(call))
1771 return true;
1772
1773 if (call->hasFnAttr("enzyme_ReadOnlyOrThrow"))
1774 return true;
1775
1776 if (auto F = getFunctionFromCall(call)) {
1777 // Do not use function attrs for if different calling conv, such as a julia
1778 // call wrapping args into an array. This is because the wrapped array
1779 // may be nocapure/readonly, but the actual arg (which will be put in the
1780 // array) may not be.
1781 if (F->getCallingConv() == call->getCallingConv())
1782 if (isReadOnlyOrThrow(F))
1783 return true;
1784 }
1785 return false;
1786}
1787
1788static inline bool isWriteOnly(const llvm::Function *F, ssize_t arg = -1) {
1789#if LLVM_VERSION_MAJOR >= 14
1790 if (F->onlyWritesMemory())
1791 return true;
1792#endif
1793 if (F->hasFnAttribute(llvm::Attribute::WriteOnly) ||
1794 F->hasFnAttribute(llvm::Attribute::ReadNone))
1795 return true;
1796 if (arg != -1) {
1797 if (F->hasParamAttribute(arg, llvm::Attribute::WriteOnly) ||
1798 F->hasParamAttribute(arg, llvm::Attribute::ReadNone))
1799 return true;
1800 }
1801 return false;
1802}
1803
1804static inline bool isWriteOnly(const llvm::CallBase *call, ssize_t arg = -1) {
1805#if LLVM_VERSION_MAJOR >= 14
1806 if (call->onlyWritesMemory())
1807 return true;
1808 if (arg != -1 && call->onlyWritesMemory(arg))
1809 return true;
1810#else
1811 if (call->hasFnAttr(llvm::Attribute::WriteOnly) ||
1812 call->hasFnAttr(llvm::Attribute::ReadNone))
1813 return true;
1814 if (arg != -1) {
1815 if (call->dataOperandHasImpliedAttr(arg + 1, llvm::Attribute::WriteOnly) ||
1816 call->dataOperandHasImpliedAttr(arg + 1, llvm::Attribute::ReadNone))
1817 return true;
1818 }
1819#endif
1820
1821 if (auto F = getFunctionFromCall(call)) {
1822 // Do not use function attrs for if different calling conv, such as a julia
1823 // call wrapping args into an array. This is because the wrapped array
1824 // may be nocapure/readonly, but the actual arg (which will be put in the
1825 // array) may not be.
1826 if (F->getCallingConv() == call->getCallingConv())
1827 return isWriteOnly(F, arg);
1828 }
1829 return false;
1830}
1831
1832static inline bool isReadNone(const llvm::CallBase *call, ssize_t arg = -1) {
1833 return isReadOnly(call, arg) && isWriteOnly(call, arg);
1834}
1835
1836static inline bool isReadNone(const llvm::Function *F, ssize_t arg = -1) {
1837 return isReadOnly(F, arg) && isWriteOnly(F, arg);
1838}
1839
1840static inline bool isNoCapture(const llvm::CallBase *call, size_t idx) {
1841 if (call->doesNotCapture(idx))
1842 return true;
1843
1844 if (auto F = getFunctionFromCall(call)) {
1845 // Do not use function attrs for if different calling conv, such as a julia
1846 // call wrapping args into an array. This is because the wrapped array
1847 // may be nocapure/readonly, but the actual arg (which will be put in the
1848 // array) may not be.
1849 if (F->getCallingConv() == call->getCallingConv())
1850 if (idx < F->arg_size() && F->getArg(idx)->hasNoCaptureAttr())
1851 return true;
1852 // if (F->getAttributes().hasParamAttribute(idx, "enzyme_NoCapture"))
1853 // return true;
1854 }
1855 return false;
1856}
1857
1858static inline bool isNoAlias(const llvm::CallBase *call) {
1859 if (call->returnDoesNotAlias())
1860 return true;
1861
1862 if (auto F = getFunctionFromCall(call)) {
1863 if (F->returnDoesNotAlias())
1864 return true;
1865 }
1866 return false;
1867}
1868
1869static inline bool isNoAlias(const llvm::Value *val) {
1870 if (auto CB = llvm::dyn_cast<llvm::CallBase>(val))
1871 return isNoAlias(CB);
1872 if (auto arg = llvm::dyn_cast<llvm::Argument>(val)) {
1873 arg->hasNoAliasAttr();
1874 }
1875 return false;
1876}
1877
1878static inline bool isNoEscapingAllocation(const llvm::Function *F) {
1879 if (F->hasFnAttribute("enzyme_no_escaping_allocation"))
1880 return true;
1881 if (F->getName() == "llvm.enzyme.lifetime_start" ||
1882 F->getName() == "llvm.enzyme.lifetime_end") {
1883 return true;
1884 }
1885 using namespace llvm;
1886 switch (F->getIntrinsicID()) {
1887 case Intrinsic::memset:
1888 case Intrinsic::memcpy:
1889 case Intrinsic::memmove:
1890#if LLVM_VERSION_MAJOR >= 12
1891 case Intrinsic::experimental_noalias_scope_decl:
1892#endif
1893 case Intrinsic::objectsize:
1894 case Intrinsic::floor:
1895 case Intrinsic::ceil:
1896 case Intrinsic::trunc:
1897 case Intrinsic::rint:
1898 case Intrinsic::lrint:
1899 case Intrinsic::llrint:
1900 case Intrinsic::nearbyint:
1901 case Intrinsic::round:
1902 case Intrinsic::roundeven:
1903 case Intrinsic::lround:
1904 case Intrinsic::llround:
1905#if LLVM_VERSION_MAJOR <= 20
1906 case Intrinsic::nvvm_barrier0:
1907#else
1908 case Intrinsic::nvvm_barrier_cta_sync_aligned_all:
1909 case Intrinsic::nvvm_barrier_cta_sync_aligned_count:
1910#endif
1911#if LLVM_VERSION_MAJOR < 22
1912 case Intrinsic::nvvm_barrier0_popc:
1913 case Intrinsic::nvvm_barrier0_and:
1914 case Intrinsic::nvvm_barrier0_or:
1915#else
1916 case Intrinsic::nvvm_barrier_cta_red_and_aligned_all:
1917 case Intrinsic::nvvm_barrier_cta_red_and_aligned_count:
1918 case Intrinsic::nvvm_barrier_cta_red_or_aligned_all:
1919 case Intrinsic::nvvm_barrier_cta_red_or_aligned_count:
1920 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_all:
1921 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_count:
1922#endif
1923 case Intrinsic::nvvm_membar_cta:
1924 case Intrinsic::nvvm_membar_gl:
1925 case Intrinsic::nvvm_membar_sys:
1926 case Intrinsic::amdgcn_s_barrier:
1927 case Intrinsic::assume:
1928 case Intrinsic::lifetime_start:
1929 case Intrinsic::lifetime_end:
1930#if LLVM_VERSION_MAJOR <= 16
1931 case Intrinsic::dbg_addr:
1932#endif
1933
1934 case Intrinsic::dbg_declare:
1935 case Intrinsic::dbg_value:
1936 case Intrinsic::dbg_label:
1937 case Intrinsic::invariant_start:
1938 case Intrinsic::invariant_end:
1939 case Intrinsic::var_annotation:
1940 case Intrinsic::ptr_annotation:
1941 case Intrinsic::annotation:
1942 case Intrinsic::codeview_annotation:
1943 case Intrinsic::expect:
1944 case Intrinsic::type_test:
1945 case Intrinsic::donothing:
1946 case Intrinsic::prefetch:
1947 case Intrinsic::trap:
1948 case Intrinsic::is_constant:
1949#if LLVM_VERSION_MAJOR >= 12
1950 case Intrinsic::smax:
1951 case Intrinsic::smin:
1952 case Intrinsic::umax:
1953 case Intrinsic::umin:
1954#endif
1955 case Intrinsic::ctlz:
1956 case Intrinsic::cttz:
1957 case Intrinsic::sadd_with_overflow:
1958 case Intrinsic::ssub_with_overflow:
1959#if LLVM_VERSION_MAJOR >= 12
1960 case Intrinsic::abs:
1961#endif
1962 case Intrinsic::sqrt:
1963 case Intrinsic::exp:
1964 case Intrinsic::cos:
1965 case Intrinsic::sin:
1966#if LLVM_VERSION_MAJOR >= 19
1967 case Intrinsic::tanh:
1968 case Intrinsic::cosh:
1969 case Intrinsic::sinh:
1970#endif
1971 case Intrinsic::copysign:
1972 case Intrinsic::fabs:
1973 return true;
1974 default:
1975 break;
1976 }
1977 // if (F->empty())
1978 // llvm::errs() << " may escape:" << F->getName() << "\n";
1979 return false;
1980}
1981static inline bool isNoEscapingAllocation(const llvm::CallBase *call) {
1982 auto AttrList =
1983 call->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
1984 if (AttrList.hasAttribute("enzyme_no_escaping_allocation"))
1985 return true;
1986 if (auto F = getFunctionFromCall(call)) {
1987 auto res = isNoEscapingAllocation(F);
1988 // if (!res && F->empty()) {
1989 // llvm::errs() << " may escape:" << *call << "\n";
1990 //}
1991 return res;
1992 }
1993 return false;
1994}
1995
1996bool attributeKnownFunctions(llvm::Function &F);
1997
1998llvm::Constant *getUndefinedValueForType(llvm::Module &M, llvm::Type *T,
1999 bool forceZero = false);
2000
2001llvm::Value *SanitizeDerivatives(llvm::Value *val, llvm::Value *toset,
2002 llvm::IRBuilder<> &BuilderM,
2003 llvm::Value *mask = nullptr);
2004
2005static inline llvm::Value *CreateSelect(llvm::IRBuilder<> &Builder2,
2006 llvm::Value *cmp, llvm::Value *tval,
2007 llvm::Value *fval,
2008 const llvm::Twine &Name = "") {
2009 if (auto cmpi = llvm::dyn_cast<llvm::ConstantInt>(cmp)) {
2010 if (cmpi->isZero())
2011 return fval;
2012 else
2013 return tval;
2014 }
2015 return Builder2.CreateSelect(cmp, tval, fval, Name);
2016}
2017
2018static inline llvm::Value *checkedMul(bool strongZero,
2019 llvm::IRBuilder<> &Builder2,
2020 llvm::Value *idiff, llvm::Value *pres,
2021 const llvm::Twine &Name = "") {
2022 llvm::Value *res = Builder2.CreateFMul(idiff, pres, Name);
2023 if (strongZero) {
2024 llvm::Value *zero = llvm::Constant::getNullValue(idiff->getType());
2025 if (auto C = llvm::dyn_cast<llvm::ConstantFP>(pres))
2026 if (!C->isInfinity() && !C->isNaN())
2027 return res;
2028 res = Builder2.CreateSelect(Builder2.CreateFCmpOEQ(idiff, zero), zero, res);
2029 }
2030 return res;
2031}
2032static inline llvm::Value *checkedDiv(bool strongZero,
2033 llvm::IRBuilder<> &Builder2,
2034 llvm::Value *idiff, llvm::Value *pres,
2035 const llvm::Twine &Name = "") {
2036 llvm::Value *res = Builder2.CreateFDiv(idiff, pres, Name);
2037 if (strongZero) {
2038 llvm::Value *zero = llvm::Constant::getNullValue(idiff->getType());
2039 if (auto C = llvm::dyn_cast<llvm::ConstantFP>(pres))
2040 if (!C->isZero() && !C->isNaN())
2041 return res;
2042 res = Builder2.CreateSelect(Builder2.CreateFCmpOEQ(idiff, zero), zero, res);
2043 }
2044 return res;
2045}
2046
2047static inline bool containsOnlyAtMostTopBit(const llvm::Value *V,
2048 llvm::Type *FT,
2049 const llvm::DataLayout &dl,
2050 llvm::Type **vFT = nullptr) {
2051 using namespace llvm;
2052 if (auto CI = dyn_cast_or_null<ConstantInt>(V)) {
2053 if (CI->isZero()) {
2054 if (vFT)
2055 *vFT = FT;
2056 return true;
2057 }
2058 if (dl.getTypeSizeInBits(FT) == dl.getTypeSizeInBits(CI->getType())) {
2059 if (CI->isNegative() && CI->isMinValue(/*signed*/ true)) {
2060 if (vFT)
2061 *vFT = FT;
2062 return true;
2063 }
2064 }
2065 }
2066 if (auto CV = dyn_cast_or_null<ConstantVector>(V)) {
2067 bool legal = true;
2068 for (size_t i = 0, end = CV->getNumOperands(); i < end; ++i) {
2069 legal &= containsOnlyAtMostTopBit(CV->getOperand(i), FT, dl);
2070 }
2071 if (legal && vFT) {
2072#if LLVM_VERSION_MAJOR >= 12
2073 *vFT = VectorType::get(FT, CV->getType()->getElementCount());
2074#else
2075 *vFT = VectorType::get(FT, CV->getType()->getNumElements());
2076#endif
2077 }
2078 return legal;
2079 }
2080
2081 if (auto CV = dyn_cast_or_null<ConstantDataVector>(V)) {
2082 bool legal = true;
2083 for (size_t i = 0, end = CV->getNumElements(); i < end; ++i) {
2084 auto CI = CV->getElementAsAPInt(i);
2085#if LLVM_VERSION_MAJOR > 16
2086 if (CI.isZero())
2087 continue;
2088#else
2089 if (CI.isNullValue())
2090 continue;
2091#endif
2092 if (dl.getTypeSizeInBits(FT) !=
2093 dl.getTypeSizeInBits(CV->getElementType())) {
2094 legal = false;
2095 break;
2096 }
2097 if (!CI.isMinSignedValue()) {
2098 legal = false;
2099 break;
2100 }
2101 }
2102 if (legal && vFT) {
2103#if LLVM_VERSION_MAJOR >= 12
2104 *vFT = VectorType::get(FT, CV->getType()->getElementCount());
2105#else
2106 *vFT = VectorType::get(FT, CV->getType()->getNumElements());
2107#endif
2108 }
2109 return legal;
2110 }
2111 if (auto BO = dyn_cast<BinaryOperator>(V)) {
2112 if (BO->getOpcode() == Instruction::And) {
2113 for (size_t i = 0; i < 2; i++) {
2114 if (containsOnlyAtMostTopBit(BO->getOperand(i), FT, dl))
2115 return true;
2116 }
2117 return false;
2118 }
2119 }
2120 return false;
2121}
2122
// Out-of-line helper: adds `arg` (of type `ty`) to `cacheValues`.
// NOTE(review): how `cache_arg` and `BuilderZ` affect what is appended is not
// visible from this declaration — confirm against the definition.
2123void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty,
2124 llvm::SmallVectorImpl<llvm::Value *> &cacheValues,
2125 llvm::IRBuilder<> &BuilderZ, const llvm::Twine &name = "");
2126
// Return the integer value `V` of type `intType` (defined in Utils.cpp).
// NOTE(review): presumably `V` is a pointer that is loaded first when `byRef`
// is true — TODO confirm.
2127llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType,
2128 llvm::Value *V, bool byRef);
2129
// Emit code operating on the `N`-order matrix `A` with leading dimension
// `lda`, of element type `fpType`, for BLAS routine `blas`.
// NOTE(review): by its name this mirrors the lower triangle into the upper
// one, honoring `layout`/`uplo` — TODO confirm against the definition.
2130void copy_lower_to_upper(llvm::IRBuilder<> &B, llvm::Type *fpType,
2131 BlasInfo blas, bool byRef, llvm::Value *layout,
2132 llvm::Value *uplo, llvm::Value *A, llvm::Value *lda,
2133 llvm::Value *N);
2134
2135// to_blas_callconv: a null julia_decl means the callee is not a Julia
2136// declaration; otherwise julia_decl is the integer type V must be cast to.
2137llvm::Value *to_blas_callconv(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
2138 bool cublas, llvm::IntegerType *julia_decl,
2139 llvm::IRBuilder<> &entryBuilder,
2140 llvm::Twine const & = "");
// Floating-point counterpart of to_blas_callconv; here julia_decl is an
// arbitrary llvm::Type (presumably the FP type to cast to — TODO confirm).
2141llvm::Value *to_blas_fp_callconv(llvm::IRBuilder<> &B, llvm::Value *V,
2142 bool byRef, llvm::Type *julia_decl,
2143 llvm::IRBuilder<> &entryBuilder,
2144 llvm::Twine const & = "");
2145
// Compute the leading dimension to use for a cached matrix (defined in
// Utils.cpp). NOTE(review): presumably chooses between `arg_ld` and
// `dim_1`/`dim_2` depending on `trans` and `cacheMat` — TODO confirm.
2146llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B,
2147 llvm::ArrayRef<llvm::Value *> trans,
2148 llvm::Value *arg_ld, llvm::Value *dim_1,
2149 llvm::Value *dim_2, bool cacheMat, bool byRef,
2150 bool cublas);
2151
2152template <typename T>
2153static inline void append(llvm::SmallVectorImpl<T> &vec) {}
2154template <typename T, typename... T2>
2155static inline void append(llvm::SmallVectorImpl<T> &vec, llvm::ArrayRef<T> vals,
2156 T2 &&...ts) {
2157 vec.append(vals.begin(), vals.end());
2158 append(vec, std::forward<T2>(ts)...);
2159}
2160template <typename... T>
2161static inline llvm::SmallVector<llvm::Value *, 1> concat_values(T &&...t) {
2162 llvm::SmallVector<llvm::Value *, 1> res;
2163 append(res, std::forward<T>(t)...);
2164 return res;
2165}
2166
// BLAS character-flag predicates: each returns an IR value computed from the
// given flag argument (`trans`, `side`, `uplo`). `byRef` and `cublas` select
// the calling convention the flag is passed with.
// NOTE(review): exact encodings per convention live in the definitions.
2167llvm::Value *is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef,
2168 bool cublas);
2169llvm::Value *is_left(llvm::IRBuilder<> &B, llvm::Value *side, bool byRef,
2170 bool cublas);
2171llvm::Value *is_lower(llvm::IRBuilder<> &B, llvm::Value *uplo, bool byRef,
2172 bool cublas);
2173llvm::Value *is_nonunit(llvm::IRBuilder<> &B, llvm::Value *uplo, bool byRef,
2174 bool cublas);
2175
// Return the element of matrix `base` (element type `fpType`, leading
// dimension `lda`) at (`row`, `col`), presumably addressed according to
// `layout` — TODO confirm.
2176llvm::Value *lookup_with_layout(llvm::IRBuilder<> &B, llvm::Type *fpType,
2177 llvm::Value *layout, llvm::Value *base,
2178 llvm::Value *lda, llvm::Value *row,
2179 llvm::Value *col);
2180
2181// The first overload assumes V is an integer value.
2182llvm::Value *transpose(std::string floatType, llvm::IRBuilder<> &B,
2183 llvm::Value *V, bool cublas);
2184// The second overload assumes V is an integer or a pointer to an integer,
// depending on byRef.
2185llvm::Value *transpose(std::string floatType, llvm::IRBuilder<> &B,
2186 llvm::Value *V, bool byRef, bool cublas,
2187 llvm::IntegerType *IT, llvm::IRBuilder<> &entryBuilder,
2188 const llvm::Twine &name);
// Select between `row` and `col` values according to `trans`.
// NOTE(review): inferred from the parameter names — confirm in Utils.cpp.
2189llvm::SmallVector<llvm::Value *, 1>
2190get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef<llvm::Value *> trans,
2191 llvm::ArrayRef<llvm::Value *> row,
2192 llvm::ArrayRef<llvm::Value *> col, bool byRef, bool cublas);
2193
2194llvm::SmallVector<llvm::Value *, 1>
2195get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef<llvm::Value *> trans,
2196 bool byRef, bool cublas);
2197
// The two attribute tables below are `static` arrays defined in a header, so
// translation units that include it without using them would warn under
// -Wunused-variable; that warning is suppressed for these definitions only.
2198#ifdef __clang__
2199#pragma clang diagnostic push
2200#pragma clang diagnostic ignored "-Wunused-variable"
2201#else
2202#pragma GCC diagnostic push
2203#pragma GCC diagnostic ignored "-Wunused-variable"
2204#endif
2205
2206// Parameter attributes from the original function/call that
2207// we should preserve on the primal of the derivative code.
// Entries are gated on the LLVM version that provides them; from LLVM 20 on,
// `Captures` is listed in place of `NoCapture`.
2208static inline llvm::Attribute::AttrKind PrimalParamAttrsToPreserve[] = {
2209 llvm::Attribute::AttrKind::ReadOnly,
2210 llvm::Attribute::AttrKind::WriteOnly,
2211 llvm::Attribute::AttrKind::ZExt,
2212 llvm::Attribute::AttrKind::SExt,
2213 llvm::Attribute::AttrKind::InReg,
2214 llvm::Attribute::AttrKind::ByVal,
2215#if LLVM_VERSION_MAJOR >= 12
2216 llvm::Attribute::AttrKind::ByRef,
2217#endif
2218 llvm::Attribute::AttrKind::Preallocated,
2219 llvm::Attribute::AttrKind::InAlloca,
2220#if LLVM_VERSION_MAJOR >= 13
2221 llvm::Attribute::AttrKind::ElementType,
2222#endif
2223#if LLVM_VERSION_MAJOR >= 15
2224 llvm::Attribute::AttrKind::AllocAlign,
2225#endif
2226 llvm::Attribute::AttrKind::NoFree,
2227 llvm::Attribute::AttrKind::Alignment,
2228 llvm::Attribute::AttrKind::StackAlignment,
2229#if LLVM_VERSION_MAJOR >= 20
2230 llvm::Attribute::AttrKind::Captures,
2231#else
2232 llvm::Attribute::AttrKind::NoCapture,
2233#endif
2234 llvm::Attribute::AttrKind::ReadNone
2235};
2236
2237// Parameter attributes from the original function/call that
2238// we should preserve on the shadow of the derivative code.
2239// Note that this will not occur for vector width > 1.
2240static inline llvm::Attribute::AttrKind ShadowParamAttrsToPreserve[] = {
2241 llvm::Attribute::AttrKind::ZExt,
2242 llvm::Attribute::AttrKind::SExt,
2243#if LLVM_VERSION_MAJOR >= 13
2244 llvm::Attribute::AttrKind::ElementType,
2245#endif
2246 llvm::Attribute::AttrKind::NoFree,
2247 llvm::Attribute::AttrKind::Alignment,
2248 llvm::Attribute::AttrKind::StackAlignment,
2249#if LLVM_VERSION_MAJOR >= 20
2250 llvm::Attribute::AttrKind::Captures,
2251#else
2252 llvm::Attribute::AttrKind::NoCapture,
2253#endif
2254 llvm::Attribute::AttrKind::ReadNone,
2255};
2256#ifdef __clang__
2257#pragma clang diagnostic pop
2258#else
2259#pragma GCC diagnostic pop
2260#endif
2261
2262static inline llvm::Function *
2263getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id,
2264 llvm::ArrayRef<llvm::Type *> Tys = {}) {
2265#if LLVM_VERSION_MAJOR >= 20
2266 return llvm::Intrinsic::getOrInsertDeclaration(M, id, Tys);
2267#else
2268 return llvm::Intrinsic::getDeclaration(M, id, Tys);
2269#endif
2270}
2271
2272static inline llvm::Instruction *getFirstNonPHIOrDbg(llvm::BasicBlock *B) {
2273#if LLVM_VERSION_MAJOR >= 20
2274 return &*B->getFirstNonPHIOrDbg();
2275#else
2276 return B->getFirstNonPHIOrDbg();
2277#endif
2278}
2279
2280static inline llvm::Instruction *
2281getFirstNonPHIOrDbgOrLifetime(llvm::BasicBlock *B) {
2282#if LLVM_VERSION_MAJOR >= 20
2283 return &*B->getFirstNonPHIOrDbgOrLifetime();
2284#else
2285 return B->getFirstNonPHIOrDbgOrLifetime();
2286#endif
2287}
2288
2289static inline void addCallSiteNoCapture(llvm::CallBase *call, size_t idx) {
2290#if LLVM_VERSION_MAJOR > 20
2291 call->addParamAttr(
2292 idx, llvm::Attribute::get(call->getContext(), llvm::Attribute::Captures,
2293 llvm::CaptureInfo::none().toIntValue()));
2294#else
2295 call->addParamAttr(idx, llvm::Attribute::NoCapture);
2296#endif
2297}
2298
2299static inline void addFunctionNoCapture(llvm::Function *call, size_t idx) {
2300#if LLVM_VERSION_MAJOR > 20
2301 call->addParamAttr(
2302 idx, llvm::Attribute::get(call->getContext(), llvm::Attribute::Captures,
2303 llvm::CaptureInfo::none().toIntValue()));
2304#else
2305 call->addParamAttr(idx, llvm::Attribute::NoCapture);
2306#endif
2307}
2308
2309[[nodiscard]] static inline llvm::AttributeList
2310addFunctionNoCapture(llvm::LLVMContext &ctx, llvm::AttributeList list,
2311 size_t idx) {
2312 unsigned idxs = {(unsigned)idx};
2313#if LLVM_VERSION_MAJOR > 20
2314 return list.addParamAttribute(
2315 ctx, idxs,
2316 llvm::Attribute::get(ctx, llvm::Attribute::Captures,
2317 llvm::CaptureInfo::none().toIntValue()));
2318#else
2319 return list.addParamAttribute(ctx, idxs, llvm::Attribute::NoCapture);
2320#endif
2321}
2322
2323static inline llvm::Type *getSubType(llvm::Type *T) { return T; }
2324
2325template <typename Arg1, typename... Args>
2326static inline llvm::Type *getSubType(llvm::Type *T, Arg1 i, Args... args) {
2327 if (auto AT = llvm::dyn_cast<llvm::ArrayType>(T))
2328 return getSubType(AT->getElementType(), args...);
2329 if (auto VT = llvm::dyn_cast<llvm::VectorType>(T))
2330 return getSubType(VT->getElementType(), args...);
2331 if (auto ST = llvm::dyn_cast<llvm::StructType>(T)) {
2332 assert((int)i != -1);
2333 return getSubType(ST->getElementType(i), args...);
2334 }
2335 llvm::errs() << *T << "\n";
2336 llvm_unreachable("unknown subtype");
2337}
2338
// Fields of CountTrackedPointers, filled in by its constructor from type T.
// NOTE(review): by their names, `count` is the number of tracked pointers
// found in T, `all` whether T consists only of them, and `derived` whether any
// is a derived pointer — TODO confirm against the constructor's definition.
2349 unsigned count = 0;
2350 bool all = true;
2351 bool derived = false;
2352 CountTrackedPointers(llvm::Type *T);
2353};
2354static inline bool isSpecialPtr(llvm::Type *Ty) {
2355 llvm::PointerType *PTy = llvm::dyn_cast<llvm::PointerType>(Ty);
2356 if (!PTy)
2357 return false;
2358 unsigned AS = PTy->getAddressSpace();
2360}
2361
// Split the offset computed by `gep` into a constant part (`ConstantOffset`)
// and per-value variable parts (`VariableOffsets`), using `BitWidth`-bit
// APInts and the layout `DL`. From LLVM 20 on the variable offsets are
// collected in a SmallMapVector, otherwise in a MapVector.
// NOTE(review): presumably mirrors llvm::GEPOperator::collectOffset, with the
// bool return reporting success — TODO confirm against the definition.
2362#if LLVM_VERSION_MAJOR >= 20
2363bool collectOffset(
2364 llvm::GEPOperator *gep, const llvm::DataLayout &DL, unsigned BitWidth,
2365 llvm::SmallMapVector<llvm::Value *, llvm::APInt, 4> &VariableOffsets,
2366 llvm::APInt &ConstantOffset);
2367#else
2368bool collectOffset(llvm::GEPOperator *gep, const llvm::DataLayout &DL,
2369 unsigned BitWidth,
2370 llvm::MapVector<llvm::Value *, llvm::APInt> &VariableOffsets,
2371 llvm::APInt &ConstantOffset);
2372#endif
2373
// Emit a call to intrinsic `ID` returning `RetTy` with arguments `Args`
// (defined in Utils.cpp). NOTE(review): `FMFSource` presumably supplies the
// fast-math flags copied onto the call — TODO confirm.
2374llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B,
2375 llvm::Intrinsic::ID ID, llvm::Type *RetTy,
2376 llvm::ArrayRef<llvm::Value *> Args,
2377 llvm::Instruction *FMFSource = nullptr,
2378 const llvm::Twine &Name = "");
2379
// Return whether `V` is a load recognized as NVPTX-specific (defined in
// Utils.cpp). NOTE(review): the exact set of recognized forms is in the
// definition — confirm there.
2380bool isNVLoad(const llvm::Value *V);
2381
2382//! Check whether V is captured after its definition but before inst executes.
2383//! If checkLoadCaptured != 0, also consider captures of any loads of the value
2384//! as a capture (for the number of loads set).
2385bool notCapturedBefore(llvm::Value *V, llvm::Instruction *inst,
2386 size_t checkLoadCaptured);
2387
2388//! Check whether V is captured.
2389bool notCaptured(llvm::Value *V);
2390
2391// Return true if guaranteed not to alias
2392// Return false if guaranteed to alias [with possible offset depending on flag].
2393// Return {} if no information is given.
2394#if LLVM_VERSION_MAJOR >= 16
2395std::optional<bool>
2396#else
2397llvm::Optional<bool>
2398#endif
2399arePointersGuaranteedNoAlias(llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA,
2400 llvm::LoopInfo &LI, llvm::Value *op0,
2401 llvm::Value *op1, bool offsetAllowed = false);
2402
2403static inline std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>
2404tripleSplitDollar(llvm::StringRef caller) {
2405 if (!startsWith(caller, "ejl")) {
2406 return {"", caller, ""};
2407 }
2408 auto &&[prefix, todo] = caller.split("$");
2409 auto &&[name, postfix] = todo.split("$");
2410 return std::make_tuple(prefix, name, postfix);
2411}
2412
2413static inline std::string getRenamedPerCallingConv(llvm::StringRef caller,
2414 llvm::StringRef callee) {
2415 if (startsWith(caller, "ejl")) {
2416 auto &&[prefix, name, postfix] = tripleSplitDollar(caller);
2417 return (prefix + "$" + getRenamedPerCallingConv(name, callee) + "$" +
2418 postfix)
2419 .str();
2420 }
2421 if (startsWith(caller, "PMPI_")) {
2422 assert(startsWith(callee, "MPI"));
2423 return ("P" + callee).str();
2424 }
2425 return callee.str();
2426}
2427
2428static inline std::string convertSRetTypeToString(llvm::Type *T) {
2429 return std::to_string((size_t)T);
2430}
2431
2432static inline llvm::Type *
2433convertSRetTypeFromString(llvm::StringRef str, llvm::LLVMContext *C = nullptr) {
2434 if (str == "test_type") {
2435 assert(C);
2436 llvm::SmallVector<llvm::Type *, 1> elts;
2437#if LLVM_VERSION_MAJOR >= 17
2438 elts.push_back(llvm::PointerType::get(*C, AddressSpace::Tracked));
2439#else
2440 elts.push_back(llvm::PointerType::get(llvm::StructType::get(*C, {}),
2442#endif
2443 llvm::Type *inner = llvm::StructType::get(*C, elts);
2444 llvm::SmallVector<llvm::Type *, 1> innerElts;
2445 innerElts.push_back(inner);
2446 return llvm::StructType::get(*C, innerElts);
2447 }
2448 if (str == "test_type2") {
2449 assert(C);
2450 return llvm::ArrayType::get(llvm::Type::getInt64Ty(*C), 6);
2451 }
2452 if (str == "test_type3") {
2453 assert(C);
2454 llvm::SmallVector<llvm::Type *, 1> elts;
2455 elts.push_back(llvm::Type::getDoubleTy(*C));
2456 return llvm::StructType::get(*C, elts);
2457 }
2458 if (str == "test_type4") {
2459 assert(C);
2460 llvm::SmallVector<llvm::Type *, 3> elts;
2461 elts.push_back(llvm::ArrayType::get(llvm::Type::getDoubleTy(*C), 2));
2462 elts.push_back(llvm::Type::getDoubleTy(*C));
2463 elts.push_back(llvm::Type::getInt64Ty(*C));
2464 return llvm::StructType::get(*C, elts);
2465 }
2466 if (str == "test_type5") {
2467 assert(C);
2468 llvm::SmallVector<llvm::Type *, 3> elts;
2469 elts.push_back(llvm::ArrayType::get(llvm::Type::getDoubleTy(*C), 1));
2470 elts.push_back(llvm::Type::getDoubleTy(*C));
2471 elts.push_back(llvm::Type::getInt64Ty(*C));
2472 return llvm::StructType::get(*C, elts);
2473 }
2474 size_t idx;
2475 bool failed = str.consumeInteger(10, idx);
2476 (void)failed;
2477 assert(!failed);
2478 return (llvm::Type *)idx;
2479}
2480
2481static inline size_t convertRRootCountFromString(llvm::StringRef str) {
2482 size_t idx;
2483 bool failed = str.consumeInteger(10, idx);
2484 (void)failed;
2485 assert(!failed);
2486 return idx;
2487}
2488
2489static inline bool hasSRetRRootsOrUnionSRet(llvm::CallBase *CB) {
2490 if (CB->hasStructRetAttr())
2491 return true;
2492 for (size_t i = 0; i < CB->arg_size(); i++) {
2493 if (CB->getAttributeAtIndex(llvm::AttributeList::FirstArgIndex + i,
2494 "enzymejl_sret_union_bytes")
2495 .isValid())
2496 return true;
2497 if (CB->getAttributeAtIndex(llvm::AttributeList::FirstArgIndex + i,
2498 "enzymejl_returnRoots")
2499 .isValid())
2500 return true;
2501 }
2502 return false;
2503}
2504
2512
// Move data between the sret buffer `sret` (Julia type `jltype`) and the
// return-roots buffer `rootRet` (type `root_ty`, starting at `rootOffset`),
// in the given `direction` (defined in Utils.cpp).
// NOTE(review): the meaning of the returned Value is not visible here —
// confirm against the definition.
2513llvm::Value *moveSRetToFromRoots(llvm::IRBuilder<> &B, llvm::Type *jltype,
2514 llvm::Value *sret, llvm::Type *root_ty,
2515 llvm::Value *rootRet, size_t rootOffset,
2516 SRetRootMovement direction);
2517
// Copy the parts of a value of type `curType` that are not Julia objects from
// `src` (type `srcType`, at index path `srcPrefix`) into `dst` (type
// `dstType`, at index path `dstPrefix`).
// NOTE(review): inferred from the name and parameters; `shouldZero`
// presumably zeroes instead of copying — TODO confirm.
2518void copyNonJLValueInto(llvm::IRBuilder<> &B, llvm::Type *curType,
2519 llvm::Type *dstType, llvm::Value *dst,
2520 llvm::ArrayRef<unsigned> dstPrefix, llvm::Type *srcType,
2521 llvm::Value *src, llvm::ArrayRef<unsigned> srcPrefix,
2522 bool shouldZero);
2523
2524static bool anyJuliaObjects(llvm::Type *T) {
2525 if (isSpecialPtr(T))
2526 return true;
2527 if (auto ST = llvm::dyn_cast<llvm::StructType>(T)) {
2528 for (auto elem : ST->elements()) {
2529 if (anyJuliaObjects(elem))
2530 return true;
2531 }
2532 return false;
2533 }
2534 if (auto AT = llvm::dyn_cast<llvm::ArrayType>(T)) {
2535 return anyJuliaObjects(AT->getElementType());
2536 }
2537 if (auto VT = llvm::dyn_cast<llvm::VectorType>(T)) {
2538 return anyJuliaObjects(VT->getElementType());
2539 }
2540 return false;
2541}
2542
// Collect the Julia object values reachable from `v`, emitting any needed IR
// with `B` (defined in Utils.cpp).
// NOTE(review): presumably the objects isSpecialPtr / anyJuliaObjects detect —
// TODO confirm against the definition.
2543llvm::SmallVector<llvm::Value *, 1> getJuliaObjects(llvm::Value *v,
2544 llvm::IRBuilder<> &B);
2545
2546// Find all user instructions of AI, returning tuples of <instruction, value,
2547// byte offset from AI>. Unlike a simple get users, this will recurse through any
2548// constant gep offsets and casts
2549llvm::SmallVector<std::tuple<llvm::Instruction *, llvm::Value *, size_t>, 1>
2550findAllUsersOf(llvm::Value *AI);
2551
2552static bool hasTerminator(llvm::BasicBlock *BB) {
2553#if LLVM_VERSION_MAJOR >= 23
2554 return BB->hasTerminator();
2555#else
2556 return BB->getTerminator();
2557#endif
2558}
2559
2560#endif // ENZYME_UTILS_H
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
static llvm::Value * checkedMul(bool strongZero, llvm::IRBuilder<> &Builder2, llvm::Value *idiff, llvm::Value *pres, const llvm::Twine &Name="")
Definition Utils.h:2018
llvm::Value * EmitNoDerivativeError(const std::string &message, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &B, llvm::Value *condition=nullptr)
Definition Utils.cpp:4295
static llvm::StringRef getFuncName(llvm::Function *called)
Definition Utils.h:1260
static bool isIntelSubscriptIntrinsic(const llvm::IntrinsicInst &II)
Definition Utils.h:1445
llvm::Function * getOrInsertDifferentialWaitallSave(llvm::Module &M, llvm::ArrayRef< llvm::Type * > T, llvm::PointerType *reqType)
static void allDomPredecessorsOf(llvm::Instruction *inst, llvm::DominatorTree &DT, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen before inst. If the function returns true,...
Definition Utils.h:941
llvm::Function * getOrInsertMemcpyMat(llvm::Module &M, llvm::Type *elementType, llvm::PointerType *PT, llvm::IntegerType *IT, unsigned dstalign, unsigned srcalign)
Turned out to be a faster alternative to LAPACK's lacpy function.
llvm::Value * load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType, llvm::Value *V, bool byRef)
Definition Utils.cpp:4075
static bool isNoEscapingAllocation(const llvm::Function *F)
Definition Utils.h:1878
llvm::Value * is_nonunit(llvm::IRBuilder<> &B, llvm::Value *uplo, bool byRef, bool cublas)
static std::string getRenamedPerCallingConv(llvm::StringRef caller, llvm::StringRef callee)
Definition Utils.h:2413
void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, bool isTape)
Definition Utils.cpp:414
llvm::cl::opt< bool > EnzymeBlasCopy
llvm::SmallVector< llvm::Value *, 1 > get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef< llvm::Value * > trans, llvm::ArrayRef< llvm::Value * > row, llvm::ArrayRef< llvm::Value * > col, bool byRef, bool cublas)
static llvm::Loop * getAncestor(llvm::Loop *R1, llvm::Loop *R2)
Definition Utils.h:1053
static llvm::Function * getDeallocatorFnFromCall(T *op)
Definition Utils.h:1360
BATCH_TYPE
Definition Utils.h:385
llvm::SmallVector< llvm::Instruction *, 2 > PostCacheStore(llvm::StoreInst *SI, llvm::IRBuilder<> &B)
Definition Utils.cpp:423
static llvm::Instruction * getNextNonDebugInstructionOrNull(llvm::Instruction *Z)
Get the next non-debug instruction, if one exists.
Definition Utils.h:318
static T max(T a, T b)
Pick the maximum value.
Definition Utils.h:262
static llvm::PointerType * getPointerType(llvm::Type *T, unsigned AddressSpace=0)
Definition Utils.h:1165
llvm::PointerType * getDefaultAnonymousTapeType(llvm::LLVMContext &C)
Definition Utils.cpp:437
void callSPMVDiagUpdate(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::IntegerType *IT, llvm::Type *BlasCT, llvm::Type *BlasFPT, llvm::Type *BlasPT, llvm::Type *BlasIT, llvm::Type *fpTy, llvm::ArrayRef< llvm::Value * > args, const llvm::ArrayRef< llvm::OperandBundleDef > bundles, bool byRef, bool julia_decl)
static void dumpMap(const llvm::ValueMap< T, N > &o, llvm::function_ref< bool(const llvm::Value *)> shouldPrint=[](T) { return true;})
Print a map, optionally with a shouldPrint function that decides whether to print a given value.
Definition Utils.h:286
bool attributeKnownFunctions(llvm::Function &F)
Definition Utils.cpp:114
static bool anyJuliaObjects(llvm::Type *T)
Definition Utils.h:2524
ReturnType
Potential return type of generated functions.
Definition Utils.h:355
@ Tape
Return is a tape type.
@ ArgsWithTwoReturns
Return is a struct of all args and two of the original return.
@ TapeAndReturn
Return is a tape type and the original return.
@ Args
Return is a struct of all args.
@ TapeAndTwoReturns
Return is a tape type and the two of the original return.
@ ArgsWithReturn
Return is a struct of all args and the original return.
void EmitWarningAlways(llvm::StringRef RemarkName, const llvm::Function &F, const Args &...args)
Definition Utils.h:166
static V * findInMap(std::map< K, V > &map, K key)
Definition Utils.h:855
llvm::Value * is_left(llvm::IRBuilder<> &B, llvm::Value *side, bool byRef, bool cublas)
static bool isNoAlias(const llvm::CallBase *call)
Definition Utils.h:1858
static llvm::Type * FloatToIntTy(llvm::Type *T)
Convert a floating point type to an integer type of the same size.
Definition Utils.h:641
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
Definition Utils.h:713
DIFFE_TYPE
Potential differentiable argument classifications.
Definition Utils.h:374
llvm::Value * CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev, llvm::Type *T, llvm::Value *OuterCount, llvm::Value *InnerCount, const llvm::Twine &Name="", llvm::CallInst **caller=nullptr, bool ZeroMem=false)
Definition Utils.cpp:590
llvm::Function * getOrInsertDifferentialMPI_Wait(llvm::Module &M, llvm::ArrayRef< llvm::Type * > T, llvm::Type *reqType, llvm::StringRef caller)
Create function for type that performs the derivative MPI_Wait.
void EmitNoTypeError(const std::string &, llvm::Instruction &inst, GradientUtils *gutils, llvm::IRBuilder<> &B)
Definition Utils.cpp:4377
static size_t convertRRootCountFromString(llvm::StringRef str)
Definition Utils.h:2481
llvm::Value * CreateAllocation(llvm::IRBuilder<> &B, llvm::Type *T, llvm::Value *Count, const llvm::Twine &Name="", llvm::CallInst **caller=nullptr, llvm::Instruction **ZeroMem=nullptr, bool isDefault=false)
llvm::CallInst * getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::IntegerType *IT, llvm::Type *BlasPT, llvm::Type *BlasIT, llvm::Type *fpTy, llvm::ArrayRef< llvm::Value * > args, const llvm::ArrayRef< llvm::OperandBundleDef > bundles, bool byRef, bool cublas, bool julia_decl)
static void allPredecessorsOf(llvm::Instruction *inst, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen before inst. If the function returns true,...
Definition Utils.h:904
llvm::Value * get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res)
Definition Utils.cpp:4272
void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef< llvm::Value * > args, llvm::ArrayRef< llvm::OperandBundleDef > bundles)
Create function for type that performs memcpy using lapack copy.
Definition Utils.cpp:1336
bool notCapturedBefore(llvm::Value *V, llvm::Instruction *inst, size_t checkLoadCaptured)
Check whether the value is captured after its definition but before inst executes.
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &B, llvm::Intrinsic::ID ID, llvm::Type *RetTy, llvm::ArrayRef< llvm::Value * > Args, llvm::Instruction *FMFSource=nullptr, const llvm::Twine &Name="")
Definition Utils.cpp:4233
llvm::Value * nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V)
Create function to compute the nearest power of two.
Definition Utils.cpp:2192
llvm::Function * getOrInsertCheckedFree(llvm::Module &M, llvm::CallInst *call, llvm::Type *Type, unsigned width)
llvm::Value * transpose(std::string floatType, llvm::IRBuilder<> &B, llvm::Value *V, bool cublas)
static llvm::Value * getMPIMemberPtr(llvm::IRBuilder<> &B, llvm::Value *V, llvm::Type *T)
Definition Utils.h:1200
static bool isCertainPrint(const llvm::StringRef name)
Definition Utils.h:729
llvm::Optional< bool > arePointersGuaranteedNoAlias(llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, llvm::LoopInfo &LI, llvm::Value *op0, llvm::Value *op1, bool offsetAllowed=false)
llvm::Value * SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM, llvm::Value *mask=nullptr)
Definition Utils.cpp:3634
void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty, llvm::SmallVectorImpl< llvm::Value * > &cacheValues, llvm::IRBuilder<> &BuilderZ, const llvm::Twine &name="")
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
llvm::Optional< BlasInfo > extractBLAS(llvm::StringRef in)
Definition Utils.cpp:3563
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
Definition Utils.h:1174
static std::map< K, V >::iterator insert_or_assign2(std::map< K, V > &map, K key, V val)
Insert into a map.
Definition Utils.h:846
static llvm::Function * getIntrinsicDeclaration(llvm::Module *M, llvm::Intrinsic::ID id, llvm::ArrayRef< llvm::Type * > Tys={})
Definition Utils.h:2263
llvm::Value * is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef, bool cublas)
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
Definition Utils.cpp:62
static bool isNoCapture(const llvm::CallBase *call, size_t idx)
Definition Utils.h:1840
static void addFunctionNoCapture(llvm::Function *call, size_t idx)
Definition Utils.h:2299
static llvm::Type * convertSRetTypeFromString(llvm::StringRef str, llvm::LLVMContext *C=nullptr)
Definition Utils.h:2433
bool notCaptured(llvm::Value *V)
Check whether the value is captured.
Definition Utils.cpp:4608
static bool isDebugFunction(llvm::Function *called)
Definition Utils.h:690
static bool isReturned(llvm::Instruction *inst)
Check whether this instruction is returned.
Definition Utils.h:631
std::vector< std::tuple< llvm::Type *, size_t, size_t > > parseTrueType(const llvm::MDNode *, DerivativeMode, bool const_src)
Definition Utils.cpp:4411
static bool isPointerArithmeticInst(const llvm::Value *V, bool includephi=true, bool includebin=true)
Definition Utils.h:1456
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
Definition Utils.h:1507
bool isNVLoad(const llvm::Value *V)
Definition Utils.cpp:4483
static llvm::Function * isCalledFunction(llvm::Value *val)
Definition Utils.h:239
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
Definition Utils.h:339
static bool containsOnlyAtMostTopBit(const llvm::Value *V, llvm::Type *FT, const llvm::DataLayout &dl, llvm::Type **vFT=nullptr)
Definition Utils.h:2047
bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::Instruction *maybeReader, llvm::Instruction *maybeWriter)
Return whether maybeReader can read from memory written to by maybeWriter.
Definition Utils.cpp:2880
llvm::Value * lookup_with_layout(llvm::IRBuilder<> &B, llvm::Type *fpType, llvm::Value *layout, llvm::Value *base, llvm::Value *lda, llvm::Value *row, llvm::Value *col)
static bool shouldDisableNoWrite(const llvm::CallInst *CI)
Definition Utils.h:1423
static llvm::SetVector< llvm::Value * > getBaseObjects(llvm::Value *V, bool offsetAllowed=true)
Definition Utils.h:1643
MPI_Elem
Definition Utils.h:1154
bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT, llvm::Instruction *maybeReader, llvm::Instruction *maybeWriter, llvm::Loop *scope=nullptr)
Definition Utils.cpp:2765
void EmitWarning(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::BasicBlock *BB, const Args &...args)
Definition Utils.h:133
static bool isReadOnlyOrThrow(const llvm::Function *F)
Definition Utils.h:1759
static bool isSpecialPtr(llvm::Type *Ty)
Definition Utils.h:2354
void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef< llvm::Value * > args, llvm::Type *cublas_retty, llvm::ArrayRef< llvm::OperandBundleDef > bundles)
Create function for type that performs memcpy with a stride using blas copy.
Definition Utils.cpp:1316
ProbProgMode
Definition Utils.h:399
static bool hasNoCache(llvm::Value *op)
Definition Utils.h:1283
static llvm::StructType * getMPIHelper(llvm::LLVMContext &Context)
Definition Utils.h:1183
llvm::CallInst * CreateDealloc(llvm::IRBuilder<> &B, llvm::Value *ToFree)
Definition Utils.cpp:742
static bool isReadOnly(const llvm::Function *F, ssize_t arg=-1)
Definition Utils.h:1673
AddressSpace
Definition Utils.h:2339
@ CalleeRooted
Definition Utils.h:2343
@ Derived
Definition Utils.h:2342
@ LastSpecial
Definition Utils.h:2346
@ Tracked
Definition Utils.h:2341
@ FirstSpecial
Definition Utils.h:2345
@ Generic
Definition Utils.h:2340
@ Loaded
Definition Utils.h:2344
static llvm::Type * getSubType(llvm::Type *T)
Definition Utils.h:2323
static llvm::Instruction * getNextNonDebugInstruction(llvm::Instruction *Z)
Get the next non-debug instruction, erring if none exists.
Definition Utils.h:327
llvm::StringMap< std::function< llvm::Value *(llvm::IRBuilder<> &, llvm::CallInst *, llvm::ArrayRef< llvm::Value * >, GradientUtils *)> > shadowHandlers
static std::vector< ssize_t > getDeallocationIndicesFromCall(T *op)
Definition Utils.h:1381
llvm::Value * moveSRetToFromRoots(llvm::IRBuilder<> &B, llvm::Type *jltype, llvm::Value *sret, llvm::Type *root_ty, llvm::Value *rootRet, size_t rootOffset, SRetRootMovement direction)
Definition Utils.cpp:4714
static void allUnsyncdPredecessorsOf(llvm::Instruction *inst, llvm::function_ref< bool(llvm::Instruction *)> f, llvm::function_ref< void()> preEntry)
Call the function f for all instructions that happen before inst. If the function returns true,...
Definition Utils.h:980
static void allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1, llvm::Instruction *inst2, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen between inst1 and inst2. If the function returns ...
Definition Utils.h:1099
void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Value *shadow, const char *Message, llvm::DebugLoc &&loc, llvm::Instruction *orig)
Definition Utils.cpp:902
llvm::cl::opt< bool > EnzymeJuliaAddrLoad
MPI_CallType
Definition Utils.h:1149
static void dumpSet(const llvm::SmallPtrSetImpl< T * > &o)
Print a set.
Definition Utils.h:301
static llvm::Function * getFunctionFromCall(T *op)
Definition Utils.h:1236
llvm::Function * getOrInsertDifferentialFloatMemmove(llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
Create function for type that performs the derivative memmove on floating point memory.
llvm::Function * getOrInsertDifferentialFloatMemcpy(llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign, unsigned dstaddr, unsigned srcaddr, unsigned bitwidth)
Create function for type that performs the derivative memcpy on floating point memory.
bool collectOffset(llvm::GEPOperator *gep, const llvm::DataLayout &DL, unsigned BitWidth, llvm::MapVector< llvm::Value *, llvm::APInt > &VariableOffsets, llvm::APInt &ConstantOffset)
llvm::Value * get_cached_mat_width(llvm::IRBuilder<> &B, llvm::ArrayRef< llvm::Value * > trans, llvm::Value *arg_ld, llvm::Value *dim_1, llvm::Value *dim_2, bool cacheMat, bool byRef, bool cublas)
Definition Utils.cpp:4018
ErrorType
Definition Utils.h:77
@ IllegalFirstPointer
@ MixedActivityError
@ IllegalReplaceFicticiousPHIs
@ TypeDepthExceeded
@ IllegalTypeAnalysis
static llvm::Value * checkedDiv(bool strongZero, llvm::IRBuilder<> &Builder2, llvm::Value *idiff, llvm::Value *pres, const llvm::Twine &Name="")
Definition Utils.h:2032
static T min(T a, T b)
Pick the minimum value.
Definition Utils.h:268
static bool hasSRetRRootsOrUnionSRet(llvm::CallBase *CB)
Definition Utils.h:2489
static std::tuple< llvm::StringRef, llvm::StringRef, llvm::StringRef > tripleSplitDollar(llvm::StringRef caller)
Definition Utils.h:2404
llvm::SmallVector< llvm::Value *, 1 > getJuliaObjects(llvm::Value *v, llvm::IRBuilder<> &B)
Definition Utils.cpp:4913
SRetRootMovement
Definition Utils.h:2505
static void addCallSiteNoCapture(llvm::CallBase *call, size_t idx)
Definition Utils.h:2289
static llvm::Type * IntToFloatTy(llvm::Type *T)
Convert an integer type to a floating-point type of the same size.
Definition Utils.h:665
static llvm::Instruction * getFirstNonPHIOrDbgOrLifetime(llvm::BasicBlock *B)
Definition Utils.h:2281
llvm::cl::opt< bool > EnzymePrintPerf
Print additional debug info relevant to performance.
llvm::Value * to_blas_callconv(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, bool cublas, llvm::IntegerType *julia_decl, llvm::IRBuilder<> &entryBuilder, llvm::Twine const &="")
llvm::Function * GetFunctionFromValue(llvm::Value *fn)
llvm::Value * is_lower(llvm::IRBuilder<> &B, llvm::Value *uplo, bool byRef, bool cublas)
static bool endsWith(llvm::StringRef string, llvm::StringRef suffix)
Definition Utils.h:721
static llvm::Value * CreateSelect(llvm::IRBuilder<> &Builder2, llvm::Value *cmp, llvm::Value *tval, llvm::Value *fval, const llvm::Twine &Name="")
Definition Utils.h:2005
llvm::Value * getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr, llvm::Type *OpType, ConcreteType CT, llvm::Type *intType, llvm::IRBuilder<> &B2)
static std::string convertSRetTypeToString(llvm::Type *T)
Definition Utils.h:2428
static llvm::SmallVector< llvm::Value *, 1 > concat_values(T &&...t)
Definition Utils.h:2161
static llvm::StringRef getFuncNameFromCall(const llvm::CallBase *op)
Definition Utils.h:1269
llvm::Value * to_blas_fp_callconv(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, llvm::Type *julia_decl, llvm::IRBuilder<> &entryBuilder, llvm::Twine const &="")
llvm::Function * getOrInsertMemcpyStrided(llvm::Module &M, llvm::Type *elementType, llvm::PointerType *T, llvm::Type *IT, unsigned dstalign, unsigned srcalign)
Create function for type that performs memcpy with a stride.
static bool isReadNone(const llvm::CallBase *call, ssize_t arg=-1)
Definition Utils.h:1832
static bool isWriteOnly(const llvm::Function *F, ssize_t arg=-1)
Definition Utils.h:1788
static bool isLocalReadOnlyOrThrow(const llvm::Function *F)
Definition Utils.h:1719
llvm::cl::opt< bool > EnzymeLapackCopy
ValueType
Classification of value as an original program variable, a derivative variable, neither,...
Definition Utils.h:409
static llvm::Instruction * getFirstNonPHIOrDbg(llvm::BasicBlock *B)
Definition Utils.h:2272
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
Definition Utils.cpp:3731
static std::string to_string(const std::set< T > &us)
Output a set as a string.
Definition Utils.h:276
static llvm::Attribute::AttrKind PrimalParamAttrsToPreserve[]
Definition Utils.h:2208
void copyNonJLValueInto(llvm::IRBuilder<> &B, llvm::Type *curType, llvm::Type *dstType, llvm::Value *dst, llvm::ArrayRef< unsigned > dstPrefix, llvm::Type *srcType, llvm::Value *src, llvm::ArrayRef< unsigned > srcPrefix, bool shouldZero)
Definition Utils.cpp:4837
void mayExecuteAfter(llvm::SmallVectorImpl< llvm::Instruction * > &results, llvm::Instruction *inst, const llvm::SmallPtrSetImpl< llvm::Instruction * > &stores, const llvm::Loop *region)
llvm::cl::opt< bool > EnzymeNonPower2Cache
static llvm::Attribute::AttrKind ShadowParamAttrsToPreserve[]
Definition Utils.h:2240
llvm::SmallVector< std::tuple< llvm::Instruction *, llvm::Value *, size_t >, 1 > findAllUsersOf(llvm::Value *AI)
static llvm::raw_ostream & operator<<(llvm::raw_ostream &os, ValueType mode)
Definition Utils.h:435
static bool hasTerminator(llvm::BasicBlock *BB)
Definition Utils.h:2552
static std::map< K, V >::iterator insert_or_assign(std::map< K, V > &map, K &key, V &&val)
Insert into a map.
Definition Utils.h:835
llvm::Constant * getUndefinedValueForType(llvm::Module &M, llvm::Type *T, bool forceZero=false)
Definition Utils.cpp:3623
static void allFollowersOf(llvm::Instruction *inst, llvm::function_ref< bool(llvm::Instruction *)> f)
Call the function f for all instructions that happen after inst. If the function returns true,...
Definition Utils.h:869
DerivativeMode
Definition Utils.h:390
void copy_lower_to_upper(llvm::IRBuilder<> &B, llvm::Type *fpType, BlasInfo blas, bool byRef, llvm::Value *layout, llvm::Value *uplo, llvm::Value *A, llvm::Value *lda, llvm::Value *N)
Definition Utils.cpp:1184
static void append(llvm::SmallVectorImpl< T > &vec)
Definition Utils.h:2153
static llvm::Optional< size_t > getAllocationIndexFromCall(const llvm::CallBase *op)
Definition Utils.h:1318
llvm::Value * simplifyLoad(llvm::Value *LI, size_t valSz=0, size_t preOffset=0)
static DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode, bool integersAreConstant, std::set< llvm::Type * > &seen)
Attempt to automatically detect the differentiable classification based off of a given type.
Definition Utils.h:519
llvm::Function * getFirstFunctionDefinition(llvm::Module &M)
llvm::Function * getOrInsertDifferentialFloatMemcpyMat(llvm::Module &M, llvm::Type *elementType, llvm::PointerType *PT, llvm::IntegerType *IT, llvm::IntegerType *CT, unsigned dstalign, unsigned srcalign, bool zeroSrc)
void deleted() override final
Definition Utils.h:1225
void allUsesReplacedWith(llvm::Value *new_value) override final
Definition Utils.h:1230
virtual ~AssertingReplacingVH()
Definition Utils.h:1233
AssertingReplacingVH()=default
AssertingReplacingVH(llvm::Value *new_value)
Definition Utils.h:1223
Concrete SubType of a given value.
EnzymeFailure(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion)
Definition Utils.cpp:785
EnzymeWarning(const llvm::Twine &Msg, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion)
Definition Utils.cpp:775
bool getContext(llvm::BasicBlock *BB, LoopContext &lc)
A holder class representing the results of running TypeAnalysis on a given function.
llvm::Type * fpType(llvm::LLVMContext &ctx, bool to_scalar=false) const
Definition Utils.cpp:981
std::string function
Definition Utils.h:748
bool is64
Definition Utils.h:749
std::string suffix
Definition Utils.h:747
std::string prefix
Definition Utils.h:746
std::string floatType
Definition Utils.h:745
llvm::IntegerType * intType(llvm::LLVMContext &ctx) const
Definition Utils.cpp:1000
CountTrackedPointers(llvm::Type *T)
Definition Utils.cpp:4136