923 llvm::Value *orig_val, llvm::MaybeAlign prevalign,
924 bool isVolatile, llvm::AtomicOrdering ordering,
925 llvm::SyncScope::ID syncScope, llvm::Value *mask) {
926 using namespace llvm;
929 Type *valType = orig_val->getType();
931 auto &DL = gutils->
newFunc->getParent()->getDataLayout();
933 if (unnecessaryStores.count(&I)) {
941 SmallVector<Metadata *, 1> scopeMD = {
943 SmallVector<Metadata *, 1> prevScopes;
944 if (
auto prev = I.getMetadata(LLVMContext::MD_alias_scope)) {
945 for (
auto &M : cast<MDNode>(prev)->operands()) {
946 scopeMD.push_back(M);
947 prevScopes.push_back(M);
950 auto scope = MDNode::get(I.getContext(), scopeMD);
952 NewI->setMetadata(LLVMContext::MD_alias_scope, scope);
954 SmallVector<Metadata *, 1> MDs;
955 SmallVector<Metadata *, 1> prevNoAlias;
956 for (
size_t j = 0; j < gutils->
getWidth(); j++) {
959 if (
auto prev = I.getMetadata(LLVMContext::MD_noalias)) {
960 for (
auto &M : cast<MDNode>(prev)->operands()) {
962 prevNoAlias.push_back(M);
965 auto noscope = MDNode::get(I.getContext(), MDs);
966 NewI->setMetadata(LLVMContext::MD_noalias, noscope);
969 parseTBAA(I, DL,
nullptr)[{-1}].isIntegral();
971 IRBuilder<> BuilderZ(NewI);
972 BuilderZ.setFastMathFlags(
getFast());
976 auto storeSize = (DL.getTypeSizeInBits(valType) + 7) / 8;
978 auto vd = TR.
query(orig_ptr).
Lookup(storeSize, DL);
982 raw_string_ostream ss(
str);
983 ss <<
"Cannot deduce type of store " << I;
986 ss <<
", assumed " << vd.str() <<
"\n";
1002 for (
size_t i = 0; i < storeSize; ++i) {
1004 dt.checkedOrIn(vd[{(int)i}],
true, Legal);
1007 raw_string_ostream ss(
str);
1008 ss <<
"Cannot deduce single type of store " << I << vd.str()
1009 <<
" size: " << storeSize;
1015 Value *diff =
nullptr;
1016 bool needs_writebarrier =
false;
1019 if (!isa<UndefValue>(orig_val) &&
1020 !isa<ConstantPointerNull>(orig_val)) {
1022 raw_string_ostream ss(
str);
1023 ss <<
"Mismatched activity for: " << I
1024 <<
" const val: " << *orig_val;
1028 wrap(orig_val), wrap(&BuilderZ)));
1030 needs_writebarrier =
true;
1042 else if (orig_val->getType()->isPointerTy() ||
1051 gutils->
setPtrDiffe(&I, orig_ptr, diff, BuilderZ, prevalign, 0, storeSize,
1052 isVolatile, ordering, syncScope, mask, prevNoAlias,
1053 prevScopes, needs_writebarrier);
1060 IRBuilder<> Builder2(&I);
1061 BasicBlock *merge =
nullptr;
1067 unsigned nextStart = storeSize;
1070 for (
size_t i = start; i < storeSize; ++i) {
1071 auto nex = vd[{(int)i}];
1078 dt.checkedOrIn(nex,
true, Legal);
1084 unsigned size = nextStart - start;
1085 if (!dt.isKnown()) {
1088 raw_string_ostream ss(
str);
1089 ss <<
"Cannot deduce type of store " << I << vd.str()
1090 <<
" start: " << start <<
" size: " << size
1091 <<
" storeSize: " << storeSize;
1098 if (start % prevalign->value() == 0)
1104 if (Type *FT = dt.isFloat()) {
1119 if (!isa<AllocaInst>(basePtr) &&
1126 shadow_ptr = gutils->
extractMeta(Builder2, shadow_ptr, 0);
1128 Value *shadow = Builder2.CreateICmpNE(primal_ptr, shadow_ptr);
1130 BasicBlock *current = Builder2.GetInsertBlock();
1132 current, current->getName() +
"_active");
1134 current->getName() +
"_amerge");
1135 Builder2.CreateCondBr(shadow, conditional, merge);
1136 Builder2.SetInsertPoint(conditional);
1144 Builder2, align, start, size, isVolatile, ordering, syncScope,
1145 mask, prevNoAlias, prevScopes);
1148 Value *maskL = mask;
1154 auto rule = [&](Value *dif1Ptr) {
1156 Builder2.CreateLoad(valType, dif1Ptr, isVolatile);
1158 dif1->setAlignment(*align);
1159 dif1->setOrdering(ordering);
1160 dif1->setSyncScopeID(syncScope);
1162 SmallVector<Metadata *, 1> scopeMD = {
1164 for (
auto M : prevScopes)
1165 scopeMD.push_back(M);
1167 SmallVector<Metadata *, 1> MDs;
1168 for (ssize_t j = -1; j < gutils->
getWidth(); j++) {
1169 if (j != (ssize_t)idx)
1172 for (
auto M : prevNoAlias)
1175 dif1->setMetadata(LLVMContext::MD_alias_scope,
1176 MDNode::get(I.getContext(), scopeMD));
1177 dif1->setMetadata(LLVMContext::MD_noalias,
1178 MDNode::get(I.getContext(), MDs));
1179 dif1->setMetadata(LLVMContext::MD_tbaa,
1180 I.getMetadata(LLVMContext::MD_tbaa));
1181 dif1->setMetadata(LLVMContext::MD_tbaa_struct,
1182 I.getMetadata(LLVMContext::MD_tbaa_struct));
1189 maskL =
lookup(mask, Builder2);
1190 Type *tys[] = {valType, orig_ptr->getType()};
1192 Intrinsic::masked_load, tys);
1194 ConstantInt::get(Type::getInt32Ty(mask->getContext()),
1195 align ? align->value() : 0);
1199 auto rule = [&](Value *ip) {
1200 Value *args[] = {ip, alignv, maskL,
1201 Constant::getNullValue(valType)};
1202 diff = Builder2.CreateCall(F, args);
1212 Builder2, align, start, size, isVolatile, ordering, syncScope,
1213 mask, prevNoAlias, prevScopes);
1215 ->addToDiffe(orig_val, diff, Builder2, FT, start, size, {},
1226 Value *diff = constantval
1227 ? Constant::getNullValue(diffeTy)
1230 gutils->
setPtrDiffe(&I, orig_ptr, diff, BuilderZ, align, start, size,
1231 isVolatile, ordering, syncScope, mask,
1232 prevNoAlias, prevScopes);
1244 if (
Constant *C = dyn_cast<Constant>(orig_val)) {
1245 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
1246 C = CE->getOperand(0);
1248 if (
auto GV = dyn_cast<GlobalVariable>(C)) {
1249 if (GV->getName() ==
"ompi_request_null") {
1255 bool backwardsShadow =
false;
1256 bool forwardsShadow =
true;
1258 if (pair.second.stores.count(&I)) {
1259 backwardsShadow =
true;
1260 forwardsShadow = pair.second.primalInitialize;
1261 if (
auto inst = dyn_cast<Instruction>(pair.first))
1262 if (!forwardsShadow && pair.second.LI &&
1263 pair.second.LI->contains(inst->getParent()))
1264 backwardsShadow =
false;
1272 (forwardsShadow || backwardsShadow)) ||
1276 Value *valueop =
nullptr;
1278 bool needs_writebarrier =
false;
1282 if (!isa<UndefValue>(orig_val) &&
1283 !isa<ConstantPointerNull>(orig_val)) {
1285 raw_string_ostream ss(
str);
1286 ss <<
"Mismatched activity for: " << I
1287 <<
" const val: " << *orig_val;
1291 gutils, wrap(orig_val), wrap(&BuilderZ)));
1293 needs_writebarrier =
true;
1304 for (
unsigned i = 0; i < gutils->
getWidth(); ++i) {
1305 array = BuilderZ.CreateInsertValue(array, val, {i});
1313 gutils->
setPtrDiffe(&I, orig_ptr, valueop, BuilderZ, align, start,
1314 size, isVolatile, ordering, syncScope, mask,
1315 prevNoAlias, prevScopes, needs_writebarrier);
1319 if (nextStart == storeSize)
1325 Builder2.CreateBr(merge);
1326 Builder2.SetInsertPoint(merge);
2379 using namespace llvm;
2381 IRBuilder<> Builder2(&BO);
2384 Value *orig_op0 = BO.getOperand(0);
2385 Value *orig_op1 = BO.getOperand(1);
2387 Value *dif0 =
nullptr;
2388 Value *dif1 =
nullptr;
2389 Value *idiff =
diffe(&BO, Builder2);
2391 Type *addingType = BO.getType();
2393 switch (BO.getOpcode()) {
2394 case Instruction::LShr: {
2396 if (
auto ci = dyn_cast<ConstantInt>(orig_op1)) {
2398 if (orig_op0->getType()->isSized())
2399 size = (gutils->
newFunc->getParent()
2401 .getTypeSizeInBits(orig_op0->getType()) +
2405 if (Type *flt = TR.
addingType(size, orig_op0)) {
2406 auto bits = gutils->
newFunc->getParent()
2408 .getTypeAllocSizeInBits(flt);
2409 if (ci->getSExtValue() >= (int64_t)bits &&
2410 ci->getSExtValue() % bits == 0) {
2411 auto rule = [&](Value *idiff) {
2412 return Builder2.CreateShl(idiff, ci);
2414 dif0 =
applyChainRule(orig_op0->getType(), Builder2, rule, idiff);
2422 llvm::errs() <<
"warning: binary operator is integer and constant: "
2429 case Instruction::AShr: {
2431 llvm::errs() <<
"warning: binary operator is integer and constant: "
2438 case Instruction::And: {
2440 auto &dl = gutils->
oldFunc->getParent()->getDataLayout();
2441 auto size = dl.getTypeSizeInBits(BO.getType()) / 8;
2446 for (
int i = 0; i < 2; ++i) {
2447 auto CI = dyn_cast<ConstantInt>(BO.getOperand(i));
2448 if (CI && dl.getTypeSizeInBits(eFT) ==
2449 dl.getTypeSizeInBits(CI->getType())) {
2450 if (eFT->isDoubleTy() && CI->getValue() == -134217728) {
2453 Constant::getNullValue(gutils->
getShadowType(BO.getType())),
2462 llvm::errs() <<
"warning: binary operator is integer and constant: "
2469 case Instruction::Xor: {
2470 auto &dl = gutils->
oldFunc->getParent()->getDataLayout();
2471 auto size = dl.getTypeSizeInBits(BO.getType()) / 8;
2477 for (
int i = 0; i < 2; ++i) {
2481 Constant::getNullValue(gutils->
getShadowType(BO.getType())),
2483 auto isZero = Builder2.CreateICmpEQ(
2485 Constant::getNullValue(BO.getType()));
2486 auto rule = [&](Value *idiff) {
2487 auto ext = Builder2.CreateBitCast(idiff, FT);
2488 auto neg = Builder2.CreateFNeg(ext);
2490 neg = Builder2.CreateBitCast(neg, BO.getType());
2493 auto bc =
applyChainRule(BO.getOperand(1 - i)->getType(), Builder2,
2495 addToDiffe(BO.getOperand(1 - i), bc, Builder2, FT);
2500 llvm::errs() <<
"warning: binary operator is integer and constant: "
2507 case Instruction::Or: {
2508 auto &dl = gutils->
oldFunc->getParent()->getDataLayout();
2509 auto size = dl.getTypeSizeInBits(BO.getType()) / 8;
2515 for (
int i = 0; i < 2; ++i) {
2516 auto CI = dyn_cast<ConstantInt>(BO.getOperand(i));
2517 if (
auto CV = dyn_cast<ConstantVector>(BO.getOperand(i))) {
2518 CI = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
2519 FT = VectorType::get(FT, CV->getType()->getElementCount());
2521 if (
auto CV = dyn_cast<ConstantDataVector>(BO.getOperand(i))) {
2522 CI = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
2523 FT = VectorType::get(FT, CV->getType()->getElementCount());
2525 if (CI && dl.getTypeSizeInBits(eFT) ==
2526 dl.getTypeSizeInBits(CI->getType())) {
2527 auto AP = CI->getValue();
2528 bool validXor =
false;
2529#if LLVM_VERSION_MAJOR > 16
2532 if (AP.isNullValue())
2539#
if LLVM_VERSION_MAJOR > 16
2540 && (AP & ~0b01111111100000000000000000000000ULL).isZero()
2542 && (AP & ~0b01111111100000000000000000000000ULL).isNullValue()
2546#
if LLVM_VERSION_MAJOR > 16
2549 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
2554 ~0b0111111111110000000000000000000000000000000000000000000000000000ULL)
2563 Constant::getNullValue(gutils->
getShadowType(BO.getType())),
2569 auto rule = [&](Value *idiff) {
2570 auto prev = Builder2.CreateOr(arg, BO.getOperand(i));
2571 prev = Builder2.CreateSub(prev, arg,
"",
true,
2574 if (FT->isFloatTy()) {
2577 assert(FT->isDoubleTy());
2578 num = 1023ULL << 52;
2580 prev = Builder2.CreateAdd(
2581 prev, ConstantInt::get(prev->getType(), num,
false),
"",
2583 prev = Builder2.CreateBitCast(
2585 Builder2.CreateBitCast(idiff, FT),
2586 Builder2.CreateBitCast(prev, FT)),
2592 Builder2, rule, idiff);
2593 addToDiffe(BO.getOperand(1 - i), prev, Builder2, FT);
2599 llvm::errs() <<
"warning: binary operator is integer and constant: "
2606 case Instruction::UDiv:
2607 case Instruction::URem:
2608 case Instruction::SRem:
2609 case Instruction::SDiv:
2610 case Instruction::Shl:
2611 case Instruction::Mul:
2612 case Instruction::Sub:
2613 case Instruction::Add: {
2616 <<
"warning: binary operator is integer and assumed constant: "
2626 llvm::raw_string_ostream ss(s);
2627 ss << *gutils->
oldFunc <<
"\n";
2628 for (
auto &arg : gutils->
oldFunc->args()) {
2629 ss <<
" constantarg[" << arg <<
"] = " << gutils->
isConstantValue(&arg)
2630 <<
" type: " << TR.
query(&arg).
str() <<
" - vals: {";
2635 for (
auto &BB : *gutils->
oldFunc)
2636 for (
auto &I : BB) {
2637 ss <<
" constantinst[" << I
2640 <<
" type: " << TR.
query(&I).
str() <<
"\n";
2642 ss <<
"cannot handle unknown binary operator: " << BO <<
"\n";
2651 addToDiffe(orig_op0, dif0, Builder2, addingType);
2653 addToDiffe(orig_op1, dif1, Builder2, addingType);
2936 using namespace llvm;
2938 IRBuilder<> BuilderZ(&MS);
2941 IRBuilder<> Builder2(&MS);
2946 bool forceErase =
false;
2949 if (pair.second.stores.count(&MS) && pair.second.LI) {
2959 Value *orig_op0 = MS.getArgOperand(0);
2960 Value *orig_op1 = MS.getArgOperand(1);
2969 if (
auto CI = dyn_cast<ConstantInt>(orig_op1))
2971 activeValToSet =
false;
2972 if (activeValToSet) {
2974 llvm::raw_string_ostream ss(s);
2975 ss <<
"couldn't handle non constant inst in memset to "
2976 "propagate differential to\n"
2986 Value *op3 =
nullptr;
2987 if (3 < MS.arg_size()) {
3001 SmallVector<Value *, 4> args = {op0, op1, op2};
3003 args.push_back(op3);
3006 if (
startsWith(funcName,
"memset_pattern") ||
3007 startsWith(funcName,
"llvm.experimental.memset"))
3008 cal = Builder2.CreateMemSet(
3009 op0, ConstantInt::get(Builder2.getInt8Ty(), 0), op2, {});
3011 cal = BuilderZ.CreateCall(MS.getCalledFunction(), args, Defs);
3013 llvm::SmallVector<unsigned int, 9> ToCopy2(
MD_ToCopy);
3014 ToCopy2.push_back(LLVMContext::MD_noalias);
3015 cal->copyMetadata(MS, ToCopy2);
3016 if (
auto m =
hasMetadata(&MS,
"enzyme_zerostack"))
3017 cal->setMetadata(
"enzyme_zerostack", m);
3019 if (
startsWith(funcName,
"memset_pattern") ||
3020 startsWith(funcName,
"llvm.experimental.memset")) {
3021 AttributeList NewAttrs;
3023 {AttributeList::ReturnIndex, AttributeList::FunctionIndex,
3024 AttributeList::FirstArgIndex})
3025 for (
auto attr : MS.getAttributes().getAttributes(idx))
3027 NewAttrs.addAttributeAtIndex(MS.getContext(), idx, attr);
3028 cal->setAttributes(NewAttrs);
3030 cal->setAttributes(MS.getAttributes());
3031 cal->setCallingConv(MS.getCallingConv());
3032 cal->setTailCallKind(MS.getTailCallKind());
3039 bool backwardsShadow =
false;
3040 bool forwardsShadow =
true;
3042 if (pair.second.stores.count(&MS)) {
3043 backwardsShadow =
true;
3044 forwardsShadow = pair.second.primalInitialize;
3045 if (
auto inst = dyn_cast<Instruction>(pair.first))
3046 if (!forwardsShadow && pair.second.LI &&
3047 pair.second.LI->contains(inst->getParent()))
3048 backwardsShadow =
false;
3053 if (
auto ci = dyn_cast<ConstantInt>(MS.getOperand(2))) {
3054 size = ci->getLimitedValue();
3061 llvm::errs() << MS <<
"\n";
3066 std::vector<std::tuple<Type *, size_t, size_t>> toIterate;
3070 if (
auto MD =
hasMetadata(&MS,
"enzyme_truetype")) {
3073 auto &DL = gutils->
newFunc->getParent()->getDataLayout();
3076 if (!vd.isKnownPastPointer()) {
3079 if (
auto CI = dyn_cast<ConstantInt>(MS.getOperand(1)))
3082 bool writtenTo =
false;
3085 if (
auto arg = dyn_cast<Argument>(root))
3086 if (arg->hasStructRetAttr())
3089 Instruction *cur = MS.getPrevNode();
3093 if (
auto MCI = dyn_cast<ConstantInt>(MS.getOperand(2))) {
3094 if (
auto II = dyn_cast<IntrinsicInst>(cur)) {
3095 if (II->getCalledFunction()->getName() ==
3096 "llvm.enzyme.lifetime_start") {
3099 dyn_cast<ConstantInt>(II->getOperand(0))) {
3100 if (MCI->getValue().ule(CI2->getValue()))
3104 cur = cur->getPrevNode();
3109 if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
3112 dyn_cast<ConstantInt>(II->getOperand(0))) {
3113 if (MCI->getValue().ule(CI2->getValue()))
3117 cur = cur->getPrevNode();
3122 if (cur->mayWriteToMemory()) {
3126 cur = cur->getPrevNode();
3137 if (!vd.isKnownPastPointer()) {
3139 if (isa<PHINode>(MS.getOperand(0)) ||
3140 isa<SelectInst>(MS.getOperand(0))) {
3141 SmallVector<Value *, 2> todo = {MS.getOperand(0)};
3143 SmallSet<Value *, 2> seen;
3145 while (todo.size()) {
3146 Value *cur = todo.back();
3148 if (seen.count(cur))
3151 if (
auto PN = dyn_cast<PHINode>(cur)) {
3152 for (
size_t i = 0, end = PN->getNumIncomingValues(); i < end;
3154 todo.push_back(PN->getIncomingValue(i));
3158 if (
auto S = dyn_cast<SelectInst>(cur)) {
3159 todo.push_back(S->getTrueValue());
3160 todo.push_back(S->getFalseValue());
3163 if (
auto CE = dyn_cast<ConstantExpr>(cur)) {
3165 todo.push_back(CE->getOperand(0));
3169 if (
auto CI = dyn_cast<CastInst>(cur)) {
3170 todo.push_back(CI->getOperand(0));
3173 if (isa<ConstantPointerNull>(cur))
3175 if (
auto CI = dyn_cast<ConstantInt>(cur))
3188 if (!vd.isKnownPastPointer()) {
3190#if LLVM_VERSION_MAJOR < 17
3191 if (
auto CI = dyn_cast<CastInst>(MS.getOperand(0))) {
3192 if (
auto PT = dyn_cast<PointerType>(CI->getSrcTy())) {
3193 auto ET = PT->getPointerElementType();
3195 if (
auto ST = dyn_cast<StructType>(ET)) {
3196 if (ST->getNumElements()) {
3197 ET = ST->getElementType(0);
3201 if (
auto AT = dyn_cast<ArrayType>(ET)) {
3202 ET = AT->getElementType();
3207 if (ET->isFPOrFPVectorTy()) {
3211 if (ET->isPointerTy()) {
3215 if (ET->isIntOrIntVectorTy()) {
3222 if (
auto gep = dyn_cast<GetElementPtrInst>(MS.getOperand(0))) {
3223 if (
auto AT = dyn_cast<ArrayType>(gep->getSourceElementType())) {
3224 if (AT->getElementType()->isIntegerTy()) {
3231 "failed to deduce type of memset ", MS);
3236 raw_string_ostream ss(
str);
3237 ss <<
"Cannot deduce type of memset " << MS;
3245 unsigned nextStart = size;
3248 for (
size_t i = start; i < size; ++i) {
3250 dt.checkedOrIn(vd[{(int)i}],
true, Legal);
3256 if (!dt.isKnown()) {
3258 llvm::errs() <<
" vd:" << vd.str() <<
" start:" << start
3259 <<
" size: " << size <<
" dt:" << dt.str() <<
"\n";
3261 assert(dt.isKnown());
3262 toIterate.emplace_back(dt.isFloat(), start, nextStart - start);
3264 if (nextStart == size)
3272 unsigned dstalign = dstAlign.valueOrOne().value();
3273 unsigned srcalign = srcAlign.valueOrOne().value();
3278 Value *op3 =
nullptr;
3279 if (3 < MS.arg_size()) {
3283 for (
auto &&[secretty_ref, seg_start_ref, seg_size_ref] : toIterate) {
3284 auto secretty = secretty_ref;
3285 auto seg_start = seg_start_ref;
3286 auto seg_size = seg_size_ref;
3288 Value *length = new_size;
3289 if (seg_start != std::get<1>(toIterate.back())) {
3290 length = ConstantInt::get(new_size->getType(), seg_start + seg_size);
3293 length = BuilderZ.CreateSub(
3294 length, ConstantInt::get(new_size->getType(), seg_start));
3297 unsigned subdstalign = dstalign;
3299 if (dstalign != 0) {
3300 if (start % dstalign != 0) {
3304 unsigned subsrcalign = srcalign;
3306 if (srcalign != 0) {
3307 if (start % srcalign != 0) {
3313 Value *shadow_dst = gutils->
invertPointerM(MS.getOperand(0), BuilderZ);
3326 auto rule = [&](Value *op0) {
3327 if (seg_start != 0) {
3328 Value *idxs[] = {ConstantInt::get(
3329 Type::getInt32Ty(op0->getContext()), seg_start)};
3330 op0 = BuilderZ.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()),
3333 SmallVector<Value *, 4> args = {op0, op1, length};
3335 args.push_back(op3);
3336 auto cal = BuilderZ.CreateCall(MS.getCalledFunction(), args, Defs);
3337 llvm::SmallVector<unsigned int, 9> ToCopy2(
MD_ToCopy);
3338 ToCopy2.push_back(LLVMContext::MD_noalias);
3339 if (
auto m =
hasMetadata(&MS,
"enzyme_zerostack"))
3340 cal->setMetadata(
"enzyme_zerostack", m);
3341 cal->copyMetadata(MS, ToCopy2);
3342 cal->setAttributes(MS.getAttributes());
3343 cal->setCallingConv(MS.getCallingConv());
3344 cal->setTailCallKind(MS.getTailCallKind());
3358 Value *op1l = gutils->
lookupM(op1, Builder2);
3361 op3l = gutils->
lookupM(op3l, BuilderZ);
3362 length = gutils->
lookupM(length, Builder2);
3363 auto rule = [&](Value *op0) {
3364 if (seg_start != 0) {
3365 Value *idxs[] = {ConstantInt::get(
3366 Type::getInt32Ty(op0->getContext()), seg_start)};
3367 op0 = Builder2.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()),
3370 SmallVector<Value *, 4> args = {op0, op1l, length};
3372 args.push_back(op3l);
3375 if (
startsWith(funcName,
"memset_pattern") ||
3376 startsWith(funcName,
"llvm.experimental.memset"))
3377 cal = Builder2.CreateMemSet(
3378 op0, ConstantInt::get(Builder2.getInt8Ty(), 0), length, {});
3380 cal = Builder2.CreateCall(MS.getCalledFunction(), args, Defs);
3381 llvm::SmallVector<unsigned int, 9> ToCopy2(
MD_ToCopy);
3382 ToCopy2.push_back(LLVMContext::MD_noalias);
3383 cal->copyMetadata(MS, ToCopy2);
3384 if (
auto m =
hasMetadata(&MS,
"enzyme_zerostack"))
3385 cal->setMetadata(
"enzyme_zerostack", m);
3387 if (
startsWith(funcName,
"memset_pattern") ||
3388 startsWith(funcName,
"llvm.experimental.memset")) {
3389 AttributeList NewAttrs;
3391 {AttributeList::ReturnIndex, AttributeList::FunctionIndex,
3392 AttributeList::FirstArgIndex})
3393 for (
auto attr : MS.getAttributes().getAttributes(idx))
3395 NewAttrs.addAttributeAtIndex(MS.getContext(), idx, attr);
3396 cal->setAttributes(NewAttrs);
3398 cal->setAttributes(MS.getAttributes());
3399 cal->setCallingConv(MS.getCallingConv());
3420 llvm::MaybeAlign dstAlign, llvm::CallInst &MTI,
3421 llvm::Value *orig_dst, llvm::Value *orig_src,
3422 llvm::Value *new_size, llvm::Value *isVolatile) {
3423 using namespace llvm;
3430 if (unnecessaryStores.count(&MTI)) {
3436 if (
auto ci = dyn_cast<ConstantInt>(new_size)) {
3437 if (ci->getValue() == 1) {
3445 if (isa<ConstantPointerNull>(orig_dst) ||
3452 if (
auto ci = dyn_cast<ConstantInt>(new_size)) {
3453 size = ci->getLimitedValue();
3471 std::vector<std::tuple<Type *, size_t, size_t>> toIterate;
3476 if (
auto MD =
hasMetadata(&MTI,
"enzyme_truetype")) {
3481 auto &DL = gutils->
newFunc->getParent()->getDataLayout();
3485 for (
size_t i = 0; i < MTI.getNumOperands(); i++)
3486 if (MTI.getOperand(i) == orig_dst)
3487 if (MTI.getAttributes().hasParamAttr(i,
"enzyme_type")) {
3488 auto attr = MTI.getAttributes().getParamAttr(i,
"enzyme_type");
3491 vd |= TT.Data0().ShiftIndices(DL, 0, size, 0);
3495 bool errorIfNoType =
true;
3499 errorIfNoType =
false;
3502 if (!vd.isKnownPastPointer()) {
3504 for (
auto val : {orig_dst, orig_src}) {
3505#if LLVM_VERSION_MAJOR < 17
3506 if (
auto CI = dyn_cast<CastInst>(val)) {
3507 if (
auto PT = dyn_cast<PointerType>(CI->getSrcTy())) {
3508 auto ET = PT->getPointerElementType();
3510 if (
auto ST = dyn_cast<StructType>(ET)) {
3511 if (ST->getNumElements()) {
3512 ET = ST->getElementType(0);
3516 if (
auto AT = dyn_cast<ArrayType>(ET)) {
3517 ET = AT->getElementType();
3522 if (ET->isFPOrFPVectorTy()) {
3527 if (ET->isPointerTy()) {
3531 if (ET->isIntOrIntVectorTy()) {
3538 if (
auto gep = dyn_cast<GetElementPtrInst>(val)) {
3539 if (
auto AT = dyn_cast<ArrayType>(gep->getSourceElementType())) {
3540 if (AT->getElementType()->isIntegerTy()) {
3550 if (size == 1 && !isa<ConstantInt>(new_size)) {
3551 for (
auto ptr : {orig_dst, orig_src}) {
3553 if (vd.isKnownPastPointer()) {
3555 size_t minInt = 0xFFFFFFFF;
3556 for (
const auto &pair : vd.getMapping()) {
3557 if (pair.first.size() != 1)
3559 if (minInt < (
size_t)pair.first[0])
3561 minInt = pair.first[0];
3572 "failed to deduce type of copy ", MTI);
3576 if (errorIfNoType) {
3578 raw_string_ostream ss(
str);
3579 ss <<
"Cannot deduce type of copy " << MTI;
3592 unsigned nextStart = size;
3595 for (
size_t i = start; i < size; ++i) {
3598 auto next = vd[{(int)i}];
3599 tmp.checkedOrIn(next,
true, Legal);
3616 if ((dt.isFloat() ==
nullptr) ==
3617 (vd[{(int)i}].isFloat() ==
nullptr)) {
3632 if (!dt.isKnown()) {
3634 llvm::errs() <<
" vd:" << vd.str() <<
" start:" << start
3635 <<
" size: " << size <<
" dt:" << dt.str() <<
"\n";
3637 assert(dt.isKnown());
3638 toIterate.emplace_back(dt.isFloat(), start, nextStart - start);
3640 if (nextStart == size)
3650 unsigned dstalign = dstAlign.valueOrOne().value();
3651 unsigned srcalign = srcAlign.valueOrOne().value();
3653 bool backwardsShadow =
false;
3654 bool forwardsShadow =
true;
3656 if (pair.second.stores.count(&MTI)) {
3657 backwardsShadow =
true;
3658 forwardsShadow = pair.second.primalInitialize;
3659 if (
auto inst = dyn_cast<Instruction>(pair.first))
3660 if (!forwardsShadow && pair.second.LI &&
3661 pair.second.LI->contains(inst->getParent()))
3662 backwardsShadow =
false;
3666 for (
auto &&[floatTy_ref, seg_start_ref, seg_size_ref] : toIterate) {
3667 auto floatTy = floatTy_ref;
3668 auto seg_start = seg_start_ref;
3669 auto seg_size = seg_size_ref;
3671 Value *length = new_size;
3672 if (seg_start != std::get<1>(toIterate.back())) {
3673 length = ConstantInt::get(new_size->getType(), seg_start + seg_size);
3676 length = BuilderZ.CreateSub(
3677 length, ConstantInt::get(new_size->getType(), seg_start));
3679 unsigned subdstalign = dstalign;
3681 if (dstalign != 0) {
3682 if (seg_start % dstalign != 0) {
3686 unsigned subsrcalign = srcalign;
3688 if (srcalign != 0) {
3689 if (seg_start % srcalign != 0) {
3701 auto rev_rule = [&](Value *shadow_dst, Value *shadow_src) {
3702 if (shadow_dst ==
nullptr)
3704 if (shadow_src ==
nullptr)
3707 gutils, Mode, floatTy, ID, subdstalign, subsrcalign,
3710 length, isVolatile, &MTI,
3711 forwardsShadow,
false,
3715 auto fwd_rule = [&](Value *ddst, Value *dsrc) {
3716 if (ddst ==
nullptr)
3718 if (dsrc ==
nullptr)
3722 dalign = MaybeAlign(subdstalign);
3725 salign = MaybeAlign(subsrcalign);
3726 if (ddst->getType()->isIntegerTy())
3728 BuilderZ.CreateIntToPtr(ddst,
getInt8PtrTy(ddst->getContext()));
3729 if (seg_start != 0) {
3730 ddst = BuilderZ.CreateConstInBoundsGEP1_64(
3731 Type::getInt8Ty(ddst->getContext()), ddst, seg_start);
3736 call = BuilderZ.CreateMemSet(
3737 ddst, ConstantInt::get(Type::getInt8Ty(ddst->getContext()), 0),
3738 length, dalign, isVolatile);
3740 if (dsrc->getType()->isIntegerTy())
3742 BuilderZ.CreateIntToPtr(dsrc,
getInt8PtrTy(dsrc->getContext()));
3743 if (seg_start != 0) {
3744 dsrc = BuilderZ.CreateConstInBoundsGEP1_64(
3745 Type::getInt8Ty(ddst->getContext()), dsrc, seg_start);
3747 if (ID == Intrinsic::memmove) {
3748 call = BuilderZ.CreateMemMove(ddst, dalign, dsrc, salign, length);
3750 call = BuilderZ.CreateMemCpy(ddst, dalign, dsrc, salign, length);
3752 call->setAttributes(MTI.getAttributes());
3755 call->setMetadata(LLVMContext::MD_alias_scope,
3756 MTI.getMetadata(LLVMContext::MD_alias_scope));
3757 call->setMetadata(LLVMContext::MD_noalias,
3758 MTI.getMetadata(LLVMContext::MD_noalias));
3759 call->setMetadata(LLVMContext::MD_tbaa,
3760 MTI.getMetadata(LLVMContext::MD_tbaa));
3761 call->setMetadata(LLVMContext::MD_tbaa_struct,
3762 MTI.getMetadata(LLVMContext::MD_tbaa_struct));
3763 call->setMetadata(LLVMContext::MD_invariant_group,
3764 MTI.getMetadata(LLVMContext::MD_invariant_group));
3765 call->setTailCallKind(MTI.getTailCallKind());
3877 llvm::SmallVectorImpl<llvm::Value *> &orig_ops) {
3878 using namespace llvm;
3880 Module *M = I.getParent()->getParent()->getParent();
3883#if LLVM_VERSION_MAJOR < 20
3884 case Intrinsic::nvvm_ldg_global_i:
3885 case Intrinsic::nvvm_ldg_global_p:
3886 case Intrinsic::nvvm_ldg_global_f:
3888 case Intrinsic::nvvm_ldu_global_i:
3889 case Intrinsic::nvvm_ldu_global_p:
3890 case Intrinsic::nvvm_ldu_global_f: {
3891 auto CI = cast<ConstantInt>(I.getOperand(1));
3900 if (ID == Intrinsic::masked_store) {
3901 auto align0 = cast<ConstantInt>(I.getOperand(2))->getZExtValue();
3902 auto align = MaybeAlign(align0);
3904 I.getOperand(0), align,
3905 false, llvm::AtomicOrdering::NotAtomic,
3906 SyncScope::SingleThread,
3910 if (ID == Intrinsic::masked_load) {
3911 auto align0 = cast<ConstantInt>(I.getOperand(1))->getZExtValue();
3912 auto align = MaybeAlign(align0);
3913 auto &DL = gutils->
newFunc->getParent()->getDataLayout();
3914 bool constantval =
parseTBAA(I, DL,
nullptr)[{-1}].isIntegral();
3921 auto mod = I.getParent()->getParent()->getParent();
3922 auto called = cast<CallInst>(&I)->getCalledFunction();
3925#include "IntrinsicDerivatives.inc"
3933#if LLVM_VERSION_MAJOR <= 20
3934 case Intrinsic::nvvm_barrier0:
3936 case Intrinsic::nvvm_barrier_cta_sync_aligned_all:
3937 case Intrinsic::nvvm_barrier_cta_sync_aligned_count:
3939#if LLVM_VERSION_MAJOR < 22
3940 case Intrinsic::nvvm_barrier0_popc:
3941 case Intrinsic::nvvm_barrier0_and:
3942 case Intrinsic::nvvm_barrier0_or:
3944 case Intrinsic::nvvm_barrier_cta_red_and_aligned_all:
3945 case Intrinsic::nvvm_barrier_cta_red_and_aligned_count:
3946 case Intrinsic::nvvm_barrier_cta_red_or_aligned_all:
3947 case Intrinsic::nvvm_barrier_cta_red_or_aligned_count:
3948 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_all:
3949 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_count:
3951 case Intrinsic::nvvm_membar_cta:
3952 case Intrinsic::nvvm_membar_gl:
3953 case Intrinsic::nvvm_membar_sys:
3954 case Intrinsic::amdgcn_s_barrier:
3959 if (ID == Intrinsic::umax || ID == Intrinsic::smax ||
3960 ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow ||
3961 ID == Intrinsic::uadd_with_overflow ||
3962 ID == Intrinsic::smul_with_overflow ||
3963 ID == Intrinsic::umul_with_overflow ||
3964 ID == Intrinsic::ssub_with_overflow ||
3965 ID == Intrinsic::usub_with_overflow)
3968 "failed to deduce type of intrinsic ", I);
3972 llvm::raw_string_ostream ss(s);
3973 ss << *gutils->
oldFunc <<
"\n";
3974 ss << *gutils->
newFunc <<
"\n";
3975 ss <<
"cannot handle (augmented) unknown intrinsic\n" << I;
3976 IRBuilder<> BuilderZ(&I);
3987 IRBuilder<> Builder2(&I);
3990 Value *vdiff =
nullptr;
3992 vdiff =
diffe(&I, Builder2);
3999#if LLVM_VERSION_MAJOR < 22
4000 case Intrinsic::nvvm_barrier0_popc:
4001 case Intrinsic::nvvm_barrier0_and:
4002 case Intrinsic::nvvm_barrier0_or:
4004 case Intrinsic::nvvm_barrier_cta_red_and_aligned_all:
4005 case Intrinsic::nvvm_barrier_cta_red_and_aligned_count:
4006 case Intrinsic::nvvm_barrier_cta_red_or_aligned_all:
4007 case Intrinsic::nvvm_barrier_cta_red_or_aligned_count:
4008 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_all:
4009 case Intrinsic::nvvm_barrier_cta_red_popc_aligned_count:
4012 SmallVector<Value *, 1> args = {};
4013#if LLVM_VERSION_MAJOR > 20
4014 auto cal = cast<CallInst>(Builder2.CreateCall(
4016 M, Intrinsic::nvvm_barrier_cta_sync_aligned_all),
4019 M, Intrinsic::nvvm_barrier_cta_sync_aligned_all)
4020 ->getCallingConv());
4022 auto cal = cast<CallInst>(Builder2.CreateCall(
4025 ->getCallingConv());
4031#if LLVM_VERSION_MAJOR <= 20
4032 case Intrinsic::nvvm_barrier0:
4034 case Intrinsic::nvvm_barrier_cta_sync_aligned_all:
4035 case Intrinsic::nvvm_barrier_cta_sync_aligned_count:
4037 case Intrinsic::amdgcn_s_barrier:
4038 case Intrinsic::nvvm_membar_cta:
4039 case Intrinsic::nvvm_membar_gl:
4040 case Intrinsic::nvvm_membar_sys: {
4041 SmallVector<Value *, 1> args = {};
4042 auto cal = cast<CallInst>(
4049 case Intrinsic::lifetime_start: {
4052 SmallVector<Value *, 2> args = {
4055 Type *tys[] = {args[1]->getType()};
4056 auto cal = Builder2.CreateCall(
4058 cal->setCallingConv(
4060 ->getCallingConv());
4064 case Intrinsic::vector_reduce_fmax: {
4067 auto VT = cast<VectorType>(orig_ops[0]->getType());
4069 assert(!VT->getElementCount().isScalable());
4070 size_t numElems = VT->getElementCount().getKnownMinValue();
4071 SmallVector<Value *> elems;
4072 SmallVector<Value *> cmps;
4074 for (
size_t i = 0; i < numElems; ++i)
4075 elems.push_back(Builder2.CreateExtractElement(prev, (uint64_t)i));
4077 Value *curmax = elems[0];
4078 for (
size_t i = 0; i < numElems - 1; ++i) {
4079 cmps.push_back(Builder2.CreateFCmpOLT(curmax, elems[i + 1]));
4080 if (i + 2 != numElems)
4081 curmax =
CreateSelect(Builder2, cmps[i], elems[i + 1], curmax);
4084 auto rule = [&](Value *vdiff) {
4085 auto nv = Constant::getNullValue(orig_ops[0]->getType());
4086 Value *res = Builder2.CreateInsertElement(nv, vdiff, (uint64_t)0);
4088 for (
size_t i = 0; i < numElems - 1; ++i) {
4089 auto rhs_v = Builder2.CreateInsertElement(nv, vdiff, i + 1);
4096 addToDiffe(orig_ops[0], dif0, Builder2, I.getType());
4103 if (ID == Intrinsic::umax || ID == Intrinsic::smax ||
4104 ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow ||
4105 ID == Intrinsic::uadd_with_overflow ||
4106 ID == Intrinsic::smul_with_overflow ||
4107 ID == Intrinsic::umul_with_overflow ||
4108 ID == Intrinsic::ssub_with_overflow ||
4109 ID == Intrinsic::usub_with_overflow)
4112 "failed to deduce type of intrinsic ", I);
4116 llvm::raw_string_ostream ss(s);
4117 ss << *gutils->
oldFunc <<
"\n";
4118 ss << *gutils->
newFunc <<
"\n";
4119 if (Intrinsic::isOverloaded(ID))
4120 ss <<
"cannot handle (reverse) unknown intrinsic\n"
4121 << Intrinsic::getName(ID, ArrayRef<Type *>(),
4122 gutils->
oldFunc->getParent(),
nullptr)
4126 ss <<
"cannot handle (reverse) unknown intrinsic\n"
4127 << Intrinsic::getName(ID) <<
"\n"
4138 IRBuilder<> Builder2(&I);
4143 case Intrinsic::vector_reduce_fmax: {
4147 auto VT = cast<VectorType>(orig_ops[0]->getType());
4149 assert(!VT->getElementCount().isScalable());
4150 size_t numElems = VT->getElementCount().getKnownMinValue();
4151 SmallVector<Value *> elems;
4152 SmallVector<Value *> cmps;
4154 for (
size_t i = 0; i < numElems; ++i)
4155 elems.push_back(Builder2.CreateExtractElement(prev, (uint64_t)i));
4157 Value *curmax = elems[0];
4158 for (
size_t i = 0; i < numElems - 1; ++i) {
4159 cmps.push_back(Builder2.CreateFCmpOLT(curmax, elems[i + 1]));
4160 if (i + 2 != numElems)
4161 curmax =
CreateSelect(Builder2, cmps[i], elems[i + 1], curmax);
4164 auto rule = [&](Value *vdiff) {
4165 Value *res = Builder2.CreateExtractElement(vdiff, (uint64_t)0);
4167 for (
size_t i = 0; i < numElems - 1; ++i) {
4168 auto rhs_v = Builder2.CreateExtractElement(vdiff, i + 1);
4173 auto vdiff =
diffe(orig_ops[0], Builder2);
4187 if (ID == Intrinsic::umax || ID == Intrinsic::smax ||
4188 ID == Intrinsic::abs || ID == Intrinsic::sadd_with_overflow ||
4189 ID == Intrinsic::uadd_with_overflow ||
4190 ID == Intrinsic::smul_with_overflow ||
4191 ID == Intrinsic::umul_with_overflow ||
4192 ID == Intrinsic::ssub_with_overflow ||
4193 ID == Intrinsic::usub_with_overflow)
4196 "failed to deduce type of intrinsic ", I);
4200 llvm::raw_string_ostream ss(s);
4201 if (Intrinsic::isOverloaded(ID))
4202 ss <<
"cannot handle (forward) unknown intrinsic\n"
4203 << Intrinsic::getName(ID, ArrayRef<Type *>(),
4204 gutils->
oldFunc->getParent(),
nullptr)
4208 ss <<
"cannot handle (forward) unknown intrinsic\n"
4209 << Intrinsic::getName(ID) <<
"\n"
4226 using namespace llvm;
4228 Function *kmpc = call.getCalledFunction();
4230 if (overwritten_args_map.find(&call) == overwritten_args_map.end()) {
4231 llvm::errs() <<
" call: " << call <<
"\n";
4232 for (
auto &pair : overwritten_args_map) {
4233 llvm::errs() <<
" + " << *pair.first <<
"\n";
4237 auto found_ow = overwritten_args_map.find(&call);
4238 assert(found_ow != overwritten_args_map.end());
4239 const bool subsequent_calls_may_write = found_ow->second.first;
4240 const std::vector<bool> &overwritten_args = found_ow->second.second;
4243 BuilderZ.setFastMathFlags(
getFast());
4245 Function *task = dyn_cast<Function>(call.getArgOperand(2));
4246 if (task ==
nullptr && isa<ConstantExpr>(call.getArgOperand(2))) {
4247 task = dyn_cast<Function>(
4248 cast<ConstantExpr>(call.getArgOperand(2))->getOperand(0));
4250 if (task ==
nullptr) {
4251 llvm::errs() <<
"could not derive underlying task from omp call: " << call
4253 llvm_unreachable(
"could not derive underlying task from omp call");
4255 if (task->empty()) {
4257 <<
"could not derive underlying task contents from omp call: " << call
4260 "could not derive underlying task contents from omp call");
4266 bool foreignFunction = called ==
nullptr;
4268 SmallVector<Value *, 8> args = {0, 0, 0};
4269 SmallVector<Value *, 8> pre_args = {0, 0, 0};
4272 SmallVector<Instruction *, 4> postCreate;
4273 SmallVector<Instruction *, 4> userReplace;
4275 SmallVector<Value *, 4> OutTypes;
4276 SmallVector<Type *, 4> OutFPTypes;
4278 for (
unsigned i = 3; i < call.arg_size(); ++i) {
4282 pre_args.push_back(argi);
4285 IRBuilder<> Builder2(&call);
4287 args.push_back(
lookup(argi, Builder2));
4290 auto argTy = gutils->
getDiffeType(call.getArgOperand(i), foreignFunction);
4291 argsInverted.push_back(argTy);
4297 auto argType = argi->getType();
4301 IRBuilder<> Builder2(&call);
4315 assert(TR.
query(call.getArgOperand(i))[{-1}].isFloat());
4316 OutTypes.push_back(call.getArgOperand(i));
4317 OutFPTypes.push_back(argType);
4325 Value *tape =
nullptr;
4326 CallInst *augmentcall =
nullptr;
4333 std::map<Value *, std::set<int64_t>> intseen;
4340 for (
auto &arg : called->args()) {
4343 std::pair<Argument *, TypeTree>(&arg, IntPtr));
4345 std::pair<Argument *, std::set<int64_t>>(&arg, {0}));
4347 nextTypeInfo.
Arguments.insert(std::pair<Argument *, TypeTree>(
4348 &arg, TR.
query(call.getArgOperand(argnum - 2 + 3))));
4350 std::pair<Argument *, std::set<int64_t>>(
4368 assert(augmentedReturn);
4369 if (augmentedReturn) {
4372 subdata = fd->second;
4384 false, nextTypeInfo,
4385 subsequent_calls_may_write, overwritten_args,
false,
4390 assert(augmentedReturn);
4391 auto subaugmentations =
4392 (std::map<const llvm::CallInst *, AugmentedReturn *>
4399 auto newcalled = subdata->
fn;
4403 ValueToValueMapTy VMap;
4404 newcalled = CloneFunction(newcalled, VMap);
4405 auto tapeArg = newcalled->arg_end();
4407 Type *tapeElemType = subdata->
tapeType;
4408 SmallVector<std::pair<ssize_t, Value *>, 4> geps;
4409 SmallPtrSet<Instruction *, 4> gepsToErase;
4410 for (
auto a : tapeArg->users()) {
4411 if (
auto gep = dyn_cast<GetElementPtrInst>(a)) {
4412 auto idx = gep->idx_begin();
4414 auto cidx = cast<ConstantInt>(idx->get());
4415 assert(gep->getNumIndices() == 2);
4416 SmallPtrSet<StoreInst *, 1> storesToErase;
4417 for (
auto st : gep->users()) {
4418 auto SI = cast<StoreInst>(st);
4419 Value *op = SI->getValueOperand();
4420 storesToErase.insert(SI);
4421 geps.emplace_back(cidx->getLimitedValue(), op);
4423 for (
auto SI : storesToErase)
4424 SI->eraseFromParent();
4425 gepsToErase.insert(gep);
4426 }
else if (
auto SI = dyn_cast<StoreInst>(a)) {
4427 Value *op = SI->getValueOperand();
4428 gepsToErase.insert(SI);
4429 geps.emplace_back(-1, op);
4431 llvm::errs() <<
"unknown tape user: " << a <<
"\n";
4432 assert(0 &&
"unknown tape user");
4433 llvm_unreachable(
"unknown tape user");
4436 for (
auto gep : gepsToErase)
4437 gep->eraseFromParent();
4438 IRBuilder<> ph(&*newcalled->getEntryBlock().begin());
4439 tape = UndefValue::get(tapeElemType);
4440 ValueToValueMapTy available;
4441 auto subarg = newcalled->arg_begin();
4444 for (
size_t i = 3; i < pre_args.size(); ++i) {
4445 available[&*subarg] = pre_args[i];
4448 for (
auto pair : geps) {
4449 Value *op = pair.second;
4451 Value *replacement = gutils->
unwrapM(op, BuilderZ, available,
4456 : BuilderZ.CreateInsertValue(tape, replacement, pair.first);
4457 if (
auto ci = dyn_cast<CastInst>(alloc)) {
4458 alloc = ci->getOperand(0);
4460 if (
auto uload = dyn_cast<Instruction>(replacement)) {
4462 if (
auto ci = dyn_cast<CastInst>(replacement)) {
4463 if (
auto ucast = dyn_cast<Instruction>(ci->getOperand(0)))
4467 if (
auto ci = dyn_cast<CallInst>(alloc)) {
4468 if (
auto F = ci->getCalledFunction()) {
4470 if (F->getName() ==
"malloc") {
4472 ->tapeIndiciesToFree.emplace(pair.first);
4473 Value *toload = tapeArg;
4474 if (pair.first != -1) {
4477 Type::getInt64Ty(tapeArg->getContext()), 0),
4479 Type::getInt32Ty(tapeArg->getContext()),
4481 toload = ph.CreateInBoundsGEP(tapeElemType, toload, Idxs);
4483 op->replaceAllUsesWith(ph.CreateLoad(op->getType(), toload));
4484 cast<Instruction>(op)->eraseFromParent();
4486 ci->eraseFromParent();
4492 ConstantInt::get(Type::getInt64Ty(tapeArg->getContext()), 0),
4493 ConstantInt::get(Type::getInt32Ty(tapeArg->getContext()),
4495 op->replaceAllUsesWith(ph.CreateLoad(
4499 : ph.CreateInBoundsGEP(tapeElemType, tapeArg, Idxs)));
4500 cast<Instruction>(op)->eraseFromParent();
4505 BuilderZ.CreateStore(tape, alloc);
4506 pre_args.push_back(alloc);
4512 auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()),
4513 pre_args.size() - 3);
4515 pre_args[1] = numargs;
4516 pre_args[2] = BuilderZ.CreatePointerCast(
4517 newcalled, kmpc->getFunctionType()->getParamType(2));
4519 BuilderZ.CreateCall(kmpc->getFunctionType(), kmpc, pre_args);
4520 augmentcall->setCallingConv(call.getCallingConv());
4521 augmentcall->setDebugLoc(
4523 BuilderZ.SetInsertPoint(
4527 assert(0 &&
"unhandled unknown outline");
4532 Intrinsic::ID ID = Intrinsic::not_intrinsic;
4534 llvm::errs() << *gutils->
oldFunc->getParent() <<
"\n";
4535 llvm::errs() << *gutils->
oldFunc <<
"\n";
4536 llvm::errs() << *gutils->
newFunc <<
"\n";
4537 llvm::errs() << *called <<
"\n";
4538 llvm_unreachable(
"no subdata");
4544 assert(found == subdata->
returns.end());
4548 assert(found == subdata->
returns.end());
4553 IRBuilder<> Builder2(&call);
4557 BuilderZ.SetInsertPoint(
4562 Function *newcalled =
nullptr;
4567 if (tape ==
nullptr) {
4568#if LLVM_VERSION_MAJOR >= 18
4569 auto It = BuilderZ.GetInsertPoint();
4570 It.setHeadBit(
true);
4571 BuilderZ.SetInsertPoint(It);
4573 tape = BuilderZ.CreatePHI(subdata->
tapeType, 0,
"tapeArg");
4578 tape =
lookup(tape, Builder2);
4580 .CreateAlloca(tape->getType());
4581 Builder2.CreateStore(tape, alloc);
4582 args.push_back(alloc);
4586 for (
size_t i = 0; i < argsInverted.size(); i++) {
4598 .todiff = cast<Function>(called),
4599 .retType = subretType,
4600 .constant_args = argsInverted,
4601 .subsequent_calls_may_write = subsequent_calls_may_write,
4602 .overwritten_args = overwritten_args,
4603 .returnUsed =
false,
4604 .shadowReturnUsed =
false,
4609 .additionalType = tape ?
getUnqual(tape->getType()) :
nullptr,
4610 .forceAnonymousTape =
false,
4611 .typeInfo = nextTypeInfo,
4619 auto tapeArg = newcalled->arg_end();
4621 LoadInst *tape =
nullptr;
4622 for (
auto u : tapeArg->users()) {
4624 if (!isa<LoadInst>(u)) {
4625 llvm::errs() <<
" newcalled: " << *newcalled <<
"\n";
4626 llvm::errs() <<
" u: " << *u <<
"\n";
4628 tape = cast<LoadInst>(u);
4631 SmallVector<Value *, 4> extracts;
4633 assert(subdata->
tapeIndices.begin()->second == -1);
4634 extracts.push_back(tape);
4636 for (
auto a : tape->users()) {
4637 extracts.push_back(a);
4640 SmallVector<LoadInst *, 4> geps;
4641 for (
auto E : extracts) {
4642 AllocaInst *AI =
nullptr;
4643 for (
auto U : E->users()) {
4644 if (
auto SI = dyn_cast<StoreInst>(U)) {
4645 assert(SI->getValueOperand() == E);
4646 AI = cast<AllocaInst>(SI->getPointerOperand());
4650 for (
auto U : AI->users()) {
4651 if (
auto LI = dyn_cast<LoadInst>(U)) {
4657 for (
auto LI : geps) {
4658 CallInst *freeCall =
nullptr;
4659 for (
auto LU : LI->users()) {
4660 if (
auto CI = dyn_cast<CallInst>(LU)) {
4661 if (
auto F = CI->getCalledFunction()) {
4662 if (F->getName() ==
"free") {
4667 }
else if (
auto BC = dyn_cast<CastInst>(LU)) {
4668 for (
auto CU : BC->users()) {
4669 if (
auto CI = dyn_cast<CallInst>(CU)) {
4670 if (
auto F = CI->getCalledFunction()) {
4671 if (F->getName() ==
"free") {
4683 freeCall->eraseFromParent();
4688 Value *OutAlloc =
nullptr;
4689 auto ST = StructType::get(newcalled->getContext(), OutFPTypes);
4690 if (OutTypes.size()) {
4692 args.push_back(OutAlloc);
4694 SmallVector<Type *, 3> MetaTypes;
4696 cast<Function>(newcalled)->getFunctionType()->params()) {
4697 MetaTypes.push_back(P);
4700 auto FT = FunctionType::get(Type::getVoidTy(newcalled->getContext()),
4703 Function::Create(FT, GlobalVariable::InternalLinkage,
4704 cast<Function>(newcalled)->getName() +
"#out",
4705 *task->getParent());
4707 BasicBlock::Create(newcalled->getContext(),
"entry", F);
4708 IRBuilder<> B(entry);
4709 SmallVector<Value *, 2> SubArgs;
4710 for (
auto &arg : F->args())
4711 SubArgs.push_back(&arg);
4712 Value *cacheArg = SubArgs.back();
4714 Value *outdiff = B.CreateCall(newcalled, SubArgs);
4715 for (
size_t ee = 0; ee < OutTypes.size(); ee++) {
4716 Value *dif = B.CreateExtractValue(outdiff, ee);
4718 ConstantInt::get(Type::getInt64Ty(ST->getContext()), 0),
4719 ConstantInt::get(Type::getInt32Ty(ST->getContext()), ee)};
4720 Value *ptr = B.CreateInBoundsGEP(ST, cacheArg, Idxs);
4722 if (dif->getType()->isIntOrIntVectorTy()) {
4724 ptr = B.CreateBitCast(
4728 cast<PointerType>(ptr->getType())->getAddressSpace()));
4729 dif = B.CreateBitCast(dif,
IntToFloatTy(dif->getType()));
4733 AtomicRMWInst::BinOp op = AtomicRMWInst::FAdd;
4734 if (
auto vt = dyn_cast<VectorType>(dif->getType())) {
4735 assert(!vt->getElementCount().isScalable());
4736 size_t numElems = vt->getElementCount().getKnownMinValue();
4737 for (
size_t i = 0; i < numElems; ++i) {
4738 auto vdif = B.CreateExtractElement(dif, i);
4740 ConstantInt::get(Type::getInt64Ty(vt->getContext()), 0),
4741 ConstantInt::get(Type::getInt32Ty(vt->getContext()), i)};
4742 auto vptr = B.CreateInBoundsGEP(vt, ptr, Idxs);
4743 B.CreateAtomicRMW(op, vptr, vdif, align,
4744 AtomicOrdering::Monotonic, SyncScope::System);
4747 B.CreateAtomicRMW(op, ptr, dif, align, AtomicOrdering::Monotonic,
4755 auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()),
4760 args[2] = Builder2.CreatePointerCast(
4761 newcalled, kmpc->getFunctionType()->getParamType(2));
4764 Builder2.CreateCall(kmpc->getFunctionType(), kmpc, args);
4765 diffes->setCallingConv(call.getCallingConv());
4768 for (
size_t i = 0; i < OutTypes.size(); i++) {
4771 if (OutTypes[i]->getType()->isSized())
4772 size = (gutils->
newFunc->getParent()
4774 .getTypeSizeInBits(OutTypes[i]->getType()) +
4778 ConstantInt::get(Type::getInt64Ty(call.getContext()), 0),
4779 ConstantInt::get(Type::getInt32Ty(call.getContext()), i)};
4781 ->addToDiffe(OutTypes[i],
4782 Builder2.CreateLoad(
4784 Builder2.CreateInBoundsGEP(ST, OutAlloc, Idxs)),
4792 : Builder2.CreateExtractValue(tape, idx));
4796 assert(0 &&
"openmp indirect unhandled");
4918 llvm::Function *called,
4919 bool subsequent_calls_may_write,
4920 const std::vector<bool> &overwritten_args,
4921 bool shadowReturnUsed,
4923 using namespace llvm;
4926 BuilderZ.setFastMathFlags(
getFast());
4929 Module &M = *call.getParent()->getParent()->getParent();
4931 bool foreignFunction = called ==
nullptr;
4942 assert(augmentedReturn);
4943 if (augmentedReturn) {
4946 subdata = fd->second;
4954 IRBuilder<> Builder2(&call);
4957 SmallVector<Value *, 8> args;
4958 std::vector<DIFFE_TYPE> argsInverted;
4959 std::map<int, Type *> gradByVal;
4960 std::map<int, std::vector<Attribute>> structAttrs;
4962 for (
unsigned i = 0; i < call.arg_size(); ++i) {
4964 if (call.paramHasAttr(i, Attribute::StructRet)) {
4965 structAttrs[args.size()].push_back(Attribute::get(
4966 call.getContext(),
"enzyme_sret",
4968 .getValueAsType())));
4970 for (
auto attr : {
"enzymejl_returnRoots",
"enzymejl_parmtype",
4971 "enzymejl_parmtype_ref",
"enzyme_type",
4972 "enzymejl_sret_union_bytes",
"enzymejl_rooted_typ"})
4973 if (call.getAttributes().hasParamAttr(i, attr)) {
4974 structAttrs[args.size()].push_back(call.getParamAttr(i, attr));
4977 if (call.getAttributes().hasParamAttr(i, ty)) {
4978 auto attr = call.getAttributes().getParamAttr(i, ty);
4979 structAttrs[args.size()].push_back(attr);
4984 if (call.isByValArgument(i)) {
4985 gradByVal[args.size()] = call.getParamByValType(i);
4988 bool writeOnlyNoCapture =
true;
4989 bool readOnly =
true;
4991 writeOnlyNoCapture =
false;
4994 writeOnlyNoCapture =
false;
5001 writeOnlyNoCapture =
false;
5004 gutils->
getDiffeType(call.getArgOperand(i), foreignFunction);
5008 (writeOnlyNoCapture ||
5011 (writeOnlyNoCapture && readOnly);
5016 argsInverted.push_back(argTy);
5017 args.push_back(argi);
5025 if (call.getAttributes().hasParamAttr(i, ty)) {
5026 auto attr = call.getAttributes().getParamAttr(i, ty);
5027 structAttrs[args.size()].push_back(attr);
5030 for (
auto attr : {
"enzymejl_returnRoots",
"enzymejl_parmtype",
5031 "enzymejl_parmtype_ref",
"enzyme_type",
5032 "enzymejl_sret_union_bytes",
"enzymejl_rooted_typ"})
5033 if (call.getAttributes().hasParamAttr(i, attr)) {
5035 structAttrs[args.size()].push_back(call.getParamAttr(i, attr));
5036 }
else if (attr == std::string(
"enzymejl_returnRoots")) {
5037 structAttrs[args.size()].push_back(
5038 Attribute::get(call.getContext(),
"enzymejl_returnRoots_v",
5039 call.getAttributes()
5040 .getParamAttr(i,
"enzymejl_returnRoots")
5041 .getValueAsString()));
5042 }
else if (attr == std::string(
"enzymejl_sret_union_bytes")) {
5043 structAttrs[args.size()].push_back(Attribute::get(
5044 call.getContext(),
"enzymejl_sret_union_bytes_v",
5045 call.getAttributes()
5046 .getParamAttr(i,
"enzymejl_sret_union_bytes")
5047 .getValueAsString()));
5048 }
else if (attr == std::string(
"enzymejl_rooted_typ")) {
5049 structAttrs[args.size()].push_back(
5050 Attribute::get(call.getContext(),
"enzymejl_rooted_typ_v",
5051 call.getAttributes()
5052 .getParamAttr(i,
"enzymejl_rooted_typ")
5053 .getValueAsString()));
5056 if (call.paramHasAttr(i, Attribute::StructRet)) {
5058 structAttrs[args.size()].push_back(
5059 Attribute::get(call.getContext(),
"enzyme_sret",
5061 call.getParamAttr(i, Attribute::StructRet)
5062 .getValueAsType())));
5064 structAttrs[args.size()].push_back(
5065 Attribute::get(call.getContext(),
"enzyme_sret_v",
5067 call.getParamAttr(i, Attribute::StructRet)
5068 .getValueAsType())));
5074 args.push_back(gutils->
invertPointerM(call.getArgOperand(i), Builder2));
5076#if LLVM_VERSION_MAJOR >= 16
5077 std::optional<int> tapeIdx;
5079 Optional<int> tapeIdx;
5083 if (found != subdata->
returns.end()) {
5084 tapeIdx = found->second;
5087 Value *tape =
nullptr;
5090 auto idx = *tapeIdx;
5091 FunctionType *FT = subdata->
fn->getFunctionType();
5092#if LLVM_VERSION_MAJOR >= 18
5093 auto It = BuilderZ.GetInsertPoint();
5094 It.setHeadBit(
true);
5095 BuilderZ.SetInsertPoint(It);
5097 tape = BuilderZ.CreatePHI(
5099 ? FT->getReturnType()
5100 : cast<StructType>(FT->getReturnType())->getElementType(idx),
5103 assert(!tape->getType()->isEmptyTy());
5107 args.push_back(tape);
5110 Value *newcalled =
nullptr;
5111 FunctionType *FT =
nullptr;
5120 tape ? tape->getType() :
nullptr, nextTypeInfo,
5121 subsequent_calls_may_write, overwritten_args,
5123 FT = cast<Function>(newcalled)->getFunctionType();
5125 auto callval = call.getCalledOperand();
5129 newcalled = BuilderZ.CreateExtractValue(newcalled, {0});
5134 "Attempting to call an indirect active function "
5135 "whose runtime value is inactive",
5138 auto ft = call.getFunctionType();
5142 ft, Mode, gutils->
getWidth(), tape ? tape->getType() :
nullptr,
5143 argsInverted,
false,
false,
5144 subretused, retActive);
5146 newcalled = BuilderZ.CreatePointerCast(newcalled,
getUnqual(fptype));
5147 newcalled = BuilderZ.CreateLoad(fptype, newcalled);
5153 SmallVector<ValueType, 2> BundleTypes;
5154 for (
auto A : argsInverted)
5163 CallInst *diffes = Builder2.CreateCall(FT, newcalled, args, Defs);
5164 diffes->setCallingConv(call.getCallingConv());
5167 for (
auto pair : gradByVal) {
5168 diffes->addParamAttr(
5170 Attribute::getWithByValType(diffes->getContext(), pair.second));
5173 for (
auto &pair : structAttrs) {
5174 for (
auto val : pair.second)
5175 diffes->addParamAttr(pair.first, val);
5180 Value *primal =
nullptr;
5181 Value *
diffe =
nullptr;
5184 primal = Builder2.CreateExtractValue(diffes, 0);
5185 diffe = Builder2.CreateExtractValue(diffes, 1);
5188 }
else if (!FT->getReturnType()->isVoidTy()) {
5193 auto placeholder = cast<PHINode>(&*ifound->second);
5196 gutils->
erase(newcall);
5205 gutils->
erase(placeholder);
5207 if (primal &&
diffe) {
5212 gutils->
erase(newcall);
5216 }
else if (primal) {
5218 gutils->
erase(newcall);
5229 SmallVector<Value *, 8> args;
5230 SmallVector<Value *, 8> pre_args;
5231 std::vector<DIFFE_TYPE> argsInverted;
5232 SmallVector<Instruction *, 4> postCreate;
5233 SmallVector<Instruction *, 4> userReplace;
5234 std::map<int, Type *> preByVal;
5235 std::map<int, Type *> gradByVal;
5236 std::map<int, std::vector<Attribute>> structAttrs;
5238 bool replaceFunction =
false;
5242 &call, *replacedReturns, postCreate, userReplace, gutils,
5243 unnecessaryInstructions, oldUnreachable, subretused);
5244 if (replaceFunction) {
5245 modifyPrimal =
false;
5249 SmallVector<ValueType, 2> PreBundleTypes;
5250 SmallVector<ValueType, 2> BundleTypes;
5252 for (
unsigned i = 0; i < call.arg_size(); ++i) {
5256 if (call.isByValArgument(i)) {
5257 preByVal[pre_args.size()] = call.getParamByValType(i);
5259 for (
auto attr : {
"enzymejl_returnRoots",
"enzymejl_parmtype",
5260 "enzymejl_parmtype_ref",
"enzyme_type",
5261 "enzymejl_sret_union_bytes",
"enzymejl_rooted_typ"})
5262 if (call.getAttributes().hasParamAttr(i, attr)) {
5263 structAttrs[pre_args.size()].push_back(call.getParamAttr(i, attr));
5265 if (call.paramHasAttr(i, Attribute::StructRet)) {
5266 structAttrs[pre_args.size()].push_back(Attribute::get(
5267 call.getContext(),
"enzyme_sret",
5269 call.getParamAttr(i, Attribute::StructRet).getValueAsType())));
5272 if (call.getAttributes().hasParamAttr(i, ty)) {
5273 auto attr = call.getAttributes().getParamAttr(i, ty);
5274 structAttrs[pre_args.size()].push_back(attr);
5277 auto argTy = gutils->
getDiffeType(call.getArgOperand(i), foreignFunction);
5279 bool writeOnlyNoCapture =
true;
5280 bool readNoneNoCapture =
false;
5282 writeOnlyNoCapture =
false;
5283 readNoneNoCapture =
false;
5286 writeOnlyNoCapture =
false;
5289 readNoneNoCapture =
false;
5293 writeOnlyNoCapture =
false;
5294 readNoneNoCapture =
false;
5297 Value *prearg = argi;
5303 if (readNoneNoCapture ||
5305 (writeOnlyNoCapture ||
5310 pre_args.push_back(prearg);
5313 IRBuilder<> Builder2(&call);
5316 if (call.isByValArgument(i)) {
5317 gradByVal[args.size()] = call.getParamByValType(i);
5320 if ((writeOnlyNoCapture && !replaceFunction) ||
5321 (readNoneNoCapture ||
5323 (writeOnlyNoCapture ||
5328 args.push_back(
lookup(argi, Builder2));
5331 argsInverted.push_back(argTy);
5334 PreBundleTypes.push_back(preType);
5335 BundleTypes.push_back(revType);
5339 auto argType = argi->getType();
5344 if (call.getAttributes().hasParamAttr(i, ty)) {
5345 auto attr = call.getAttributes().getParamAttr(i, ty);
5346 structAttrs[pre_args.size()].push_back(attr);
5349 for (
auto attr : {
"enzymejl_returnRoots",
"enzymejl_parmtype",
5350 "enzymejl_parmtype_ref",
"enzyme_type",
5351 "enzymejl_sret_union_bytes",
"enzymejl_rooted_typ"})
5352 if (call.getAttributes().hasParamAttr(i, attr)) {
5354 structAttrs[pre_args.size()].push_back(
5355 call.getParamAttr(i, attr));
5356 }
else if (attr == std::string(
"enzymejl_returnRoots")) {
5357 structAttrs[pre_args.size()].push_back(
5358 Attribute::get(call.getContext(),
"enzymejl_returnRoots_v",
5359 call.getAttributes()
5360 .getParamAttr(i, attr)
5361 .getValueAsString()));
5362 }
else if (attr == std::string(
"enzymejl_sret_union_bytes")) {
5363 structAttrs[pre_args.size()].push_back(Attribute::get(
5364 call.getContext(),
"enzymejl_sret_union_bytes_v",
5365 call.getAttributes()
5366 .getParamAttr(i, attr)
5367 .getValueAsString()));
5368 }
else if (attr == std::string(
"enzymejl_rooted_typ")) {
5369 structAttrs[pre_args.size()].push_back(
5370 Attribute::get(call.getContext(),
"enzymejl_rooted_typ_v",
5371 call.getAttributes()
5372 .getParamAttr(i, attr)
5373 .getValueAsString()));
5376 if (call.paramHasAttr(i, Attribute::StructRet)) {
5378 structAttrs[pre_args.size()].push_back(
5379 Attribute::get(call.getContext(),
"enzyme_sret",
5381 call.getParamAttr(i, Attribute::StructRet)
5382 .getValueAsType())));
5384 structAttrs[pre_args.size()].push_back(
5385 Attribute::get(call.getContext(),
"enzyme_sret_v",
5387 call.getParamAttr(i, Attribute::StructRet)
5388 .getValueAsType())));
5392 IRBuilder<> Builder2(&call);
5395 Value *darg =
nullptr;
5397 if (((writeOnlyNoCapture && TR.
query(call.getArgOperand(
5408 args.push_back(
lookup(darg, Builder2));
5424 raw_string_ostream ss(
str);
5425 ss <<
"Mismatched estimated activity type for " << *argType
5426 <<
" expected DUP_ARG or CONSTANT found " << wt
5427 <<
", call = " << call <<
"\n";
5433 EmitFailure(
"MismatchArgType", call.getDebugLoc(), &call, ss.str());
5437 if (foreignFunction)
5438 assert(!argType->isIntOrIntVectorTy());
5442 PreBundleTypes.push_back(preType);
5443 BundleTypes.push_back(revType);
5446 if (call.arg_size() !=
5447 cast<Function>(called)->getFunctionType()->getNumParams()) {
5448 llvm::errs() << *gutils->
oldFunc->getParent() <<
"\n";
5449 llvm::errs() << *gutils->
oldFunc <<
"\n";
5450 llvm::errs() << call <<
"\n";
5451 llvm::errs() <<
" number of arg operands != function parameters\n";
5452 EmitFailure(
"MismatchArgs", call.getDebugLoc(), &call,
5453 "Number of arg operands != function parameters\n", call);
5457 Value *tape =
nullptr;
5458 CallInst *augmentcall =
nullptr;
5459 Value *cachereplace =
nullptr;
5463#if LLVM_VERSION_MAJOR >= 16
5464 std::optional<int> tapeIdx;
5465 std::optional<int> returnIdx;
5466 std::optional<int> differetIdx;
5468 Optional<int> tapeIdx;
5469 Optional<int> returnIdx;
5470 Optional<int> differetIdx;
5474 Value *newcalled =
nullptr;
5475 FunctionType *FT =
nullptr;
5479 auto callval = call.getCalledOperand();
5480 Value *uncast = callval;
5481 while (
auto CE = dyn_cast<ConstantExpr>(uncast)) {
5483 uncast = CE->getOperand(0);
5488 if (isa<ConstantInt>(uncast)) {
5490 raw_string_ostream ss(
str);
5491 ss <<
"cannot find shadow for " << *callval
5492 <<
" for use as function in " << call;
5500 "Attempting to call an indirect active function "
5501 "whose runtime value is inactive",
5504 FunctionType *ft = call.getFunctionType();
5506 std::set<llvm::Type *> seen;
5510 ft,
true, subretType);
5511 FT = FunctionType::get(
5512 StructType::get(newcalled->getContext(), res.second), res.first,
5515 newcalled = BuilderZ.CreatePointerCast(newcalled,
getUnqual(fptype));
5516 newcalled = BuilderZ.CreateLoad(fptype, newcalled);
5519 if (!call.getType()->isVoidTy()) {
5532 subretused, shadowReturnUsed, nextTypeInfo,
5533 subsequent_calls_may_write, overwritten_args,
false,
5537 assert(augmentedReturn);
5538 auto subaugmentations =
5539 (std::map<const llvm::CallInst *, AugmentedReturn *>
5546 Intrinsic::ID ID = Intrinsic::not_intrinsic;
5549 llvm::errs() << *gutils->
oldFunc->getParent() <<
"\n";
5550 llvm::errs() << *gutils->
oldFunc <<
"\n";
5551 llvm::errs() << *gutils->
newFunc <<
"\n";
5552 llvm::errs() << *called <<
"\n";
5558 fnandtapetype = subdata;
5559 newcalled = subdata->
fn;
5560 FT = cast<Function>(newcalled)->getFunctionType();
5564 if (found != subdata->
returns.end()) {
5565 differetIdx = found->second;
5567 assert(!shadowReturnUsed);
5571 if (found != subdata->
returns.end()) {
5572 returnIdx = found->second;
5574 assert(!subretused);
5578 if (found != subdata->
returns.end()) {
5579 tapeIdx = found->second;
5595 auto NC = dyn_cast<Function>(newcalled);
5596 llvm::errs() << *gutils->
oldFunc <<
"\n";
5597 llvm::errs() << *gutils->
newFunc <<
"\n";
5599 llvm::errs() <<
" trying to call " << NC->getName() <<
" " << *FT
5602 llvm::errs() <<
" trying to call " << *newcalled <<
" " << *FT
5605 for (
unsigned i = 0; i < pre_args.size(); ++i) {
5606 assert(pre_args[i]);
5607 assert(pre_args[i]->getType());
5608 llvm::errs() <<
"args[" << i <<
"] = " << *pre_args[i]
5609 <<
" FT:" << *FT->getParamType(i) <<
"\n";
5611 assert(0 &&
"calling with wrong number of arguments");
5615 if (pre_args.size() != FT->getNumParams())
5616 goto badaugmentedfn;
5618 for (
unsigned i = 0; i < pre_args.size(); ++i) {
5619 if (pre_args[i]->getType() == FT->getParamType(i))
5621 else if (!call.getCalledFunction())
5623 BuilderZ.CreateBitCast(pre_args[i], FT->getParamType(i));
5625 goto badaugmentedfn;
5628 augmentcall = BuilderZ.CreateCall(
5629 FT, newcalled, pre_args,
5632 augmentcall->setCallingConv(call.getCallingConv());
5633 augmentcall->setDebugLoc(
5636 for (
auto pair : preByVal) {
5637 augmentcall->addParamAttr(
5638 pair.first, Attribute::getWithByValType(augmentcall->getContext(),
5642 for (
auto &pair : structAttrs) {
5643 for (
auto val : pair.second)
5644 augmentcall->addParamAttr(pair.first, val);
5647 if (!augmentcall->getType()->isVoidTy())
5648 augmentcall->setName(call.getName() +
"_augmented");
5651 auto tval = *tapeIdx;
5652 tape = (tval == -1) ? augmentcall
5653 : BuilderZ.CreateExtractValue(
5654 augmentcall, {(unsigned)tval},
"subcache");
5655 if (tape->getType()->isEmptyTy()) {
5656 auto tt = tape->getType();
5657 gutils->
erase(cast<Instruction>(tape));
5658 tape = UndefValue::get(tt);
5667 Value *dcall =
nullptr;
5669 assert(augmentcall);
5670 auto rval = *returnIdx;
5671 dcall = (rval < 0) ? augmentcall
5672 : BuilderZ.CreateExtractValue(augmentcall,
5678 assert(dcall->getType() == call.getType());
5682 if (!call.getType()->isFPOrFPVectorTy() && TR.
anyPointer(&call)) {
5689 assert(dcall->getType() == call.getType());
5692 if (isa<Instruction>(dcall) && !isa<PHINode>(dcall)) {
5693 cast<Instruction>(dcall)->takeName(newCall);
5699 std::map<UsageKey, bool> Seen;
5700 bool primalNeededInReverse =
false;
5703 if (pair.first == &call) {
5704 primalNeededInReverse =
true;
5710 if (!primalNeededInReverse) {
5715 primalNeededInReverse =
5720 if (primalNeededInReverse)
5724 BuilderZ.SetInsertPoint(newCall->getNextNode());
5725 gutils->
erase(newCall);
5727 BuilderZ.SetInsertPoint(BuilderZ.GetInsertPoint()->getNextNode());
5742 auto tval = *tapeIdx;
5743#if LLVM_VERSION_MAJOR >= 18
5744 auto It = BuilderZ.GetInsertPoint();
5745 It.setHeadBit(
true);
5746 BuilderZ.SetInsertPoint(It);
5748 tape = BuilderZ.CreatePHI(
5749 (tapeIdx == -1) ? FT->getReturnType()
5750 : cast<StructType>(FT->getReturnType())
5751 ->getElementType(tval),
5760 Intrinsic::ID ID = Intrinsic::not_intrinsic;
5767#if LLVM_VERSION_MAJOR >= 18
5768 auto It = BuilderZ.GetInsertPoint();
5769 It.setHeadBit(
true);
5770 BuilderZ.SetInsertPoint(It);
5775 raw_string_ostream ss(
str);
5776 ss <<
"Failed to compute consistent cache index for operation: "
5783 EmitFailure(
"GetIndexError", call.getDebugLoc(), &call,
5788 cachereplace = newCall;
5790 cachereplace = BuilderZ.CreatePHI(
5791 call.getType(), 1, call.getName() +
"_tmpcacheB");
5797#if LLVM_VERSION_MAJOR >= 18
5798 auto It = BuilderZ.GetInsertPoint();
5799 It.setHeadBit(
true);
5800 BuilderZ.SetInsertPoint(It);
5802 auto pn = BuilderZ.CreatePHI(
5803 call.getType(), 1, (call.getName() +
"_replacementE").str());
5810 BuilderZ.SetInsertPoint(BuilderZ.GetInsertPoint()->getNextNode());
5817 auto placeholder = cast<PHINode>(&*ifound->second);
5824 bool hasNonReturnUse =
false;
5825 for (
auto use : call.users()) {
5827 !isa<ReturnInst>(use)) {
5828 hasNonReturnUse =
true;
5832 if (subcheck && hasNonReturnUse) {
5834 Value *newip =
nullptr;
5840 raw_string_ostream ss(
str);
5841 ss <<
"Did not have return index set when differentiating "
5843 ss <<
" call" << call <<
"\n";
5844 ss <<
" augmentcall" << *augmentcall <<
"\n";
5850 EmitFailure(
"GetIndexError", call.getDebugLoc(), &call,
5853 placeholder->replaceAllUsesWith(
5854 UndefValue::get(placeholder->getType()));
5855 if (placeholder == &*BuilderZ.GetInsertPoint()) {
5856 BuilderZ.SetInsertPoint(placeholder->getNextNode());
5858 gutils->
erase(placeholder);
5860 auto drval = *differetIdx;
5863 : BuilderZ.CreateExtractValue(augmentcall,
5865 call.getName() +
"'ac");
5866 assert(newip->getType() == placeholder->getType());
5867 placeholder->replaceAllUsesWith(newip);
5868 if (placeholder == &*BuilderZ.GetInsertPoint()) {
5869 BuilderZ.SetInsertPoint(placeholder->getNextNode());
5871 gutils->
erase(placeholder);
5874 newip = placeholder;
5884 if (placeholder == &*BuilderZ.GetInsertPoint()) {
5885 BuilderZ.SetInsertPoint(placeholder->getNextNode());
5887 gutils->
erase(placeholder);
5891 if (fnandtapetype && fnandtapetype->
tapeType &&
5897 auto tapep = BuilderZ.CreatePointerCast(
5900 cast<PointerType>(tape->getType())->getAddressSpace()));
5902 BuilderZ.CreateLoad(fnandtapetype->
tapeType, tapep,
"tapeld");
5903 truetape->setMetadata(
"enzyme_mustcache",
5904 MDNode::get(truetape->getContext(), {}));
5912 auto placeholder = cast<PHINode>(&*ifound->second);
5914 gutils->
erase(placeholder);
5917 subretused && !call.doesNotAccessMemory()) {
5921 assert(!replaceFunction);
5922#if LLVM_VERSION_MAJOR >= 18
5923 auto It = BuilderZ.GetInsertPoint();
5924 It.setHeadBit(
true);
5925 BuilderZ.SetInsertPoint(It);
5927 cachereplace = BuilderZ.CreatePHI(call.getType(), 1,
5928 call.getName() +
"_cachereplace2");
5930 BuilderZ, cachereplace,
5933#if LLVM_VERSION_MAJOR >= 18
5934 auto It = BuilderZ.GetInsertPoint();
5935 It.setHeadBit(
true);
5936 BuilderZ.SetInsertPoint(It);
5938 auto pn = BuilderZ.CreatePHI(call.getType(), 1,
5939 call.getName() +
"_replacementC");
5945 if (!subretused && !replaceFunction)
5954 IRBuilder<> Builder2(&call);
5957 Value *newcalled =
nullptr;
5958 FunctionType *FT =
nullptr;
5965 for (
size_t i = 0; i < argsInverted.size(); i++) {
5977 .todiff = cast<Function>(called),
5978 .retType = subretType,
5979 .constant_args = argsInverted,
5980 .subsequent_calls_may_write = subsequent_calls_may_write,
5981 .overwritten_args = overwritten_args,
5982 .returnUsed = replaceFunction && subretused,
5983 .shadowReturnUsed = shadowReturnUsed && replaceFunction,
5988 .additionalType = tape ? tape->getType() :
nullptr,
5989 .forceAnonymousTape =
false,
5990 .typeInfo = nextTypeInfo,
5996 FT = cast<Function>(newcalled)->getFunctionType();
6001 auto callval = call.getCalledOperand();
6005 llvm::raw_string_ostream ss(s);
6006 ss << *gutils->
oldFunc <<
"\n";
6007 ss <<
"in Mode: " <<
to_string(Mode) <<
"\n";
6008 ss <<
" orig: " << call <<
" callval: " << *callval <<
"\n";
6009 ss <<
" constant function being called, but active call instruction\n";
6020 auto ft = call.getFunctionType();
6025 res.first.push_back(
getInt8PtrTy(newcalled->getContext()));
6026 FT = FunctionType::get(
6027 StructType::get(newcalled->getContext(), res.second), res.first,
6030 newcalled = Builder2.CreatePointerCast(newcalled,
getUnqual(fptype));
6031 newcalled = Builder2.CreateLoad(
6032 fptype, Builder2.CreateConstGEP1_64(fptype, newcalled, 1));
6036 args.push_back(
diffe(&call, Builder2));
6040 auto ntape = gutils->
lookupM(tape, Builder2);
6042 assert(ntape->getType());
6043 args.push_back(ntape);
6051 auto NC = dyn_cast<Function>(newcalled);
6052 llvm::errs() << *gutils->
oldFunc <<
"\n";
6053 llvm::errs() << *gutils->
newFunc <<
"\n";
6055 llvm::errs() <<
" trying to call " << NC->getName() <<
" " << *FT
6058 llvm::errs() <<
" trying to call " << *newcalled <<
" " << *FT <<
"\n";
6060 for (
unsigned i = 0; i < args.size(); ++i) {
6062 assert(args[i]->getType());
6063 llvm::errs() <<
"args[" << i <<
"] = " << *args[i]
6064 <<
" FT:" << *FT->getParamType(i) <<
"\n";
6066 assert(0 &&
"calling with wrong number of arguments");
6070 if (args.size() != FT->getNumParams())
6073 for (
unsigned i = 0; i < args.size(); ++i) {
6074 if (args[i]->getType() == FT->getParamType(i))
6076 else if (!call.getCalledFunction())
6077 args[i] = Builder2.CreateBitCast(args[i], FT->getParamType(i));
6083 Builder2.CreateCall(FT, newcalled, args,
6085 &call, BundleTypes, Builder2,
true));
6086 diffes->setCallingConv(call.getCallingConv());
6089 for (
auto pair : gradByVal) {
6090 diffes->addParamAttr(pair.first, Attribute::getWithByValType(
6091 diffes->getContext(), pair.second));
6094 for (
auto &pair : structAttrs) {
6095 for (
auto val : pair.second)
6096 diffes->addParamAttr(pair.first, val);
6099 unsigned structidx = 0;
6100 if (replaceFunction) {
6103 if (shadowReturnUsed)
6107 for (
unsigned i = 0; i < call.arg_size(); ++i) {
6109 Value *diffeadd = Builder2.CreateExtractValue(diffes, {structidx});
6114 if (call.getArgOperand(i)->getType()->isSized())
6115 size = (gutils->
newFunc->getParent()
6117 .getTypeSizeInBits(call.getArgOperand(i)->getType()) +
6121 addToDiffe(call.getArgOperand(i), diffeadd, Builder2,
6127 if (diffes->getType()->isVoidTy()) {
6128 if (structidx != 0) {
6129 llvm::errs() << *gutils->
oldFunc->getParent() <<
"\n";
6130 llvm::errs() <<
"diffes: " << *diffes <<
" structidx=" << structidx
6131 <<
" subretused=" << subretused
6132 <<
" shadowReturnUsed=" << shadowReturnUsed <<
"\n";
6134 assert(structidx == 0);
6136 assert(cast<StructType>(diffes->getType())->getNumElements() ==
6142 Constant::getNullValue(gutils->
getShadowType(call.getType())),
6145 if (replaceFunction) {
6151 auto placeholder = cast<PHINode>(&*ifound->second);
6153 if (shadowReturnUsed) {
6155 auto dretval = cast<Instruction>(
6156 Builder2.CreateExtractValue(diffes, {subretused ? 1U : 0U}));
6158 assert(!subretused);
6162 gutils->
erase(placeholder);
6165 Instruction *retval =
nullptr;
6168 retval = cast<Instruction>(Builder2.CreateExtractValue(diffes, {0}));
6177 SmallPtrSet<Value *, 2> postCreateSet(postCreate.begin(),
6179 for (
auto a : postCreate) {
6180 a->moveBefore(*Builder2.GetInsertBlock(), Builder2.GetInsertPoint());
6181 for (
size_t i = 0; i < a->getNumOperands(); i++) {
6182 auto op = dyn_cast<Instruction>(a->getOperand(i));
6183 if (!op || postCreateSet.count(op))
6186 IRBuilder<> BuilderA(a);
6187 a->setOperand(i, gutils->
lookupM(op, BuilderA));
6196 gutils->
erase(newCall);
6203 Value *dcall =
nullptr;
6204 assert(cachereplace->getType() == call.getType());
6205 assert(dcall ==
nullptr);
6206 dcall = cachereplace;
6213 if (!call.getType()->isFPOrFPVectorTy() && TR.
anyPointer(&call)) {
6220 assert(dcall->getType() == call.getType());
6221 newCall->replaceAllUsesWith(dcall);
6222 if (isa<Instruction>(dcall) && !isa<PHINode>(dcall)) {
6223 cast<Instruction>(dcall)->takeName(&call);
6225 gutils->
erase(newCall);
6249 using namespace llvm;
6258 if (
startsWith(funcName, (
"llvm.intel.subscript"))) {
6259 assert(isa<IntrinsicInst>(call));
6264 if (funcName ==
"llvm.enzyme.lifetime_start") {
6268 if (funcName ==
"llvm.enzyme.lifetime_end") {
6269 SmallVector<Value *, 2> orig_ops(call.getNumOperands());
6270 for (
unsigned i = 0; i < call.getNumOperands(); ++i) {
6271 orig_ops[i] = call.getOperand(i);
6279 IRBuilder<> BuilderZ(newCall);
6280 BuilderZ.setFastMathFlags(
getFast());
6282 if (overwritten_args_map.find(&call) == overwritten_args_map.end() &&
6285 llvm::errs() <<
" call: " << call <<
"\n";
6286 for (
auto &pair : overwritten_args_map) {
6287 llvm::errs() <<
" + " << *pair.first <<
"\n";
6291 assert(overwritten_args_map.find(&call) != overwritten_args_map.end() ||
6294 const bool subsequent_calls_may_write =
6298 : overwritten_args_map.find(&call)->second.first;
6299 const std::vector<bool> &overwritten_args =
6302 ? std::vector<bool>()
6303 : overwritten_args_map.find(&call)->second.second;
6307 bool subretused =
false;
6308 bool shadowReturnUsed =
false;
6313 &call, &subretused, &shadowReturnUsed, smode);
6319 Value *invertedReturn =
nullptr;
6322 invertedReturn = cast<PHINode>(&*ifound->second);
6325 Value *normalReturn = subretused ? newCall :
nullptr;
6327 bool noMod = found->second(BuilderZ, &call, *gutils, normalReturn,
6331 assert(normalReturn == newCall);
6337 auto placeholder = cast<PHINode>(&*ifound->second);
6338 if (invertedReturn && invertedReturn != placeholder) {
6339 if (invertedReturn->getType() !=
6341 llvm::errs() <<
" o: " << call <<
"\n";
6342 llvm::errs() <<
" ot: " << *call.getType() <<
"\n";
6343 llvm::errs() <<
" ir: " << *invertedReturn <<
"\n";
6344 llvm::errs() <<
" irt: " << *invertedReturn->getType() <<
"\n";
6345 llvm::errs() <<
" p: " << *placeholder <<
"\n";
6346 llvm::errs() <<
" PT: " << *placeholder->getType() <<
"\n";
6347 llvm::errs() <<
" newCall: " << *newCall <<
"\n";
6348 llvm::errs() <<
" newCallT: " << *newCall->getType() <<
"\n";
6350 assert(invertedReturn->getType() ==
6352 placeholder->replaceAllUsesWith(invertedReturn);
6353 gutils->
erase(placeholder);
6355 std::make_pair((
const Value *)&call,
6359 gutils->
erase(placeholder);
6363 if (normalReturn && normalReturn != newCall) {
6364 assert(normalReturn->getType() == newCall->getType());
6366 gutils->
erase(newCall);
6377 IRBuilder<> Builder2(&call);
6382 Value *invertedReturn =
nullptr;
6384 PHINode *placeholder =
nullptr;
6386 placeholder = cast<PHINode>(&*ifound->second);
6387 if (shadowReturnUsed)
6388 invertedReturn = placeholder;
6391 Value *normalReturn = subretused ? newCall :
nullptr;
6393 Value *tape =
nullptr;
6395 Type *tapeType =
nullptr;
6399 bool noMod = found->second.first(BuilderZ, &call, *gutils,
6400 normalReturn, invertedReturn, tape);
6403 assert(normalReturn == newCall);
6407 tapeType = tape->getType();
6412 assert(augmentedReturn);
6413 auto subaugmentations =
6414 (std::map<const llvm::CallInst *, AugmentedReturn *>
6427 assert(augmentedReturn);
6428 auto subaugmentations =
6429 (std::map<const llvm::CallInst *, AugmentedReturn *>
6431 auto fd = subaugmentations->find(&call);
6432 assert(fd != subaugmentations->end());
6438 tapeType = (llvm::Type *)fd->second;
6440#if LLVM_VERSION_MAJOR >= 18
6441 auto It = BuilderZ.GetInsertPoint();
6442 It.setHeadBit(
true);
6443 BuilderZ.SetInsertPoint(It);
6445 tape = BuilderZ.CreatePHI(tapeType, 0);
6451 tape = gutils->
lookupM(tape, Builder2);
6457 if (!shadowReturnUsed) {
6459 gutils->
erase(placeholder);
6461 if (invertedReturn && invertedReturn != placeholder) {
6462 if (invertedReturn->getType() !=
6464 llvm::errs() <<
" o: " << call <<
"\n";
6465 llvm::errs() <<
" ot: " << *call.getType() <<
"\n";
6466 llvm::errs() <<
" ir: " << *invertedReturn <<
"\n";
6467 llvm::errs() <<
" irt: " << *invertedReturn->getType() <<
"\n";
6468 llvm::errs() <<
" p: " << *placeholder <<
"\n";
6469 llvm::errs() <<
" PT: " << *placeholder->getType() <<
"\n";
6470 llvm::errs() <<
" newCall: " << *newCall <<
"\n";
6471 llvm::errs() <<
" newCallT: " << *newCall->getType() <<
"\n";
6473 assert(invertedReturn->getType() ==
6475 placeholder->replaceAllUsesWith(invertedReturn);
6476 gutils->
erase(placeholder);
6478 BuilderZ, invertedReturn,
6485 if (placeholder->getType() != invertedReturn->getType())
6486 llvm::errs() <<
" place: " << *placeholder
6487 <<
" invRet: " << *invertedReturn;
6488 placeholder->replaceAllUsesWith(invertedReturn);
6489 gutils->
erase(placeholder);
6494 std::make_pair((
const Value *)&call,
6499 bool primalNeededInReverse;
6504 std::map<UsageKey, bool> Seen;
6508 primalNeededInReverse =
6512 if (subretused && primalNeededInReverse) {
6513 if (normalReturn != newCall) {
6514 assert(normalReturn->getType() == newCall->getType());
6516 BuilderZ.SetInsertPoint(newCall->getNextNode());
6517 gutils->
erase(newCall);
6520 BuilderZ, normalReturn,
6523 if (normalReturn && normalReturn != newCall) {
6524 assert(normalReturn->getType() == newCall->getType());
6527 BuilderZ.SetInsertPoint(newCall->getNextNode());
6528 gutils->
erase(newCall);
6537 if (funcName ==
"__kmpc_fork_call") {
6544 subsequent_calls_may_write, overwritten_args,
6548 bool useConstantFallback =
6551 if (!useConstantFallback) {
6555 "Call was deduced inactive but still doing differential "
6556 "rewrite as it may escape an allocation",
6560 if (useConstantFallback) {
6564 PHINode *placeholder = cast<PHINode>(&*found->second);
6566 gutils->
erase(placeholder);
6571 noFree |= call.hasFnAttr(Attribute::NoFree);
6572 if (!noFree && called) {
6573 noFree |= called->hasFnAttribute(Attribute::NoFree);
6576 std::map<UsageKey, bool> CacheResults;
6579 cast<Instruction>(pair.first))) {
6585 bool mayActiveFree =
false;
6586 for (
unsigned i = 0; i < call.arg_size(); ++i) {
6587 Value *a = call.getOperand(i);
6592 if (!TR.
query(a)[{-1}].isPossiblePointer())
6596 mayActiveFree =
true;
6604 bool isAllocation =
false;
6605 for (
auto objv = obj;;) {
6607 isAllocation =
true;
6610 if (
auto objC = dyn_cast<CallBase>(objv))
6613 SmallPtrSet<Value *, 1> set;
6614 for (
auto &B : *F) {
6615 if (
auto RI = dyn_cast<ReturnInst>(B.getTerminator())) {
6617 if (isa<ConstantPointerNull>(v))
6622 if (set.size() == 1) {
6623 objv = *set.begin();
6629 if (!isAllocation) {
6630 mayActiveFree =
true;
6636 if (!found->second) {
6637 auto CacheResults2(CacheResults);
6642 CacheResults2, oldUnreachable)) {
6643 mayActiveFree =
true;
6650 auto CacheResults2(CacheResults);
6654 CacheResults2, oldUnreachable)) {
6655 mayActiveFree =
true;
6663 auto callval = call.getCalledOperand();
6664 if (!isa<Constant>(callval))
6685 (call.mayWriteToMemory() ||
6689 std::map<UsageKey, bool> Seen;
6690 bool primalNeededInReverse =
false;
6693 if (pair.first == &call) {
6694 primalNeededInReverse =
true;
6700 if (!primalNeededInReverse) {
6705 primalNeededInReverse =
6710 if (primalNeededInReverse) {
6730 if (call.mayWriteToMemory() &&
6738 if (!call.mayWriteToMemory() && !subretused) {
6747 call, called, subsequent_calls_may_write, overwritten_args,
6748 shadowReturnUsed, subretType, subretused);