26#include <llvm/Config/llvm-config.h>
29#if LLVM_VERSION_MAJOR >= 16
31#include "llvm/Analysis/ScalarEvolution.h"
32#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
35#include "SCEV/ScalarEvolution.h"
36#include "SCEV/ScalarEvolutionExpander.h"
39#include "llvm/ADT/ArrayRef.h"
40#include "llvm/ADT/MapVector.h"
42#if LLVM_VERSION_MAJOR <= 16
43#include "llvm/ADT/Optional.h"
45#include "llvm/ADT/SetVector.h"
46#include "llvm/ADT/SmallSet.h"
47#include "llvm/ADT/SmallVector.h"
49#include "llvm/Passes/PassBuilder.h"
51#include "llvm/IR/BasicBlock.h"
52#include "llvm/IR/Constants.h"
53#include "llvm/IR/Function.h"
54#include "llvm/IR/IRBuilder.h"
55#include "llvm/IR/InstrTypes.h"
56#include "llvm/IR/Instructions.h"
57#include "llvm/IR/MDBuilder.h"
58#include "llvm/IR/Metadata.h"
60#include "llvm/Analysis/ScalarEvolution.h"
61#include "llvm/Support/Debug.h"
62#include "llvm/Support/ErrorHandling.h"
63#include "llvm/Transforms/Scalar.h"
65#include "llvm/Analysis/BasicAliasAnalysis.h"
66#include "llvm/Analysis/GlobalsModRef.h"
67#include "llvm/Analysis/InlineAdvisor.h"
68#include "llvm/Analysis/InlineCost.h"
69#include "llvm/Analysis/ScalarEvolution.h"
70#include "llvm/Analysis/TargetLibraryInfo.h"
71#include "llvm/IR/AbstractCallSite.h"
72#include "llvm/Support/CommandLine.h"
73#include "llvm/Transforms/Utils/BasicBlockUtils.h"
74#include "llvm/Transforms/Utils/Cloning.h"
76#include "ActivityAnalysis.h"
78#include "EnzymeLogic.h"
79#include "GradientUtils.h"
87#include "llvm/Transforms/Utils.h"
89#include "llvm/Transforms/IPO/Attributor.h"
90#include "llvm/Transforms/IPO/OpenMPOpt.h"
91#include "llvm/Transforms/Utils/Mem2Reg.h"
98#define DEBUG_TYPE "lower-enzyme-intrinsic"
// Command-line switches of the Enzyme pass. The option heads visible in this
// excerpt are "enzyme-enable" (default true), "enzyme-omp-opt" (default
// false), "enzyme-detect-readthrow" (default true) and "enzyme-truncate-all"
// (default ""); each of them is marked cl::Hidden.
100llvm::cl::opt<bool>
EnzymeEnable(
"enzyme-enable", cl::init(
true), cl::Hidden,
101 cl::desc(
"Run the Enzyme pass"));
// NOTE(review): the declaration heads belonging to the next two descriptions
// are not visible in this excerpt.
105 cl::desc(
"Run enzymepostprocessing optimizations"));
109 cl::desc(
"Run attributor post Enzyme"));
111llvm::cl::opt<bool>
EnzymeOMPOpt(
"enzyme-omp-opt", cl::init(
false), cl::Hidden,
112 cl::desc(
"Whether to enable openmp opt"));
115 "enzyme-detect-readthrow", cl::init(
true), cl::Hidden,
116 cl::desc(
"Run preprocessing detect readonly or throw optimization"));
// The accepted spellings of the truncation spec are given in the description
// string below.
119 "enzyme-truncate-all", cl::init(
""), cl::Hidden,
121 "Truncate all floating point operations. "
122 "E.g. \"64to32\" or \"64to<exponent_width>-<significand_width>\"."));
124#define addAttribute addAttributeAtIndex
125#define getAttribute getAttributeAtIndex
// Coerce a shadow argument to `destType`, the type expected for parameter
// `truei` of the function type `FT`. `i` is only used in diagnostics, where it
// is printed as the argument index at the Enzyme call `CI`.
//
// Visible behaviour:
//  - pointer to pointer with differing address spaces: an addrspacecast is
//    emitted and a "Warning cast(2)" message is written to llvm::errs();
//  - a type that cannot be losslessly bitcast to `destType`: the
//    "Cannot cast __enzyme_autodiff shadow argument" diagnostic arguments are
//    assembled (the call that emits them is not visible in this excerpt);
//  - otherwise the result is Builder.CreateBitCast(value, destType).
// NOTE(review): `res` is declared on a line that is not part of this excerpt.
129castToDiffeFunctionArgType(IRBuilder<> &Builder, llvm::CallInst *CI,
130 llvm::FunctionType *FT, llvm::Type *destType,
132 llvm::Value *value,
unsigned int truei) {
134 if (
auto ptr = dyn_cast<PointerType>(res->getType())) {
135 if (
auto PT = dyn_cast<PointerType>(destType)) {
136 if (ptr->getAddressSpace() != PT->getAddressSpace()) {
137#if LLVM_VERSION_MAJOR < 17
// With typed pointers the pointee type of the source pointer is kept and only
// the address space is switched to that of the destination.
138 if (CI->getContext().supportsTypedPointers()) {
139 res = Builder.CreateAddrSpaceCast(
140 res, PointerType::get(ptr->getPointerElementType(),
141 PT->getAddressSpace()));
143 res = Builder.CreateAddrSpaceCast(res, PT);
146 res = Builder.CreateAddrSpaceCast(res, PT);
151 llvm::errs() <<
"Warning cast(2) __enzyme_autodiff argument " << i
152 <<
" " << *res <<
"|" << *res->getType() <<
" to argument "
153 << truei <<
" " << *destType <<
"\n"
154 <<
"orig: " << *FT <<
"\n";
160 if (!res->getType()->canLosslesslyBitCastTo(destType)) {
162 assert(value->getType());
// Prefer the debug location of the offending instruction over that of the
// Enzyme call when the value is an instruction.
165 auto loc = CI->getDebugLoc();
166 if (
auto arg = dyn_cast<Instruction>(res)) {
167 loc = arg->getDebugLoc();
170 "Cannot cast __enzyme_autodiff shadow argument ", i,
", found ",
171 *res,
", type ", *res->getType(),
" - to arg ", truei,
" ",
175 return Builder.CreateBitCast(value, destType);
178#if LLVM_VERSION_MAJOR > 16
179static std::optional<StringRef> getMetadataName(llvm::Value *res);
181static Optional<StringRef> getMetadataName(llvm::Value *res);
// Derive a single metadata name for the PHI node `val` by visiting the PHI
// nodes reachable through incoming values. `todo` is the worklist and `done`
// the set of PHIs already visited, so PHI cycles are not revisited.
// `finalMetadata` receives the first name reported by getMetadataName; a later
// name that differs from it is handled in the `else if` branch below (its body
// is not visible in this excerpt). The function ends by returning
// `finalMetadata`. The optional type used depends on the LLVM major version.
185#if LLVM_VERSION_MAJOR > 16
186static std::optional<StringRef> recursePhiReads(PHINode *val)
188static Optional<StringRef> recursePhiReads(PHINode *val)
191#if LLVM_VERSION_MAJOR > 16
192 std::optional<StringRef> finalMetadata;
194 Optional<StringRef> finalMetadata;
196 SmallVector<PHINode *, 1> todo = {val};
197 SmallSet<PHINode *, 1> done;
198 while (todo.size()) {
199 auto phiInst = todo.back();
201 if (done.count(phiInst))
203 done.insert(phiInst);
204 for (
unsigned j = 0; j < phiInst->getNumIncomingValues(); ++j) {
205 auto newVal = phiInst->getIncomingValue(j);
// Incoming PHIs are recognised here; presumably the name lookup below sits in
// the matching else branch (the line in between is not visible) — confirm.
206 if (
auto phi = dyn_cast<PHINode>(newVal)) {
209 auto metaString = getMetadataName(newVal);
211 if (!finalMetadata) {
212 finalMetadata = metaString;
213 }
else if (finalMetadata != metaString) {
220 return finalMetadata;
// Map a value used as an Enzyme marker operand to a name. The shapes handled
// in the visible branches are:
//  - MetadataAsValue: the string of its MDString;
//  - a load or cast instruction whose operand 0 is a GlobalVariable: the
//    global's name;
//  - a load whose pointer operand is a cast ConstantExpr: the name of the
//    GlobalVariable that is operand 0 of that expression;
//  - a GlobalVariable, or a cast ConstantExpr of one: the global's name;
//  - an AllocaInst, or a cast instruction of one: the alloca's name;
//  - a PHINode: delegated to recursePhiReads.
// The first return below recurses on `S`, which is defined on a line that is
// not part of this excerpt.
223#if LLVM_VERSION_MAJOR > 16
224std::optional<StringRef> getMetadataName(llvm::Value *res)
226Optional<StringRef> getMetadataName(llvm::Value *res)
230 return getMetadataName(S);
232 if (
auto av = dyn_cast<MetadataAsValue>(res)) {
233 return cast<MDString>(av->getMetadata())->getString();
234 }
else if ((isa<LoadInst>(res) || isa<CastInst>(res)) &&
235 isa<GlobalVariable>(cast<Instruction>(res)->getOperand(0))) {
237 cast<GlobalVariable>(cast<Instruction>(res)->getOperand(0));
238 return gv->getName();
239 }
else if (isa<LoadInst>(res) &&
240 isa<ConstantExpr>(cast<LoadInst>(res)->getOperand(0)) &&
241 cast<ConstantExpr>(cast<LoadInst>(res)->getOperand(0))->isCast() &&
243 cast<ConstantExpr>(cast<LoadInst>(res)->getOperand(0))
245 auto gv = cast<GlobalVariable>(
246 cast<ConstantExpr>(cast<LoadInst>(res)->getOperand(0))->getOperand(0));
247 return gv->getName();
248 }
else if (
auto gv = dyn_cast<GlobalVariable>(res)) {
249 return gv->getName();
250 }
else if (isa<ConstantExpr>(res) && cast<ConstantExpr>(res)->isCast() &&
251 isa<GlobalVariable>(cast<ConstantExpr>(res)->getOperand(0))) {
252 auto gv = cast<GlobalVariable>(cast<ConstantExpr>(res)->getOperand(0));
253 return gv->getName();
254 }
// NOTE(review): `cast<CastInst>(res)` used as a condition adds nothing after
// the preceding isa<CastInst>(res); it looks like a leftover.
else if (isa<CastInst>(res) && cast<CastInst>(res) &&
255 isa<AllocaInst>(cast<CastInst>(res)->getOperand(0))) {
256 auto gv = cast<AllocaInst>(cast<CastInst>(res)->getOperand(0));
257 return gv->getName();
258 }
else if (
auto gv = dyn_cast<AllocaInst>(res)) {
259 return gv->getName();
260 }
else if (isa<PHINode>(res)) {
261 return recursePhiReads(cast<PHINode>(res));
// When the type of `ret` is a struct, build a value of that struct type from
// the first `width` members of `diffret`, starting from a zero aggregate.
// Members of `diffret` that are fixed vectors are split into their lanes;
// other members are inserted at index `i` unchanged.
267static Value *adaptReturnedVector(Value *ret, Value *diffret,
268 IRBuilder<> &Builder,
unsigned width) {
269 Type *returnType = ret->getType();
271 if (StructType *sty = dyn_cast<StructType>(returnType)) {
272 Value *agg = ConstantAggregateZero::get(sty);
274 for (
unsigned int i = 0; i < width; i++) {
275 Value *elem = Builder.CreateExtractValue(diffret, {i});
276 if (
auto vty = dyn_cast<FixedVectorType>(elem->getType())) {
277 for (
unsigned j = 0; j < vty->getNumElements(); ++j) {
278 Value *vecelem = Builder.CreateExtractElement(elem, j);
// NOTE(review): the destination index `i * j` is 0 whenever i == 0 or j == 0,
// so different lanes are written to the same struct slot and later writes
// overwrite earlier ones. Confirm the intended flattening of (member, lane)
// into a struct index.
279 agg = Builder.CreateInsertValue(agg, vecelem, {i * j});
282 agg = Builder.CreateInsertValue(agg, elem, {i});
// Substitute the derivative result `diffret` for the original Enzyme call
// `CI`, whose result (or result storage) is `ret`. The cases are tried in the
// order below; part of the parameter list and the return statements are not
// visible in this excerpt.
//  1. either type is empty or void: uses of CI become undef, CI is erased;
//  2. identical types: uses of CI are replaced by `diffret`;
//  3. layout-identical struct types: the value is rebuilt member by member;
//  4. `ret` is a pointer: the comparison type becomes `retElemType` and the
//     result is stored through `ret`;
//  5. size-based reinterpretation through a fresh alloca;
//  6. aggregate `diffret`: its element 0 is tried;
//  7. otherwise the "IllegalReturnCast" failure is emitted.
290static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret,
291 Type *retElemType, Value *diffret,
293 Type *retType = ret->getType();
294 Type *diffretType = diffret->getType();
295 auto &DL = CI->getModule()->getDataLayout();
297 if (diffretType->isEmptyTy() || diffretType->isVoidTy() ||
298 retType->isEmptyTy() || retType->isVoidTy()) {
299 CI->replaceAllUsesWith(UndefValue::get(CI->getType()));
300 CI->eraseFromParent();
304 if (retType == diffretType) {
305 CI->replaceAllUsesWith(diffret);
306 CI->eraseFromParent();
// Distinct struct types with the same layout: copy each member into a value of
// the type the call's users expect.
310 if (
auto sretType = dyn_cast<StructType>(retType),
311 diffsretType = dyn_cast<StructType>(diffretType);
312 sretType && diffsretType && sretType->isLayoutIdentical(diffsretType)) {
313 Value *newStruct = UndefValue::get(sretType);
314 for (
unsigned int i = 0; i < sretType->getStructNumElements(); i++) {
315 Value *elem = Builder.CreateExtractValue(diffret, {i});
316 newStruct = Builder.CreateInsertValue(newStruct, elem, {i});
318 CI->replaceAllUsesWith(newStruct);
319 CI->eraseFromParent();
// Pointer result: from here on `retType` refers to the pointee type passed in
// as `retElemType`, and the derivative is written through `ret`.
323 if (isa<PointerType>(retType)) {
324 retType = retElemType;
326 if (
auto sretType = dyn_cast<StructType>(retType),
327 diffsretType = dyn_cast<StructType>(diffretType);
328 sretType && diffsretType && sretType->isLayoutIdentical(diffsretType)) {
329 for (
unsigned int i = 0; i < sretType->getStructNumElements(); i++) {
330 Value *sgep = Builder.CreateStructGEP(retType, ret, i);
331 Builder.CreateStore(Builder.CreateExtractValue(diffret, {i}), sgep);
333 CI->eraseFromParent();
// The pointee is at least as large (in bits) as the derivative value, so the
// value is stored through a pointer cast of `ret`.
337 if (DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) {
339 diffret, Builder.CreatePointerCast(ret,
getUnqual(diffretType)));
340 CI->eraseFromParent();
// Two size conditions are visible: one accepts a result type at least as large
// as `diffretType`, the other requires equal sizes. The terms that select
// between them are on lines not visible in this excerpt.
346 DL.getTypeSizeInBits(retType) >= DL.getTypeSizeInBits(diffretType)) ||
349 DL.getTypeSizeInBits(retType) == DL.getTypeSizeInBits(diffretType))) {
// Reinterpret through memory: the alloca is created in the entry block of the
// calling function, while the store and load use the caller-supplied Builder.
350 IRBuilder<> EB(CI->getFunction()->getEntryBlock().getFirstNonPHI());
351 auto AL = EB.CreateAlloca(retType);
352 Builder.CreateStore(diffret,
353 Builder.CreatePointerCast(AL,
getUnqual(diffretType)));
354 Value *cload = Builder.CreateLoad(retType, AL);
355 CI->replaceAllUsesWith(cload);
356 CI->eraseFromParent();
361 diffret->getType()->isAggregateType()) {
362 auto diffreti = Builder.CreateExtractValue(diffret, {0});
363 if (diffreti->getType() == retType) {
364 CI->replaceAllUsesWith(diffreti);
365 CI->eraseFromParent();
367 }
else if (diffretType == retType) {
368 CI->replaceAllUsesWith(diffret);
369 CI->eraseFromParent();
// No conversion applied: report both types together with their sizes in bits.
374 auto diffretsize = DL.getTypeSizeInBits(diffretType);
375 auto retsize = DL.getTypeSizeInBits(retType);
376 EmitFailure(
"IllegalReturnCast", CI->getDebugLoc(), CI,
377 "Cannot cast return type of gradient ", *diffretType, *diffret,
378 " of size ", diffretsize,
" bits ",
", to desired type ",
379 *retType,
" of size ", retsize,
" bits");
386 EnzymeBase(
bool PostOpt)
// Return the function that the Enzyme call `CI` asks to differentiate.
// The candidate is operand 0 of the call, or operand 1 when the call carries a
// struct-return attribute. Two failures are reported, both with the message
// "failed to find fn to differentiate":
//  - "NoFunctionToDifferentiate" when the candidate is null or not a Function;
//  - "EmptyFunctionToDifferentiate" when the Function is empty.
// NOTE(review): lines between the operand selection and the first check are
// not visible in this excerpt; `fn` may be transformed there.
391 Function *parseFunctionParameter(CallInst *CI) {
392 Value *fn = CI->getArgOperand(0);
395 if (CI->hasStructRetAttr()) {
396 fn = CI->getArgOperand(1);
402 if (!fn || !isa<Function>(fn)) {
404 EmitFailure(
"NoFunctionToDifferentiate", CI->getDebugLoc(), CI,
405 "failed to find fn to differentiate", *CI,
" - found - ",
409 if (cast<Function>(fn)->empty()) {
410 EmitFailure(
"EmptyFunctionToDifferentiate", CI->getDebugLoc(), CI,
411 "failed to find fn to differentiate", *CI,
" - found - ",
416 return cast<Function>(fn);
// Scan all operands of `CI` for the "enzyme_width" marker (recognised through
// getMetadataName) and read the vector width from the operand that follows it.
// That operand has to be a ConstantInt; its zero-extended value is stored in
// `width` (declared on a line not visible in this excerpt).
// Diagnostics visible here: the marker appearing more than once
// ("IllegalVectorWidth"), no operand after the marker ("MissingVectorWidth"),
// a non-constant width, and an illegal width value (both "IllegalVectorWidth").
// The optional type of the result depends on the LLVM major version.
419#if LLVM_VERSION_MAJOR > 16
420 static std::optional<unsigned> parseWidthParameter(CallInst *CI)
422 static Optional<unsigned> parseWidthParameter(CallInst *CI)
// `found` is the second loop-carried variable of the structured binding.
427 for (
auto [i, found] = std::tuple{0u,
false}; i < CI->arg_size(); ++i) {
428 Value *arg = CI->getArgOperand(i);
430 if (
auto MDName = getMetadataName(arg)) {
431 if (*MDName ==
"enzyme_width") {
433 EmitFailure(
"IllegalVectorWidth", CI->getDebugLoc(), CI,
434 "vector width declared more than once",
435 *CI->getArgOperand(i),
" in", *CI);
// The marker must not be the last operand: the width is read from i + 1.
439 if (i + 1 >= CI->arg_size()) {
440 EmitFailure(
"MissingVectorWidth", CI->getDebugLoc(), CI,
441 "constant integer followong enzyme_width is missing",
442 *CI->getArgOperand(i),
" in", *CI);
446 Value *width_arg = CI->getArgOperand(i + 1);
447 if (
auto cint = dyn_cast<ConstantInt>(width_arg)) {
448 width = cint->getZExtValue();
451 EmitFailure(
"IllegalVectorWidth", CI->getDebugLoc(), CI,
452 "enzyme_width must be a constant integer",
453 *CI->getArgOperand(i),
" in", *CI);
458 EmitFailure(
"IllegalVectorWidth", CI->getDebugLoc(), CI,
459 "illegal enzyme vector argument width ",
460 *CI->getArgOperand(i),
" in", *CI);
472 Value *dynamic_interface;
476 Value *diffeLikelihood;
478 int allocatedTapeSize;
482 bool differentialReturn;
486 StringSet<> ActiveRandomVariables;
487 std::vector<bool> overwritten_args;
488 bool runtimeActivity;
490 bool subsequent_calls_may_write;
// Parse the operands of the Enzyme call `CI` that differentiates `fn`.
//
// Outputs: `constants` receives DIFFE_TYPE entries, `args` the (possibly cast
// or loaded) values to forward, and `byVal` maps positions in `args` to the
// byval type of operands that were passed byval at the call. The success path
// at the end returns an Options value; the optional type of the result depends
// on the LLVM major version.
//
// Two indices are used throughout and must not be mixed up: `i` indexes the
// operands of `CI` (which include "enzyme_*" marker operands), while `truei`
// indexes the parameters of `FT`, the function type of `fn`.
493#if LLVM_VERSION_MAJOR > 16
494 static std::optional<Options>
495 handleArguments(IRBuilder<> &Builder, CallInst *CI, Function *fn,
497 std::vector<DIFFE_TYPE> &constants,
498 SmallVectorImpl<Value *> &args, std::map<int, Type *> &byVal)
500 static Optional<Options>
501 handleArguments(IRBuilder<> &Builder, CallInst *CI, Function *fn,
503 std::vector<DIFFE_TYPE> &constants,
504 SmallVectorImpl<Value *> &args, std::map<int, Type *> &byVal)
507 FunctionType *FT = fn->getFunctionType();
509 Value *differet =
nullptr;
510 Value *tape =
nullptr;
511 Value *dynamic_interface =
nullptr;
512 Value *trace =
nullptr;
513 Value *observations =
nullptr;
514 Value *likelihood =
nullptr;
515 Value *diffeLikelihood =
nullptr;
517 int allocatedTapeSize = -1;
518 bool freeMemory =
true;
519 bool tapeIsPointer =
false;
520 bool diffeTrace =
false;
522 unsigned byRefSize = 0;
523 bool primalReturn =
false;
524 bool runtimeActivity =
false;
525 bool strongZero =
false;
526 bool subsequent_calls_may_write =
530 StringSet<> ActiveRandomVariables;
534 if (fn->hasParamAttribute(0, Attribute::StructRet)) {
536 Ty = fn->getParamAttribute(0, Attribute::StructRet).getValueAsType();
543 !fn->getReturnType()->isVoidTy() && !fn->getReturnType()->isEmptyTy();
// `sret` is true when either the call or the callee's first parameter carries
// a struct-return attribute; it shifts the first operand that is scanned.
545 bool sret = CI->hasStructRetAttr() ||
546 fn->hasParamAttribute(0, Attribute::StructRet);
548 std::vector<bool> overwritten_args(
549 fn->getFunctionType()->getNumParams(),
// First pass over the operands: only the return-related markers are examined.
552 for (
unsigned i = 1 + sret; i < CI->arg_size(); ++i) {
553 Value *res = CI->getArgOperand(i);
554 auto metaString = getMetadataName(res);
556 if (metaString &&
startsWith(*metaString,
"enzyme_")) {
557 if (*metaString ==
"enzyme_const_return") {
560 }
else if (*metaString ==
"enzyme_active_return") {
563 }
else if (*metaString ==
"enzyme_dup_return") {
566 }
else if (*metaString ==
"enzyme_noret") {
569 }
else if (*metaString ==
"enzyme_primal_return") {
580 if (
auto parsedWidth = parseWidthParameter(CI)) {
581 width = *parsedWidth;
// Callee returns through an sret parameter: the storage supplied at the call
// must hold `count` values of the callee's sret type (sizes compared in bytes).
587 if (fn->hasParamAttribute(0, Attribute::StructRet)) {
590 const DataLayout &DL = CI->getParent()->getModule()->getDataLayout();
592 Ty = fn->getParamAttribute(0, Attribute::StructRet).getValueAsType();
594 CTy = CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
596 auto FnSize = (DL.getTypeSizeInBits(Ty) / 8);
597 auto CSize = CTy ? (DL.getTypeSizeInBits(CTy) / 8) : 0;
605 if (CSize < count * FnSize) {
607 "IllegalByRefSize", CI->getDebugLoc(), CI,
"Struct return type ",
608 *CTy,
" (", CSize,
" bytes), not large enough to store ", count,
609 " returns of type ", *Ty,
" (", FnSize,
" bytes), width=", width,
610 " primal requested=", primalReturn);
612 Value *primal =
nullptr;
614 Value *sretPt = CI->getArgOperand(0);
615 PointerType *pty = cast<PointerType>(sretPt->getType());
616 primal = Builder.CreatePointerCast(
617 sretPt, PointerType::get(Ty, pty->getAddressSpace()));
619 AllocaInst *primalA =
new AllocaInst(Ty, DL.getAllocaAddrSpace(),
620 nullptr, DL.getPrefTypeAlign(Ty));
621 primalA->insertBefore(CI);
625 Value *shadow =
nullptr;
631 Value *sretPt = CI->getArgOperand(0);
632 PointerType *pty = cast<PointerType>(sretPt->getType());
633 auto shadowPtr = Builder.CreatePointerCast(
634 sretPt, PointerType::get(Ty, pty->getAddressSpace()));
637 shadowPtr = Builder.CreateConstGEP1_64(Ty, shadowPtr, 1);
// For several shadows, build an array of `width` pointers into the sret
// storage; the offset `i + primalReturn` skips one slot when the primal is
// requested.
640 Value *acc = UndefValue::get(ArrayType::get(
641 PointerType::get(Ty, pty->getAddressSpace()), width));
642 for (
size_t i = 0; i < width; ++i) {
644 Builder.CreateConstGEP1_64(Ty, shadowPtr, i + primalReturn);
645 acc = Builder.CreateInsertValue(acc, elem, i);
656 shadow = CI->getArgOperand(1);
662 args.push_back(primal);
664 args.push_back(shadow);
667 constants.push_back(retType);
669 primalReturn =
false;
672 ssize_t interleaved = -1;
675 maxsize = CI->arg_size();
676 size_t num_args = maxsize;
// Second pass: look for the "enzyme_interleave" marker.
677 for (
unsigned i = 1 + sret; i < maxsize; ++i) {
678 Value *res = CI->getArgOperand(i);
679 auto metaString = getMetadataName(res);
680 if (metaString &&
startsWith(*metaString,
"enzyme_")) {
681 if (*metaString ==
"enzyme_interleave") {
// Main pass: one iteration per argument, preceded by any number of markers.
691 for (ssize_t i = 1 + sret; (size_t)i < maxsize; ++i) {
692 Value *res = CI->getArgOperand(i);
693 auto metaString = getMetadataName(res);
694#if LLVM_VERSION_MAJOR > 16
695 std::optional<Value *> batchOffset;
696 std::optional<DIFFE_TYPE> opt_ty;
698 Optional<Value *> batchOffset;
699 Optional<DIFFE_TYPE> opt_ty;
704 bool skipArg =
false;
// Consume consecutive "enzyme_*" markers; `res` and `metaString` are re-read
// from the next operand further down in this loop.
707 while (metaString &&
startsWith(*metaString,
"enzyme_")) {
708 if (*metaString ==
"enzyme_not_overwritten") {
710 }
else if (*metaString ==
"enzyme_byref") {
712 if (!isa<ConstantInt>(CI->getArgOperand(i))) {
713 EmitFailure(
"IllegalAllocatedSize", CI->getDebugLoc(), CI,
714 "illegal enzyme byref size ", *CI->getArgOperand(i),
718 byRefSize = cast<ConstantInt>(CI->getArgOperand(i))->getZExtValue();
719 assert(byRefSize > 0);
722 }
else if (*metaString ==
"enzyme_dup") {
724 }
else if (*metaString ==
"enzyme_dupv") {
727 Value *offset_arg = CI->getArgOperand(i);
728 if (offset_arg->getType()->isIntegerTy()) {
729 batchOffset = offset_arg;
731 EmitFailure(
"IllegalVectorOffset", CI->getDebugLoc(), CI,
732 "enzyme_batch must be followd by an integer "
734 *CI->getArgOperand(i),
" in", *CI);
737 }
else if (*metaString ==
"enzyme_dupnoneed") {
739 }
else if (*metaString ==
"enzyme_dupnoneedv") {
742 Value *offset_arg = CI->getArgOperand(i);
743 if (offset_arg->getType()->isIntegerTy()) {
744 batchOffset = offset_arg;
746 EmitFailure(
"IllegalVectorOffset", CI->getDebugLoc(), CI,
747 "enzyme_batch must be followd by an integer "
749 *CI->getArgOperand(i),
" in", *CI);
752 }
else if (*metaString ==
"enzyme_out") {
754 }
else if (*metaString ==
"enzyme_const") {
756 }
else if (*metaString ==
"enzyme_noret") {
759 }
else if (*metaString ==
"enzyme_allocated") {
762 if (!isa<ConstantInt>(CI->getArgOperand(i))) {
763 EmitFailure(
"IllegalAllocatedSize", CI->getDebugLoc(), CI,
764 "illegal enzyme allocated size ", *CI->getArgOperand(i),
769 cast<ConstantInt>(CI->getArgOperand(i))->getZExtValue();
772 }
else if (*metaString ==
"enzyme_tape") {
775 tape = CI->getArgOperand(i);
776 tapeIsPointer =
true;
779 }
else if (*metaString ==
"enzyme_nofree") {
784 }
else if (*metaString ==
"enzyme_runtime_activity") {
785 runtimeActivity =
true;
788 }
else if (*metaString ==
"enzyme_strong_zero") {
792 }
else if (*metaString ==
"enzyme_primal_return") {
795 }
else if (*metaString ==
"enzyme_const_return") {
798 }
else if (*metaString ==
"enzyme_active_return") {
801 }
else if (*metaString ==
"enzyme_dup_return") {
804 }
else if (*metaString ==
"enzyme_width") {
808 }
else if (*metaString ==
"enzyme_interface") {
810 dynamic_interface = CI->getArgOperand(i);
813 }
else if (*metaString ==
"enzyme_trace") {
814 trace = CI->getArgOperand(++i);
818 }
else if (*metaString ==
"enzyme_duptrace") {
819 trace = CI->getArgOperand(++i);
824 }
else if (*metaString ==
"enzyme_likelihood") {
825 likelihood = CI->getArgOperand(++i);
829 }
else if (*metaString ==
"enzyme_duplikelihood") {
830 likelihood = CI->getArgOperand(++i);
831 diffeLikelihood = CI->getArgOperand(++i);
835 }
else if (*metaString ==
"enzyme_observations") {
836 observations = CI->getArgOperand(++i);
840 }
else if (*metaString ==
"enzyme_active_rand_var") {
841 Value *
string = CI->getArgOperand(++i);
842 StringRef const_string;
843 if (getConstantStringInfo(
string, const_string)) {
844 ActiveRandomVariables.insert(const_string);
847 "IllegalStringType", CI->getDebugLoc(), CI,
848 "active variable address must be a compile-time constant", *CI,
854 EmitFailure(
"IllegalDiffeType", CI->getDebugLoc(), CI,
855 "illegal enzyme metadata classification ", *CI,
861 constants.push_back(*opt_ty);
867 if (i == CI->arg_size()) {
868 EmitFailure(
"EnzymeCallingError", CI->getDebugLoc(), CI,
869 "Too few arguments to Enzyme call ", *CI);
872 res = CI->getArgOperand(i);
873 metaString = getMetadataName(res);
// Work out the type of a by-reference argument: the pointer operand is
// loaded as `subTy` below, after its size in bytes is checked against
// `byRefSize`.
880 Type *subTy =
nullptr;
// Fix: the parameter type is looked up with `truei`, the index into FT's
// parameter list that the guard on this branch checks. The call-operand index
// `i` counts marker operands and starts past the function operand, so it
// selected the wrong parameter or ran past the end of the parameter list.
881 if (truei < FT->getNumParams()) {
882 subTy = FT->getParamType(truei);
885 if (differentialReturn && differet ==
nullptr) {
886 subTy = FT->getReturnType();
892 "illegal enzyme byval arg", truei,
" ", *res);
896 auto &DL = fn->getParent()->getDataLayout();
897 auto BitSize = DL.getTypeSizeInBits(subTy);
898 if (BitSize / 8 != byRefSize) {
899 EmitFailure(
"IllegalByRefSize", CI->getDebugLoc(), CI,
900 "illegal enzyme pointer type size ", *res,
" expected ",
901 byRefSize,
" (bytes) actual size ", BitSize,
904 res = Builder.CreateBitCast(
907 subTy, cast<PointerType>(res->getType())->getAddressSpace()));
908 res = Builder.CreateLoad(subTy, res);
// Operands beyond the callee's parameter list: the first one may be taken as
// the differential return, a later one as the tape; anything else is reported
// as an extra argument.
912 if (truei >= FT->getNumParams()) {
913 if (!isa<MetadataAsValue>(res) &&
916 if (differentialReturn && differet ==
nullptr) {
918 if (CI->paramHasAttr(i, Attribute::ByVal)) {
920 T = CI->getParamAttr(i, Attribute::ByVal).getValueAsType();
921 differet = Builder.CreateLoad(T, differet);
// A differential return whose struct type differs from the callee's return
// type but has the same layout is converted through a stack slot.
923 if (differet->getType() != fn->getReturnType())
924 if (
auto ST0 = dyn_cast<StructType>(differet->getType()))
925 if (
auto ST1 = dyn_cast<StructType>(fn->getReturnType()))
926 if (ST0->isLayoutIdentical(ST1)) {
927 IRBuilder<> B(&Builder.GetInsertBlock()
931 auto AI = B.CreateAlloca(ST1);
932 Builder.CreateStore(differet, Builder.CreatePointerCast(
934 differet = Builder.CreateLoad(ST1, AI);
937 if (differet->getType() !=
940 "Bad DiffRet type ", *differet,
" expected ",
941 *fn->getReturnType());
945 }
else if (tape ==
nullptr) {
947 if (CI->paramHasAttr(i, Attribute::ByVal)) {
949 T = CI->getParamAttr(i, Attribute::ByVal).getValueAsType();
950 tape = Builder.CreateLoad(T, tape);
956 "Had too many arguments to __enzyme_autodiff", *CI,
957 " - extra arg - ", *res);
960 assert(truei < FT->getNumParams());
961 overwritten_args[truei] = overwritten;
963 auto PTy = FT->getParamType(truei);
966 : ((interleaved == -1) ?
whatType(PTy, mode) : last_ty);
969 constants.push_back(ty);
971 assert(truei < FT->getNumParams());
// Reconcile the primal operand's type with the parameter type `PTy`:
// address-space cast (with a warning), lossless bitcast, or truncation of an
// integer to i1; a remaining mismatch is reported.
973 if (PTy != res->getType()) {
974 if (
auto ptr = dyn_cast<PointerType>(res->getType())) {
975 if (
auto PT = dyn_cast<PointerType>(PTy)) {
976 if (ptr->getAddressSpace() != PT->getAddressSpace()) {
977#if LLVM_VERSION_MAJOR < 17
978 if (CI->getContext().supportsTypedPointers()) {
979 res = Builder.CreateAddrSpaceCast(
980 res, PointerType::get(ptr->getPointerElementType(),
981 PT->getAddressSpace()));
983 res = Builder.CreateAddrSpaceCast(res, PT);
986 res = Builder.CreateAddrSpaceCast(res, PT);
991 llvm::errs() <<
"Warning cast(1) __enzyme_autodiff argument " << i
992 <<
" " << *res <<
"|" << *res->getType()
993 <<
" to argument " << truei <<
" " << *PTy <<
"\n"
994 <<
"orig: " << *FT <<
"\n";
998 if (res->getType()->canLosslesslyBitCastTo(PTy)) {
999 res = Builder.CreateBitCast(res, PTy);
1001 if (res->getType() != PTy && res->getType()->isIntegerTy() &&
1002 PTy->isIntegerTy(1)) {
1003 res = Builder.CreateTrunc(res, PTy);
1005 if (res->getType() != PTy) {
1006 auto loc = CI->getDebugLoc();
1007 if (
auto arg = dyn_cast<Instruction>(res)) {
1008 loc = arg->getDebugLoc();
1014 "Cannot cast __enzyme_autodiff primal argument ", i,
1015 ", found ", *res,
", type ", *res->getType(),
1016 " (simplified to ", *S,
" ) ",
" - to arg ", truei,
", ",
// Remember the byval type under the position this value takes in `args`.
1021 if (CI->isByValArgument(i)) {
1022 byVal[args.size()] = CI->getParamByValType(i);
1025 args.push_back(res);
1027 if (interleaved == -1)
// Shadow collection: gather `width` shadow values for the primal just pushed.
// The operand index used is `i`, or `interleaved` when that is not -1.
1030 Value *res =
nullptr;
1031#if LLVM_VERSION_MAJOR >= 16
1032 bool batch = batchOffset.has_value();
1034 bool batch = batchOffset.hasValue();
1037 for (
unsigned v = 0; v < width; ++v) {
1038 if ((
size_t)((interleaved == -1) ? i : interleaved) >= num_args) {
1039 EmitFailure(
"MissingArgShadow", CI->getDebugLoc(), CI,
1040 "__enzyme_autodiff missing argument shadow at index ",
1041 *((interleaved == -1) ? &i : &interleaved),
1042 ", need shadow of type ", *PTy,
1043 " to shadow primal argument ", *args.back(),
1050 CI->getArgOperand((interleaved == -1) ? i : interleaved);
// Batched shadows are addressed through an i8 GEP on the shadow pointer, then
// cast back to the original pointer type; non-pointer operands are rejected.
1052 if (
auto elementPtrTy = dyn_cast<PointerType>(element->getType())) {
1053 element = Builder.CreateBitCast(
1054 element, PointerType::get(Type::getInt8Ty(CI->getContext()),
1055 elementPtrTy->getAddressSpace()));
1056 element = Builder.CreateGEP(
1057 Type::getInt8Ty(CI->getContext()), element,
1060 ConstantInt::get((*batchOffset)->getType(), v)));
1061 element = Builder.CreateBitCast(element, elementPtrTy);
1064 "NonPointerBatch", CI->getDebugLoc(), CI,
1065 "Batched argument at index ",
1066 *((interleaved == -1) ? &i : &interleaved),
1067 " must be of pointer type, found: ", *element->getType());
1071 if (PTy != element->getType()) {
1072 element = castToDiffeFunctionArgType(
1073 Builder, CI, FT, PTy, (interleaved == -1) ? i : interleaved,
1074 mode, element, truei);
// The shadows are accumulated into an array value of `width` elements.
1082 res ? Builder.CreateInsertValue(res, element, {v})
1083 : Builder.CreateInsertValue(UndefValue::get(ArrayType::get(
1084 element->getType(), width)),
1087 if (v < width - 1 && !batch && (interleaved == -1)) {
1095 if (interleaved != -1)
1099 args.push_back(res);
// Every parameter of the callee must have been matched by a primal argument.
1104 if (truei < FT->getNumParams()) {
1105 auto numParams = FT->getNumParams();
1107 "EnzymeInsufficientArgs", CI->getDebugLoc(), CI,
1108 "Insufficient number of args passed to derivative call required ",
1109 numParams,
" primal args, found ", truei);
1113 return Options({differet,
1129 ActiveRandomVariables,
1133 subsequent_calls_may_write});
1139 for (
auto &a : type_args.
Function->args()) {
1141 if (a.getType()->isFPOrFPVectorTy()) {
1143 }
else if (a.getType()->isPointerTy()) {
1144#if LLVM_VERSION_MAJOR < 17
1145 if (a.getContext().supportsTypedPointers()) {
1146 auto et = a.getType()->getPointerElementType();
1147 if (et->isFPOrFPVectorTy()) {
1149 }
else if (et->isPointerTy()) {
1155 }
else if (a.getType()->isIntOrIntVectorTy()) {
1159 std::pair<Argument *, TypeTree>(&a, dt.
Only(-1,
nullptr)));
1163 std::pair<Argument *, std::set<int64_t>>(&a, {}));
1166 if (fn->getReturnType()->isFPOrFPVectorTy()) {
1167 dt =
ConcreteType(fn->getReturnType()->getScalarType());
1184 llvm_unreachable(
"Invalid float width");
// Lower a __enzyme_truncate_func call `CI`. The function operand is resolved
// with parseFunctionParameter; the call must have 3 or 4 operands:
//  - 3 operands: operand 1 and operand 2 are passed to getDefaultFloatRepr as
//    the source and target widths;
//  - 4 operands: operand 1 is the source width, operands 2 and 3 are read as
//    `Cto_exponent` and `Cto_significand`.
// The resulting value `res` is pointer-cast to the call's type, replaces all
// uses of `CI`, and `CI` is erased.
// NOTE(review): the width operands are read with cast<ConstantInt>, which
// assumes they are constants instead of emitting a diagnostic, and the asserts
// on the cast results cannot fire on their own — consider dyn_cast plus
// EmitFailure.
1188 bool HandleTruncateFunc(CallInst *CI,
TruncateMode mode) {
1189 IRBuilder<> Builder(CI);
1190 Function *F = parseFunctionParameter(CI);
1193 unsigned ArgSize = CI->arg_size();
1194 if (ArgSize != 4 && ArgSize != 3) {
1196 "Had incorrect number of args to __enzyme_truncate_func", *CI,
1197 " - expected 3 or 4");
1202 auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
1204 auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
1207 getDefaultFloatRepr((
unsigned)Cfrom->getValue().getZExtValue()),
1208 getDefaultFloatRepr((
unsigned)Cto->getValue().getZExtValue()),
1210 }
else if (ArgSize == 4) {
1211 auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
1213 auto Cto_exponent = cast<ConstantInt>(CI->getArgOperand(2));
1214 assert(Cto_exponent);
1215 auto Cto_significand = cast<ConstantInt>(CI->getArgOperand(3));
1216 assert(Cto_significand);
1218 getDefaultFloatRepr((
unsigned)Cfrom->getValue().getZExtValue()),
1220 (
unsigned)Cto_exponent->getValue().getZExtValue(),
1221 (
unsigned)Cto_significand->getValue().getZExtValue()),
// The argument count was validated above, so no other count reaches here.
1224 llvm_unreachable(
"??");
1231 res = Builder.CreatePointerCast(res, CI->getType());
1232 CI->replaceAllUsesWith(res);
1233 CI->eraseFromParent();
// Lower a __enzyme_truncate_value call `CI`, which must have exactly 3
// operands: operand 0 is read into `Addr`, operands 1 and 2 are constant
// integers passed to getDefaultFloatRepr as the source and target widths.
// NOTE(review): how `isTruncate` and `Addr` are used is on lines not visible
// in this excerpt.
1237 bool HandleTruncateValue(CallInst *CI,
bool isTruncate) {
1238 IRBuilder<> Builder(CI);
1239 if (CI->arg_size() != 3) {
1241 "Had incorrect number of args to __enzyme_truncate_value",
1242 *CI,
" - expected 3");
1245 auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
1247 auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
1249 auto Addr = CI->getArgOperand(0);
1253 getDefaultFloatRepr((
unsigned)Cfrom->getValue().getZExtValue()),
1254 getDefaultFloatRepr((
unsigned)Cto->getValue().getZExtValue()),
// Lower a __enzyme_batch call `CI`. The callee is resolved with
// parseFunctionParameter and the width with parseWidthParameter. Each operand
// may be preceded by one of the markers "enzyme_scalar", "enzyme_vector",
// "enzyme_buffer" (followed by an integer offset operand) or "enzyme_width";
// other "enzyme_*" names are reported as "IllegalDiffeType". The collected
// `args` are passed to a call of `newFunc`, the result goes through
// adaptReturnedVector and is handed to ReplaceOriginalCall.
// `i` indexes operands of `CI`; `truei` indexes parameters of `FT`.
1261 bool HandleBatch(CallInst *CI) {
1264 std::map<unsigned, Value *> batchOffset;
1265 SmallVector<Value *, 4> args;
1266 SmallVector<BATCH_TYPE, 4> arg_types;
1267 IRBuilder<> Builder(CI);
1268 Function *F = parseFunctionParameter(CI);
1273 FunctionType *FT = F->getFunctionType();
1276 if (
auto parsedWidth = parseWidthParameter(CI)) {
1277 width = *parsedWidth;
1284 CI->hasStructRetAttr() || F->hasParamAttribute(0, Attribute::StructRet);
// A callee with an sret first parameter gets the call's operand 0 forwarded.
1286 if (F->hasParamAttribute(0, Attribute::StructRet)) {
1288 Value *sretPt = CI->getArgOperand(0);
1290 args.push_back(sretPt);
1294 for (
unsigned i = 1 + sret; i < CI->arg_size(); ++i) {
1295 Value *res = CI->getArgOperand(i);
1297 if (truei >= FT->getNumParams()) {
1299 "Had too many arguments to __enzyme_batch", *CI,
1300 " - extra arg - ", *res);
1303 assert(truei < FT->getNumParams());
1304 auto PTy = FT->getParamType(truei);
1307 auto metaString = getMetadataName(res);
1310 if (metaString &&
startsWith(*metaString,
"enzyme_")) {
1311 if (*metaString ==
"enzyme_scalar") {
1313 }
else if (*metaString ==
"enzyme_vector") {
1315 }
else if (*metaString ==
"enzyme_buffer") {
1318 Value *offset_arg = CI->getArgOperand(i);
1319 if (offset_arg->getType()->isIntegerTy()) {
1320 batchOffset[i + 1] = offset_arg;
1322 EmitFailure(
"IllegalVectorOffset", CI->getDebugLoc(), CI,
1323 "enzyme_batch must be followd by an integer "
1325 *CI->getArgOperand(i),
" in", *CI);
1329 }
else if (*metaString ==
"enzyme_width") {
1333 EmitFailure(
"IllegalDiffeType", CI->getDebugLoc(), CI,
1334 "illegal enzyme metadata classification ", *CI,
1339 res = CI->getArgOperand(i);
1342 arg_types.push_back(ty);
1346 Value *res =
nullptr;
// NOTE(review): the offset is stored under key `i + 1` above but looked up
// with `i - 1` here and in the GEP below. Whether the keys agree depends on
// increments of `i` that are not visible in this excerpt — confirm.
1347 bool batch = batchOffset.count(i - 1) != 0;
1349 for (
unsigned v = 0; v < width; ++v) {
1350 if (i >= CI->arg_size()) {
1351 EmitFailure(
"MissingVectorArg", CI->getDebugLoc(), CI,
1352 "__enzyme_batch missing vector argument at index ", i,
1353 ", need argument of type ", *PTy,
" at call ", *CI);
1358 Value *element = CI->getArgOperand(i);
// Buffer operands are addressed through an i8 GEP and cast back to the
// original pointer type.
1360 if (
auto elementPtrTy = dyn_cast<PointerType>(element->getType())) {
1361 element = Builder.CreateBitCast(
1362 element, PointerType::get(Type::getInt8Ty(CI->getContext()),
1363 elementPtrTy->getAddressSpace()));
1364 element = Builder.CreateGEP(
1365 Type::getInt8Ty(CI->getContext()), element,
1368 ConstantInt::get(batchOffset[i - 1]->getType(), v)));
1369 element = Builder.CreateBitCast(element, elementPtrTy);
// The per-lane values are accumulated into an array value of `width` elements.
1377 res ? Builder.CreateInsertValue(res, element, {v})
1378 : Builder.CreateInsertValue(UndefValue::get(ArrayType::get(
1379 element->getType(), width)),
1382 if (v < width - 1 && !batch) {
1391 args.push_back(res);
1394 args.push_back(res);
1400 BATCH_TYPE ret_type = (F->getReturnType()->isVoidTy() || width == 1)
1405 arg_types, ret_type);
1411 Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
1413 batch = adaptReturnedVector(CI, batch, Builder, width);
// With an sret call the result storage is operand 0 of the call.
1416 Type *retElemType =
nullptr;
1417 if (CI->hasStructRetAttr()) {
1418 ret = CI->getArgOperand(0);
1420 CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
1423 ReplaceOriginalCall(Builder, ret, retElemType, batch, CI,
// HandleAutoDiff: builds the call to the generated derivative function for the
// instruction `CI` and replaces the original instruction with its result.
//
// NOTE(review): this listing is a lossy extract. The number at the start of
// each line is the original file's line number and many original lines are
// absent. That includes the parameter line between `fn` and `calls`, which
// presumably declares `mode`, `options` and `sizeOnly` since the body uses
// them (confirm against the real file). It also includes most of the guards
// that choose between the alternative code paths below.
//
// Visible parameters:
//   CI           instruction being replaced
//   CallingConv  calling convention applied to the emitted call
//   ret          passed through to adaptReturnedVector / ReplaceOriginalCall
//   retElemType  passed through to ReplaceOriginalCall
//   args         argument list for the generated function; extended in place
//   byVal        parameter index -> type, re-applied as byval attributes
//   constants    per-argument DIFFE_TYPE classification
//   fn           the function being differentiated
//   calls        receives the emitted CallInst
1429 bool HandleAutoDiff(Instruction *CI, CallingConv::ID CallingConv, Value *ret,
1430 Type *retElemType, SmallVectorImpl<Value *> &args,
1431 const std::map<int, Type *> &byVal,
1432 const std::vector<DIFFE_TYPE> &constants, Function *fn,
1434 SmallVectorImpl<CallInst *> &calls) {
// Local aliases for the fields of `options` (bound by reference unless copied
// with plain `auto`).
1435 auto &differet = options.differet;
1436 auto &tape = options.tape;
1437 auto &width = options.width;
1438 auto &allocatedTapeSize = options.allocatedTapeSize;
1439 auto &freeMemory = options.freeMemory;
1440 auto &returnUsed = options.returnUsed;
1441 auto &tapeIsPointer = options.tapeIsPointer;
1442 auto &differentialReturn = options.differentialReturn;
1443 auto &retType = options.retType;
1444 auto &overwritten_args = options.overwritten_args;
1445 auto primalReturn = options.primalReturn;
1446 auto subsequent_calls_may_write = options.subsequent_calls_may_write;
// AtomicAdd is set when the module targets nvptx, nvptx64 or amdgcn.
1448 auto Arch = Triple(CI->getModule()->getTargetTriple()).getArch();
1449 bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
1450 Arch == Triple::amdgcn;
1453 FnTypeInfo type_args = populate_type_args(TA, fn, mode);
1455 IRBuilder Builder(CI);
1459 Function *newFunc =
nullptr;
1460 Type *tapeType =
nullptr;
// Requesting the primal result of a void-returning function is reported as an
// error.
1465 if (primalReturn && fn->getReturnType()->isVoidTy()) {
1466 auto fnname = fn->getName();
1467 EmitFailure(
"PrimalRetOfVoid", CI->getDebugLoc(), CI,
1468 "Requested primal result of void-returning function type ",
1469 *fn->getFunctionType(),
" ", fnname,
" ", *CI);
1472 context, fn, retType, constants, TA,
1473 primalReturn, mode, freeMemory,
1474 options.runtimeActivity, options.strongZero, width,
1475 nullptr, type_args, subsequent_calls_may_write,
// The tape is anonymous when this is not a size-only query and no tape size
// was supplied (allocatedTapeSize == -1).
1480 bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
1482 context, fn, retType, constants, TA,
1483 false,
false, type_args,
1484 subsequent_calls_may_write, overwritten_args, forceAnonymousTape,
1485 options.runtimeActivity, options.strongZero, width,
1487 auto &DL = fn->getParent()->getDataLayout();
1488 if (!forceAnonymousTape) {
// The tape type is either the whole return type of the augmented function
// (tapeIdx == -1) or the struct element at tapeIdx.
1492 tapeType = (tapeIdx == -1)
1493 ? aug->
fn->getReturnType()
1494 : cast<StructType>(aug->
fn->getReturnType())
1495 ->getElementType(tapeIdx);
// Two replace-and-erase paths follow: CI becomes the constant 0, or the tape
// size in bytes. The guards selecting them are not visible in this extract.
1498 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0,
false));
1499 CI->eraseFromParent();
1504 auto size = DL.getTypeSizeInBits(tapeType) / 8;
1505 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size,
false));
1506 CI->eraseFromParent();
// A tape type larger than the caller-provided allocation is reported as an
// error (the comparison is done in bits).
1510 DL.getTypeSizeInBits(tapeType) > 8 * (
size_t)allocatedTapeSize) {
1511 auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
1512 EmitFailure(
"Insufficient tape allocation size", CI->getDebugLoc(),
1513 CI,
"need ", bytes,
" bytes have ", allocatedTapeSize,
1520 context, fn, retType, constants, TA,
1521 primalReturn, mode, freeMemory,
1522 options.runtimeActivity, options.strongZero, width,
1523 tapeType, type_args, subsequent_calls_may_write,
1524 overwritten_args, aug);
1533 .constant_args = constants,
1534 .subsequent_calls_may_write =
1535 subsequent_calls_may_write,
1536 .overwritten_args = overwritten_args,
1537 .returnUsed = primalReturn,
1538 .shadowReturnUsed =
false,
1541 .freeMemory = freeMemory,
1542 .AtomicAdd = AtomicAdd,
1543 .additionalType =
nullptr,
1544 .forceAnonymousTape =
false,
1545 .typeInfo = type_args,
1546 .runtimeActivity = options.runtimeActivity,
1547 .strongZero = options.strongZero},
1554 "SplitPrimalRet", CI->getDebugLoc(), CI,
1555 "Option enzyme_primal_return not available in reverse split mode");
// Second occurrence of the tape-type / tape-size handling; it mirrors the
// sequence above but uses returnUsed / shadowReturnUsed for the augmented
// function.
1557 bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
1561 context, fn, retType, constants, TA, returnUsed, shadowReturnUsed,
1562 type_args, subsequent_calls_may_write, overwritten_args,
1563 forceAnonymousTape, options.runtimeActivity, options.strongZero,
1566 auto &DL = fn->getParent()->getDataLayout();
1567 if (!forceAnonymousTape) {
1568 assert(!aug->tapeType);
1571 tapeType = (tapeIdx == -1)
1572 ? aug->fn->getReturnType()
1573 : cast<StructType>(aug->fn->getReturnType())
1574 ->getElementType(tapeIdx);
1577 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0,
false));
1578 CI->eraseFromParent();
1583 auto size = DL.getTypeSizeInBits(tapeType) / 8;
1584 CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size,
false));
1585 CI->eraseFromParent();
1589 DL.getTypeSizeInBits(tapeType) > 8 * (
size_t)allocatedTapeSize) {
1590 auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
1591 EmitFailure(
"Insufficient tape allocation size", CI->getDebugLoc(),
1592 CI,
"need ", bytes,
" bytes have ", allocatedTapeSize,
1605 .constant_args = constants,
1606 .subsequent_calls_may_write =
1607 subsequent_calls_may_write,
1608 .overwritten_args = overwritten_args,
1609 .returnUsed =
false,
1610 .shadowReturnUsed =
false,
1613 .freeMemory = freeMemory,
1614 .AtomicAdd = AtomicAdd,
1615 .additionalType = tapeType,
1616 .forceAnonymousTape = forceAnonymousTape,
1617 .typeInfo = type_args,
1618 .runtimeActivity = options.runtimeActivity,
1619 .strongZero = options.strongZero},
// Failure to produce a derivative is reported at the first instruction of
// fn's entry block.
1625 StringRef n = fn->getName();
1626 EmitFailure(
"FailedToDifferentiate", fn->getSubprogram(),
1627 &*fn->getEntryBlock().begin(),
1628 "Could not generate derivative function of ", n);
// Differential-return seed. An explicit `differet` is passed through.
// Otherwise a 1.0 seed is synthesised for floating-point, struct and array
// return types; any other return type is reported as an error.
1632 if (differentialReturn) {
1634 args.push_back(differet);
1635 else if (fn->getReturnType()->isFPOrFPVectorTy()) {
1636 Constant *seed = ConstantFP::get(fn->getReturnType(), 1.0);
1638 args.push_back(seed);
1640 ArrayType *arrayType = ArrayType::get(fn->getReturnType(), width);
1641 args.push_back(ConstantArray::get(
1642 arrayType, SmallVector<Constant *, 3>(width, seed)));
1644 }
else if (
auto ST = dyn_cast<StructType>(fn->getReturnType())) {
1645 SmallVector<Constant *, 2> csts;
1646 for (
auto e : ST->elements()) {
1647 csts.push_back(ConstantFP::get(e, 1.0));
1649 args.push_back(ConstantStruct::get(ST, csts));
1650 }
else if (
auto AT = dyn_cast<ArrayType>(fn->getReturnType())) {
1651 SmallVector<Constant *, 2> csts(
1652 AT->getNumElements(), ConstantFP::get(AT->getElementType(), 1.0));
1653 args.push_back(ConstantArray::get(AT, csts));
1655 auto RT = fn->getReturnType();
1656 EmitFailure(
"EnzymeCallingError", CI->getDebugLoc(), CI,
1657 "Differential return required for call ", *CI,
1658 " but one of type ", *RT,
" could not be auto deduced");
// Tape argument. A pointer tape is cast and loaded as tapeType. A by-value
// tape of a different type that is at least as large as tapeType is spilled
// to an entry-block alloca and reloaded as tapeType.
1666 auto &DL = fn->getParent()->getDataLayout();
1667 if (tapeIsPointer) {
1668 tape = Builder.CreateBitCast(
1669 tape, PointerType::get(
1671 cast<PointerType>(tape->getType())->getAddressSpace()));
1672 tape = Builder.CreateLoad(tapeType, tape);
1673 }
else if (tapeType != tape->getType() &&
1674 DL.getTypeSizeInBits(tapeType) <=
1675 DL.getTypeSizeInBits(tape->getType())) {
1676 IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front());
1677 auto AL = EB.CreateAlloca(tape->getType());
1678 Builder.CreateStore(tape, AL);
1679 tape = Builder.CreateLoad(
1680 tapeType, Builder.CreatePointerCast(AL,
getUnqual(tapeType)));
1682 assert(tape->getType() == tapeType);
1683 args.push_back(tape);
1687 llvm::errs() <<
"postfn:\n" << *newFunc <<
"\n";
1689 Builder.setFastMathFlags(
getFast());
// Argument-count check against the generated function's signature; on a
// mismatch the call, the function and every argument are dumped first.
1692 if (args.size() != newFunc->getFunctionType()->getNumParams()) {
1693 llvm::errs() << *CI <<
"\n";
1694 llvm::errs() << *newFunc <<
"\n";
1695 for (
auto arg : args) {
1696 llvm::errs() <<
" + " << *arg <<
"\n";
1700 "TooFewArguments", CI->getDebugLoc(), CI,
1701 "Too few arguments passed to __enzyme_autodiff mode=", modestr);
1704 assert(args.size() == newFunc->getFunctionType()->getNumParams());
// Per-argument type check against the generated function's parameter types.
1705 for (
size_t i = 0; i < args.size(); i++) {
1706 if (args[i]->getType() != newFunc->getFunctionType()->getParamType(i)) {
1707 llvm::errs() << *CI <<
"\n";
1708 llvm::errs() << *newFunc <<
"\n";
1709 for (
auto arg : args) {
1710 llvm::errs() <<
" + " << *arg <<
"\n";
1713 EmitFailure(
"BadArgumentType", CI->getDebugLoc(), CI,
1714 "Incorrect argument type passed to __enzyme_autodiff mode=",
1715 modestr,
" at index ", i,
" expected ",
1716 *newFunc->getFunctionType()->getParamType(i),
" found ",
1717 *args[i]->getType());
// Emit the call; it takes the requested calling convention, CI's debug
// location, and a byval attribute for every entry of `byVal`.
1721 CallInst *diffretc = cast<CallInst>(Builder.CreateCall(newFunc, args));
1722 diffretc->setCallingConv(CallingConv);
1723 diffretc->setDebugLoc(CI->getDebugLoc());
1725 for (
auto &&[attr, ty] : byVal) {
1726 diffretc->addParamAttr(
1727 attr, Attribute::getWithByValType(diffretc->getContext(), ty));
1730 Value *diffret = diffretc;
// Tape write-back: the tape part of the result is stored through the `tape`
// pointer (the guard and some operands are not visible in this extract).
1734 tapeType = (tapeIdx == -1) ? aug->
fn->getReturnType()
1735 : cast<StructType>(aug->
fn->getReturnType())
1736 ->getElementType(tapeIdx);
1737 unsigned idxs[] = {(unsigned)tapeIdx};
1738 Value *tapeRes = (tapeIdx == -1)
1740 : Builder.CreateExtractValue(diffret, idxs);
1741 Builder.CreateStore(
1743 Builder.CreateBitCast(
1747 cast<PointerType>(tape->getType())->getAddressSpace())));
// When the tape is a struct element, rebuild the result as a struct without
// its first element: element i + 1 of the old value becomes element i.
1748 if (tapeIdx != -1) {
1749 auto ST = cast<StructType>(diffret->getType());
1750 SmallVector<Type *, 2> tys(ST->elements().begin(),
1751 ST->elements().end());
1752 tys.erase(tys.begin());
1753 auto ST0 = StructType::get(ST->getContext(), tys);
1754 Value *out = UndefValue::get(ST0);
1755 for (
unsigned i = 0; i < tys.size(); i++) {
1756 out = Builder.CreateInsertValue(
1757 out, Builder.CreateExtractValue(diffret, {i + 1}), {i});
1761 auto ST0 = StructType::get(tape->getContext(), {});
1762 diffret = UndefValue::get(ST0);
// Results for width > 1 that are neither empty nor void are reshaped with
// adaptReturnedVector (one further condition is not visible in this extract).
1769 if (width > 1 && !diffret->getType()->isEmptyTy() &&
1770 !diffret->getType()->isVoidTy() &&
1774 diffret = adaptReturnedVector(ret, diffret, Builder, width);
1777 ReplaceOriginalCall(Builder, ret, retElemType, diffret, CI, mode);
1778 calls.push_back(diffretc);
// HandleAutoDiffArguments: front end for a single autodiff call site.
// It resolves the target function from `CI`, lets handleArguments fill
// `constants`, `args` and `byVal`, and then delegates to HandleAutoDiff.
//
//   CI        the call being lowered
//   mode      derivative mode forwarded to handleArguments / HandleAutoDiff
//   sizeOnly  forwarded unchanged to handleArguments / HandleAutoDiff
//   calls     forwarded to HandleAutoDiff (the tail of that call is not
//             visible in this extract)
//
// NOTE(review): this listing is a lossy extract. The number at the start of
// each line is the original file's line number and several original lines are
// absent, e.g. the remaining arguments of the handleArguments call and the
// declaration of `ret`.
1783 bool HandleAutoDiffArguments(CallInst *CI,
DerivativeMode mode,
bool sizeOnly,
1784 SmallVectorImpl<CallInst *> &calls) {
1787 Function *fn = parseFunctionParameter(CI);
1791 IRBuilder<> Builder(CI);
// Debug dump of the function before differentiation (its guard is not
// visible in this extract).
1794 llvm::errs() <<
"prefn:\n" << *fn <<
"\n";
1796 std::map<int, Type *> byVal;
1797 std::vector<DIFFE_TYPE> constants;
1798 SmallVector<Value *, 2> args;
1800 auto options = handleArguments(Builder, CI, fn, mode, sizeOnly, constants,
// For calls carrying an sret attribute the result destination is operand 0.
1808 Type *retElemType =
nullptr;
1809 if (CI->hasStructRetAttr()) {
1810 ret = CI->getArgOperand(0);
1812 CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
1816 return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
1817 byVal, constants, fn, mode, *options, sizeOnly,
// Probabilistic-programming lowering for one call site.
//
// NOTE(review): this listing is a lossy extract. The number at the start of
// each line is the original file's line number and many original lines are
// absent, including the first line of this function's signature. This is
// presumably HandleProbProg(CI, mode, calls), judging from the call
// `HandleProbProg(call, mode, calls)` later in this file. Confirm against the
// real file.
//
// Visible flow:
//   1. Resolve the target function and parse the call's arguments with
//      handleArguments; `dargs` starts as a copy of `args`.
//   2. Read the trace / likelihood / observation values out of the options.
//   3. Collect the module's "__enzyme_sample" / "__enzyme_observe" functions.
//   4. Append the likelihood, observation and trace values to `args` and
//      `dargs`.
//   5. Either replace CI with a direct call to `newFunc`, or delegate to
//      HandleAutoDiff with `dargs`.
1822 SmallVectorImpl<CallInst *> &calls) {
1823 IRBuilder<> Builder(CI);
1824 Function *F = parseFunctionParameter(CI);
1830 std::vector<DIFFE_TYPE> constants;
1831 std::map<int, Type *> byVal;
1832 SmallVector<Value *, 4> args;
1836 auto opt = handleArguments(Builder, CI, F, diffeMode,
false, constants,
1839 SmallVector<Value *, 6> dargs(args.begin(), args.end());
// The optional returned by handleArguments is tested with has_value() on
// LLVM >= 16 and hasValue() otherwise.
1841#if LLVM_VERSION_MAJOR >= 16
1842 if (!opt.has_value())
1845 if (!opt.hasValue())
1849 auto dynamic_interface = opt->dynamic_interface;
1850 auto trace = opt->trace;
1851 auto dtrace = opt->diffeTrace;
1852 auto observations = opt->observations;
1853 auto likelihood = opt->likelihood;
1854 auto dlikelihood = opt->diffeLikelihood;
1857 bool has_dynamic_interface = dynamic_interface !=
nullptr;
1858 bool needs_interface =
1860 std::unique_ptr<TraceInterface> interface;
1861 if (has_dynamic_interface) {
1864 }
else if (needs_interface) {
// Gather every function in the module whose name contains "__enzyme_sample"
// or "__enzyme_observe"; each is asserted to take at least three parameters.
1869 SmallPtrSet<Function *, 4> sampleFunctions;
1870 SmallPtrSet<Function *, 4> observeFunctions;
1871 for (
auto &func : F->getParent()->functions()) {
1872 if (func.getName().contains(
"__enzyme_sample")) {
1873 assert(func.getFunctionType()->getNumParams() >= 3);
1874 sampleFunctions.insert(&func);
1875 }
else if (func.getName().contains(
"__enzyme_observe")) {
1876 assert(func.getFunctionType()->getNumParams() >= 3);
1877 observeFunctions.insert(&func);
1881 assert(!sampleFunctions.empty() || !observeFunctions.empty());
// Differentiation is requested when either derivative output was supplied.
1883 bool autodiff = dtrace || dlikelihood;
1884 IRBuilder<> AllocaBuilder(CI->getParent()->getFirstNonPHI());
// A double slot named "likelihood" is allocated at the top of CI's block and
// a zero double is stored (the store's destination operand and the guard are
// not visible in this extract).
1887 likelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(),
1888 nullptr,
"likelihood");
1889 Builder.CreateStore(ConstantFP::getNullValue(Builder.getDoubleTy()),
1892 args.push_back(likelihood);
// When differentiating without a caller-supplied derivative slot, a
// "dlikelihood" double slot is allocated and a 1.0 store is emitted.
1894 if (autodiff && !dlikelihood) {
1895 dlikelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(),
1896 nullptr,
"dlikelihood");
1897 Builder.CreateStore(ConstantFP::get(Builder.getDoubleTy(), 1.0),
1902 dargs.push_back(likelihood);
1903 dargs.push_back(dlikelihood);
// Each appended argument gets a matching `false` entry in overwritten_args.
1905 opt->overwritten_args.push_back(
false);
1908 opt->overwritten_args.push_back(
false);
1912 opt->overwritten_args.push_back(
false);
1913 args.push_back(observations);
1914 dargs.push_back(observations);
1919 opt->overwritten_args.push_back(
false);
1920 args.push_back(trace);
1921 dargs.push_back(trace);
1926 RequestContext(CI, &Builder), F, sampleFunctions, observeFunctions,
1927 opt->ActiveRandomVariables, mode, autodiff, interface.get());
// Direct path: CI is replaced in place by a call to newFunc with `args`
// (the guard for this path is not visible in this extract).
1930 auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args);
1931 ReplaceInstWithInst(CI, call);
// Differentiated path: hand over to HandleAutoDiff with `dargs`.
1936 Type *retElemType =
nullptr;
1937 if (CI->hasStructRetAttr()) {
1938 ret = CI->getArgOperand(0);
1940 CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
1944 bool status = HandleAutoDiff(
1945 CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
// handleFullModuleTrunc: applies the configured whole-module floating-point
// truncations to function `F`.
//
// NOTE(review): this listing is a lossy extract. The number at the start of
// each line is the original file's line number and many original lines are
// absent (return statements, the source of `ConfigStr`, and the call that
// creates `TruncatedFunc`).
//
// Visible flow:
//   - `FullModuleTruncs` is a function-local static whose initializer is a
//     lambda that parses `ConfigStr`.
//   - Each entry is written as <from> "to" <to>, with an optional ";" consumed
//     after it.
//   - For every truncation, the blocks of `TruncatedFunc` are spliced into F
//     and remapped onto F's own arguments. TruncatedFunc's body is then
//     deleted.
1951 bool handleFullModuleTrunc(Function &F) {
1954 typedef std::vector<FloatTruncation> TruncationsTy;
1955 static TruncationsTy FullModuleTruncs = []() -> TruncationsTy {
// Malformed configuration is fatal.
1957 auto Invalid = [=]() {
1959 llvm::report_fatal_error(
"error: invalid format for truncation config");
// Parses one float representation: a decimal integer, optionally followed by
// "-" and a second decimal integer.
1963 auto parseFloatRepr = [&]() -> std::optional<FloatRepresentation> {
1965 if (ConfigStr.consumeInteger(10, Tmp))
1967 if (ConfigStr.consume_front(
"-")) {
1969 if (ConfigStr.consumeInteger(10, Tmp2))
1973 return getDefaultFloatRepr(Tmp);
1979 auto From = parseFloatRepr();
1980 if (!From && !ConfigStr.empty())
1984 if (!ConfigStr.consume_front(
"to"))
1986 auto To = parseFloatRepr();
1990 ConfigStr.consume_front(
";");
1995 if (FullModuleTruncs.empty())
1999 for (
auto Truncation : FullModuleTruncs) {
2000 IRBuilder<> Builder(F.getContext());
// Map each argument of the truncated function back to the corresponding
// argument of F.
2005 ValueToValueMapTy Mapping;
2006 for (
auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args()))
2007 Mapping[&TArg] = &Arg;
// Move TruncatedFunc's basic blocks to the front of F (the API differs
// between LLVM versions).
2011#if LLVM_VERSION_MAJOR >= 16
2012 F.splice(F.begin(), TruncatedFunc);
2014 F.getBasicBlockList().splice(F.begin(),
2015 TruncatedFunc->getBasicBlockList());
2017 RemapFunction(F, Mapping,
2018 RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
2019 TruncatedFunc->deleteBody();
// lowerEnzymeCalls: finds and lowers the `__enzyme_*` call sites in `F`.
// `done` is forwarded to the recursive call made for each target function.
//
// NOTE(review): this listing is a lossy extract. The number at the start of
// each line is the original file's line number and many original lines are
// absent (closing braces, `continue`/`return` statements, several guards and
// some assignments). The comments describe only the visible statements.
//
// Visible phases:
//   1. handleFullModuleTrunc(F).
//   2. Turn invokes of `__enzyme_*` functions into plain calls.
//   3. Scan every call: attach memory attributes to known helper/runtime
//      functions, then classify `__enzyme_*` calls into work lists.
//   4. Process the work lists (inactive calls, size queries, autodiff,
//      virtual reverse, batch, truncate/expand, prob-prog).
//   5. Try to inline the generated calls.
//   6. Run the Attributor over the module (not under FLANG / ROCM).
2024 bool lowerEnzymeCalls(Function &F, std::set<Function *> &done) {
2032 if (handleFullModuleTrunc(F))
2035 bool Changed =
false;
// Phase 2: for each block ending in an invoke, resolve the callee (looking
// through a cast constant expression) and test its name against the list
// below.
2037 for (BasicBlock &BB : F)
2038 if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator())) {
2040 Function *Fn = II->getCalledFunction();
2042 if (
auto castinst = dyn_cast<ConstantExpr>(II->getCalledOperand())) {
2043 if (castinst->isCast())
2044 if (
auto fn = dyn_cast<Function>(castinst->getOperand(0)))
2050 if (!(Fn->getName().contains(
"__enzyme_float") ||
2051 Fn->getName().contains(
"__enzyme_double") ||
2052 Fn->getName().contains(
"__enzyme_integer") ||
2053 Fn->getName().contains(
"__enzyme_pointer") ||
2054 Fn->getName().contains(
"__enzyme_virtualreverse") ||
2055 Fn->getName().contains(
"__enzyme_call_inactive") ||
2056 Fn->getName().contains(
"__enzyme_autodiff") ||
2057 Fn->getName().contains(
"__enzyme_fwddiff") ||
2058 Fn->getName().contains(
"__enzyme_fwdsplit") ||
2059 Fn->getName().contains(
"__enzyme_augmentfwd") ||
2060 Fn->getName().contains(
"__enzyme_augmentsize") ||
2061 Fn->getName().contains(
"__enzyme_reverse") ||
2062 Fn->getName().contains(
"__enzyme_truncate") ||
2063 Fn->getName().contains(
"__enzyme_batch") ||
2064 Fn->getName().contains(
"__enzyme_error_estimate") ||
2065 Fn->getName().contains(
"__enzyme_trace") ||
2066 Fn->getName().contains(
"__enzyme_condition")))
// Build an equivalent CallInst before the invoke. It takes over the invoke's
// name, calling convention, attributes, debug location and uses.
2069 SmallVector<Value *, 16> CallArgs(II->arg_begin(), II->arg_end());
2070 SmallVector<OperandBundleDef, 1> OpBundles;
2071 II->getOperandBundlesAsDefs(OpBundles);
2074 CallInst::Create(II->getFunctionType(), II->getCalledOperand(),
2075 CallArgs, OpBundles,
"", II);
2076 NewCall->takeName(II);
2077 NewCall->setCallingConv(II->getCallingConv());
2078 NewCall->setAttributes(II->getAttributes());
2079 NewCall->setDebugLoc(II->getDebugLoc());
2080 II->replaceAllUsesWith(NewCall);
// Control flow: branch straight to the normal destination, drop this block as
// a predecessor of the unwind destination, then remove the invoke.
2083 BranchInst::Create(II->getNormalDest(), II);
2086 II->getUnwindDest()->removePredecessor(&BB);
2088 II->eraseFromParent();
// Work lists filled by the scan below, one per lowering kind.
2092 MapVector<CallInst *, DerivativeMode> toLower;
2093 MapVector<CallInst *, DerivativeMode> toVirtual;
2094 MapVector<CallInst *, DerivativeMode> toSize;
2095 SmallVector<CallInst *, 4> toBatch;
2096 SmallVector<CallInst *, 4> toTruncateFuncMem;
2097 SmallVector<CallInst *, 4> toTruncateFuncOp;
2098 SmallVector<CallInst *, 4> toTruncateValue;
2099 SmallVector<CallInst *, 4> toExpandValue;
2100 MapVector<CallInst *, ProbProgMode> toProbProg;
2101 SetVector<CallInst *> InactiveCalls;
2102 SetVector<CallInst *> IterCalls;
// Phase 3: scan every call instruction in F.
2104 for (BasicBlock &BB : F) {
2105 for (Instruction &I : BB) {
2106 CallInst *CI = dyn_cast<CallInst>(&I);
2111 Function *Fn =
nullptr;
// Resolve the callee, looking through a cast constant expression.
2113 Value *FnOp = CI->getCalledOperand();
2115 if ((Fn = dyn_cast<Function>(FnOp)))
2117 if (
auto castinst = dyn_cast<ConstantExpr>(FnOp)) {
2118 if (castinst->isCast()) {
2119 FnOp = castinst->getOperand(0);
2129 size_t num_args = CI->arg_size();
// Attribute annotations. In the blocks below, the LLVM >= 16 branch calls
// both setOnlyReadsMemory and setOnlyWritesMemory where the older branch adds
// ReadNone.
2131 if (Fn->getName().contains(
"__enzyme_todense") ||
2132 Fn->getName().contains(
"__enzyme_ignore_derivatives")) {
2133#if LLVM_VERSION_MAJOR >= 16
2134 CI->setOnlyReadsMemory();
2135 CI->setOnlyWritesMemory();
2137 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
// The type-annotation markers (__enzyme_float / integer / double / pointer)
// additionally mark every pointer operand ReadNone.
2140 if (Fn->getName().contains(
"__enzyme_float")) {
2141#if LLVM_VERSION_MAJOR >= 16
2142 CI->setOnlyReadsMemory();
2143 CI->setOnlyWritesMemory();
2145 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2147 for (
size_t i = 0; i < num_args; ++i) {
2148 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2149 CI->addParamAttr(i, Attribute::ReadNone);
2154 if (Fn->getName().contains(
"__enzyme_integer")) {
2155#if LLVM_VERSION_MAJOR >= 16
2156 CI->setOnlyReadsMemory();
2157 CI->setOnlyWritesMemory();
2159 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2161 for (
size_t i = 0; i < num_args; ++i) {
2162 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2163 CI->addParamAttr(i, Attribute::ReadNone);
2168 if (Fn->getName().contains(
"__enzyme_double")) {
2169#if LLVM_VERSION_MAJOR >= 16
2170 CI->setOnlyReadsMemory();
2171 CI->setOnlyWritesMemory();
2173 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2175 for (
size_t i = 0; i < num_args; ++i) {
2176 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2177 CI->addParamAttr(i, Attribute::ReadNone);
2182 if (Fn->getName().contains(
"__enzyme_pointer")) {
2183#if LLVM_VERSION_MAJOR >= 16
2184 CI->setOnlyReadsMemory();
2185 CI->setOnlyWritesMemory();
2187 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2189 for (
size_t i = 0; i < num_args; ++i) {
2190 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2191 CI->addParamAttr(i, Attribute::ReadNone);
2196 if (Fn->getName().contains(
"__enzyme_virtualreverse")) {
2197#if LLVM_VERSION_MAJOR >= 16
2198 CI->setOnlyReadsMemory();
2199 CI->setOnlyWritesMemory();
2201 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
2204 if (Fn->getName().contains(
"__enzyme_iter")) {
2205#if LLVM_VERSION_MAJOR >= 16
2206 CI->setOnlyReadsMemory();
2207 CI->setOnlyWritesMemory();
2209 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
// __enzyme_call_inactive calls are queued and rewritten after the scan.
2212 if (Fn->getName().contains(
"__enzyme_call_inactive")) {
2213 InactiveCalls.insert(CI);
// OpenMP queries: inaccessible-memory-only and read-only, applied to both the
// callee and the call site.
2215 if (Fn->getName() ==
"omp_get_max_threads" ||
2216 Fn->getName() ==
"omp_get_thread_num") {
2217#if LLVM_VERSION_MAJOR >= 16
2218 Fn->setOnlyAccessesInaccessibleMemory();
2219 CI->setOnlyAccessesInaccessibleMemory();
2220 Fn->setOnlyReadsMemory();
2221 CI->setOnlyReadsMemory();
2223 Fn->addFnAttr(Attribute::InaccessibleMemOnly);
2224 CI->addAttribute(AttributeList::FunctionIndex,
2225 Attribute::InaccessibleMemOnly);
2226 Fn->addFnAttr(Attribute::ReadOnly);
2227 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
// Undefined cblas dot products: argument-memory-only and read-only; operands
// 1 and 3 are marked read-only.
2230 if ((Fn->getName() ==
"cblas_ddot" || Fn->getName() ==
"cblas_sdot") &&
2231 Fn->isDeclaration()) {
2232#if LLVM_VERSION_MAJOR >= 16
2233 Fn->setOnlyAccessesArgMemory();
2234 Fn->setOnlyReadsMemory();
2235 CI->setOnlyReadsMemory();
2237 Fn->addFnAttr(Attribute::ArgMemOnly);
2238 Fn->addFnAttr(Attribute::ReadOnly);
2239 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
2241 CI->addParamAttr(1, Attribute::ReadOnly);
2243 CI->addParamAttr(3, Attribute::ReadOnly);
// frexp family: argument-memory-only; operand 1 is write-only.
2246 if (Fn->getName() ==
"frexp" || Fn->getName() ==
"frexpf" ||
2247 Fn->getName() ==
"frexpl") {
2248#if LLVM_VERSION_MAJOR >= 16
2249 CI->setOnlyAccessesArgMemory();
2251 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ArgMemOnly);
2253 CI->addParamAttr(1, Attribute::WriteOnly);
2255 if (Fn->getName() ==
"__fd_sincos_1" || Fn->getName() ==
"__fd_cos_1" ||
2256 Fn->getName() ==
"__mth_i_ipowi") {
2257#if LLVM_VERSION_MAJOR >= 16
2258 CI->setOnlyReadsMemory();
2259 CI->setOnlyWritesMemory();
2261 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
// NOTE(review): the condition guarding the next group of attribute calls is
// not visible in this extract.
2265 Fn->addParamAttr(0, Attribute::ReadOnly);
2266 Fn->addParamAttr(1, Attribute::ReadOnly);
2267#if LLVM_VERSION_MAJOR >= 16
2268 Fn->setOnlyReadsMemory();
2269 CI->setOnlyReadsMemory();
2271 Fn->addFnAttr(Attribute::ReadOnly);
2272 CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
// f90io_* / ftnio_* / f90_pausea entries: each is marked
// inaccessible-memory-only or inaccessible-or-argument-memory-only, and the
// listed pointer operands are marked read-only. In some of the per-index
// loops only the pointer-type test is visible.
2275 if (Fn->getName() ==
"f90io_fmtw_end" ||
2276 Fn->getName() ==
"f90io_unf_end") {
2277#if LLVM_VERSION_MAJOR >= 16
2278 Fn->setOnlyAccessesInaccessibleMemory();
2279 CI->setOnlyAccessesInaccessibleMemory();
2281 Fn->addFnAttr(Attribute::InaccessibleMemOnly);
2282 CI->addAttribute(AttributeList::FunctionIndex,
2283 Attribute::InaccessibleMemOnly);
2286 if (Fn->getName() ==
"f90io_open2003a") {
2287#if LLVM_VERSION_MAJOR >= 16
2288 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2289 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2291 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2292 CI->addAttribute(AttributeList::FunctionIndex,
2293 Attribute::InaccessibleMemOrArgMemOnly);
2295 for (
size_t i : {0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13}) {
2297 CI->getArgOperand(i)->getType()->isPointerTy()) {
2298 CI->addParamAttr(i, Attribute::ReadOnly);
2302 for (
size_t i : {0, 1}) {
2304 CI->getArgOperand(i)->getType()->isPointerTy()) {
2309 if (Fn->getName() ==
"f90io_fmtw_inita") {
2310#if LLVM_VERSION_MAJOR >= 16
2311 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2312 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2314 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2315 CI->addAttribute(AttributeList::FunctionIndex,
2316 Attribute::InaccessibleMemOrArgMemOnly);
2319 for (
size_t i : {0, 2}) {
2321 CI->getArgOperand(i)->getType()->isPointerTy()) {
2322 CI->addParamAttr(i, Attribute::ReadOnly);
2327 for (
size_t i : {0, 2}) {
2329 CI->getArgOperand(i)->getType()->isPointerTy()) {
2335 if (Fn->getName() ==
"f90io_unf_init") {
2336#if LLVM_VERSION_MAJOR >= 16
2337 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2338 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2340 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2341 CI->addAttribute(AttributeList::FunctionIndex,
2342 Attribute::InaccessibleMemOrArgMemOnly);
2345 for (
size_t i : {0, 1, 2, 3}) {
2347 CI->getArgOperand(i)->getType()->isPointerTy()) {
2348 CI->addParamAttr(i, Attribute::ReadOnly);
2353 for (
size_t i : {0, 1, 2, 3}) {
2355 CI->getArgOperand(i)->getType()->isPointerTy()) {
2361 if (Fn->getName() ==
"f90io_src_info03a") {
2362#if LLVM_VERSION_MAJOR >= 16
2363 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2364 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2366 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2367 CI->addAttribute(AttributeList::FunctionIndex,
2368 Attribute::InaccessibleMemOrArgMemOnly);
2371 for (
size_t i : {0, 1}) {
2373 CI->getArgOperand(i)->getType()->isPointerTy()) {
2374 CI->addParamAttr(i, Attribute::ReadOnly);
2379 for (
size_t i : {0}) {
2381 CI->getArgOperand(i)->getType()->isPointerTy()) {
2386 if (Fn->getName() ==
"f90io_sc_d_fmt_write" ||
2387 Fn->getName() ==
"f90io_sc_i_fmt_write" ||
2388 Fn->getName() ==
"ftnio_fmt_write64" ||
2389 Fn->getName() ==
"f90io_fmt_write64_aa" ||
2390 Fn->getName() ==
"f90io_fmt_writea" ||
2391 Fn->getName() ==
"f90io_unf_writea" ||
2392 Fn->getName() ==
"f90_pausea") {
2393#if LLVM_VERSION_MAJOR >= 16
2394 Fn->setOnlyAccessesInaccessibleMemOrArgMem();
2395 CI->setOnlyAccessesInaccessibleMemOrArgMem();
2397 Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
2398 CI->addAttribute(AttributeList::FunctionIndex,
2399 Attribute::InaccessibleMemOrArgMemOnly);
2401 for (
size_t i = 0; i < num_args; ++i) {
2402 if (CI->getArgOperand(i)->getType()->isPointerTy()) {
2403 CI->addParamAttr(i, Attribute::ReadOnly);
// Classification flags for this call. They are set by the name-matching chain
// below; some of the per-branch assignments are not visible in this extract.
2409 bool enableEnzyme =
false;
2410 bool virtualCall =
false;
2411 bool sizeOnly =
false;
2413 bool truncateFuncOp =
false;
2414 bool truncateFuncMem =
false;
2415 bool truncateValue =
false;
2416 bool expandValue =
false;
2417 bool probProg =
false;
// Names are matched with contains() and the first matching branch wins, so
// the order of this chain matters.
2420 if (Fn->getName().contains(
"__enzyme_autodiff")) {
2421 enableEnzyme =
true;
2423 }
else if (Fn->getName().contains(
"__enzyme_fwddiff")) {
2424 enableEnzyme =
true;
2426 }
else if (Fn->getName().contains(
"__enzyme_error_estimate")) {
2427 enableEnzyme =
true;
2429 }
else if (Fn->getName().contains(
"__enzyme_fwdsplit")) {
2430 enableEnzyme =
true;
2432 }
else if (Fn->getName().contains(
"__enzyme_augmentfwd")) {
2433 enableEnzyme =
true;
2435 }
else if (Fn->getName().contains(
"__enzyme_augmentsize")) {
2436 enableEnzyme =
true;
2439 }
else if (Fn->getName().contains(
"__enzyme_reverse")) {
2440 enableEnzyme =
true;
2442 }
else if (Fn->getName().contains(
"__enzyme_virtualreverse")) {
2443 enableEnzyme =
true;
2446 }
else if (Fn->getName().contains(
"__enzyme_batch")) {
2447 enableEnzyme =
true;
2449 }
else if (Fn->getName().contains(
"__enzyme_truncate_mem_func")) {
2450 enableEnzyme =
true;
2451 truncateFuncMem =
true;
2452 }
else if (Fn->getName().contains(
"__enzyme_truncate_op_func")) {
2453 enableEnzyme =
true;
2454 truncateFuncOp =
true;
2455 }
else if (Fn->getName().contains(
"__enzyme_truncate_mem_value")) {
2456 enableEnzyme =
true;
2457 truncateValue =
true;
2458 }
else if (Fn->getName().contains(
"__enzyme_expand_mem_value")) {
2459 enableEnzyme =
true;
2461 }
else if (Fn->getName().contains(
"__enzyme_likelihood")) {
2462 enableEnzyme =
true;
2465 }
else if (Fn->getName().contains(
"__enzyme_trace")) {
2466 enableEnzyme =
true;
2469 }
else if (Fn->getName().contains(
"__enzyme_condition")) {
2470 enableEnzyme =
true;
// Find the function being operated on: start from operand 0 and strip cast
// instructions, block addresses and constant expressions.
2477 Value *fn = CI->getArgOperand(0);
2478 while (
auto ci = dyn_cast<CastInst>(fn)) {
2479 fn = ci->getOperand(0);
2481 while (
auto ci = dyn_cast<BlockAddress>(fn)) {
2482 fn = ci->getFunction();
2484 while (
auto ci = dyn_cast<ConstantExpr>(fn)) {
2485 fn = ci->getOperand(0);
// Operand 0 is a select: split the block at CI and branch on the select's
// condition into two new blocks. "sel1" gets a clone of the call using the
// true value, and the original call is switched to the false value. If the
// call has uses, a PHI at the top of the continuation block merges the two
// results.
2487 if (
auto si = dyn_cast<SelectInst>(fn)) {
2488 BasicBlock *post = BB.splitBasicBlock(CI);
2489 BasicBlock *sel1 = BasicBlock::Create(BB.getContext(),
"sel1", &F);
2490 BasicBlock *sel2 = BasicBlock::Create(BB.getContext(),
"sel2", &F);
2491 BB.getTerminator()->eraseFromParent();
2492 IRBuilder<> PB(&BB);
2493 PB.CreateCondBr(si->getCondition(), sel1, sel2);
2494 IRBuilder<> S1(sel1);
2495 auto B1 = S1.CreateBr(post);
2496 CallInst *cloned = cast<CallInst>(CI->clone());
2497 cloned->insertBefore(B1);
2498 cloned->setOperand(0, si->getTrueValue());
2499 IRBuilder<> S2(sel2);
2500 auto B2 = S2.CreateBr(post);
2502 CI->setOperand(0, si->getFalseValue());
2503 if (CI->getNumUses() != 0) {
2504 IRBuilder<> P(post->getFirstNonPHI());
2505 auto merge = P.CreatePHI(CI->getType(), 2);
2506 merge->addIncoming(cloned, sel1);
2507 merge->addIncoming(CI, sel2);
2508 CI->replaceAllUsesWith(merge);
// Route the call into exactly one work list according to the flags above.
2513 toVirtual[CI] = derivativeMode;
2515 toSize[CI] = derivativeMode;
2517 toBatch.push_back(CI);
2518 else if (truncateFuncOp)
2519 toTruncateFuncOp.push_back(CI);
2520 else if (truncateFuncMem)
2521 toTruncateFuncMem.push_back(CI);
2522 else if (truncateValue)
2523 toTruncateValue.push_back(CI);
2524 else if (expandValue)
2525 toExpandValue.push_back(CI);
2526 else if (probProg) {
2527 toProbProg[CI] = probProgMode;
2529 toLower[CI] = derivativeMode;
// The target function is itself lowered recursively first.
2531 if (
auto dc = dyn_cast<Function>(fn)) {
2536 Changed |= lowerEnzymeCalls(*dc, done);
// Phase 4a: each __enzyme_call_inactive call becomes a call of operand 0
// (pointer-cast when needed) with the remaining operands as arguments, tagged
// with the "enzyme_inactive" attribute.
2543 for (
auto CI : InactiveCalls) {
2545 Value *fn = CI->getArgOperand(0);
2546 SmallVector<Value *, 4>
Args;
2547 SmallVector<Type *, 4> ArgTypes;
2548 for (
size_t i = 1; i < CI->arg_size(); ++i) {
2549 Args.push_back(CI->getArgOperand(i));
2550 ArgTypes.push_back(CI->getArgOperand(i)->getType());
2552 auto FT = FunctionType::get(CI->getType(), ArgTypes,
false);
2553 if (fn->getType() != FT) {
2554 fn = B.CreatePointerCast(fn,
getUnqual(FT));
2556 auto Rep = B.CreateCall(FT, fn,
Args);
2557 Rep->addAttribute(AttributeList::FunctionIndex,
2558 Attribute::get(Rep->getContext(),
"enzyme_inactive"));
2559 CI->replaceAllUsesWith(Rep);
2560 CI->eraseFromParent();
// Phase 4b: `calls` collects the generated calls for the inlining step.
2564 SmallVector<CallInst *, 1> calls;
2567 for (
auto pair : toSize) {
2568 bool successful = HandleAutoDiffArguments(pair.first, pair.second,
2574 for (
auto pair : toLower) {
2575 bool successful = HandleAutoDiffArguments(pair.first, pair.second,
// Virtual-reverse requests need a constant operand 0; a non-constant operand
// is reported as "IllegalVirtual".
2582 for (
auto pair : toVirtual) {
2583 auto CI = pair.first;
2584 Constant *fn = dyn_cast<Constant>(CI->getArgOperand(0));
2586 EmitFailure(
"IllegalVirtual", CI->getDebugLoc(), CI,
2587 "Cannot create virtual version of non-constant value ", *CI,
2588 *CI->getArgOperand(0));
2595 CI->getParent()->getParent()->getParent()->getTargetTriple())
2598 bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
2599 Arch == Triple::amdgcn;
2601 IRBuilder<> Builder(CI);
2604 Logic.
PPC.
FAM.getResult<TargetLibraryAnalysis>(F), TA, fn,
2605 pair.second,
false,
false,
2607 CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType()));
2608 CI->eraseFromParent();
2612 for (
auto call : toBatch) {
2615 for (
auto call : toTruncateFuncMem) {
2618 for (
auto call : toTruncateFuncOp) {
// HandleTruncateValue receives `true` for truncate requests and `false` for
// expand requests.
2621 for (
auto call : toTruncateValue) {
2622 HandleTruncateValue(call,
true);
2624 for (
auto call : toExpandValue) {
2625 HandleTruncateValue(call,
false);
2628 for (
auto &&[call, mode] : toProbProg) {
2629 HandleProbProg(call, mode, calls);
// Phase 5: try to inline the generated calls. Work is taken from the front of
// the set `Q`; llvm::shouldInline decides using a cost callback built on
// llvm::getInlineCost with the default inline parameters.
2633 auto Params = llvm::getInlineParams();
2635 llvm::SetVector<CallInst *> Q;
2636 for (
auto call : calls)
2639 auto cur = *Q.begin();
2640 Function *outerFunc = cur->getParent()->getParent();
2641 llvm::OptimizationRemarkEmitter ORE(outerFunc);
2643 if (
auto F = cur->getCalledFunction()) {
// Assumption caches are created on demand and stored in ACAlloc.
2646 SmallVector<std::unique_ptr<AssumptionCache>, 2> ACAlloc;
2647 auto getAC = [&](Function &F) -> llvm::AssumptionCache & {
2648 auto AC = std::make_unique<AssumptionCache>(F);
2649 ACAlloc.push_back(std::move(AC));
2650 return *ACAlloc.back();
2653 [&](llvm::Function &F) ->
const llvm::TargetLibraryInfo & {
2654 return Logic.
PPC.
FAM.getResult<TargetLibraryAnalysis>(F);
2657 TargetTransformInfo TTI(F->getParent()->getDataLayout());
2658 auto GetInlineCost = [&](CallBase &CB) {
2659 auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI);
// llvm::shouldInline takes an extra TTI argument from LLVM 20 on.
2662#if LLVM_VERSION_MAJOR >= 20
2663 if (llvm::shouldInline(*cur, TTI, GetInlineCost, ORE))
2665 if (llvm::shouldInline(*cur, GetInlineCost, ORE))
2668 InlineFunctionInfo IFI;
2669 InlineResult IR = InlineFunction(*cur, IFI);
// After a successful inline, direct calls to outerFunc are examined (what is
// done with them is not visible in this extract).
2670 if (IR.isSuccess()) {
2672 for (
auto U : outerFunc->users()) {
2673 if (
auto CI = dyn_cast<CallInst>(U)) {
2674 if (CI->getCalledFunction() == outerFunc) {
// Phase 6 (compiled out under FLANG and ROCM): run the Attributor over every
// function of the module, restricted to the abstract attributes listed in
// `Allowed`, with function deletion disabled.
2691#if !defined(FLANG) && !defined(ROCM)
2693 AnalysisGetter AG(Logic.
PPC.
FAM);
2694 SetVector<Function *> Functions;
2695 for (Function &F2 : *F.getParent()) {
2696 Functions.insert(&F2);
2699 CallGraphUpdater CGUpdater;
2700 BumpPtrAllocator Allocator;
2701 InformationCache InfoCache(*F.getParent(), AG, Allocator,
2704 DenseSet<const char *> Allowed = {
2708 &AAMemoryBehavior::ID,
2709 &AAMemoryLocation::ID,
2717 &AADereferenceable::ID,
2719#if LLVM_VERSION_MAJOR < 17
2720 &AAReturnedValues::ID,
2732 AttributorConfig aconfig(CGUpdater);
2733 aconfig.Allowed = &Allowed;
2734 aconfig.DeleteFns =
false;
2735 Attributor A(Functions, InfoCache, aconfig);
2736 for (Function *F : Functions) {
2739 A.identifyDefaultAbstractAttributes(*F);
2748 bool run(Module &M) {
2751 for (Function &F : make_early_inc_range(M)) {
2755 bool changed =
false;
2756 for (Function &F : M) {
2759 for (BasicBlock &BB : F) {
2760 for (Instruction &I : make_early_inc_range(BB)) {
2761 if (
auto CI = dyn_cast<CallInst>(&I)) {
2762 Function *F = CI->getCalledFunction();
2764 dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
2765 if (castinst->isCast())
2766 if (
auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
2770 if (F && F->getName() ==
"f90_mzero8") {
2774 args[0] = CI->getArgOperand(0);
2775 args[1] = ConstantInt::get(Type::getInt8Ty(M.getContext()), 0);
2776 args[2] = B.CreateMul(
2777 CI->getArgOperand(1),
2778 ConstantInt::get(CI->getArgOperand(1)->getType(), 8));
2779 B.CreateMemSet(args[0], args[1], args[2], MaybeAlign());
2781 CI->eraseFromParent();
2789 OpenMPOptPass().run(M, Logic.
PPC.
MAM);
2791 AttributorPass().run(M, Logic.
PPC.
MAM);
2794 PromotePass().run(F, Logic.
PPC.
FAM);
2802 std::set<Function *> done;
2803 for (Function &F : M) {
2807 changed |= lowerEnzymeCalls(F, done);
2810 for (Function &F : M) {
2814 for (BasicBlock &BB : F) {
2815 for (Instruction &I : make_early_inc_range(BB)) {
2816 if (
auto CI = dyn_cast<CallInst>(&I)) {
2817 Function *F = CI->getCalledFunction();
2819 dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
2820 if (castinst->isCast())
2821 if (
auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
2826 if (F->getName().contains(
"__enzyme_float") ||
2827 F->getName().contains(
"__enzyme_double") ||
2828 F->getName().contains(
"__enzyme_integer") ||
2829 F->getName().contains(
"__enzyme_pointer")) {
2830 CI->eraseFromParent();
2833 if (F->getName().contains(
"__enzyme_iter") ||
2834 F->getName().contains(
"__enzyme_ignore_derivatives")) {
2835 CI->replaceAllUsesWith(CI->getArgOperand(0));
2836 CI->eraseFromParent();
2845 SmallPtrSet<CallInst *, 16> sample_calls;
2846 SmallPtrSet<CallInst *, 16> observe_calls;
2847 for (
auto &&func : M) {
2848 for (
auto &&BB : func) {
2849 for (
auto &&Inst : BB) {
2850 if (
auto CI = dyn_cast<CallInst>(&Inst)) {
2851 Function *fun = CI->getCalledFunction();
2855 if (fun->getName().contains(
"__enzyme_sample")) {
2856 if (CI->getNumOperands() < 3) {
2858 "IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2859 "Not enough arguments passed to call to __enzyme_sample");
2863 samplefn->getFunctionType()->getNumParams() + 3;
2864 unsigned actual = CI->arg_size();
2865 if (actual - 3 != samplefn->getFunctionType()->getNumParams()) {
2866 EmitFailure(
"IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2867 "Illegal number of arguments passed to call to "
2869 " Expected: ", expected,
" got: ", actual);
2873 for (
unsigned i = 0;
2874 i < samplefn->getFunctionType()->getNumParams(); ++i) {
2875 Value *ci_arg = CI->getArgOperand(i + 3);
2876 Value *sample_arg = samplefn->arg_begin() + i;
2877 Value *pdf_arg = pdf->arg_begin() + i;
2879 if (ci_arg->getType() != sample_arg->getType()) {
2881 "IllegalSampleType", CI->getDebugLoc(), CI,
2882 "Type of: ", *ci_arg,
" (", *ci_arg->getType(),
")",
2883 " does not match the argument type of the sample "
2885 *samplefn,
" at: ", i,
" (", *sample_arg->getType(),
")");
2887 if (ci_arg->getType() != pdf_arg->getType()) {
2888 EmitFailure(
"IllegalSampleType", CI->getDebugLoc(), CI,
2889 "Type of: ", *ci_arg,
" (", *ci_arg->getType(),
2891 " does not match the argument type of the "
2892 "density function: ",
2893 *pdf,
" at: ", i,
" (", *pdf_arg->getType(),
")");
2897 if ((pdf->arg_end() - 1)->getType() !=
2898 samplefn->getReturnType()) {
2900 "IllegalSampleType", CI->getDebugLoc(), CI,
2901 "Return type of ", *samplefn,
" (",
2902 *samplefn->getReturnType(),
")",
2903 " does not match the last argument type of the density "
2905 *pdf,
" (", *(pdf->arg_end() - 1)->getType(),
")");
2907 sample_calls.insert(CI);
2909 }
else if (fun->getName().contains(
"__enzyme_observe")) {
2910 if (CI->getNumOperands() < 3) {
2912 "IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2913 "Not enough arguments passed to call to __enzyme_sample");
2915 Value *observed = CI->getOperand(0);
2917 unsigned expected = pdf->getFunctionType()->getNumParams() - 1;
2919 unsigned actual = CI->arg_size();
2920 if (actual - 3 != expected) {
2921 EmitFailure(
"IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2922 "Illegal number of arguments passed to call to "
2923 "__enzyme_observe.",
2924 " Expected: ", expected,
" got: ", actual);
2927 for (
unsigned i = 0;
2928 i < pdf->getFunctionType()->getNumParams() - 1; ++i) {
2929 Value *ci_arg = CI->getArgOperand(i + 3);
2930 Value *pdf_arg = pdf->arg_begin() + i;
2932 if (ci_arg->getType() != pdf_arg->getType()) {
2933 EmitFailure(
"IllegalSampleType", CI->getDebugLoc(), CI,
2934 "Type of: ", *ci_arg,
" (", *ci_arg->getType(),
2936 " does not match the argument type of the "
2937 "density function: ",
2938 *pdf,
" at: ", i,
" (", *pdf_arg->getType(),
")");
2942 if ((pdf->arg_end() - 1)->getType() != observed->getType()) {
2944 "IllegalSampleType", CI->getDebugLoc(), CI,
2945 "Return type of ", *observed,
" (", *observed->getType(),
2947 " does not match the last argument type of the density "
2949 *pdf,
" (", *(pdf->arg_end() - 1)->getType(),
")");
2951 observe_calls.insert(CI);
2960 for (
auto call : sample_calls) {
2963 SmallVector<Value *, 2> args;
2964 for (
auto it = call->arg_begin() + 3; it != call->arg_end(); it++) {
2965 args.push_back(*it);
2968 CallInst::Create(samplefn->getFunctionType(), samplefn, args);
2970 ReplaceInstWithInst(call, choice);
2973 for (
auto call : observe_calls) {
2974 Value *observed = call->getArgOperand(0);
2976 if (!call->getType()->isVoidTy())
2977 call->replaceAllUsesWith(observed);
2978 call->eraseFromParent();
2981 for (
const auto &pair : Logic.
PPC.
cache)
2982 pair.second->eraseFromParent();
2985 if (changed && Logic.
PostOpt) {
2986 TimeTraceScope timeScope(
"Enzyme PostOpt", M.getName());
2989 LoopAnalysisManager LAM;
2990 FunctionAnalysisManager FAM;
2991 CGSCCAnalysisManager CGAM;
2992 ModuleAnalysisManager MAM;
2993 PB.registerModuleAnalyses(MAM);
2994 PB.registerFunctionAnalyses(FAM);
2995 PB.registerLoopAnalyses(LAM);
2996 PB.registerCGSCCAnalyses(CGAM);
2997 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
2998 auto PM = PB.buildModuleSimplificationPipeline(OptimizationLevel::O2,
2999 ThinOrFullLTOPhase::None);
3002 OpenMPOptPass().run(M, MAM);
3004 AttributorPass().run(M, MAM);
3007 PromotePass().run(F, FAM);
3019class EnzymeOldPM :
public EnzymeBase,
public ModulePass {
/// Construct the legacy-pass-manager flavour of the Enzyme pass.
/// \p PostOpt (defaulting to false) is forwarded unchanged to EnzymeBase; the
/// ModulePass base is initialised with this class's static ID member.
3022 EnzymeOldPM(
bool PostOpt =
false) : EnzymeBase(PostOpt),
ModulePass(ID) {}
3024 void getAnalysisUsage(AnalysisUsage &AU)
const override {
3025 AU.addRequired<TargetLibraryInfoWrapperPass>();
/// ModulePass override: delegates to run(M) and returns its boolean result
/// unchanged.
3034 bool runOnModule(Module &M)
override {
return run(M); }
// Out-of-class definition of the static ID member that the EnzymeOldPM
// constructor hands to ModulePass.
3039char EnzymeOldPM::ID = 0;
// File-local registration object for EnzymeOldPM, constructed with the
// strings "enzyme" and "Enzyme Pass" (presumably the pass argument and its
// human-readable description -- confirm against llvm::RegisterPass).
3041static RegisterPass<EnzymeOldPM>
X(
"enzyme",
"Enzyme Pass");
3045#include <llvm-c/Core.h>
3046#include <llvm-c/Types.h>
3048#include "llvm/IR/LegacyPassManager.h"
3054#if LLVM_VERSION_MAJOR >= 22
3055#include "llvm/Plugins/PassPlugin.h"
3057#include "llvm/Passes/PassPlugin.h"
3064 static llvm::AnalysisKey Key;
3070 Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) {
3071 return EnzymeBase::run(M) ? PreservedAnalyses::none()
3072 : PreservedAnalyses::all();
3079AnalysisKey EnzymeNewPM::Key;
3086#include "llvm/Passes/PassBuilder.h"
3087#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
3088#include "llvm/Transforms/IPO/AlwaysInliner.h"
3089#include "llvm/Transforms/IPO/CalledValuePropagation.h"
3090#include "llvm/Transforms/IPO/ConstantMerge.h"
3091#include "llvm/Transforms/IPO/CrossDSOCFI.h"
3092#include "llvm/Transforms/IPO/DeadArgumentElimination.h"
3093#include "llvm/Transforms/IPO/FunctionAttrs.h"
3094#include "llvm/Transforms/IPO/GlobalDCE.h"
3095#include "llvm/Transforms/IPO/GlobalOpt.h"
3096#include "llvm/Transforms/IPO/GlobalSplit.h"
3097#include "llvm/Transforms/IPO/InferFunctionAttrs.h"
3098#include "llvm/Transforms/IPO/SCCP.h"
3099#include "llvm/Transforms/InstCombine/InstCombine.h"
3100#include "llvm/Transforms/Scalar/CallSiteSplitting.h"
3101#include "llvm/Transforms/Scalar/EarlyCSE.h"
3102#include "llvm/Transforms/Scalar/Float2Int.h"
3103#include "llvm/Transforms/Scalar/GVN.h"
3104#include "llvm/Transforms/Scalar/LoopDeletion.h"
3105#include "llvm/Transforms/Scalar/LoopRotation.h"
3106#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
3107#include "llvm/Transforms/Scalar/SROA.h"
3109#include "llvm/Transforms/IPO/ArgumentPromotion.h"
3110#include "llvm/Transforms/Scalar/ConstraintElimination.h"
3111#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
3112#include "llvm/Transforms/Scalar/JumpThreading.h"
3113#include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
3114#include "llvm/Transforms/Scalar/NewGVN.h"
3115#include "llvm/Transforms/Scalar/TailRecursionElimination.h"
3116#if LLVM_VERSION_MAJOR >= 17
3117#include "llvm/Transforms/Utils/MoveAutoInit.h"
3119#include "llvm/Transforms/Scalar/IndVarSimplify.h"
3120#include "llvm/Transforms/Scalar/LICM.h"
3121#include "llvm/Transforms/Scalar/LoopFlatten.h"
3122#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h"
3125#if LLVM_VERSION_MAJOR >= 23
3126 return getInlineParams(Level.getSpeedupLevel());
3128 return getInlineParams(Level.getSpeedupLevel(), Level.getSizeLevel());
3132#include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h"
3133#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
3137#define EnableLoopFlatten false
3138#define EagerlyInvalidateAnalyses false
3139#define RunNewGVN false
3140#define EnableConstraintElimination true
3141#define UseInlineAdvisor InliningAdvisorMode::Default
3142#define EnableMemProfContextDisambiguation false
3144#define EnableMatrix false
3145#define EnableModuleInliner false
3150 auto prePass = [](ModulePassManager &MPM, OptimizationLevel Level) {
3151 FunctionPassManager OptimizePM;
3152 OptimizePM.addPass(Float2IntPass());
3153 OptimizePM.addPass(LowerConstantIntrinsicsPass());
3156 OptimizePM.addPass(LowerMatrixIntrinsicsPass());
3157 OptimizePM.addPass(EarlyCSEPass());
3160 LoopPassManager LPM;
3161 bool LTOPreLink =
false;
3164#if LLVM_VERSION_MAJOR >= 23
3165 LPM.addPass(LoopRotatePass(
true, LTOPreLink,
3168 LPM.addPass(LoopRotatePass(Level != OptimizationLevel::Oz, LTOPreLink));
3174 LPM.addPass(LoopDeletionPass());
3176 LPM.addPass(llvm::LoopFullUnrollPass());
3177 OptimizePM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM)));
3179 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizePM)));
3182#if LLVM_VERSION_MAJOR >= 20
3183 auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level,
3186 auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level)
3194 if (Level != OptimizationLevel::O0)
3195 prePass(MPM, Level);
3196 MPM.addPass(llvm::AlwaysInlinerPass());
3197 FunctionPassManager OptimizerPM;
3198 FunctionPassManager OptimizerPM2;
3199#if LLVM_VERSION_MAJOR >= 16
3200 OptimizerPM.addPass(llvm::GVNPass());
3201 OptimizerPM.addPass(llvm::SROAPass(llvm::SROAOptions::PreserveCFG));
3203 OptimizerPM.addPass(llvm::GVNPass());
3204 OptimizerPM.addPass(llvm::SROAPass());
3206 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM)));
3209#if LLVM_VERSION_MAJOR >= 16
3210 OptimizerPM2.addPass(llvm::GVNPass());
3211 OptimizerPM2.addPass(llvm::SROAPass(llvm::SROAOptions::PreserveCFG));
3213 OptimizerPM2.addPass(llvm::GVNPass());
3214 OptimizerPM2.addPass(llvm::SROAPass());
3217 LoopPassManager LPM1;
3218 LPM1.addPass(LoopDeletionPass());
3219 OptimizerPM2.addPass(createFunctionToLoopPassAdaptor(std::move(LPM1)));
3221 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM2)));
3222 MPM.addPass(GlobalOptPass());
3225 PB.registerOptimizerEarlyEPCallback(
loadPass);
3227 auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel) {
3234 PB.registerPipelineStartEPCallback(loadNVVM);
3235 PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM);
3237 auto preLTOPass = [](ModulePassManager &MPM, OptimizationLevel Level) {
3240 MPM.addPass(CrossDSOCFIPass());
3242 if (Level == OptimizationLevel::O0) {
3248#if LLVM_VERSION_MAJOR >= 16
3249 MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));
3251 MPM.addPass(OpenMPOptPass());
3256 MPM.addPass(GlobalDCEPass());
3260 MPM.addPass(InferFunctionAttrsPass());
3262 if (Level.getSpeedupLevel() > 1) {
3263 MPM.addPass(createModuleToFunctionPassAdaptor(CallSiteSplittingPass(),
3279#if LLVM_VERSION_MAJOR >= 23
3280 MPM.addPass(IPSCCPPass(IPSCCPOptions(
true)));
3281#elif LLVM_VERSION_MAJOR >= 16
3282 MPM.addPass(IPSCCPPass(IPSCCPOptions(
3283 Level != OptimizationLevel::Os &&
3284 Level != OptimizationLevel::Oz)));
3286 MPM.addPass(IPSCCPPass());
3291 MPM.addPass(CalledValuePropagationPass());
3296 createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
3302 MPM.addPass(ReversePostOrderFunctionAttrsPass());
3306 MPM.addPass(GlobalSplitPass());
3313 if (Level == OptimizationLevel::O1) {
3318 MPM.addPass(GlobalOptPass());
3321 MPM.addPass(createModuleToFunctionPassAdaptor(PromotePass()));
3325 MPM.addPass(ConstantMergePass());
3328 MPM.addPass(DeadArgumentEliminationPass());
3334 FunctionPassManager PeepholeFPM;
3335 PeepholeFPM.addPass(InstCombinePass());
3336 if (Level.getSpeedupLevel() > 1)
3337 PeepholeFPM.addPass(AggressiveInstCombinePass());
3339 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(PeepholeFPM),
3349 ThinOrFullLTOPhase::FullLTOPostLink));
3351 MPM.addPass(ModuleInlinerWrapperPass(
3354 InlineContext{ThinOrFullLTOPhase::FullLTOPostLink,
3355 InlinePass::CGSCCInliner}));
3364 MPM.addPass(GlobalOptPass());
3367#if LLVM_VERSION_MAJOR >= 16
3368 MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));
3370 MPM.addPass(OpenMPOptPass());
3374 MPM.addPass(GlobalDCEPass());
3379 createModuleToPostOrderCGSCCPassAdaptor(ArgumentPromotionPass()));
3381 FunctionPassManager FPM;
3383 FPM.addPass(InstCombinePass());
3386 FPM.addPass(ConstraintEliminationPass());
3388 FPM.addPass(JumpThreadingPass());
3394 if (PGOOpt->CSAction == PGOOptions::CSIRInstr)
3395 addPGOInstrPasses(MPM, Level,
true,
3396 true, PGOOpt->CSProfileGenFile,
3397 PGOOpt->ProfileRemappingFile,
3398 ThinOrFullLTOPhase::FullLTOPostLink, PGOOpt->FS);
3399 else if (PGOOpt->CSAction == PGOOptions::CSIRUse)
3400 addPGOInstrPasses(MPM, Level,
false,
3401 true, PGOOpt->ProfileFile,
3402 PGOOpt->ProfileRemappingFile,
3403 ThinOrFullLTOPhase::FullLTOPostLink, PGOOpt->FS);
3408#if LLVM_VERSION_MAJOR >= 16
3409 FPM.addPass(SROAPass(SROAOptions::ModifyCFG));
3411 FPM.addPass(SROAPass());
3416 FPM.addPass(TailCallElimPass());
3419 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
3423 createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
3427 MPM.addPass(RequireAnalysisPass<GlobalsAA, Module>());
3430 auto loadLTO = [preLTOPass,
loadPass](ModulePassManager &MPM,
3431 OptimizationLevel Level) {
3432 preLTOPass(MPM, Level);
3434 createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
3438 MPM.addPass(RequireAnalysisPass<GlobalsAA, Module>());
3443 createModuleToFunctionPassAdaptor(InvalidateAnalysisPass<AAManager>()));
3445 FunctionPassManager MainFPM;
3446#if LLVM_VERSION_MAJOR >= 22
3447 MainFPM.addPass(createFunctionToLoopPassAdaptor(
3452 MainFPM.addPass(createFunctionToLoopPassAdaptor(
3459 MainFPM.addPass(NewGVNPass());
3461 MainFPM.addPass(GVNPass());
3464 MainFPM.addPass(MemCpyOptPass());
3467 MainFPM.addPass(DSEPass());
3468#if LLVM_VERSION_MAJOR >= 17
3469 MainFPM.addPass(MoveAutoInitPass());
3471 MainFPM.addPass(MergedLoadStoreMotionPass());
3473 LoopPassManager LPM;
3475 LPM.addPass(LoopFlattenPass());
3476 LPM.addPass(IndVarSimplifyPass());
3477 LPM.addPass(LoopDeletionPass());
3480#if LLVM_VERSION_MAJOR >= 20
3481 loadPass(MPM, Level, ThinOrFullLTOPhase::None);
3486 PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO);
3492 bool augment =
false) {
3496 PB.registerPipelineParsingCallback(
3497 [](llvm::StringRef Name, llvm::ModulePassManager &MPM,
3498 llvm::ArrayRef<llvm::PassBuilder::PipelineElement>) {
3499 if (Name ==
"enzyme") {
3506 if (Name ==
"preserve-nvvm") {
3510 if (Name ==
"preserve-nvvm-end") {
3514 if (Name ==
"print-type-analysis") {
3518 if (Name ==
"print-activity-analysis") {
3524 PB.registerPipelineParsingCallback(
3525 [](llvm::StringRef Name, llvm::FunctionPassManager &FPM,
3526 llvm::ArrayRef<llvm::PassBuilder::PipelineElement>) {
3527 if (Name ==
"jl-inst-simplify") {
3531 if (Name ==
"simple-gvn") {
3540#ifdef ENZYME_RUNPASS
3547extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
3549 return {LLVM_PLUGIN_API_VERSION,
"EnzymeNewPM",
"v0.1",
registerEnzyme};
llvm::cl::opt< bool > EnzymeEnable
static void loadPass(const PassManagerBuilder &Builder, legacy::PassManagerBase &PM)
#define EagerlyInvalidateAnalyses
#define EnableLoopFlatten
void augmentPassBuilder(llvm::PassBuilder &PB)
::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK llvmGetPassPluginInfo()
ModulePass * createEnzymePass(bool PostOpt)
llvm::cl::opt< bool > EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden, cl::desc("Whether to enable openmp opt"))
static InlineParams getInlineParamsFromOptLevel(OptimizationLevel Level)
void AddEnzymePass(LLVMPassManagerRef PM)
llvm::cl::opt< bool > EnzymeAttributor("enzyme-attributor", cl::init(false), cl::Hidden, cl::desc("Run attributor post Enzyme"))
#define EnableModuleInliner
#define EnableConstraintElimination
llvm::cl::opt< bool > EnzymeDetectReadThrow("enzyme-detect-readthrow", cl::init(true), cl::Hidden, cl::desc("Run preprocessing detect readonly or throw optimization"))
void registerEnzyme(llvm::PassBuilder &PB)
void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB, bool augment=false)
llvm::cl::opt< std::string > EnzymeTruncateAll("enzyme-truncate-all", cl::init(""), cl::Hidden, cl::desc("Truncate all floating point operations. " "E.g. \"64to32\" or \"64to<exponent_width>-<significand_width>\"."))
static RegisterPass< EnzymeOldPM > X("enzyme", "Enzyme Pass")
llvm::cl::opt< bool > EnzymePostOpt("enzyme-postopt", cl::init(false), cl::Hidden, cl::desc("Run enzymepostprocessing optimizations"))
bool registerFixupJuliaPass(llvm::StringRef Name, llvm::ModulePassManager &MPM)
bool DetectReadonlyOrThrow(Module &M)
bool LowerSparsification(llvm::Function *F, bool replaceAll)
Lower __enzyme_todense, returning whether anything changed.
constexpr const char * to_string(ActivityAnalyzer::UseActivity UA)
llvm::cl::opt< bool > EnzymePrint
constexpr char EnzymeFPRTPrefix[]
Value * simplifyLoad(Value *V, size_t valSz, size_t preOffset)
bool attributeKnownFunctions(llvm::Function &F)
Function * GetFunctionFromValue(Value *fn)
llvm::FastMathFlags getFast()
Get LLVM fast math flags.
static llvm::StringRef getFuncName(llvm::Function *called)
@ Args
Return is a struct of all args.
static bool startsWith(llvm::StringRef string, llvm::StringRef prefix)
DIFFE_TYPE
Potential differentiable argument classifications.
static llvm::PointerType * getUnqual(llvm::Type *T)
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
static llvm::PointerType * getInt8PtrTy(llvm::LLVMContext &Context, unsigned AddressSpace=0)
static void addCallSiteNoCapture(llvm::CallBase *call, size_t idx)
static bool isWriteOnly(const llvm::Function *F, ssize_t arg=-1)
static DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode, bool integersAreConstant, std::set< llvm::Type * > &seen)
Attempt to automatically detect the differentiable classification based off of a given type.
return structtype if recursive function
llvm::Type * tapeType
return structtype if recursive function
std::map< AugmentedStruct, int > returns
Map from information desired from a augmented return to its index in the returned struct.
Concrete SubType of a given value.
llvm::Function * CreateTruncateFunc(RequestContext context, llvm::Function *tobatch, FloatTruncation truncation, TruncateMode mode)
bool CreateTruncateValue(RequestContext context, llvm::Value *addr, FloatRepresentation from, FloatRepresentation to, bool isTruncate)
const AugmentedReturn & CreateAugmentedPrimal(RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, TypeAnalysis &TA, bool returnUsed, bool shadowReturnUsed, const FnTypeInfo &typeInfo, bool subsequent_calls_may_write, const std::vector< bool > _overwritten_args, bool forceAnonymousTape, bool runtimeActivity, bool strongZero, unsigned width, bool AtomicAdd, bool omp=false)
Create an augmented forward pass.
llvm::Function * CreateForwardDiff(RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, TypeAnalysis &TA, bool returnValue, DerivativeMode mode, bool freeMemory, bool runtimeActivity, bool strongZero, unsigned width, llvm::Type *additionalArg, const FnTypeInfo &typeInfo, bool subsequent_calls_may_write, const std::vector< bool > _overwritten_args, const AugmentedReturn *augmented, bool omp=false)
Create the forward (or forward split) mode derivative function.
llvm::Function * CreateTrace(RequestContext context, llvm::Function *totrace, const llvm::SmallPtrSetImpl< llvm::Function * > &sampleFunctions, const llvm::SmallPtrSetImpl< llvm::Function * > &observeFunctions, const llvm::StringSet<> &ActiveRandomVariables, ProbProgMode mode, bool autodiff, TraceInterface *interface)
Create a traced version of a function. `context` is the instruction which requested this trace (or null).
llvm::Function * CreatePrimalAndGradient(RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA, const AugmentedReturn *augmented, bool omp=false)
Create the reverse pass, or combined forward+reverse derivative function.
llvm::Function * CreateBatch(RequestContext context, llvm::Function *tobatch, unsigned width, llvm::ArrayRef< BATCH_TYPE > arg_types, BATCH_TYPE ret_type)
Create a function batched in its inputs.
bool PostOpt
PostOpt indicates whether to perform basic optimization of the function after synthesis.
EnzymeNewPM(bool PostOpt=false)
llvm::PreservedAnalyses Result
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM)
static llvm::Constant * GetOrCreateShadowConstant(RequestContext context, EnzymeLogic &Logic, llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, llvm::Constant *F, DerivativeMode mode, bool runtimeActivity, bool strongZero, unsigned width, bool AtomicAdd)
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
llvm::ModuleAnalysisManager MAM
std::map< std::pair< llvm::Function *, DerivativeMode >, llvm::Function * > cache
llvm::FunctionAnalysisManager FAM
Full interprocedural TypeAnalysis.
TypeResults analyzeFunction(const FnTypeInfo &fn)
Analyze a particular function, returning the results.
FnTypeInfo getAnalyzedTypeInfo() const
The TypeInfo calling convention.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
TypeTree Only(int Off, llvm::Instruction *orig) const
Prepend an offset to all mappings.
bool insert(const std::vector< int > Seq, ConcreteType CT, bool PointerIntSame=false)
Returns whether anything changed.
cl::opt< unsigned > SetLicmMssaOptCap
cl::opt< unsigned > SetLicmMssaNoAccForPromotionCap
Struct containing all contextual type information for a particular function call.
std::map< llvm::Argument *, TypeTree > Arguments
Types of arguments.
llvm::Function * Function
Function being analyzed.
TypeTree Return
Type of return.
std::map< llvm::Argument *, std::set< int64_t > > KnownValues
The specific constant(s) known to be represented by an argument, if constant.
todiff is the function to differentiate; retType is the activity info of the return.