26#if LLVM_VERSION_MAJOR >= 16
28#include "llvm/Analysis/ScalarEvolution.h"
29#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
32#include "SCEV/ScalarEvolution.h"
33#include "SCEV/ScalarEvolutionExpander.h"
38#include "EnzymeLogic.h"
39#include "GradientUtils.h"
41#if LLVM_VERSION_MAJOR >= 16
42#include "llvm/Analysis/TargetLibraryInfo.h"
44#include "SCEV/TargetLibraryInfo.h"
49#include "llvm/Analysis/CallGraph.h"
50#include "llvm/Analysis/GlobalsModRef.h"
51#include "llvm/IR/DIBuilder.h"
52#include "llvm/IR/MDBuilder.h"
53#include "llvm/Transforms/Utils/Cloning.h"
55#include "llvm/IR/LegacyPassManager.h"
56#include "llvm/Transforms/IPO/Attributor.h"
58#define addAttribute addAttributeAtIndex
59#define removeAttribute removeAttributeAtIndex
60#define getAttribute getAttributeAtIndex
61#define hasAttribute hasAttributeAtIndex
65TargetLibraryInfo
eunwrap(LLVMTargetLibraryInfoRef P) {
66 return TargetLibraryInfo(*
reinterpret_cast<TargetLibraryInfoImpl *
>(P));
108 llvm_unreachable(
"Unknown concrete type to unwrap");
113 for (
size_t i = 0; i < IL.
size; i++) {
114 v.push_back((
int)IL.
data[i]);
120 for (
size_t i = 0; i < IL.
size; i++) {
121 v.insert((int64_t)IL.
data[i]);
131 if (flt->isFloatTy())
133 if (flt->isDoubleTy())
135 if (flt->isX86_FP80Ty())
137 if (flt->isBFloatTy())
139 if (flt->isFP128Ty())
152 llvm_unreachable(
"Illegal conversion of concretetype");
155 llvm_unreachable(
"Illegal conversion of concretetype");
160 IL.
size = offsets.size();
162 for (
size_t i = 0; i < offsets.size(); i++) {
163 IL.
data[i] = offsets[i];
178 for (
auto &arg : F->args()) {
189 auto cl = (llvm::cl::opt<bool> *)ptr;
190 cl->setValue((
bool)val);
194 auto cl = (llvm::cl::opt<bool> *)ptr;
195 return (uint8_t)(bool)cl->getValue();
199 auto cl = (llvm::cl::opt<int> *)ptr;
200 cl->setValue((
int)val);
204 auto cl = (llvm::cl::opt<int> *)ptr;
205 return (int64_t)cl->getValue();
209 if (
auto *clopt =
static_cast<cl::opt<std::string> *
>(ptr))
210 clopt->setValue(val);
218 eunwrap(Ref).ExternalContext = ExternalContext;
222 return eunwrap(Ref).ExternalContext;
230 LLVMContextRef C, LLVMValueRef getTraceFunction,
231 LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction,
232 LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction,
233 LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction,
234 LLVMValueRef insertChoiceGradientFunction,
235 LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction,
236 LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction,
237 LLVMValueRef hasChoiceFunction) {
239 *unwrap(C), cast<Function>(unwrap(getTraceFunction)),
240 cast<Function>(unwrap(getChoiceFunction)),
241 cast<Function>(unwrap(insertCallFunction)),
242 cast<Function>(unwrap(insertChoiceFunction)),
243 cast<Function>(unwrap(insertArgumentFunction)),
244 cast<Function>(unwrap(insertReturnFunction)),
245 cast<Function>(unwrap(insertFunctionFunction)),
246 cast<Function>(unwrap(insertChoiceGradientFunction)),
247 cast<Function>(unwrap(insertArgumentGradientFunction)),
248 cast<Function>(unwrap(newTraceFunction)),
249 cast<Function>(unwrap(freeTraceFunction)),
250 cast<Function>(unwrap(hasCallFunction)),
251 cast<Function>(unwrap(hasChoiceFunction))));
257 unwrap(interface), cast<Function>(unwrap(F))));
264 for (
const auto &pair : Logic.PPC.cache)
265 pair.second->eraseFromParent();
275 char **customRuleNames,
280 for (
size_t i = 0; i < numRules; i++) {
283 [=](
int direction,
TypeTree &returnTree, ArrayRef<TypeTree> argTrees,
284 ArrayRef<std::set<int64_t>> knownValues, CallBase *call,
289 for (
size_t i = 0; i < argTrees.size(); ++i) {
291 kvs[i].
size = knownValues[i].size();
292 kvs[i].
data =
new int64_t[kvs[i].
size];
294 for (
auto val : knownValues[i]) {
295 kvs[i].data[j] = val;
299 uint8_t result = rule(direction, creturnTree, cargs, kvs, argTrees.
size(),
302 for (
size_t i = 0; i < argTrees.size(); ++i) {
303 delete[] kvs[i].
data;
326 return (
void *)((
TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer;
338 return G->
erase(cast<Instruction>(unwrap(I)));
341 LLVMValueRef orig, uint8_t erase) {
343 cast<Instruction>(unwrap(orig)),
344 "_replacementABI", erase != 0);
355 ArrayRef<Value *>
Args,
357 SmallVector<LLVMValueRef, 3> refs;
359 refs.push_back(wrap(a));
361 AHandle(wrap(&B), wrap(CI),
Args.size(), refs.data(), gutils));
365 Value *ToFree) -> llvm::CallInst * {
366 return cast_or_null<CallInst>(unwrap(FHandle(wrap(&B), wrap(ToFree))));
374 pair.first = [=](IRBuilder<> &B, CallInst *CI,
GradientUtils &gutils,
375 Value *&normalReturn, Value *&shadowReturn,
376 Value *&tape) ->
bool {
377 LLVMValueRef normalR = wrap(normalReturn);
378 LLVMValueRef shadowR = wrap(shadowReturn);
379 LLVMValueRef tapeR = wrap(tape);
381 FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR);
382 normalReturn = unwrap(normalR);
383 shadowReturn = unwrap(shadowR);
384 tape = unwrap(tapeR);
389 RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape));
395 pair = [=](IRBuilder<> &B, CallInst *CI,
GradientUtils &gutils,
396 Value *&normalReturn, Value *&shadowReturn) ->
bool {
397 LLVMValueRef normalR = wrap(normalReturn);
398 LLVMValueRef shadowR = wrap(shadowReturn);
399 uint8_t noMod = FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR);
400 normalReturn = unwrap(normalR);
401 shadowReturn = unwrap(shadowR);
409 pair = [=](
const CallInst *CI,
const GradientUtils *gutils,
const Value *arg,
411 uint8_t useDefaultC = 0;
412 uint8_t noMod = Handle(wrap(CI), gutils, wrap(arg), isshadow,
414 useDefault = useDefaultC != 0;
455 auto orig = cast<Instruction>(unwrap(origC));
456 auto rep = cast<Instruction>(unwrap(repC));
459 auto newCall = found->second;
471 uint8_t foreignFunction) {
477 uint8_t *needsPrimal,
478 uint8_t *needsShadow,
483 unwrap(oval), &needsPrimalB, &needsShadowB, (
DerivativeMode)mode));
485 *needsPrimal = needsPrimalB;
487 *needsShadow = needsShadowB;
494 return cast<Instruction>(unwrap(val))
496 cast<Instruction>(unwrap(orig))->getDebugLoc()));
500 LLVMValueRef val2,
unsigned *sz, int64_t length,
502 return wrap(unwrap(B)->CreateInsertValue(
503 unwrap(val), unwrap(val2), ArrayRef<unsigned>(sz, sz + length), name));
508 return wrap(gutils->
lookupM(unwrap(val), *unwrap(B)));
518 LLVMValueRef val, LLVMBuilderRef B) {
519 return wrap(gutils->
diffe(unwrap(val), *unwrap(B)));
523 LLVMValueRef diffe, LLVMBuilderRef B,
525 gutils->
addToDiffe(unwrap(val), unwrap(diffe), *unwrap(B), unwrap(T));
530 LLVMTypeRef addingType,
unsigned start,
unsigned size, LLVMValueRef origptr,
531 LLVMValueRef dif, LLVMBuilderRef BuilderM,
unsigned align,
535 align2 = MaybeAlign(align);
536 auto inst = cast_or_null<Instruction>(unwrap(orig));
538 start, size, unwrap(origptr), unwrap(dif),
539 *unwrap(BuilderM), align2, unwrap(mask));
544 CTypeTreeRef vd,
unsigned LoadSize, LLVMValueRef origptr,
545 LLVMValueRef prediff, LLVMBuilderRef BuilderM,
unsigned align,
546 LLVMValueRef premask) {
549 align2 = MaybeAlign(align);
550 auto inst = cast_or_null<Instruction>(unwrap(orig));
552 LoadSize, unwrap(origptr), unwrap(prediff),
553 *unwrap(BuilderM), align2, unwrap(premask));
557 LLVMValueRef diffe, LLVMBuilderRef B) {
558 gutils->
setDiffe(unwrap(val), unwrap(diffe), *unwrap(B));
576 LLVMValueRef orig, uint8_t *data,
585 CallInst *call = cast<CallInst>(unwrap(orig));
590 llvm::errs() <<
" oldFunc " << *gutils->
oldFunc <<
"\n";
592 llvm::errs() <<
" + " << *pair.first <<
"\n";
594 llvm::errs() <<
" could not find call orig in overwritten_args_map_ptr "
599 const std::vector<bool> &overwritten_args = found->second.second;
601 if (size != overwritten_args.size()) {
602 llvm::errs() <<
" orig: " << *call <<
"\n";
603 llvm::errs() <<
" size: " << size
604 <<
" overwritten_args.size(): " << overwritten_args.size()
607 assert(size == overwritten_args.size());
608 for (uint64_t i = 0; i < size; i++) {
609 data[i] = overwritten_args[i];
616 auto v = unwrap(val);
628 uint64_t intrinsic, uint64_t dstAlign, uint64_t srcAlign, uint64_t offset,
629 uint8_t dstConstant, LLVMValueRef shadow_dst, uint8_t srcConstant,
630 LLVMValueRef shadow_src, LLVMValueRef length, LLVMValueRef isVolatile,
631 LLVMValueRef MTI, uint8_t allowForward, uint8_t shadowsLookedUp) {
632 auto orig = unwrap(MTI);
635 (Intrinsic::ID)intrinsic, (
unsigned)dstAlign,
636 (
unsigned)srcAlign, (
unsigned)offset, (
bool)dstConstant,
637 unwrap(shadow_dst), (
bool)srcConstant, unwrap(shadow_src),
638 unwrap(length), unwrap(isVolatile), cast<CallInst>(orig),
639 (
bool)allowForward, (
bool)shadowsLookedUp);
643 LLVMBasicBlockRef block,
647 return wrap(gutils->
addReverseBlock(cast<BasicBlock>(unwrap(block)), name,
652 LLVMBasicBlockRef block) {
653 auto endBlock = cast<BasicBlock>(unwrap(block));
658 vec.push_back(endBlock);
662 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
666 uint8_t strongZero,
unsigned width, LLVMTypeRef additionalArg,
667 CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write,
668 uint8_t *_overwritten_args,
size_t overwritten_args_size,
670 SmallVector<DIFFE_TYPE, 4> nconstant_args((
DIFFE_TYPE *)constant_args,
673 std::vector<bool> overwritten_args;
674 assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size());
675 for (uint64_t i = 0; i < overwritten_args_size; i++) {
676 overwritten_args.push_back(_overwritten_args[i]);
678 return wrap(
eunwrap(Logic).CreateForwardDiff(
681 cast<Function>(unwrap(todiff)), (
DIFFE_TYPE)retType, nconstant_args,
683 runtimeActivity, strongZero, width, unwrap(additionalArg),
684 eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
685 subsequent_calls_may_write, overwritten_args,
eunwrap(augmented)));
688 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
692 uint8_t strongZero,
unsigned width, uint8_t freeMemory,
693 LLVMTypeRef additionalArg, uint8_t forceAnonymousTape,
CFnTypeInfo typeInfo,
694 uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args,
697 std::vector<DIFFE_TYPE> nconstant_args((
DIFFE_TYPE *)constant_args,
700 std::vector<bool> overwritten_args;
701 assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size());
702 for (uint64_t i = 0; i < overwritten_args_size; i++) {
703 overwritten_args.push_back(_overwritten_args[i]);
705 return wrap(
eunwrap(Logic).CreatePrimalAndGradient(
709 .todiff = cast<Function>(unwrap(todiff)),
711 .constant_args = nconstant_args,
712 .subsequent_calls_may_write = (
bool)subsequent_calls_may_write,
713 .overwritten_args = overwritten_args,
714 .returnUsed = (bool)returnValue,
715 .shadowReturnUsed = (
bool)dretUsed,
718 .freeMemory = (
bool)freeMemory,
719 .AtomicAdd = (bool)AtomicAdd,
720 .additionalType = unwrap(additionalArg),
721 .forceAnonymousTape = (bool)forceAnonymousTape,
722 .typeInfo =
eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
723 .runtimeActivity = (bool)runtimeActivity,
724 .strongZero = (
bool)strongZero},
728 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
732 uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args,
733 size_t overwritten_args_size, uint8_t forceAnonymousTape,
734 uint8_t runtimeActivity, uint8_t strongZero,
unsigned width,
737 SmallVector<DIFFE_TYPE, 4> nconstant_args((
DIFFE_TYPE *)constant_args,
740 std::vector<bool> overwritten_args;
741 assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size());
742 for (uint64_t i = 0; i < overwritten_args_size; i++) {
743 overwritten_args.push_back(_overwritten_args[i]);
748 cast<Function>(unwrap(todiff)), (
DIFFE_TYPE)retType, nconstant_args,
749 eunwrap(TA), returnUsed, shadowReturnUsed,
750 eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
751 subsequent_calls_may_write, overwritten_args, forceAnonymousTape,
752 runtimeActivity, strongZero, width, AtomicAdd));
756 LLVMBuilderRef request_ip, LLVMValueRef tobatch,
760 return wrap(
eunwrap(Logic).CreateBatch(
763 cast<Function>(unwrap(tobatch)), width,
770 EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
771 LLVMValueRef totrace, LLVMValueRef *sample_functions,
772 size_t sample_functions_size, LLVMValueRef *observe_functions,
773 size_t observe_functions_size,
const char *active_random_variables[],
774 size_t active_random_variables_size,
CProbProgMode mode, uint8_t autodiff,
777 SmallPtrSet<Function *, 4> SampleFunctions;
778 for (
size_t i = 0; i < sample_functions_size; i++) {
779 SampleFunctions.insert(cast<Function>(unwrap(sample_functions[i])));
782 SmallPtrSet<Function *, 4> ObserveFunctions;
783 for (
size_t i = 0; i < observe_functions_size; i++) {
784 ObserveFunctions.insert(cast<Function>(unwrap(observe_functions[i])));
787 StringSet<> ActiveRandomVariables;
788 for (
size_t i = 0; i < active_random_variables_size; i++) {
789 ActiveRandomVariables.insert(active_random_variables[i]);
792 return wrap(
eunwrap(Logic).CreateTrace(
795 cast<Function>(unwrap(totrace)), SampleFunctions, ObserveFunctions,
796 ActiveRandomVariables, (
ProbProgMode)mode, (
bool)autodiff,
809 return wrap(AR->tapeType);
816 if (found == AR->returns.end()) {
817 return wrap((Type *)
nullptr);
819 if (found->second == -1) {
820 return wrap(AR->fn->getReturnType());
823 cast<StructType>(AR->fn->getReturnType())->getTypeAtIndex(found->second));
826 uint8_t *existed,
size_t len) {
831 for (
size_t i = 0; i < len; i++) {
832 auto found = AR->returns.find(todo[i]);
833 if (found != AR->returns.end()) {
835 data[i] = (int64_t)found->second;
843 Metadata *MD = MAV->getMetadata();
844 assert((isa<MDNode>(MD) || isa<ConstantAsMetadata>(MD)) &&
845 "Expected a metadata node or a canonicalized constant");
847 if (MDNode *N = dyn_cast<MDNode>(MD))
850 return MDNode::get(MAV->getContext(), MD);
855 MDNode *N = Val ?
extractMDNode(unwrap<MetadataAsValue>(Val)) :
nullptr;
861 auto MD = ((
TypeTree *)CTR)->toMD(*unwrap(ctx));
862 return wrap(MetadataAsValue::get(MD->getContext(), MD));
884 ->checkedOrIn(*(
TypeTree *)src,
false, legal);
902 ((
TypeTree *)CTT)->CanonicalizeInPlace(size, DataLayout(dl));
910 int64_t offset, int64_t maxSize,
911 uint64_t addOffset) {
912 DataLayout DL(datalayout);
914 ((
TypeTree *)CTT)->ShiftIndices(DL, offset, maxSize, addOffset);
918 std::vector<int> seq;
919 for (
size_t i = 0; i < len; i++) {
920 seq.push_back(indices[i]);
925 std::string tmp = ((
TypeTree *)src)->str();
926 char *cstr =
new char[tmp.length() + 1];
927 std::strcpy(cstr, tmp.c_str());
938 raw_string_ostream ss(
str);
941 char *cstr =
new char[
str.length() + 1];
942 std::strcpy(cstr,
str.c_str());
949 raw_string_ostream ss(
str);
951 ss <<
"available inversion for " << *z.first <<
" of " << *z.second <<
"\n";
954 char *cstr =
new char[
str.length() + 1];
955 std::strcpy(cstr,
str.c_str());
960 GradientUtils *gutils, LLVMValueRef func, LLVMTypeRef funcTy,
961 LLVMValueRef *args_vr, uint64_t args_size, LLVMValueRef orig_vr,
962 CValueType *valTys, uint64_t valTys_size, LLVMBuilderRef B,
964 auto orig = cast<CallInst>(unwrap(orig_vr));
966 ArrayRef<ValueType> ar((
ValueType *)valTys, valTys_size);
968 IRBuilder<> &BR = *unwrap(B);
972 SmallVector<Value *, 1> args;
973 for (
size_t i = 0; i < args_size; i++) {
974 args.push_back(unwrap(args_vr[i]));
977 auto callval = unwrap(func);
980 BR.CreateCall(cast<FunctionType>(unwrap(funcTy)), callval, args, Defs);
988 Instruction *I1 = cast<Instruction>(unwrap(inst1));
989 Instruction *I2 = cast<Instruction>(unwrap(inst2));
992 IRBuilder<> &BR = *unwrap(B);
993 if (I1->getIterator() == BR.GetInsertPoint()) {
994 if (I2->getNextNode() ==
nullptr)
995 BR.SetInsertPoint(I1->getParent());
997 BR.SetInsertPoint(I1->getNextNode());
1005 MDNode *N = Val ?
extractMDNode(unwrap<MetadataAsValue>(Val)) :
nullptr;
1006 Value *V = unwrap(Inst);
1007 if (
auto I = dyn_cast<Instruction>(V))
1008 I->setMetadata(Kind, N);
1010 cast<GlobalVariable>(V)->setMetadata(Kind, N);
1014 auto *I = unwrap<Instruction>(Inst);
1015 assert(I &&
"Expected instruction");
1016 if (
auto *MD = I->getMetadata(Kind))
1017 return wrap(MetadataAsValue::get(I->getContext(), MD));
1022 Instruction *I1 = cast<Instruction>(unwrap(inst1));
1023 I1->setMetadata(
"enzyme_mustcache", MDNode::get(I1->getContext(), {}));
1027 Instruction *I1 = cast<Instruction>(unwrap(inst1));
1032 auto &OldFunc = *cast<Function>(unwrap(F));
1033 auto &NewFunc = *cast<Function>(unwrap(NF));
1034 auto OldSP = OldFunc.getSubprogram();
1037 DIBuilder DIB(*OldFunc.getParent(),
false,
1039 auto SPType = DIB.createSubroutineType(DIB.getOrCreateTypeArray({}));
1040 DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition |
1041 DISubprogram::SPFlagOptimized |
1042 DISubprogram::SPFlagLocalToUnit;
1043 auto NewSP = DIB.createFunction(
1044 OldSP->getUnit(), NewFunc.getName(), NewFunc.getName(), OldSP->getFile(),
1045 0, SPType, 0, DINode::FlagZero, SPFlags);
1046 NewFunc.setSubprogram(NewSP);
1047 DIB.finalizeSubprogram(NewSP);
1060 llvm::errs() << *unwrap(M) <<
"\n";
1068 SetVector<Function *> &Functions,
1070 CallGraphUpdater &CGUpdater,
1071 bool DeleteFns,
bool IsModulePass) {
1072 if (Functions.empty())
1077 AttributorConfig AC(CGUpdater);
1078 AC.RewriteSignatures =
false;
1079 AC.IsModulePass = IsModulePass;
1080 AC.DeleteFns = DeleteFns;
1081 Attributor A(Functions, InfoCache, AC);
1083 for (Function *F : Functions) {
1086 A.identifyDefaultAbstractAttributes(*F);
1089 ChangeStatus Changed = A.run();
1091 return Changed == ChangeStatus::CHANGED;
1095 auto &M = *unwrap(M0);
1097 SetVector<Function *> Functions;
1098 for (Function &F : M)
1099 Functions.insert(&F);
1101 CallGraphUpdater CGUpdater;
1102 BumpPtrAllocator Allocator;
1103 InformationCache InfoCache(M, AG, Allocator,
nullptr);
1119 SetVector<Function *> Functions;
1120 for (Function &F : M)
1121 Functions.insert(&F);
1123 CallGraphUpdater CGUpdater;
1124 BumpPtrAllocator Allocator;
1125 InformationCache InfoCache(M, AG, Allocator,
nullptr);
1133 AU.addRequired<TargetLibraryInfoWrapperPass>();
1142 auto M = cast<MDNode>(unwrap(MD));
1143 if (M->getNumOperands() != 4)
1145 auto CAM = dyn_cast<ConstantAsMetadata>(M->getOperand(3));
1148 if (!CAM->getValue()->isOneValue())
1150 SmallVector<Metadata *, 4> MDs;
1151 for (
auto &M : M->operands())
1154 ConstantAsMetadata::get(ConstantInt::get(CAM->getValue()->getType(), 0));
1155 return wrap(MDNode::get(M->getContext(), MDs));
1158 cast<Instruction>(unwrap(inst1))
1159 ->copyMetadata(*cast<Instruction>(unwrap(inst2)));
1162 cast<AllocaInst>(unwrap(inst1))
1163 ->setAlignment(cast<AllocaInst>(unwrap(inst2))->getAlign());
1166 unwrap(inst1)->takeName(unwrap(inst2));
1170 LLVMContextRef ctx) {
1171 MDBuilder MDB(*unwrap(ctx));
1172 MDNode *scope = MDB.createAnonymousAliasScopeDomain(
str);
1177 auto dom = cast<MDNode>(unwrap(domain));
1178 MDBuilder MDB(dom->getContext());
1179 MDNode *scope = MDB.createAnonymousAliasScope(dom,
str);
1191 uint64_t *argrem, uint64_t num_argrem) {
1192 auto CI = cast<CallInst>(unwrap(C_CI));
1193 auto F = cast<Function>(unwrap(C_F));
1194 auto Attrs = CI->getAttributes();
1195 AttributeList NewAttrs;
1197 if (CI->getType() == F->getReturnType()) {
1198 for (
auto attr : Attrs.getAttributes(AttributeList::ReturnIndex))
1199 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1200 AttributeList::ReturnIndex, attr);
1202 for (
auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1203 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1204 AttributeList::FunctionIndex, attr);
1206 size_t argremsz = 0;
1208 SmallVector<Value *, 1> vals;
1209 for (
size_t i = 0, end = CI->arg_size(); i < end; i++) {
1210 if (argremsz < num_argrem) {
1211 if (i == argrem[argremsz]) {
1216 for (
auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i))
1217 NewAttrs = NewAttrs.addAttribute(
1218 F->getContext(), AttributeList::FirstArgIndex + nexti, attr);
1219 vals.push_back(CI->getArgOperand(i));
1222 assert(argremsz == num_argrem);
1225 SmallVector<OperandBundleDef, 1> Bundles;
1226 for (
unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I)
1227 Bundles.emplace_back(CI->getOperandBundleAt(I));
1228 auto NC = B.CreateCall(F, vals, Bundles);
1229 NC->setAttributes(NewAttrs);
1230 NC->copyMetadata(*CI);
1232 if (CI->getType() == F->getReturnType())
1233 CI->replaceAllUsesWith(NC);
1235 if (!NC->getType()->isVoidTy())
1237 NC->setCallingConv(CI->getCallingConv());
1238 CI->eraseFromParent();
1243 uint8_t keepReturnU,
1245 uint64_t num_argrem) {
1246 auto F = cast<Function>(unwrap(FC));
1247 auto FT = F->getFunctionType();
1248 bool keepReturn = keepReturnU != 0;
1250 size_t argremsz = 0;
1252 SmallVector<Type *, 1> types;
1253 auto Attrs = F->getAttributes();
1254 AttributeList NewAttrs;
1255 for (
size_t i = 0, end = FT->getNumParams(); i < end; i++) {
1256 if (argremsz < num_argrem) {
1257 if (i == argrem[argremsz]) {
1262 for (
auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i))
1263 NewAttrs = NewAttrs.addAttribute(
1264 F->getContext(), AttributeList::FirstArgIndex + nexti, attr);
1265 types.push_back(F->getFunctionType()->getParamType(i));
1269 for (
auto attr : Attrs.getAttributes(AttributeList::ReturnIndex))
1270 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1271 AttributeList::ReturnIndex, attr);
1273 for (
auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1274 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1275 AttributeList::FunctionIndex, attr);
1277 FunctionType *FTy = FunctionType::get(
1278 keepReturn ? F->getReturnType() : Type::getVoidTy(F->getContext()), types,
1282 Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(),
1283 F->getName(), F->getParent());
1285 ValueToValueMapTy VMap;
1289 Function::arg_iterator DestI = NewF->arg_begin();
1290 for (
const Argument &I : F->args()) {
1291 if (argremsz < num_argrem) {
1292 if (I.getArgNo() == argrem[argremsz]) {
1293 VMap[&I] = UndefValue::get(I.getType());
1298 DestI->setName(I.getName());
1299 VMap[&I] = &*DestI++;
1302 SmallVector<ReturnInst *, 8> Returns;
1303 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
1304 Returns,
"",
nullptr);
1307 for (
auto &B : *NewF) {
1308 if (
auto RI = dyn_cast<ReturnInst>(B.getTerminator())) {
1310 auto NRI = B.CreateRetVoid();
1311 NRI->copyMetadata(*RI);
1312 RI->eraseFromParent();
1316 NewF->setAttributes(NewAttrs);
1318 for (
auto &Arg : NewF->args())
1319 Arg.removeAttr(Attribute::Returned);
1320 SmallVector<std::pair<unsigned, MDNode *>, 1> MD;
1321 F->getAllMetadata(MD);
1322 for (
auto pair : MD)
1323 if (pair.first != LLVMContext::MD_dbg)
1324 NewF->addMetadata(pair.first, *pair.second);
1326 NewF->setCallingConv(F->getCallingConv());
1328 NewF->addFnAttr(
"enzyme_retremove",
"");
1331 SmallVector<uint64_t, 1> previdx;
1332 if (Attrs.hasAttribute(AttributeList::FunctionIndex,
"enzyme_parmremove")) {
1334 Attrs.getAttribute(AttributeList::FunctionIndex,
"enzyme_parmremove");
1335 auto prevstr = attr.getValueAsString();
1336 SmallVector<StringRef, 1> sub;
1337 prevstr.split(sub,
",");
1338 for (
auto s : sub) {
1340 bool b = s.getAsInteger(10, ival);
1343 previdx.push_back(ival);
1346 SmallVector<uint64_t, 1> nextidx;
1347 for (
size_t i = 0; i < num_argrem; i++) {
1348 auto val = argrem[i];
1349 nextidx.push_back(val);
1354 SmallVector<uint64_t, 1> out;
1355 while (prevcnt < previdx.size() && nextcnt < nextidx.size()) {
1356 if (previdx[prevcnt] <= nextidx[nextcnt] + prevcnt) {
1357 out.push_back(previdx[prevcnt]);
1360 out.push_back(nextidx[nextcnt] + prevcnt);
1364 while (prevcnt < previdx.size()) {
1365 out.push_back(previdx[prevcnt]);
1368 while (nextcnt < nextidx.size()) {
1369 out.push_back(nextidx[nextcnt] + prevcnt);
1374 for (
auto arg : out) {
1377 remstr += std::to_string(arg);
1380 NewF->addFnAttr(
"enzyme_parmremove", remstr);
1385 return wrap(cast<AllocaInst>(unwrap(V))->getAllocatedType());
1389 IRBuilder<> &B = *unwrap(B_r);
1390 auto T = cast<IntegerType>(unwrap(T_r));
1391 auto width = T->getBitWidth();
1392 auto uw = unwrap(V_r);
1393 GEPOperator *gep = isa<GetElementPtrInst>(uw)
1394 ? cast<GEPOperator>(cast<GetElementPtrInst>(uw))
1395 : cast<GEPOperator>(cast<ConstantExpr>(uw));
1396 auto &DL = B.GetInsertBlock()->getParent()->getParent()->getDataLayout();
1398#if LLVM_VERSION_MAJOR >= 20
1399 SmallMapVector<Value *, APInt, 4> VariableOffsets;
1401 MapVector<Value *, APInt> VariableOffsets;
1403 APInt Offset(width, 0);
1404 bool success =
collectOffset(gep, DL, width, VariableOffsets, Offset);
1407 Value *start = ConstantInt::get(T, Offset);
1408 for (
auto &pair : VariableOffsets)
1409 start = B.CreateAdd(
1410 start, B.CreateMul(pair.first, ConstantInt::get(T, pair.second)));
1418 unsigned *Index,
unsigned Size,
1420 return wrap(unwrap(B)->CreateExtractValue(
1421 unwrap(AggVal), ArrayRef<unsigned>(Index, Size), Name));
1425 LLVMValueRef EltVal,
unsigned *Index,
1426 unsigned Size,
const char *Name) {
1427 return wrap(unwrap(B)->CreateInsertValue(
1428 unwrap(AggVal), unwrap(EltVal), ArrayRef<unsigned>(Index, Size), Name));
int64_t EnzymeGetCLInteger(void *ptr)
void EnzymeGradientUtilsAddToInvertedPointerDiffeTT(DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal, CTypeTreeRef vd, unsigned LoadSize, LLVMValueRef origptr, LLVMValueRef prediff, LLVMBuilderRef BuilderM, unsigned align, LLVMValueRef premask)
void EnzymeGradientUtilsSubTransferHelper(GradientUtils *gutils, CDerivativeMode mode, LLVMTypeRef secretty, uint64_t intrinsic, uint64_t dstAlign, uint64_t srcAlign, uint64_t offset, uint8_t dstConstant, LLVMValueRef shadow_dst, uint8_t srcConstant, LLVMValueRef shadow_src, LLVMValueRef length, LLVMValueRef isVolatile, LLVMValueRef MTI, uint8_t allowForward, uint8_t shadowsLookedUp)
EnzymeTypeAnalysisRef EnzymeGetTypeAnalysisFromTypeAnalyzer(void *TAR)
void EnzymeAttributeKnownFunctions(LLVMValueRef FC)
void EnzymeTypeTreeLookupEq(CTypeTreeRef CTT, int64_t size, const char *dl)
LLVMMetadataRef EnzymeAnonymousAliasScopeDomain(const char *str, LLVMContextRef ctx)
void * EnzymeGradientUtilsGetExternalContext(GradientUtils *gutils)
uint8_t EnzymeGradientUtilsIsConstantValue(GradientUtils *gutils, LLVMValueRef val)
LLVMValueRef EnzymeInsertValue(LLVMBuilderRef B, LLVMValueRef val, LLVMValueRef val2, unsigned *sz, int64_t length, const char *name)
void EnzymeSetCalledFunction(LLVMValueRef C_CI, LLVMValueRef C_F, uint64_t *argrem, uint64_t num_argrem)
CDIFFE_TYPE EnzymeGradientUtilsGetReturnDiffeType(GradientUtils *G, LLVMValueRef oval, uint8_t *needsPrimal, uint8_t *needsShadow, CDerivativeMode mode)
void EnzymeStringFree(const char *cstr)
void EnzymeCopyAlignment(LLVMValueRef inst1, LLVMValueRef inst2)
void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM)
CTypeTreeRef EnzymeNewTypeTreeTR(CTypeTreeRef CTR)
void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR)
const char * EnzymeGradientUtilsInvertedPointersToString(GradientUtils *gutils, void *src)
LLVMBasicBlockRef EnzymeGradientUtilsAddReverseBlock(GradientUtils *gutils, LLVMBasicBlockRef block, const char *name, uint8_t forkCache, uint8_t push)
void EnzymeTakeName(LLVMValueRef inst1, LLVMValueRef inst2)
void EnzymeTypeTreeData0Eq(CTypeTreeRef CTT)
EnzymeTraceInterfaceRef CreateEnzymeDynamicTraceInterface(LLVMValueRef interface, LLVMValueRef F)
static bool runAttributorOnFunctions(InformationCache &InfoCache, SetVector< Function * > &Functions, AnalysisGetter &AG, CallGraphUpdater &CGUpdater, bool DeleteFns, bool IsModulePass)
LLVMValueRef EnzymeGetStringMD(LLVMValueRef Inst, const char *Kind)
void EnzymeFreeTypeTree(CTypeTreeRef CTT)
EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M)
static MDNode * extractMDNode(MetadataAsValue *MAV)
LLVMTypeRef EnzymeExtractUnderlyingTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret)
void EnzymeSetCLInteger(void *ptr, int64_t val)
void ClearTypeAnalysis(EnzymeTypeAnalysisRef TAR)
CConcreteType EnzymeTypeTreeInner0(CTypeTreeRef CTT)
void EnzymeTypeTreeToStringFree(const char *cstr)
EnzymeLogicRef CreateEnzymeLogic(uint8_t PostOpt)
void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle)
uint8_t EnzymeGetCLBool(void *ptr)
void EnzymeDumpValueRef(LLVMValueRef M)
void FreeEnzymeLogic(EnzymeLogicRef Ref)
void EnzymeGradientUtilsSetDebugLocFromOriginal(GradientUtils *gutils, LLVMValueRef val, LLVMValueRef orig)
LLVMValueRef EnzymeGradientUtilsLookup(GradientUtils *gutils, LLVMValueRef val, LLVMBuilderRef B)
const char * EnzymeTypeAnalyzerToString(void *src)
CDIFFE_TYPE EnzymeGradientUtilsGetDiffeType(GradientUtils *G, LLVMValueRef oval, uint8_t foreignFunction)
uint8_t EnzymeLowerSparsification(LLVMValueRef F, uint8_t replaceAll)
LLVMMetadataRef EnzymeAnonymousAliasScope(LLVMMetadataRef domain, const char *str)
void EnzymeReplaceOriginalToNew(GradientUtils *gutils, LLVMValueRef origC, LLVMValueRef repC)
void EnzymeSetStringMD(LLVMValueRef Inst, const char *Kind, LLVMValueRef Val)
uint8_t EnzymeGradientUtilsGetAtomicAdd(GradientUtils *gutils)
LLVMValueRef EnzymeGradientUtilsCallWithInvertedBundles(GradientUtils *gutils, LLVMValueRef func, LLVMTypeRef funcTy, LLVMValueRef *args_vr, uint64_t args_size, LLVMValueRef orig_vr, CValueType *valTys, uint64_t valTys_size, LLVMBuilderRef B, uint8_t lookup)
void EnzymeDetectReadonlyOrThrow(LLVMModuleRef M)
LLVMValueRef EnzymeBuildExtractValue(LLVMBuilderRef B, LLVMValueRef AggVal, unsigned *Index, unsigned Size, const char *Name)
void EnzymeLogicSetExternalContext(EnzymeLogicRef Ref, void *ExternalContext)
LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils, LLVMValueRef val)
void EnzymeTypeTreeInsertEq(CTypeTreeRef CTT, const int64_t *indices, size_t len, CConcreteType ct, LLVMContextRef ctx)
uint8_t EnzymeGradientUtilsGetRuntimeActivity(GradientUtils *gutils)
void EnzymeDumpTypeRef(LLVMTypeRef M)
void EnzymeCloneFunctionDISubprogramInto(LLVMValueRef NF, LLVMValueRef F)
void FreeTraceInterface(EnzymeTraceInterfaceRef Ref)
void EnzymeLogicErasePreprocessedFunctions(EnzymeLogicRef Ref)
CTypeTreeRef EnzymeNewTypeTree()
void EnzymeGradientUtilsSetReverseBlock(GradientUtils *gutils, LLVMBasicBlockRef block)
LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef tobatch, unsigned width, CBATCH_TYPE *arg_types, size_t arg_types_size, CBATCH_TYPE retType)
void * EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G)
uint8_t EnzymeGradientUtilsIsConstantInstruction(GradientUtils *gutils, LLVMValueRef val)
LLVMTypeRef EnzymeGradientUtilsGetShadowType(GradientUtils *gutils, LLVMTypeRef T)
void RunAttributorOnModule(LLVMModuleRef M0)
void EnzymeRegisterCallHandler(const char *Name, CustomAugmentedFunctionForward FwdHandle, CustomFunctionReverse RevHandle)
LLVMValueRef EnzymeCreateForwardDiff(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity, uint8_t strongZero, unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented)
void EnzymeGradientUtilsDumpTypeResults(GradientUtils *gutils)
uint8_t EnzymeSetTypeTree(CTypeTreeRef dst, CTypeTreeRef src)
uint8_t EnzymeGradientUtilsGetStrongZero(GradientUtils *gutils)
EnzymeLogicRef EnzymeTypeAnalysisGetLogic(EnzymeTypeAnalysisRef TAR)
LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils)
void EnzymeGradientUtilsAddToDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, LLVMValueRef diffe, LLVMBuilderRef B, LLVMTypeRef T)
LLVMValueRef EnzymeCreatePrimalAndGradient(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity, uint8_t strongZero, unsigned width, uint8_t freeMemory, LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd)
uint8_t EnzymeGradientUtilsGetUncacheableArgs(GradientUtils *gutils, LLVMValueRef orig, uint8_t *data, uint64_t size)
LLVMValueRef EnzymeCreateTrace(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef totrace, LLVMValueRef *sample_functions, size_t sample_functions_size, LLVMValueRef *observe_functions, size_t observe_functions_size, const char *active_random_variables[], size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff, EnzymeTraceInterfaceRef interface)
void EnzymeTypeTreeOnlyEq(CTypeTreeRef CTT, int64_t x)
TargetLibraryInfo eunwrap(LLVMTargetLibraryInfoRef P)
LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r, LLVMTypeRef T_r)
void EnzymeSetMustCache(LLVMValueRef inst1)
void EnzymeDumpModuleRef(LLVMModuleRef M)
void EnzymeGradientUtilsEraseWithPlaceholder(GradientUtils *G, LLVMValueRef I, LLVMValueRef orig, uint8_t erase)
LLVMTypeRef EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret)
CTypeTreeRef EnzymeTypeTreeFromMD(LLVMValueRef Val)
CDerivativeMode EnzymeGradientUtilsGetMode(GradientUtils *gutils)
LLVMValueRef EnzymeGradientUtilsDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, LLVMBuilderRef B)
EnzymeAugmentedReturnPtr ewrap(const AugmentedReturn &AR)
CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType CT, LLVMContextRef ctx)
EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface(LLVMContextRef C, LLVMValueRef getTraceFunction, LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction, LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction, LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction, LLVMValueRef insertChoiceGradientFunction, LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction, LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction, LLVMValueRef hasChoiceFunction)
void EnzymeGradientUtilsAddToInvertedPointerDiffe(DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal, LLVMTypeRef addingType, unsigned start, unsigned size, LLVMValueRef origptr, LLVMValueRef dif, LLVMBuilderRef BuilderM, unsigned align, LLVMValueRef mask)
uint8_t EnzymeMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src)
LLVMValueRef EnzymeBuildInsertValue(LLVMBuilderRef B, LLVMValueRef AggVal, LLVMValueRef EltVal, unsigned *Index, unsigned Size, const char *Name)
void EnzymeGradientUtilsErase(GradientUtils *G, LLVMValueRef I)
uint8_t EnzymeHasFromStack(LLVMValueRef inst1)
EnzymeTypeAnalysisRef CreateTypeAnalysis(EnzymeLogicRef Log, char **customRuleNames, CustomRuleType *customRules, size_t numRules)
LLVMValueRef EnzymeTypeTreeToMD(CTypeTreeRef CTR, LLVMContextRef ctx)
void * EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI, LLVMValueRef F)
LLVMMetadataRef EnzymeMakeNonConstTBAA(LLVMMetadataRef MD)
void EnzymeGradientUtilsSetDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, LLVMValueRef diffe, LLVMBuilderRef B)
void EnzymeSetCLBool(void *ptr, uint8_t val)
LLVMValueRef EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret)
LLVMValueRef EnzymeGradientUtilsInvertPointer(GradientUtils *gutils, LLVMValueRef val, LLVMBuilderRef B)
void EnzymeSetCLString(void *ptr, const char *val)
std::set< int64_t > eunwrap64(IntList IL)
void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data, uint8_t *existed, size_t len)
CTypeTreeRef EnzymeGradientUtilsAllocAndGetTypeTree(GradientUtils *gutils, LLVMValueRef val)
void EnzymeGradientUtilsReplaceAWithB(GradientUtils *G, LLVMValueRef A, LLVMValueRef B)
void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef CTT, const char *datalayout, int64_t offset, int64_t maxSize, uint64_t addOffset)
void EnzymeCopyMetadata(LLVMValueRef inst1, LLVMValueRef inst2)
uint8_t EnzymeCheckedMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src, uint8_t *legalP)
void * EnzymeLogicGetExternalContext(EnzymeLogicRef Ref)
void EnzymeReplaceFunctionImplementation(LLVMModuleRef M)
void EnzymeTypeTreeCanonicalizeInPlace(CTypeTreeRef CTT, int64_t size, const char *dl)
void EnzymeRegisterDiffUseCallHandler(char *Name, CustomFunctionDiffUse Handle)
const char * EnzymeTypeTreeToString(CTypeTreeRef src)
void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, CustomShadowFree FHandle)
uint64_t EnzymeGradientUtilsGetWidth(GradientUtils *gutils)
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, size_t overwritten_args_size, uint8_t forceAnonymousTape, uint8_t runtimeActivity, uint8_t strongZero, unsigned width, uint8_t AtomicAdd)
LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC, uint8_t keepReturnU, uint64_t *argrem, uint64_t num_argrem)
LLVMTypeRef EnzymeGetShadowType(uint64_t width, LLVMTypeRef T)
LLVMTypeRef EnzymeAllocaType(LLVMValueRef V)
void ClearEnzymeLogic(EnzymeLogicRef Ref)
void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2, LLVMBuilderRef B)
uint8_t(* CustomFunctionForward)(LLVMBuilderRef, LLVMValueRef, GradientUtils *, LLVMValueRef *, LLVMValueRef *)
struct EnzymeOpaqueTypeAnalysis * EnzymeTypeAnalysisRef
struct EnzymeOpaqueTraceInterface * EnzymeTraceInterfaceRef
struct EnzymeTypeTree * CTypeTreeRef
uint8_t(* CustomAugmentedFunctionForward)(LLVMBuilderRef, LLVMValueRef, GradientUtils *, LLVMValueRef *, LLVMValueRef *, LLVMValueRef *)
uint8_t(* CustomRuleType)(int, CTypeTreeRef, CTypeTreeRef *, struct IntList *, size_t, LLVMValueRef, void *)
uint8_t(* CustomFunctionDiffUse)(LLVMValueRef, const GradientUtils *, LLVMValueRef, uint8_t, CDerivativeMode, uint8_t *)
void(* CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef, DiffeGradientUtils *, LLVMValueRef)
LLVMValueRef(* CustomShadowFree)(LLVMBuilderRef, LLVMValueRef)
struct EnzymeOpaqueLogic * EnzymeLogicRef
LLVMValueRef(* CustomShadowAlloc)(LLVMBuilderRef, LLVMValueRef, size_t, LLVMValueRef *, GradientUtils *)
struct EnzymeOpaqueAugmentedReturn * EnzymeAugmentedReturnPtr
StringMap< std::function< bool(const CallInst *, const GradientUtils *, const Value *, bool, DerivativeMode, bool &)> > customDiffUseHandlers
void ReplaceFunctionImplementation(Module &M)
bool DetectReadonlyOrThrow(Module &M)
bool LowerSparsification(llvm::Function *F, bool replaceAll)
Lower __enzyme_todense, returning true if anything was changed.
static std::string str(AugmentedStruct c)
StringMap< std::function< CallInst *(IRBuilder<> &, Value *)> > shadowErasers
StringMap< std::function< Value *(IRBuilder<> &, CallInst *, ArrayRef< Value * >, GradientUtils *)> > shadowHandlers
StringMap< std::pair< std::function< bool(IRBuilder<> &, CallInst *, GradientUtils &, Value *&, Value *&, Value *&)>, std::function< void(IRBuilder<> &, CallInst *, DiffeGradientUtils &, Value *)> > > customCallHandlers
StringMap< std::function< bool(IRBuilder<> &, CallInst *, GradientUtils &, Value *&, Value *&)> > customFwdCallHandlers
void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode, Type *secretty, Intrinsic::ID intrinsic, unsigned dstalign, unsigned srcalign, unsigned offset, bool dstConstant, Value *shadow_dst, bool srcConstant, Value *shadow_src, Value *length, Value *isVolatile, llvm::CallInst *MTI, bool allowForward, bool shadowsLookedUp, bool backwardsShadow)
bool attributeKnownFunctions(llvm::Function &F)
bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, MapVector< Value *, APInt > &VariableOffsets, APInt &ConstantOffset)
@ Args
Return is a struct of all args.
DIFFE_TYPE
Potential differentiable argument classifications.
static llvm::MDNode * hasMetadata(const llvm::GlobalObject *O, llvm::StringRef kind)
Check if a global has metadata.
ValueType
Classification of value as an original program variable, a derivative variable, neither,...
return structtype if recursive function
llvm::BasicBlock * inversionAllocs
Concrete SubType of a given value.
BaseType SubTypeEnum
Category of underlying type.
llvm::Type * isFloat() const
Return the floating point type, if this is a float.
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal, llvm::Type *addingType, unsigned start, unsigned size, llvm::Value *origptr, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::MaybeAlign align=llvm::MaybeAlign(), llvm::Value *mask=nullptr)
align is the alignment that should be specified for load/store to pointer
llvm::Value * diffe(llvm::Value *val, llvm::IRBuilder<> &BuilderM)
void setDiffe(llvm::Value *val, llvm::Value *toset, llvm::IRBuilder<> &BuilderM)
llvm::SmallVector< llvm::SelectInst *, 4 > addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM, llvm::Type *addingType, llvm::ArrayRef< llvm::Value * > idxs={}, llvm::Value *mask=nullptr)
Returns created select instructions, if any.
void * ExternalContext
Provided through the frontend and only used from it.
llvm::ValueMap< const llvm::Value *, InvertedPointerVH > invertedPointers
llvm::DebugLoc getNewFromOriginal(const llvm::DebugLoc L) const
void eraseWithPlaceholder(llvm::Instruction *I, llvm::Instruction *orig, const llvm::Twine &suffix="_replacementA", bool erase=true)
llvm::BasicBlock * addReverseBlock(llvm::BasicBlock *currentBlock, llvm::Twine const &name, bool forkCache=true, bool push=true)
const std::map< llvm::CallInst *, std::pair< bool, const std::vector< bool > > > * overwritten_args_map_ptr
std::map< llvm::BasicBlock *, llvm::SmallVector< llvm::BasicBlock *, 4 > > reverseBlocks
Map of primal block to corresponding block(s) in reverse.
llvm::SmallVector< llvm::OperandBundleDef, 2 > getInvertedBundles(llvm::CallInst *orig, llvm::ArrayRef< ValueType > types, llvm::IRBuilder<> &Builder2, bool lookup, const llvm::ValueToValueMapTy &available=llvm::ValueToValueMapTy())
void replaceAWithB(llvm::Value *A, llvm::Value *B, bool storeInCache=false) override
Replace this instruction both in LLVM modules and any local data-structures.
DIFFE_TYPE getDiffeType(llvm::Value *v, bool foreignFunction) const
llvm::Value * lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, const llvm::ValueToValueMapTy &incoming_availalble=llvm::ValueToValueMapTy(), bool tryLegalRecomputeCheck=true, llvm::BasicBlock *scope=nullptr) override
High-level utility to get the value of an instruction at a new location specified by BuilderM.
llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > originalToNewFn
DIFFE_TYPE getReturnDiffeType(llvm::Value *orig, bool *primalReturnUsedP, bool *shadowReturnUsedP, DerivativeMode cmode) const
std::map< llvm::BasicBlock *, llvm::BasicBlock * > reverseBlockToPrimal
Map of block in reverse to corresponding primal block.
bool isConstantInstruction(const llvm::Instruction *inst) const
llvm::ValueMap< const llvm::Value *, AssertingReplacingVH > newToOriginalFn
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
bool isConstantValue(llvm::Value *val) const
llvm::Value * invertPointerM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, bool nullShadow=false)
void erase(llvm::Instruction *I) override
Erase this instruction both from LLVM modules and any local data-structures.
Full interprocedural TypeAnalysis.
llvm::StringMap< std::function< bool(int, TypeTree &, llvm::ArrayRef< TypeTree >, llvm::ArrayRef< std::set< int64_t > >, llvm::CallBase *, TypeAnalyzer *)> > CustomRules
Map of custom function call handlers.
void clear()
Clear existing analyses.
Helper class that computes the fixed-point type results of a given function.
TypeAnalysis & interprocedural
Calling TypeAnalysis to be used in the case of calls to other functions.
void dump(llvm::raw_ostream &ss=llvm::errs()) const
Prints all known information.
TypeTree query(llvm::Value *val) const
The TypeTree of a particular Value.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
void insertFromMD(llvm::MDNode *md, const std::vector< int > &prev={})
struct IntList * KnownValues
The specific constant(s) known to be represented by an argument, if constant.
CTypeTreeRef * Arguments
Types of arguments, assumed of size len(Arguments)
CTypeTreeRef Return
Type of return.
Struct containing all contextual type information for a particular function call.
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
TypeTree Return
Type of return.
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to be represented by an argument, if constant.
void getAnalysisUsage(AnalysisUsage &AU) const override
bool runOnModule(Module &M) override
todiff is the function to differentiate; retType is the activity info of the return.