37 llvm::StringRef funcName) {
44 BuilderZ.setFastMathFlags(
getFast());
47 if (funcName ==
"PMPI_Isend" || funcName ==
"MPI_Isend" ||
48 funcName ==
"PMPI_Irecv" || funcName ==
"MPI_Irecv") {
54 Value *d_req = gutils->
invertPointerM(call.getOperand(6), BuilderZ);
55 if (d_req->getType()->isIntegerTy()) {
56 d_req = BuilderZ.CreateIntToPtr(
60 auto i64 = Type::getInt64Ty(call.getContext());
67 d_req = BuilderZ.CreateBitCast(d_req,
getUnqual(impialloc->getType()));
68 Value *d_req_prev = BuilderZ.CreateLoad(impialloc->getType(), d_req);
70 BuilderZ.CreatePointerCast(d_req_prev,
73 BuilderZ.CreateStore(impialloc, d_req);
75 if (funcName ==
"MPI_Isend" || funcName ==
"PMPI_Isend") {
78 BuilderZ, call.getType(), called);
80 auto len_arg = BuilderZ.CreateZExtOrTrunc(
82 Type::getInt64Ty(call.getContext()));
83 len_arg = BuilderZ.CreateMul(
85 BuilderZ.CreateZExtOrTrunc(tysize,
86 Type::getInt64Ty(call.getContext())),
89 Value *firstallocation =
91 len_arg,
"mpirecv_malloccache");
93 BuilderZ, impialloc, impi));
96 Value *ibuf = gutils->
invertPointerM(call.getOperand(0), BuilderZ);
97 if (ibuf->getType()->isIntegerTy())
99 BuilderZ.CreateIntToPtr(ibuf,
getInt8PtrTy(call.getContext()));
100 BuilderZ.CreateStore(
104 BuilderZ.CreateStore(
105 BuilderZ.CreateZExtOrTrunc(
110 if (dataType->getType()->isIntegerTy())
111 dataType = BuilderZ.CreateIntToPtr(
113 BuilderZ.CreateStore(
114 BuilderZ.CreatePointerCast(dataType,
118 BuilderZ.CreateStore(
119 BuilderZ.CreateZExtOrTrunc(
123 BuilderZ.CreateStore(
124 BuilderZ.CreateZExtOrTrunc(
129 if (comm->getType()->isIntegerTy())
130 comm = BuilderZ.CreateIntToPtr(comm,
132 BuilderZ.CreateStore(
133 BuilderZ.CreatePointerCast(comm,
getInt8PtrTy(call.getContext())),
136 BuilderZ.CreateStore(
138 Type::getInt8Ty(impialloc->getContext()),
139 (funcName ==
"MPI_Isend" || funcName ==
"PMPI_Isend")
147 IRBuilder<> Builder2(&call);
150 Type *statusType =
nullptr;
151#if LLVM_VERSION_MAJOR < 17
152 if (Function *recvfn = called->getParent()->getFunction(
154 auto statusArg = recvfn->arg_end();
156 if (
auto PT = dyn_cast<PointerType>(statusArg->getType()))
157 statusType = PT->getPointerElementType();
160 if (statusType ==
nullptr) {
161 statusType = ArrayType::get(Type::getInt8Ty(call.getContext()), 24);
162 llvm::errs() <<
" warning could not automatically determine mpi "
163 "status type, assuming [24 x i8]\n";
169 if (d_req->getType()->isIntegerTy()) {
171 Builder2.CreateIntToPtr(d_req,
getInt8PtrTy(call.getContext()));
175 Value *helper = Builder2.CreatePointerCast(d_req,
getUnqual(helperTy));
176 helper = Builder2.CreateLoad(helperTy, helper);
178 auto i64 = Type::getInt64Ty(call.getContext());
180 Value *firstallocation;
181 firstallocation = Builder2.CreateLoad(
184 Value *len_arg =
nullptr;
185 if (
auto C = dyn_cast<Constant>(
187 len_arg = Builder2.CreateZExtOrTrunc(C, i64);
189 len_arg = Builder2.CreateLoad(
192 Value *tysize =
nullptr;
193 if (
auto C = dyn_cast<Constant>(
197 tysize = Builder2.CreateLoad(
203 prev = Builder2.CreateLoad(
207 Builder2.CreateStore(prev, Builder2.CreatePointerCast(
213 tysize =
MPI_TYPE_SIZE(tysize, Builder2, call.getType(), called);
215 Value *args[] = { req,
217 .CreateAlloca(statusType)};
218 FunctionCallee waitFunc =
nullptr;
222 if (Function *recvfn = called->getParent()->getFunction(
224 auto statusArg = recvfn->arg_end();
226 if (statusArg->getType()->isIntegerTy())
227 args[1] = Builder2.CreatePtrToInt(args[1], statusArg->getType());
229 args[1] = Builder2.CreateBitCast(args[1], statusArg->getType());
234 Type *types[
sizeof(args) /
sizeof(*args)];
235 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
236 types[i] = args[i]->getType();
237 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
238 waitFunc = called->getParent()->getOrInsertFunction(
258 auto fcall = Builder2.CreateCall(waitFunc, args, ReqDefs);
260 if (
auto F = dyn_cast<Function>(waitFunc.getCallee()))
261 fcall->setCallingConv(F->getCallingConv());
262 len_arg = Builder2.CreateMul(
264 Builder2.CreateZExtOrTrunc(tysize,
265 Type::getInt64Ty(Builder2.getContext())),
267 if (funcName ==
"MPI_Irecv" || funcName ==
"PMPI_Irecv") {
269 ConstantInt::get(Type::getInt8Ty(Builder2.getContext()), 0);
270 auto volatile_arg = ConstantInt::getFalse(Builder2.getContext());
272 auto dbuf = firstallocation;
273 Value *nargs[] = {dbuf, val_arg, len_arg, volatile_arg};
274 Type *tys[] = {dbuf->getType(), len_arg->getType()};
276 auto memset = cast<CallInst>(Builder2.CreateCall(
280 memset->addParamAttr(0, Attribute::NonNull);
281 }
else if (funcName ==
"MPI_Isend" || funcName ==
"PMPI_Isend") {
288 shadow, len_arg, Builder2, BufferDefs);
294 assert(0 &&
"illegal mpi");
300 IRBuilder<> Builder2(&call);
306 Value *buf = gutils->
invertPointerM(call.getOperand(0), Builder2);
312 Value *request = gutils->
invertPointerM(call.getOperand(6), Builder2);
331 auto callval = call.getCalledOperand();
333 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
342 if (funcName ==
"MPI_Wait" || funcName ==
"PMPI_Wait") {
343 Value *d_reqp =
nullptr;
348 Value *d_req = gutils->
invertPointerM(call.getOperand(0), BuilderZ);
350 if (req->getType()->isIntegerTy()) {
351 req = BuilderZ.CreateIntToPtr(
355 Value *isNull =
nullptr;
356 if (
auto GV = gutils->
newFunc->getParent()->getNamedValue(
357 "ompi_request_null")) {
358 Value *reql = BuilderZ.CreatePointerCast(req,
getUnqual(GV->getType()));
359 reql = BuilderZ.CreateLoad(GV->getType(), reql);
360 isNull = BuilderZ.CreateICmpEQ(reql, GV);
363 if (d_req->getType()->isIntegerTy()) {
364 d_req = BuilderZ.CreateIntToPtr(
368 d_reqp = BuilderZ.CreateLoad(
374 Constant::getNullValue(d_reqp->getType()), d_reqp);
375 if (
auto I = dyn_cast<Instruction>(d_reqp))
382 IRBuilder<> Builder2(&call);
390 d_reqp = BuilderZ.CreatePHI(
getUnqual(impi), 0);
395 d_reqp =
lookup(d_reqp, Builder2);
397 Value *isNull = Builder2.CreateICmpEQ(
398 d_reqp, Constant::getNullValue(d_reqp->getType()));
400 BasicBlock *currentBlock = Builder2.GetInsertBlock();
402 currentBlock, currentBlock->getName() +
"_nonnull");
404 nonnullBlock, currentBlock->getName() +
"_end",
407 Builder2.CreateCondBr(isNull, endBlock, nonnullBlock);
408 Builder2.SetInsertPoint(nonnullBlock);
410 Value *cache = Builder2.CreateLoad(impi, d_reqp);
421 Type *types[
sizeof(args) /
sizeof(*args) - 1];
422 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args) - 1; i++)
423 types[i] = args[i]->getType();
425 *called->getParent(), types, call.getOperand(0)->getType(),
432 Builder2.CreateCall(dwait, args,
434 &call, {ValueType::Shadow, ValueType::None},
436 cal->setCallingConv(dwait->getCallingConv());
438 cal->addFnAttr(Attribute::AlwaysInline);
439 Builder2.CreateBr(endBlock);
443 SmallVector<BasicBlock *, 4> &vec =
446 vec.push_back(endBlock);
448 Builder2.SetInsertPoint(endBlock);
451 IRBuilder<> Builder2(&call);
456 Value *request = gutils->
invertPointerM(call.getArgOperand(0), Builder2);
457 Value *status = gutils->
invertPointerM(call.getArgOperand(1), Builder2);
459 Value *args[] = { request,
466 auto callval = call.getCalledOperand();
468 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
476 if (funcName ==
"MPI_Waitall" || funcName ==
"PMPI_Waitall") {
477 Value *d_reqp =
nullptr;
484 Value *d_req = gutils->
invertPointerM(call.getOperand(1), BuilderZ);
486 if (req->getType()->isIntegerTy()) {
487 req = BuilderZ.CreateIntToPtr(
491 if (d_req->getType()->isIntegerTy()) {
492 d_req = BuilderZ.CreateIntToPtr(
498 {count->getType(), req->getType(), d_req->getType()}, reqType);
500 d_reqp = BuilderZ.CreateCall(dsave, {count, req, d_req});
501 cast<CallInst>(d_reqp)->setCallingConv(dsave->getCallingConv());
502 cast<CallInst>(d_reqp)->setDebugLoc(
509 IRBuilder<> Builder2(&call);
519 d_reqp = BuilderZ.CreatePHI(
getUnqual(reqType), 0);
524 d_reqp =
lookup(d_reqp, Builder2);
526 BasicBlock *currentBlock = Builder2.GetInsertBlock();
528 currentBlock, currentBlock->getName() +
"_loop");
530 loopBlock, currentBlock->getName() +
"_nonnull");
532 nonnullBlock, currentBlock->getName() +
"_eloop");
533 BasicBlock *endBlock =
537 Builder2.CreateCondBr(
538 Builder2.CreateICmpNE(count,
539 ConstantInt::get(count->getType(), 0,
false)),
540 loopBlock, endBlock);
542 Builder2.SetInsertPoint(loopBlock);
543 auto idx = Builder2.CreatePHI(count->getType(), 2);
544 idx->addIncoming(ConstantInt::get(count->getType(), 0,
false),
546 Value *inc = Builder2.CreateAdd(
547 idx, ConstantInt::get(count->getType(), 1,
false),
"",
true,
true);
548 idx->addIncoming(inc, eloopBlock);
550 Value *idxs[] = {idx};
551 Value *req = Builder2.CreateInBoundsGEP(reqType, req_orig, idxs);
552 Value *d_req = Builder2.CreateInBoundsGEP(reqType, d_reqp, idxs);
554 d_req = Builder2.CreateLoad(
558 Value *isNull = Builder2.CreateICmpEQ(
559 d_req, Constant::getNullValue(d_req->getType()));
561 Builder2.CreateCondBr(isNull, eloopBlock, nonnullBlock);
562 Builder2.SetInsertPoint(nonnullBlock);
564 Value *cache = Builder2.CreateLoad(impi, d_req);
575 Type *types[
sizeof(args) /
sizeof(*args) - 1];
576 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args) - 1; i++)
577 types[i] = args[i]->getType();
579 *called->getParent(), types, req->getType(), called->getName());
584 auto cal = Builder2.CreateCall(
587 {ValueType::None, ValueType::None,
588 ValueType::None, ValueType::None,
589 ValueType::None, ValueType::None,
592 cal->setCallingConv(dwait->getCallingConv());
594 cal->addFnAttr(Attribute::AlwaysInline);
595 Builder2.CreateBr(eloopBlock);
597 Builder2.SetInsertPoint(eloopBlock);
598 Builder2.CreateCondBr(Builder2.CreateICmpEQ(inc, count), endBlock,
603 SmallVector<BasicBlock *, 4> &vec =
606 vec.push_back(endBlock);
608 Builder2.SetInsertPoint(endBlock);
614 IRBuilder<> Builder2(&call);
619 Value *array_of_requests =
621 if (array_of_requests->getType()->isIntegerTy()) {
622 array_of_requests = Builder2.CreateIntToPtr(
637 auto callval = call.getCalledOperand();
639 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
647 if (funcName ==
"MPI_Send" || funcName ==
"MPI_Ssend" ||
648 funcName ==
"PMPI_Send" || funcName ==
"PMPI_Ssend") {
656 IRBuilder<> Builder2 =
657 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
664 Value *shadow = gutils->
invertPointerM(call.getOperand(0), Builder2);
666 shadow =
lookup(shadow, Builder2);
667 Value *shadowOrig = shadow;
668 if (shadow->getType()->isIntegerTy())
670 Builder2.CreateIntToPtr(shadow,
getInt8PtrTy(call.getContext()));
672 Type *statusType =
nullptr;
673#if LLVM_VERSION_MAJOR < 17
674 if (called->getContext().supportsTypedPointers()) {
675 if (Function *recvfn = called->getParent()->getFunction(
677 auto statusArg = recvfn->arg_end();
679 if (
auto PT = dyn_cast<PointerType>(statusArg->getType()))
680 statusType = PT->getPointerElementType();
684 if (statusType ==
nullptr) {
685 statusType = ArrayType::get(Type::getInt8Ty(call.getContext()), 24);
686 llvm::errs() <<
" warning could not automatically determine mpi "
687 "status type, assuming [24 x i8]\n";
692 count =
lookup(count, Builder2);
696 datatype =
lookup(datatype, Builder2);
700 src =
lookup(src, Builder2);
704 tag =
lookup(tag, Builder2);
708 comm =
lookup(comm, Builder2);
726 auto callval = call.getCalledOperand();
727 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
741 Value *tysize =
MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
743 auto len_arg = Builder2.CreateZExtOrTrunc(
744 args[1], Type::getInt64Ty(call.getContext()));
746 Builder2.CreateMul(len_arg,
747 Builder2.CreateZExtOrTrunc(
748 tysize, Type::getInt64Ty(call.getContext())),
751 Value *firstallocation =
753 len_arg,
"mpirecv_malloccache");
754 args[0] = firstallocation;
756 Type *types[
sizeof(args) /
sizeof(*args)];
757 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
758 types[i] = args[i]->getType();
759 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
761 Builder2.SetInsertPoint(Builder2.GetInsertBlock());
769 auto fcall = Builder2.CreateCall(
770 called->getParent()->getOrInsertFunction(
773 fcall->setCallingConv(call.getCallingConv());
776 shadow, len_arg, Builder2, BufferDefs);
787 if (funcName ==
"MPI_Recv" || funcName ==
"PMPI_Recv") {
795 IRBuilder<> Builder2 =
796 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
803 Value *shadow = gutils->
invertPointerM(call.getOperand(0), Builder2);
805 shadow =
lookup(shadow, Builder2);
809 count =
lookup(count, Builder2);
813 datatype =
lookup(datatype, Builder2);
817 source =
lookup(source, Builder2);
821 tag =
lookup(tag, Builder2);
825 comm =
lookup(comm, Builder2);
829 Value *args[] = {shadow, count, datatype, source, tag, comm, status};
836 Builder2, !forwardMode);
838 auto callval = call.getCalledOperand();
840 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
844 Value *args[] = {shadow, count, datatype, source, tag, comm};
851 Builder2, !forwardMode);
853 Type *types[
sizeof(args) /
sizeof(*args)];
854 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
855 types[i] = args[i]->getType();
856 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
858 auto fcall = Builder2.CreateCall(
859 called->getParent()->getOrInsertFunction(
862 fcall->setCallingConv(call.getCallingConv());
865 Builder2.CreateBitCast(args[0],
getInt8PtrTy(call.getContext()));
866 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
867 auto len_arg = Builder2.CreateZExtOrTrunc(
868 args[1], Type::getInt64Ty(call.getContext()));
869 auto tysize =
MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
871 Builder2.CreateMul(len_arg,
872 Builder2.CreateZExtOrTrunc(
873 tysize, Type::getInt64Ty(call.getContext())),
875 auto volatile_arg = ConstantInt::getFalse(call.getContext());
877 Value *nargs[] = {dst_arg, val_arg, len_arg, volatile_arg};
878 Type *tys[] = {dst_arg->getType(), len_arg->getType()};
885 auto memset = cast<CallInst>(Builder2.CreateCall(
887 Intrinsic::memset, tys),
889 memset->addParamAttr(0, Attribute::NonNull);
902 if (funcName ==
"MPI_Bcast" || funcName ==
"PMPI_Bcast") {
910 IRBuilder<> Builder2 =
911 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
918 Value *shadow = gutils->
invertPointerM(call.getOperand(0), Builder2);
920 shadow =
lookup(shadow, Builder2);
921 if (shadow->getType()->isIntegerTy())
923 Builder2.CreateIntToPtr(shadow,
getInt8PtrTy(call.getContext()));
927 Type *MPI_OP_Ptr_type =
getUnqual(MPI_OP_type);
931 count =
lookup(count, Builder2);
934 datatype =
lookup(datatype, Builder2);
937 root =
lookup(root, Builder2);
941 comm =
lookup(comm, Builder2);
958 auto callval = call.getCalledOperand();
959 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
963 Value *rank =
MPI_COMM_RANK(comm, Builder2, root->getType(), called);
964 Value *tysize =
MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
966 auto len_arg = Builder2.CreateZExtOrTrunc(
967 count, Type::getInt64Ty(call.getContext()));
969 Builder2.CreateMul(len_arg,
970 Builder2.CreateZExtOrTrunc(
971 tysize, Type::getInt64Ty(call.getContext())),
978 BasicBlock *currentBlock = Builder2.GetInsertBlock();
980 currentBlock, currentBlock->getName() +
"_root", gutils->
newFunc);
982 rootBlock, currentBlock->getName() +
"_post", gutils->
newFunc);
984 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
987 Builder2.SetInsertPoint(rootBlock);
991 len_arg,
"mpireduce_malloccache");
992 Builder2.CreateBr(mergeBlock);
994 Builder2.SetInsertPoint(mergeBlock);
996 buf = Builder2.CreatePHI(rootbuf->getType(), 2);
997 buf->addIncoming(rootbuf, rootBlock);
998 buf->addIncoming(UndefValue::get(buf->getType()), currentBlock);
1020 MPI_OP_Ptr_type, MPI_OP_type, CT,
1021 root->getType(), Builder2),
1025 Type *types[
sizeof(args) /
sizeof(*args)];
1026 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
1027 types[i] = args[i]->getType();
1029 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
1031 Builder2.CreateCall(
1032 called->getParent()->getOrInsertFunction(
1038 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1040 currentBlock, currentBlock->getName() +
"_root", gutils->
newFunc);
1042 rootBlock, currentBlock->getName() +
"_nonroot", gutils->
newFunc);
1044 nonrootBlock, currentBlock->getName() +
"_post", gutils->
newFunc);
1046 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1049 Builder2.SetInsertPoint(rootBlock);
1052 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1053 Value *nargs[] = {shadow, buf, len_arg, volatile_arg};
1055 Type *tys[] = {shadow->getType(), buf->getType(), len_arg->getType()};
1058 Intrinsic::memcpy, tys);
1061 cast<CallInst>(Builder2.CreateCall(memcpyF, nargs, BufferDefs));
1062 mem->setCallingConv(memcpyF->getCallingConv());
1070 Builder2.CreateBr(mergeBlock);
1072 Builder2.SetInsertPoint(nonrootBlock);
1075 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1076 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1077 Value *args[] = {shadow, val_arg, len_arg, volatile_arg};
1078 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1079 auto memset = cast<CallInst>(Builder2.CreateCall(
1081 Intrinsic::memset, tys),
1083 memset->addParamAttr(0, Attribute::NonNull);
1084 Builder2.CreateBr(mergeBlock);
1086 Builder2.SetInsertPoint(mergeBlock);
1105 if (funcName ==
"MPI_Reduce" || funcName ==
"PMPI_Reduce") {
1115 IRBuilder<> Builder2 =
1116 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1124 Value *orig_sendbuf = call.getOperand(0);
1125 Value *orig_recvbuf = call.getOperand(1);
1126 Value *orig_count = call.getOperand(2);
1127 Value *orig_datatype = call.getOperand(3);
1128 Value *orig_op = call.getOperand(4);
1129 Value *orig_root = call.getOperand(5);
1130 Value *orig_comm = call.getOperand(6);
1133 if (
Constant *C = dyn_cast<Constant>(orig_op)) {
1134 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
1135 C = CE->getOperand(0);
1137 if (
auto GV = dyn_cast<GlobalVariable>(C)) {
1138 if (GV->getName() ==
"ompi_mpi_op_sum") {
1143 if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
1144 if (CI->getValue() == 1476395011) {
1151 llvm::raw_string_ostream ss(s);
1152 ss <<
" call: " << call <<
"\n";
1153 ss <<
" unhandled mpi_reduce op: " << *orig_op <<
"\n";
1158 Value *shadow_recvbuf = gutils->
invertPointerM(orig_recvbuf, Builder2);
1160 shadow_recvbuf =
lookup(shadow_recvbuf, Builder2);
1161 if (shadow_recvbuf->getType()->isIntegerTy())
1162 shadow_recvbuf = Builder2.CreateIntToPtr(
1165 Value *shadow_sendbuf = gutils->
invertPointerM(orig_sendbuf, Builder2);
1167 shadow_sendbuf =
lookup(shadow_sendbuf, Builder2);
1168 if (shadow_sendbuf->getType()->isIntegerTy())
1169 shadow_sendbuf = Builder2.CreateIntToPtr(
1178 Builder2, !forwardMode);
1182 count =
lookup(count, Builder2);
1186 datatype =
lookup(datatype, Builder2);
1190 op =
lookup(op, Builder2);
1194 root =
lookup(root, Builder2);
1198 comm =
lookup(comm, Builder2);
1200 Value *rank =
MPI_COMM_RANK(comm, Builder2, root->getType(), called);
1220 auto callval = call.getCalledOperand();
1221 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
1225 Value *tysize =
MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
1228 auto len_arg = Builder2.CreateZExtOrTrunc(
1229 count, Type::getInt64Ty(call.getContext()));
1231 Builder2.CreateMul(len_arg,
1232 Builder2.CreateZExtOrTrunc(
1233 tysize, Type::getInt64Ty(call.getContext())),
1239 len_arg,
"mpireduce_malloccache");
1244 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1246 currentBlock, currentBlock->getName() +
"_root", gutils->
newFunc);
1248 rootBlock, currentBlock->getName() +
"_post", gutils->
newFunc);
1250 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1253 Builder2.SetInsertPoint(rootBlock);
1256 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1257 Value *nargs[] = {buf, shadow_recvbuf, len_arg, volatile_arg};
1259 Type *tys[] = {nargs[0]->getType(), nargs[1]->getType(),
1260 len_arg->getType()};
1263 Intrinsic::memcpy, tys);
1266 cast<CallInst>(Builder2.CreateCall(memcpyF, nargs, BufferDefs));
1267 mem->setCallingConv(memcpyF->getCallingConv());
1270 Builder2.CreateBr(mergeBlock);
1271 Builder2.SetInsertPoint(mergeBlock);
1286 Type *types[
sizeof(args) /
sizeof(*args)];
1287 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
1288 types[i] = args[i]->getType();
1290 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
1291 Builder2.CreateCall(
1292 called->getParent()->getOrInsertFunction(
1299 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1301 currentBlock, currentBlock->getName() +
"_root", gutils->
newFunc);
1303 rootBlock, currentBlock->getName() +
"_post", gutils->
newFunc);
1305 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1308 Builder2.SetInsertPoint(rootBlock);
1310 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1311 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1312 Value *args[] = {shadow_recvbuf, val_arg, len_arg, volatile_arg};
1313 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1314 auto memset = cast<CallInst>(Builder2.CreateCall(
1316 Intrinsic::memset, tys),
1318 memset->addParamAttr(0, Attribute::NonNull);
1320 Builder2.CreateBr(mergeBlock);
1321 Builder2.SetInsertPoint(mergeBlock);
1326 len_arg, Builder2, BufferDefs);
1348 if (funcName ==
"MPI_Allreduce" || funcName ==
"PMPI_Allreduce") {
1358 IRBuilder<> Builder2 =
1359 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1367 Value *orig_sendbuf = call.getOperand(0);
1368 Value *orig_recvbuf = call.getOperand(1);
1369 Value *orig_count = call.getOperand(2);
1370 Value *orig_datatype = call.getOperand(3);
1371 Value *orig_op = call.getOperand(4);
1372 Value *orig_comm = call.getOperand(5);
1375 if (
Constant *C = dyn_cast<Constant>(orig_op)) {
1376 while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
1377 C = CE->getOperand(0);
1379 if (
auto GV = dyn_cast<GlobalVariable>(C)) {
1380 if (GV->getName() ==
"ompi_mpi_op_sum") {
1385 if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
1386 if (CI->getValue() == 1476395011) {
1393 llvm::raw_string_ostream ss(s);
1394 ss <<
" call: " << call <<
"\n";
1395 ss <<
" unhandled mpi_allreduce op: " << *orig_op <<
"\n";
1400 Value *shadow_recvbuf = gutils->
invertPointerM(orig_recvbuf, Builder2);
1402 shadow_recvbuf =
lookup(shadow_recvbuf, Builder2);
1403 if (shadow_recvbuf->getType()->isIntegerTy())
1404 shadow_recvbuf = Builder2.CreateIntToPtr(
1407 Value *shadow_sendbuf = gutils->
invertPointerM(orig_sendbuf, Builder2);
1409 shadow_sendbuf =
lookup(shadow_sendbuf, Builder2);
1410 if (shadow_sendbuf->getType()->isIntegerTy())
1411 shadow_sendbuf = Builder2.CreateIntToPtr(
1419 Builder2, !forwardMode);
1423 count =
lookup(count, Builder2);
1427 datatype =
lookup(datatype, Builder2);
1431 comm =
lookup(comm, Builder2);
1435 op =
lookup(op, Builder2);
1447 auto callval = call.getCalledOperand();
1448 Builder2.CreateCall(call.getFunctionType(), callval, args, BufferDefs);
1453 Value *tysize =
MPI_TYPE_SIZE(datatype, Builder2, call.getType(), called);
1456 auto len_arg = Builder2.CreateZExtOrTrunc(
1457 count, Type::getInt64Ty(call.getContext()));
1459 Builder2.CreateMul(len_arg,
1460 Builder2.CreateZExtOrTrunc(
1461 tysize, Type::getInt64Ty(call.getContext())),
1467 len_arg,
"mpireduce_malloccache");
1481 Type *types[
sizeof(args) /
sizeof(*args)];
1482 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
1483 types[i] = args[i]->getType();
1485 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
1486 Builder2.CreateCall(
1487 called->getParent()->getOrInsertFunction(
1494 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1495 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1496 Value *args[] = {shadow_recvbuf, val_arg, len_arg, volatile_arg};
1497 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1498 auto memset = cast<CallInst>(Builder2.CreateCall(
1500 Intrinsic::memset, tys),
1502 memset->addParamAttr(0, Attribute::NonNull);
1506 len_arg, Builder2, BufferDefs);
1529 if (funcName ==
"MPI_Gather" || funcName ==
"PMPI_Gather") {
1537 IRBuilder<> Builder2 =
1538 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1545 Value *orig_sendbuf = call.getOperand(0);
1546 Value *orig_sendcount = call.getOperand(1);
1547 Value *orig_sendtype = call.getOperand(2);
1548 Value *orig_recvbuf = call.getOperand(3);
1549 Value *orig_recvcount = call.getOperand(4);
1550 Value *orig_recvtype = call.getOperand(5);
1551 Value *orig_root = call.getOperand(6);
1552 Value *orig_comm = call.getOperand(7);
1554 Value *shadow_recvbuf = gutils->
invertPointerM(orig_recvbuf, Builder2);
1556 shadow_recvbuf =
lookup(shadow_recvbuf, Builder2);
1557 if (shadow_recvbuf->getType()->isIntegerTy())
1558 shadow_recvbuf = Builder2.CreateIntToPtr(
1561 Value *shadow_sendbuf = gutils->
invertPointerM(orig_sendbuf, Builder2);
1563 shadow_sendbuf =
lookup(shadow_sendbuf, Builder2);
1564 if (shadow_sendbuf->getType()->isIntegerTy())
1565 shadow_sendbuf = Builder2.CreateIntToPtr(
1570 recvcount =
lookup(recvcount, Builder2);
1574 recvtype =
lookup(recvtype, Builder2);
1578 sendcount =
lookup(sendcount, Builder2);
1582 sendtype =
lookup(sendtype, Builder2);
1586 root =
lookup(root, Builder2);
1590 comm =
lookup(comm, Builder2);
1592 Value *rank =
MPI_COMM_RANK(comm, Builder2, root->getType(), called);
1593 Value *tysize =
MPI_TYPE_SIZE(sendtype, Builder2, call.getType(), called);
1614 auto callval = call.getCalledOperand();
1615 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
1620 auto sendlen_arg = Builder2.CreateZExtOrTrunc(
1621 sendcount, Type::getInt64Ty(call.getContext()));
1623 Builder2.CreateMul(sendlen_arg,
1624 Builder2.CreateZExtOrTrunc(
1625 tysize, Type::getInt64Ty(call.getContext())),
1639 sendlen_arg,
"mpireduce_malloccache");
1657 Type *types[
sizeof(args) /
sizeof(*args)];
1658 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
1659 types[i] = args[i]->getType();
1661 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
1662 Builder2.CreateCall(
1663 called->getParent()->getOrInsertFunction(
1671 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1673 currentBlock, currentBlock->getName() +
"_root", gutils->
newFunc);
1675 rootBlock, currentBlock->getName() +
"_post", gutils->
newFunc);
1677 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1680 Builder2.SetInsertPoint(rootBlock);
1681 auto recvlen_arg = Builder2.CreateZExtOrTrunc(
1682 recvcount, Type::getInt64Ty(call.getContext()));
1684 Builder2.CreateMul(recvlen_arg,
1685 Builder2.CreateZExtOrTrunc(
1686 tysize, Type::getInt64Ty(call.getContext())),
1688 recvlen_arg = Builder2.CreateMul(
1690 Builder2.CreateZExtOrTrunc(
1692 Type::getInt64Ty(call.getContext())),
1695 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1696 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1697 Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
1698 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1699 auto memset = cast<CallInst>(Builder2.CreateCall(
1701 Intrinsic::memset, tys),
1703 memset->addParamAttr(0, Attribute::NonNull);
1705 Builder2.CreateBr(mergeBlock);
1706 Builder2.SetInsertPoint(mergeBlock);
1711 sendlen_arg, Builder2, BufferDefs);
1734 if (funcName ==
"MPI_Scatter" || funcName ==
"PMPI_Scatter") {
1742 IRBuilder<> Builder2 =
1743 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1750 Value *orig_sendbuf = call.getOperand(0);
1751 Value *orig_sendcount = call.getOperand(1);
1752 Value *orig_sendtype = call.getOperand(2);
1753 Value *orig_recvbuf = call.getOperand(3);
1754 Value *orig_recvcount = call.getOperand(4);
1755 Value *orig_recvtype = call.getOperand(5);
1756 Value *orig_root = call.getOperand(6);
1757 Value *orig_comm = call.getOperand(7);
1759 Value *shadow_recvbuf = gutils->
invertPointerM(orig_recvbuf, Builder2);
1761 shadow_recvbuf =
lookup(shadow_recvbuf, Builder2);
1762 if (shadow_recvbuf->getType()->isIntegerTy())
1763 shadow_recvbuf = Builder2.CreateIntToPtr(
1766 Value *shadow_sendbuf = gutils->
invertPointerM(orig_sendbuf, Builder2);
1768 shadow_sendbuf =
lookup(shadow_sendbuf, Builder2);
1769 if (shadow_sendbuf->getType()->isIntegerTy())
1770 shadow_sendbuf = Builder2.CreateIntToPtr(
1775 recvcount =
lookup(recvcount, Builder2);
1779 recvtype =
lookup(recvtype, Builder2);
1783 sendcount =
lookup(sendcount, Builder2);
1787 sendtype =
lookup(sendtype, Builder2);
1791 root =
lookup(root, Builder2);
1795 comm =
lookup(comm, Builder2);
1797 Value *rank =
MPI_COMM_RANK(comm, Builder2, root->getType(), called);
1798 Value *tysize =
MPI_TYPE_SIZE(sendtype, Builder2, call.getType(), called);
1819 auto callval = call.getCalledOperand();
1820 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
1824 auto recvlen_arg = Builder2.CreateZExtOrTrunc(
1825 recvcount, Type::getInt64Ty(call.getContext()));
1827 Builder2.CreateMul(recvlen_arg,
1828 Builder2.CreateZExtOrTrunc(
1829 tysize, Type::getInt64Ty(call.getContext())),
1842 PHINode *sendlen_phi;
1845 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1847 currentBlock, currentBlock->getName() +
"_root", gutils->
newFunc);
1849 rootBlock, currentBlock->getName() +
"_post", gutils->
newFunc);
1851 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1854 Builder2.SetInsertPoint(rootBlock);
1856 auto sendlen_arg = Builder2.CreateZExtOrTrunc(
1857 sendcount, Type::getInt64Ty(call.getContext()));
1859 Builder2.CreateMul(sendlen_arg,
1860 Builder2.CreateZExtOrTrunc(
1861 tysize, Type::getInt64Ty(call.getContext())),
1863 sendlen_arg = Builder2.CreateMul(
1865 Builder2.CreateZExtOrTrunc(
1867 Type::getInt64Ty(call.getContext())),
1872 sendlen_arg,
"mpireduce_malloccache");
1874 Builder2.CreateBr(mergeBlock);
1876 Builder2.SetInsertPoint(mergeBlock);
1878 buf = Builder2.CreatePHI(rootbuf->getType(), 2);
1879 buf->addIncoming(rootbuf, rootBlock);
1880 buf->addIncoming(UndefValue::get(buf->getType()), currentBlock);
1882 sendlen_phi = Builder2.CreatePHI(sendlen_arg->getType(), 2);
1883 sendlen_phi->addIncoming(sendlen_arg, rootBlock);
1884 sendlen_phi->addIncoming(UndefValue::get(sendlen_arg->getType()),
1904 Type *types[
sizeof(args) /
sizeof(*args)];
1905 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
1906 types[i] = args[i]->getType();
1908 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
1909 Builder2.CreateCall(
1910 called->getParent()->getOrInsertFunction(
1917 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
1918 auto volatile_arg = ConstantInt::getFalse(call.getContext());
1919 Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
1920 Type *tys[] = {args[0]->getType(), args[2]->getType()};
1921 auto memset = cast<CallInst>(Builder2.CreateCall(
1923 Intrinsic::memset, tys),
1925 memset->addParamAttr(0, Attribute::NonNull);
1932 BasicBlock *currentBlock = Builder2.GetInsertBlock();
1934 currentBlock, currentBlock->getName() +
"_root", gutils->
newFunc);
1936 rootBlock, currentBlock->getName() +
"_post", gutils->
newFunc);
1938 Builder2.CreateCondBr(Builder2.CreateICmpEQ(rank, root), rootBlock,
1941 Builder2.SetInsertPoint(rootBlock);
1945 sendlen_phi, Builder2, BufferDefs);
1952 Builder2.CreateBr(mergeBlock);
1953 Builder2.SetInsertPoint(mergeBlock);
1974 if (funcName ==
"MPI_Allgather" || funcName ==
"PMPI_Allgather") {
1982 IRBuilder<> Builder2 =
1983 forwardMode ? IRBuilder<>(&call) : IRBuilder<>(call.getParent());
1990 Value *orig_sendbuf = call.getOperand(0);
1991 Value *orig_sendcount = call.getOperand(1);
1992 Value *orig_sendtype = call.getOperand(2);
1993 Value *orig_recvbuf = call.getOperand(3);
1994 Value *orig_recvcount = call.getOperand(4);
1995 Value *orig_recvtype = call.getOperand(5);
1996 Value *orig_comm = call.getOperand(6);
1998 Value *shadow_recvbuf = gutils->
invertPointerM(orig_recvbuf, Builder2);
2000 shadow_recvbuf =
lookup(shadow_recvbuf, Builder2);
2002 if (shadow_recvbuf->getType()->isIntegerTy())
2003 shadow_recvbuf = Builder2.CreateIntToPtr(
2006 Value *shadow_sendbuf = gutils->
invertPointerM(orig_sendbuf, Builder2);
2008 shadow_sendbuf =
lookup(shadow_sendbuf, Builder2);
2010 if (shadow_sendbuf->getType()->isIntegerTy())
2011 shadow_sendbuf = Builder2.CreateIntToPtr(
2016 recvcount =
lookup(recvcount, Builder2);
2020 recvtype =
lookup(recvtype, Builder2);
2024 sendcount =
lookup(sendcount, Builder2);
2028 sendtype =
lookup(sendtype, Builder2);
2032 comm =
lookup(comm, Builder2);
2034 Value *tysize =
MPI_TYPE_SIZE(sendtype, Builder2, call.getType(), called);
2054 auto callval = call.getCalledOperand();
2055 Builder2.CreateCall(call.getFunctionType(), callval, args, Defs);
2059 auto sendlen_arg = Builder2.CreateZExtOrTrunc(
2060 sendcount, Type::getInt64Ty(call.getContext()));
2062 Builder2.CreateMul(sendlen_arg,
2063 Builder2.CreateZExtOrTrunc(
2064 tysize, Type::getInt64Ty(call.getContext())),
2078 sendlen_arg,
"mpireduce_malloccache");
2082 Type *MPI_OP_Ptr_type =
getUnqual(MPI_OP_type);
2100 MPI_OP_Ptr_type, MPI_OP_type, CT,
2101 call.getType(), Builder2),
2104 Type *types[
sizeof(args) /
sizeof(*args)];
2105 for (
size_t i = 0; i <
sizeof(args) /
sizeof(*args); i++)
2106 types[i] = args[i]->getType();
2108 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
2109 Builder2.CreateCall(
2110 called->getParent()->getOrInsertFunction(
2112 "MPI_Reduce_scatter_block"),
2119 auto recvlen_arg = Builder2.CreateZExtOrTrunc(
2120 recvcount, Type::getInt64Ty(call.getContext()));
2122 Builder2.CreateMul(recvlen_arg,
2123 Builder2.CreateZExtOrTrunc(
2124 tysize, Type::getInt64Ty(call.getContext())),
2126 recvlen_arg = Builder2.CreateMul(
2128 Builder2.CreateZExtOrTrunc(
2130 Type::getInt64Ty(call.getContext())),
2132 auto val_arg = ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
2133 auto volatile_arg = ConstantInt::getFalse(call.getContext());
2134 Value *args[] = {shadow_recvbuf, val_arg, recvlen_arg, volatile_arg};
2135 Type *tys[] = {args[0]->getType(), args[2]->getType()};
2136 auto memset = cast<CallInst>(Builder2.CreateCall(
2138 Intrinsic::memset, tys),
2140 memset->addParamAttr(0, Attribute::NonNull);
2145 sendlen_arg, Builder2, BufferDefs);
2159 if (funcName ==
"MPI_Barrier" || funcName ==
"PMPI_Barrier") {
2162 IRBuilder<> Builder2(&call);
2164 auto callval = call.getCalledOperand();
2167 Builder2.CreateCall(call.getFunctionType(), callval, args);
2176 if (funcName ==
"MPI_Comm_free" || funcName ==
"MPI_Comm_disconnect") {
2188 IRBuilder<> Builder2(&call);
2191 Value *args[] = {
lookup(call.getOperand(commFound->second), Builder2)};
2192 Type *types[] = {args[0]->getType()};
2194 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
2195 Builder2.CreateCall(
2196 called->getParent()->getOrInsertFunction(
2205 llvm::errs() << *gutils->
oldFunc->getParent() <<
"\n";
2206 llvm::errs() << *gutils->
oldFunc <<
"\n";
2207 llvm::errs() << call <<
"\n";
2208 llvm::errs() << called <<
"\n";
2209 llvm_unreachable(
"Unhandled MPI FUNCTION");
2213 CallInst &call, Function *called, StringRef funcName,
2214 bool subsequent_calls_may_write,
const std::vector<bool> &overwritten_args,
2215 CallInst *
const newCall) {
2216 bool subretused =
false;
2217 bool shadowReturnUsed =
false;
2221 IRBuilder<> BuilderZ(newCall);
2222 BuilderZ.setFastMathFlags(
getFast());
2225 if (funcName ==
"__kmpc_for_static_init_4" ||
2226 funcName ==
"__kmpc_for_static_init_4u" ||
2227 funcName ==
"__kmpc_for_static_init_8" ||
2228 funcName ==
"__kmpc_for_static_init_8u") {
2229 IRBuilder<> Builder2(&call);
2231 auto fini = called->getParent()->getFunction(
"__kmpc_for_static_fini");
2236 auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args);
2237 fcall->setCallingConv(fini->getCallingConv());
2244 funcName ==
"MPI_Comm_free" || funcName ==
"MPI_Comm_disconnect" ||
2252 if (handleBLAS(call, called, *blas, overwritten_args))
2256 if (funcName ==
"printf" || funcName ==
"puts" ||
2257 startsWith(funcName,
"_ZN3std2io5stdio6_print") ||
2264 if (called && (called->getName().contains(
"__enzyme_float") ||
2265 called->getName().contains(
"__enzyme_double") ||
2266 called->getName().contains(
"__enzyme_integer") ||
2267 called->getName().contains(
"__enzyme_pointer"))) {
2274 if (funcName ==
"__kmpc_for_static_init_4" ||
2275 funcName ==
"__kmpc_for_static_init_4u" ||
2276 funcName ==
"__kmpc_for_static_init_8" ||
2277 funcName ==
"__kmpc_for_static_init_8u") {
2279 IRBuilder<> Builder2(&call);
2281 auto fini = called->getParent()->getFunction(
"__kmpc_for_static_fini");
2287 auto fcall = Builder2.CreateCall(fini->getFunctionType(), fini, args);
2288 fcall->setCallingConv(fini->getCallingConv());
2292 if (funcName ==
"__kmpc_for_static_fini") {
2301 if (funcName ==
"__kmpc_barrier") {
2304 IRBuilder<> Builder2(&call);
2306 auto callval = call.getCalledOperand();
2310 Builder2.CreateCall(call.getFunctionType(), callval, args);
2314 if (funcName ==
"__kmpc_critical") {
2316 IRBuilder<> Builder2(&call);
2318 auto crit2 = called->getParent()->getFunction(
"__kmpc_end_critical");
2325 auto fcall = Builder2.CreateCall(crit2->getFunctionType(), crit2, args);
2326 fcall->setCallingConv(crit2->getCallingConv());
2330 if (funcName ==
"__kmpc_end_critical") {
2332 IRBuilder<> Builder2(&call);
2334 auto crit2 = called->getParent()->getFunction(
"__kmpc_critical");
2341 auto fcall = Builder2.CreateCall(crit2->getFunctionType(), crit2, args);
2342 fcall->setCallingConv(crit2->getCallingConv());
2348 funcName !=
"__kmpc_global_thread_num") {
2350 llvm::raw_string_ostream ss(s);
2351 ss <<
" unhandled openmp function: " << call <<
"\n";
2356 auto mod = call.getParent()->getParent()->getParent();
2357#include "CallDerivatives.inc"
2359 if (funcName ==
"llvm.julia.gc_preserve_end") {
2363 auto begin_call = cast<CallInst>(call.getOperand(0));
2365 IRBuilder<> Builder2(&call);
2367 SmallVector<Value *, 1> args;
2368 for (
auto &arg : begin_call->args()) {
2369 bool primalUsed =
false;
2370 bool shadowUsed =
false;
2378 Value *ptrshadow = gutils->
lookupM(
2381 args.push_back(ptrshadow);
2383 for (
size_t i = 0; i < gutils->
getWidth(); ++i)
2384 args.push_back(gutils->
extractMeta(Builder2, ptrshadow, i));
2388 auto newp = Builder2.CreateCall(
2389 called->getParent()->getOrInsertFunction(
2390 "llvm.julia.gc_preserve_begin",
2391 FunctionType::get(Type::getTokenTy(call.getContext()),
2392 ArrayRef<Type *>(),
true)),
2396 auto placeholder = cast<CallInst>(&*ifound->second);
2402 gutils->
erase(placeholder);
2406 if (funcName ==
"llvm.julia.gc_preserve_begin") {
2407 SmallVector<Value *, 1> args;
2408 for (
auto &arg : call.args()) {
2409 bool primalUsed =
false;
2410 bool shadowUsed =
false;
2419 args.push_back(ptrshadow);
2421 for (
size_t i = 0; i < gutils->
getWidth(); ++i)
2422 args.push_back(gutils->
extractMeta(BuilderZ, ptrshadow, i));
2426 auto newp = BuilderZ.CreateCall(called, args);
2429 gutils->
erase(oldp);
2433 IRBuilder<> Builder2(&call);
2438 auto placeholder = cast<CallInst>(&*ifound->second);
2439 Builder2.CreateCall(
2440 called->getParent()->getOrInsertFunction(
2441 "llvm.julia.gc_preserve_end",
2442 FunctionType::get(Builder2.getVoidTy(), call.getType(),
false)),
2456 if (funcName ==
"gsl_sf_legendre_array_e") {
2467 IRBuilder<> Builder2(&call);
2476 call.getOperand(0)->getType(), call.getOperand(1)->getType(),
2477 call.getOperand(2)->getType(), call.getOperand(3)->getType(),
2478 call.getOperand(4)->getType(), call.getOperand(4)->getType(),
2480 FunctionType *FT = FunctionType::get(call.getType(), types,
false);
2481 auto F = called->getParent()->getOrInsertFunction(
2482 "gsl_sf_legendre_deriv_array_e", FT);
2484 llvm::Value *args[6] = {
2496 Type *typesS[] = {args[1]->getType()};
2498 FunctionType::get(args[1]->getType(), typesS,
false);
2499 auto FS = called->getParent()->getOrInsertFunction(
2500 "gsl_sf_legendre_array_n", FTS);
2501 Value *alSize = Builder2.CreateCall(FS, args[1]);
2504 Builder2.CreateLifetimeStart(tmp);
2505 Builder2.CreateLifetimeStart(dtmp);
2507 args[4] = Builder2.CreateBitCast(tmp, types[4]);
2508 args[5] = Builder2.CreateBitCast(dtmp, types[5]);
2510 Builder2.CreateCall(F, args, Defs);
2511 Builder2.CreateLifetimeEnd(tmp);
2514 BasicBlock *currentBlock = Builder2.GetInsertBlock();
2517 currentBlock, currentBlock->getName() +
"_loop");
2518 BasicBlock *endBlock =
2522 Builder2.CreateCondBr(
2523 Builder2.CreateICmpEQ(args[1], Constant::getNullValue(types[1])),
2524 endBlock, loopBlock);
2525 Builder2.SetInsertPoint(loopBlock);
2527 auto idx = Builder2.CreatePHI(types[1], 2);
2528 idx->addIncoming(ConstantInt::get(types[1], 0,
false), currentBlock);
2530 auto acc_idx = Builder2.CreatePHI(types[2], 2);
2532 Value *inc = Builder2.CreateAdd(
2533 idx, ConstantInt::get(types[1], 1,
false),
"",
true,
true);
2534 idx->addIncoming(inc, loopBlock);
2535 acc_idx->addIncoming(Constant::getNullValue(types[2]), currentBlock);
2537 Value *idxs[] = {idx};
2538 Value *dtmp_idx = Builder2.CreateInBoundsGEP(types[2], dtmp, idxs);
2539 Value *d_req = Builder2.CreateInBoundsGEP(
2541 Builder2.CreatePointerCast(
2546 auto l0 = Builder2.CreateLoad(types[2], dtmp_idx);
2547 auto l1 = Builder2.CreateLoad(types[2], d_req);
2548 auto acc = Builder2.CreateFAdd(acc_idx, Builder2.CreateFMul(l0, l1));
2549 Builder2.CreateStore(Constant::getNullValue(types[2]), d_req);
2551 acc_idx->addIncoming(acc, loopBlock);
2553 Builder2.CreateCondBr(Builder2.CreateICmpEQ(inc, args[1]), endBlock,
2556 Builder2.SetInsertPoint(endBlock);
2560 SmallVector<BasicBlock *, 4> &vec =
2563 vec.push_back(endBlock);
2566 auto fin_idx = Builder2.CreatePHI(types[2], 2);
2567 fin_idx->addIncoming(Constant::getNullValue(types[2]), currentBlock);
2568 fin_idx->addIncoming(acc, loopBlock);
2570 Builder2.CreateLifetimeEnd(dtmp);
2574 ->addToDiffe(call.getOperand(2), fin_idx, Builder2, types[2]);
2582 if (funcName ==
"_ZSt29_Rb_tree_insert_and_rebalancebPSt18_Rb_tree_"
2583 "node_baseS0_RS_") {
2590 SmallVector<Value *, 2> args;
2591 for (
auto &arg : call.args()) {
2597 BuilderZ.CreateCall(called, args);
2603 if (funcName ==
"_ZNSt8ios_baseC2Ev" || funcName ==
"_ZNSt8ios_baseD2Ev" ||
2604 funcName ==
"_ZNSt6localeC1Ev" || funcName ==
"_ZNSt6localeD1Ev" ||
2605 funcName ==
"_ZNKSt5ctypeIcE13_M_widen_initEv") {
2613 Value *args[] = {gutils->
invertPointerM(call.getArgOperand(0), BuilderZ)};
2614 BuilderZ.CreateCall(called, args);
2618 if (funcName ==
"_ZNSt9basic_iosIcSt11char_traitsIcEE4initEPSt15basic_"
2619 "streambufIcS1_E") {
2627 Value *args[] = {gutils->
invertPointerM(call.getArgOperand(0), BuilderZ),
2629 BuilderZ.CreateCall(called, args);
2635 if (funcName ==
"__dynamic_cast" ||
2636 funcName ==
"_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base" ||
2637 funcName ==
"_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base" ||
2638 funcName ==
"_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base" ||
2639 funcName ==
"_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base" ||
2640 funcName ==
"jl_ptr_to_array" || funcName ==
"jl_ptr_to_array_1d") {
2641 bool shouldCache =
false;
2648 ValueToValueMapTy empty;
2654 auto placeholder = cast<PHINode>(&*ifound->second);
2657 Value *shadow = placeholder;
2665 SmallVector<Value *, 2> args;
2667 for (
auto &arg : call.args()) {
2669 (funcName ==
"__dynamic_cast" && i > 0) ||
2670 (funcName ==
"jl_ptr_to_array_1d" && i != 1) ||
2671 (funcName ==
"jl_ptr_to_array" && i != 1))
2677 shadow = BuilderZ.CreateCall(called, args);
2681 bool needsReplacement =
true;
2687 needsReplacement =
false;
2692 if (needsReplacement) {
2693 assert(shadow != placeholder);
2695 gutils->
erase(placeholder);
2699 gutils->
erase(placeholder);
2710 if (!shouldCache && !lrc) {
2711 std::map<UsageKey, bool> Seen;
2714 bool primalNeededInReverse =
2717 shouldCache = primalNeededInReverse;
2721 BuilderZ.SetInsertPoint(newCall->getNextNode());
2731 if (funcName ==
"julia.write_barrier" ||
2732 funcName ==
"julia.write_barrier_binding") {
2734 std::map<UsageKey, bool> Seen;
2739 bool backwardsShadow =
false;
2740 bool forwardsShadow =
true;
2742 if (pair.second.stores.count(&call)) {
2743 backwardsShadow =
true;
2744 forwardsShadow = pair.second.primalInitialize;
2745 if (
auto inst = dyn_cast<Instruction>(pair.first))
2746 if (!forwardsShadow && pair.second.LI &&
2747 pair.second.LI->contains(inst->getParent()))
2748 backwardsShadow =
false;
2756 (forwardsShadow || backwardsShadow)) ||
2760 for (
size_t i = 0; i < gutils->
getWidth(); i++) {
2761 SmallVector<Value *, 1> iargs;
2763 for (
auto &arg : call.args()) {
2767 ptrshadow = gutils->
extractMeta(BuilderZ, ptrshadow, i);
2769 iargs.push_back(ptrshadow);
2777 BuilderZ.CreateCall(called, iargs);
2782 bool forceErase =
false;
2789 if (!pair.second.stores.count(&call))
2791 bool primalNeededInReverse =
2799 bool cacheWholeAllocation =
2800 gutils->needsCacheWholeAllocation(pair.first);
2801 if (cacheWholeAllocation) {
2802 primalNeededInReverse =
true;
2805 if (primalNeededInReverse && !cacheWholeAllocation)
2810 if (!pair.second.LI || !pair.second.LI->contains(&call)) {
2822 Intrinsic::ID ID = Intrinsic::not_intrinsic;
2839 if (ID != Intrinsic::not_intrinsic) {
2840 SmallVector<Value *, 2> orig_ops(call.getNumOperands());
2841 for (
unsigned i = 0; i < call.getNumOperands(); ++i) {
2842 orig_ops[i] = call.getOperand(i);
2861 if (
auto assembly = dyn_cast<InlineAsm>(call.getCalledOperand())) {
2862 if (assembly->getAsmString() ==
"maxpd $1, $0") {
2877 SmallVector<Value *, 2> orig_ops(call.getNumOperands());
2878 for (
unsigned i = 0; i < call.getNumOperands(); ++i) {
2879 orig_ops[i] = call.getOperand(i);
2894 if (funcName ==
"realloc") {
2898 IRBuilder<> Builder2(&call);
2903 auto rule = [&](Value *ip) {
2909 llvm::Value *args[2] = {
2911 CallInst *CI = Builder2.CreateCall(
2912 call.getFunctionType(), call.getCalledFunction(), args, Defs);
2913 CI->setAttributes(call.getAttributes());
2914 CI->setCallingConv(call.getCallingConv());
2915 CI->setTailCallKind(call.getTailCallKind());
2916 CI->setDebugLoc(dbgLoc);
2921 call.getType(), Builder2, rule,
2925 PHINode *placeholder = cast<PHINode>(&*found->second);
2929 gutils->
erase(placeholder);
2945 PHINode *placeholder = cast<PHINode>(&*found->second);
2946 IRBuilder<> bb(placeholder);
2948 SmallVector<Value *, 8> args;
2949 for (
auto &arg : call.args()) {
2958 Value *anti = placeholder;
2965 bool backwardsShadow =
false;
2966 bool forwardsShadow =
true;
2967 bool inLoop =
false;
2968 bool isAlloca = isa<AllocaInst>(&call);
2972 backwardsShadow =
true;
2973 forwardsShadow = found->second.primalInitialize;
2975 if (found->second.LI &&
2976 found->second.LI->contains(call.getParent()))
2982 if (!forwardsShadow) {
2986 *gutils->
oldFunc->getParent(), placeholder->getType());
2991 gutils->
erase(placeholder);
2994 }
else if (inLoop) {
3002 placeholder->setName(
"");
3004 bb.SetInsertPoint(placeholder);
3011 return shadowHandlers[funcName](bb, &call, args, gutils);
3013 if (anti->getType() != placeholder->getType()) {
3014 llvm::errs() <<
"orig: " << call <<
"\n";
3015 llvm::errs() <<
"placeholder: " << *placeholder <<
"\n";
3016 llvm::errs() <<
"anti: " << *anti <<
"\n";
3019 bb.SetInsertPoint(placeholder);
3022 gutils->
erase(placeholder);
3025 if (
auto inst = dyn_cast<Instruction>(anti))
3026 bb.SetInsertPoint(inst);
3028 if (!backwardsShadow)
3032 bool zeroed =
false;
3034 Value *prev =
nullptr;
3038 bb.CreateCall(call.getFunctionType(), call.getCalledOperand(),
3039 args, call.getName() +
"'mi");
3040 cast<CallInst>(anti)->setAttributes(call.getAttributes());
3041 cast<CallInst>(anti)->setCallingConv(call.getCallingConv());
3042 cast<CallInst>(anti)->setTailCallKind(call.getTailCallKind());
3043 cast<CallInst>(anti)->setDebugLoc(dbgLoc);
3045 if (anti->getType()->isPointerTy()) {
3046 cast<CallInst>(anti)->addAttributeAtIndex(
3047 AttributeList::ReturnIndex, Attribute::NoAlias);
3048 cast<CallInst>(anti)->addAttributeAtIndex(
3049 AttributeList::ReturnIndex, Attribute::NonNull);
3051 if (funcName ==
"malloc" || funcName ==
"_Znwm" ||
3052 funcName ==
"??2@YAPAXI@Z" ||
3053 funcName ==
"??2@YAPEAX_K@Z") {
3054 if (
auto ci = dyn_cast<ConstantInt>(args[0])) {
3055 unsigned derefBytes = ci->getLimitedValue();
3058 cast<CallInst>(anti)->addDereferenceableRetAttr(derefBytes);
3059 cal->addDereferenceableRetAttr(derefBytes);
3060#if !defined(FLANG) && !defined(ROCM)
3061 AttrBuilder B(ci->getContext());
3065 B.addDereferenceableOrNullAttr(derefBytes);
3066 cast<CallInst>(anti)->setAttributes(
3067 cast<CallInst>(anti)->getAttributes().addRetAttributes(
3068 call.getContext(), B));
3069 cal->setAttributes(cal->getAttributes().addRetAttributes(
3070 call.getContext(), B));
3071 cal->addAttributeAtIndex(AttributeList::ReturnIndex,
3072 Attribute::NoAlias);
3073 cal->addAttributeAtIndex(AttributeList::ReturnIndex,
3074 Attribute::NonNull);
3077 if (funcName ==
"julia.gc_alloc_obj" ||
3078 funcName ==
"jl_gc_alloc_typed" ||
3079 funcName ==
"ijl_gc_alloc_typed") {
3081 bool used = unnecessaryInstructions.find(&call) ==
3082 unnecessaryInstructions.end();
3084 idx, wrap(prev), used);
3109 if (&*bb.GetInsertPoint() == placeholder)
3110 bb.SetInsertPoint(placeholder->getNextNode());
3112 gutils->
erase(placeholder);
3114 if (!backwardsShadow)
3118 if (
auto MD =
hasMetadata(&call,
"enzyme_fromstack")) {
3120 bb.SetInsertPoint(cast<Instruction>(anti));
3122 if (funcName ==
"malloc")
3124 else if (funcName ==
"julia.gc_alloc_obj" ||
3125 funcName ==
"jl_gc_alloc_typed" ||
3126 funcName ==
"ijl_gc_alloc_typed")
3129 llvm_unreachable(
"Unknown allocation to upgrade");
3131 Type *elTy = Type::getInt8Ty(call.getContext());
3132 if (MD->getNumOperands() == 2) {
3133 elTy = (Type *)cast<ConstantInt>(
3134 cast<ConstantAsMetadata>(MD->getOperand(1))
3136 ->getLimitedValue();
3137 Value *tsize = ConstantInt::get(
3138 Size->getType(), (gutils->
newFunc->getParent()
3140 .getTypeAllocSizeInBits(elTy) +
3144 Size = bb.CreateUDiv(Size, tsize,
"",
true);
3146 std::string name =
"";
3147#if LLVM_VERSION_MAJOR < 17
3148 if (call.getContext().supportsTypedPointers()) {
3149 for (
auto U : call.users()) {
3150 if (
hasMetadata(cast<Instruction>(U),
"enzyme_caststack")) {
3151 if (MD->getNumOperands() == 1) {
3152 elTy = U->getType()->getPointerElementType();
3153 Value *tsize = ConstantInt::get(
3157 .getTypeAllocSizeInBits(elTy) +
3161 Size = bb.CreateUDiv(Size, tsize,
"",
true);
3163 name = (U->getName() +
"'ai").str();
3169 auto rule = [&](Value *anti) {
3170 bb.SetInsertPoint(cast<Instruction>(anti));
3171 Value *replacement = bb.CreateAlloca(elTy, Size, name);
3172 if (name.size() == 0)
3173 replacement->takeName(anti);
3176 auto Alignment = cast<ConstantInt>(cast<ConstantAsMetadata>(
3179 ->getLimitedValue();
3181 cast<AllocaInst>(replacement)
3182 ->setAlignment(Align(Alignment));
3184#if LLVM_VERSION_MAJOR < 17
3185 if (call.getContext().supportsTypedPointers()) {
3186 if (anti->getType()->getPointerElementType() != elTy)
3187 replacement = bb.CreatePointerCast(
3189 getUnqual(anti->getType()->getPointerElementType()));
3192 if (
int AS = cast<PointerType>(anti->getType())
3193 ->getAddressSpace()) {
3194 llvm::PointerType *PT;
3195#if LLVM_VERSION_MAJOR < 17
3196 if (call.getContext().supportsTypedPointers()) {
3197 PT = PointerType::get(
3198 anti->getType()->getPointerElementType(), AS);
3200#if LLVM_VERSION_MAJOR < 17
3203 PT = PointerType::get(anti->getContext(), AS);
3204#if LLVM_VERSION_MAJOR < 17
3207 replacement = bb.CreateAddrSpaceCast(replacement, PT);
3208 cast<Instruction>(replacement)
3211 MDNode::get(replacement->getContext(), {}));
3213 gutils->
replaceAWithB(cast<Instruction>(anti), replacement);
3214 bb.SetInsertPoint(cast<Instruction>(anti)->getNextNode());
3215 gutils->
erase(cast<Instruction>(anti));
3243 IRBuilder<> Builder2(&call);
3246 Value *tofree =
lookup(anti, Builder2);
3248 assert(tofree->getType());
3249 auto rule = [&](Value *tofree) {
3251 gutils->
TLI, &call, gutils);
3253 CI->addAttributeAtIndex(AttributeList::FirstArgIndex,
3254 Attribute::NonNull);
3260 IRBuilder<> Builder2(&call);
3263 SmallVector<Value *, 2> args;
3264 for (
unsigned i = 0; i < call.arg_size(); ++i) {
3265 auto arg = call.getArgOperand(i);
3277 CallInst *CI = Builder2.CreateCall(
3278 call.getFunctionType(), call.getCalledFunction(), args, Defs);
3279 CI->setAttributes(call.getAttributes());
3280 CI->setCallingConv(call.getCallingConv());
3281 CI->setTailCallKind(call.getTailCallKind());
3282 CI->setDebugLoc(dbgLoc);
3284 if (funcName ==
"julia.gc_alloc_obj" ||
3285 funcName ==
"jl_gc_alloc_typed" ||
3286 funcName ==
"ijl_gc_alloc_typed") {
3288 bool used = unnecessaryInstructions.find(&call) ==
3289 unnecessaryInstructions.end();
3302 PHINode *placeholder = cast<PHINode>(&*found->second);
3306 gutils->
erase(placeholder);
3319 std::map<UsageKey, bool> Seen;
3325 bool primalNeededInReverse =
3332 bool cacheWholeAllocation = gutils->needsCacheWholeAllocation(&call);
3333 if (cacheWholeAllocation) {
3334 primalNeededInReverse =
true;
3337 auto restoreFromStack = [&](MDNode *MD) {
3338 IRBuilder<> B(newCall);
3340 if (funcName ==
"malloc")
3341 Size = call.getArgOperand(0);
3342 else if (funcName ==
"julia.gc_alloc_obj" ||
3343 funcName ==
"jl_gc_alloc_typed" ||
3344 funcName ==
"ijl_gc_alloc_typed")
3345 Size = call.getArgOperand(1);
3347 llvm_unreachable(
"Unknown allocation to upgrade");
3348 Size = gutils->getNewFromOriginal(Size);
3350 if (isa<ConstantInt>(Size)) {
3351 B.SetInsertPoint(gutils->inversionAllocs);
3353 Type *elTy = Type::getInt8Ty(call.getContext());
3354 if (MD->getNumOperands() == 2) {
3355 elTy = (Type *)cast<ConstantInt>(
3356 cast<ConstantAsMetadata>(MD->getOperand(1))->getValue())
3357 ->getLimitedValue();
3358 Value *tsize = ConstantInt::get(Size->getType(),
3359 (gutils->newFunc->getParent()
3361 .getTypeAllocSizeInBits(elTy) +
3364 Size = B.CreateUDiv(Size, tsize,
"",
true);
3366 Instruction *I =
nullptr;
3367#if LLVM_VERSION_MAJOR < 17
3368 if (call.getContext().supportsTypedPointers()) {
3369 for (
auto U : call.users()) {
3370 if (
hasMetadata(cast<Instruction>(U),
"enzyme_caststack")) {
3371 if (MD->getNumOperands() == 1) {
3372 elTy = U->getType()->getPointerElementType();
3373 Value *tsize = ConstantInt::get(
3374 Size->getType(), (gutils->newFunc->getParent()
3376 .getTypeAllocSizeInBits(elTy) +
3380 Size = B.CreateUDiv(Size, tsize,
"",
true);
3382 I = gutils->getNewFromOriginal(cast<Instruction>(U));
3388 Value *replacement = B.CreateAlloca(elTy, Size);
3389 for (
auto MD : {
"enzyme_active",
"enzyme_inactive",
"enzyme_type",
3390 "enzymejl_allocart",
"enzymejl_allocart_name",
3391 "enzymejl_gc_alloc_rt"})
3392 if (
auto M = call.getMetadata(MD))
3393 cast<AllocaInst>(replacement)->setMetadata(MD, M);
3395 replacement->takeName(I);
3397 replacement->takeName(newCall);
3400 cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
3401 ->getLimitedValue();
3404 cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
3406#if LLVM_VERSION_MAJOR < 17
3407 if (call.getContext().supportsTypedPointers()) {
3408 if (call.getType()->getPointerElementType() != elTy)
3409 replacement = B.CreatePointerCast(
3410 replacement,
getUnqual(call.getType()->getPointerElementType()));
3413 if (
int AS = cast<PointerType>(call.getType())->getAddressSpace()) {
3414 llvm::PointerType *PT;
3415#if LLVM_VERSION_MAJOR < 17
3416 if (call.getContext().supportsTypedPointers()) {
3417 PT = PointerType::get(call.getType()->getPointerElementType(), AS);
3419#if LLVM_VERSION_MAJOR < 17
3422 PT = PointerType::get(call.getContext(), AS);
3423#if LLVM_VERSION_MAJOR < 17
3426 replacement = B.CreateAddrSpaceCast(replacement, PT);
3427 cast<Instruction>(replacement)
3428 ->setMetadata(
"enzyme_backstack",
3429 MDNode::get(replacement->getContext(), {}));
3431 gutils->replaceAWithB(newCall, replacement);
3432 gutils->erase(newCall);
3437 auto found = gutils->rematerializableAllocations.find(&call);
3438 if (found != gutils->rematerializableAllocations.end()) {
3441 if (primalNeededInReverse && !cacheWholeAllocation) {
3442 assert(!unnecessaryValues.count(&call));
3451 auto AllocationLoop = gutils->OrigLI->getLoopFor(call.getParent());
3455 assert(found->second.LI);
3456 if (
auto MD =
hasMetadata(&call,
"enzyme_fromstack")) {
3458 gutils->rematerializedPrimalOrShadowAllocations.push_back(
3461 restoreFromStack(MD);
3468 cast<PointerType>(call.getType())->getAddressSpace() == 10) {
3470 gutils->rematerializedPrimalOrShadowAllocations.push_back(
3479 IRBuilder<> Builder2(&call);
3481 auto dbgLoc = gutils->getNewFromOriginal(call.getDebugLoc());
3483 dbgLoc, gutils->TLI, &call, gutils);
3485 gutils->rematerializedPrimalOrShadowAllocations.push_back(
3492 }
else if (!cacheWholeAllocation) {
3493 if (unnecessaryValues.count(&call)) {
3502 else if (
auto MD =
hasMetadata(&call,
"enzyme_fromstack")) {
3503 restoreFromStack(MD);
3514 bool hasPDFree = gutils->allocationsWithGuaranteedFree.count(&call);
3515 if (!primalNeededInReverse && hasPDFree) {
3516 if (unnecessaryValues.count(&call)) {
3524 if (
auto MD =
hasMetadata(&call,
"enzyme_fromstack")) {
3525 restoreFromStack(MD);
3534 cast<PointerType>(call.getType())->getAddressSpace() == 10) {
3539 if (!primalNeededInReverse) {
3542 auto pn = BuilderZ.CreatePHI(call.getType(), 1,
3543 call.getName() +
"_replacementJ");
3544 gutils->fictiousPHIs[pn] = &call;
3545 gutils->replaceAWithB(newCall, pn);
3546 gutils->erase(newCall);
3549 gutils->cacheForReverse(BuilderZ, newCall,
3561 if ((primalNeededInReverse &&
3562 !gutils->unnecessaryIntermediates.count(&call)) ||
3564 Value *nop = gutils->cacheForReverse(
3570 IRBuilder<> Builder2(&call);
3572 auto dbgLoc = gutils->getNewFromOriginal(call.getDebugLoc());
3574 gutils->TLI, &call, gutils);
3582 auto pn = BuilderZ.CreatePHI(call.getType(), 1,
3583 call.getName() +
"_replacementB");
3584 gutils->fictiousPHIs[pn] = &call;
3585 gutils->replaceAWithB(newCall, pn);
3586 gutils->erase(newCall);
3592 if (funcName ==
"julia.gc_loaded") {
3600 if (
auto placeholder = dyn_cast<PHINode>(&*ifound->second)) {
3605 gutils->invertedPointers.erase(ifound);
3606 gutils->erase(placeholder);
3611 gutils->invertedPointers.erase(ifound);
3612 auto res = gutils->invertPointerM(&call, BuilderZ);
3614 gutils->replaceAWithB(placeholder, res);
3615 gutils->erase(placeholder);
3622 if (funcName ==
"julia.pointer_from_objref") {
3631 auto placeholder = cast<PHINode>(&*ifound->second);
3635 gutils, &call, Mode, oldUnreachable);
3638 gutils->
erase(placeholder);
3643 Value *ptrshadow = gutils->
invertPointerM(call.getArgOperand(0), BuilderZ);
3646 call.getType(), BuilderZ,
3647 [&](Value *v) -> Value * { return BuilderZ.CreateCall(called, {v}); },
3651 gutils->
erase(placeholder);
3655 if (funcName.contains(
"__enzyme_todense")) {
3656 if (gutils->isConstantValue(&call)) {
3657 eraseIfUnused(call);
3661 auto ifound = gutils->invertedPointers.find(&call);
3662 assert(ifound != gutils->invertedPointers.end());
3664 auto placeholder = cast<PHINode>(&*ifound->second);
3668 gutils, &call, Mode, oldUnreachable);
3670 gutils->invertedPointers.erase(ifound);
3671 gutils->erase(placeholder);
3672 eraseIfUnused(call);
3676 SmallVector<Value *, 3> args;
3677 for (
size_t i = 0; i < 2; i++)
3678 args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
3679 for (
size_t i = 2; i < call.arg_size(); ++i)
3680 args.push_back(gutils->invertPointerM(call.getArgOperand(0), BuilderZ));
3682 Value *res = UndefValue::get(gutils->getShadowType(call.getType()));
3683 if (gutils->getWidth() == 1) {
3684 res = BuilderZ.CreateCall(called, args);
3686 for (
size_t w = 0; w < gutils->getWidth(); ++w) {
3687 SmallVector<Value *, 3> targs = {args[0], args[1]};
3688 for (
size_t i = 2; i < call.arg_size(); ++i)
3691 auto tres = BuilderZ.CreateCall(called, targs);
3692 res = BuilderZ.CreateInsertValue(res, tres, w);
3696 gutils->replaceAWithB(placeholder, res);
3697 gutils->erase(placeholder);
3698 eraseIfUnused(call);
3702 if (funcName ==
"memcpy" || funcName ==
"memmove") {
3703 auto ID = (funcName ==
"memcpy") ? Intrinsic::memcpy : Intrinsic::memmove;
3704 visitMemTransferCommon(ID, MaybeAlign(1),
3705 MaybeAlign(1), call,
3706 call.getArgOperand(0), call.getArgOperand(1),
3707 gutils->getNewFromOriginal(call.getArgOperand(2)),
3708 ConstantInt::getFalse(call.getContext()));
3711 if (funcName ==
"memset" || funcName ==
"memset_pattern16" ||
3712 funcName ==
"__memset_chk") {
3713 visitMemSetCommon(call);
3716 if (funcName ==
"enzyme_zerotype") {
3717 IRBuilder<> BuilderZ(&call);
3718 getForwardBuilder(BuilderZ);
3720 bool backwardsShadow =
false;
3721 bool forwardsShadow =
true;
3722 for (
auto pair : gutils->backwardsOnlyShadows) {
3723 if (pair.second.stores.count(&call)) {
3724 backwardsShadow =
true;
3725 forwardsShadow = pair.second.primalInitialize;
3726 if (
auto inst = dyn_cast<Instruction>(pair.first))
3727 if (!forwardsShadow && pair.second.LI &&
3728 pair.second.LI->contains(inst->getParent()))
3729 backwardsShadow =
false;
3740 eraseIfUnused(call,
true,
false);
3742 eraseIfUnused(call);
3744 Value *orig_op0 = call.getArgOperand(0);
3747 if (gutils->isConstantValue(orig_op0)) {
3752 Value *op0 = gutils->invertPointerM(orig_op0, BuilderZ);
3753 Value *op1 = gutils->getNewFromOriginal(call.getArgOperand(1));
3754 Value *op2 = gutils->getNewFromOriginal(call.getArgOperand(2));
3755 auto Defs = gutils->getInvertedBundles(
3762 SmallVector<Value *, 4> args = {op0, op1, op2};
3764 BuilderZ.CreateCall(call.getCalledFunction(), args, Defs);
3765 llvm::SmallVector<unsigned int, 9> ToCopy2(
MD_ToCopy);
3766 ToCopy2.push_back(LLVMContext::MD_noalias);
3767 cal->copyMetadata(call, ToCopy2);
3768 cal->setAttributes(call.getAttributes());
3769 if (
auto m =
hasMetadata(&call,
"enzyme_zerostack"))
3770 cal->setMetadata(
"enzyme_zerostack", m);
3771 cal->setCallingConv(call.getCallingConv());
3772 cal->setTailCallKind(call.getTailCallKind());
3773 cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
3779 if (funcName ==
"cuStreamCreate") {
3780 Value *val =
nullptr;
3782#if LLVM_VERSION_MAJOR < 17
3783 if (call.getContext().supportsTypedPointers()) {
3784 if (isa<PointerType>(call.getArgOperand(0)->getType()))
3785 PT = call.getArgOperand(0)->getType()->getPointerElementType();
3790 val = gutils->getNewFromOriginal(call.getOperand(0));
3791 if (!isa<PointerType>(val->getType()))
3792 val = BuilderZ.CreateIntToPtr(val,
getUnqual(PT));
3793 val = BuilderZ.CreateLoad(PT, val);
3794 val = gutils->cacheForReverse(BuilderZ, val,
3798 PHINode *toReplace =
3799 BuilderZ.CreatePHI(PT, 1, call.getName() +
"_psxtmp");
3800 val = gutils->cacheForReverse(BuilderZ, toReplace,
3805 IRBuilder<> Builder2(&call);
3806 getReverseBuilder(Builder2);
3807 val = gutils->lookupM(val, Builder2);
3808 auto FreeFunc = gutils->newFunc->getParent()->getOrInsertFunction(
3809 "cuStreamDestroy", call.getType(), PT);
3810 Value *nargs[] = {val};
3811 Builder2.CreateCall(FreeFunc, nargs);
3814 eraseIfUnused(call,
true,
false);
3817 if (funcName ==
"cuStreamDestroy") {
3820 eraseIfUnused(call,
true,
false);
3823 if (funcName ==
"cuStreamSynchronize") {
3826 IRBuilder<> Builder2(&call);
3827 getReverseBuilder(Builder2);
3828 Value *nargs[] = {gutils->lookupM(
3829 gutils->getNewFromOriginal(call.getOperand(0)), Builder2)};
3830 auto callval = call.getCalledOperand();
3831 Builder2.CreateCall(call.getFunctionType(), callval, nargs);
3834 eraseIfUnused(call,
true,
false);
3837 if (funcName ==
"posix_memalign" || funcName ==
"cuMemAllocAsync" ||
3838 funcName ==
"cuMemAlloc" || funcName ==
"cuMemAlloc_v2" ||
3839 funcName ==
"cudaMalloc" || funcName ==
"cudaMallocAsync" ||
3840 funcName ==
"cudaMallocHost" || funcName ==
"cudaMallocFromPoolAsync") {
3841 bool constval = gutils->isConstantInstruction(&call);
3845#if LLVM_VERSION_MAJOR < 17
3846 if (call.getContext().supportsTypedPointers()) {
3847 if (isa<PointerType>(call.getArgOperand(0)->getType()))
3848 PT = call.getArgOperand(0)->getType()->getPointerElementType();
3852 Value *stream =
nullptr;
3853 if (funcName ==
"cuMemAllocAsync")
3854 stream = gutils->getNewFromOriginal(call.getArgOperand(2));
3855 else if (funcName ==
"cudaMallocAsync")
3856 stream = gutils->getNewFromOriginal(call.getArgOperand(2));
3857 else if (funcName ==
"cudaMallocFromPoolAsync")
3858 stream = gutils->getNewFromOriginal(call.getArgOperand(3));
3860 auto M = gutils->newFunc->getParent();
3867 gutils->invertPointerM(call.getArgOperand(0), BuilderZ);
3868 SmallVector<Value *, 1> args;
3869 SmallVector<ValueType, 1> valtys;
3870 args.push_back(ptrshadow);
3872 for (
size_t i = 1; i < call.arg_size(); ++i) {
3873 args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
3877 auto Defs = gutils->getInvertedBundles(&call, valtys, BuilderZ,
3880 val = applyChainRule(
3882 [&](Value *ptrshadow) {
3883 args[0] = ptrshadow;
3885 BuilderZ.CreateCall(called, args, Defs);
3886 if (!isa<PointerType>(ptrshadow->getType()))
3887 ptrshadow = BuilderZ.CreateIntToPtr(ptrshadow,
getUnqual(PT));
3888 Value *val = BuilderZ.CreateLoad(PT, ptrshadow);
3891 BuilderZ.CreateBitCast(val,
getInt8PtrTy(call.getContext()));
3894 ConstantInt::get(Type::getInt8Ty(call.getContext()), 0);
3895 auto len_arg = gutils->getNewFromOriginal(
3896 call.getArgOperand((funcName ==
"posix_memalign") ? 2 : 1));
3898 if (funcName ==
"posix_memalign" ||
3899 funcName ==
"cudaMallocHost") {
3900 BuilderZ.CreateMemSet(dst_arg, val_arg, len_arg, MaybeAlign());
3901 }
else if (funcName ==
"cudaMalloc") {
3902 Type *tys[] = {PT, val_arg->getType(), len_arg->getType()};
3903 auto F = M->getOrInsertFunction(
3905 FunctionType::get(call.getType(), tys,
false));
3906 Value *nargs[] = {dst_arg, val_arg, len_arg};
3907 auto memset = cast<CallInst>(BuilderZ.CreateCall(F, nargs));
3908 memset->addParamAttr(0, Attribute::NonNull);
3909 }
else if (funcName ==
"cudaMallocAsync" ||
3910 funcName ==
"cudaMallocFromPoolAsync") {
3911 Type *tys[] = {PT, val_arg->getType(), len_arg->getType(),
3913 auto F = M->getOrInsertFunction(
3915 FunctionType::get(call.getType(), tys,
false));
3916 Value *nargs[] = {dst_arg, val_arg, len_arg, stream};
3917 auto memset = cast<CallInst>(BuilderZ.CreateCall(F, nargs));
3918 memset->addParamAttr(0, Attribute::NonNull);
3919 }
else if (funcName ==
"cuMemAllocAsync") {
3920 Type *tys[] = {PT, val_arg->getType(), len_arg->getType(),
3922 auto F = M->getOrInsertFunction(
3924 FunctionType::get(call.getType(), tys,
false));
3925 Value *nargs[] = {dst_arg, val_arg, len_arg, stream};
3926 auto memset = cast<CallInst>(BuilderZ.CreateCall(F, nargs));
3927 memset->addParamAttr(0, Attribute::NonNull);
3928 }
else if (funcName ==
"cuMemAlloc" ||
3929 funcName ==
"cuMemAlloc_v2") {
3930 Type *tys[] = {PT, val_arg->getType(), len_arg->getType()};
3931 auto F = M->getOrInsertFunction(
3933 FunctionType::get(call.getType(), tys,
false));
3934 Value *nargs[] = {dst_arg, val_arg, len_arg};
3935 auto memset = cast<CallInst>(BuilderZ.CreateCall(F, nargs));
3936 memset->addParamAttr(0, Attribute::NonNull);
3938 llvm_unreachable(
"unhandled allocation");
3946 val = gutils->cacheForReverse(
3949 PHINode *toReplace = BuilderZ.CreatePHI(gutils->getShadowType(PT), 1,
3950 call.getName() +
"_psxtmp");
3951 val = gutils->cacheForReverse(
3958 IRBuilder<> Builder2(&call);
3959 getReverseBuilder(Builder2);
3960 Value *tofree = gutils->lookupM(val, Builder2, ValueToValueMapTy(),
3963 Type *VoidTy = Type::getVoidTy(M->getContext());
3966 Value *streamL =
nullptr;
3968 streamL = gutils->lookupM(stream, Builder2);
3972 [&](Value *tofree) {
3973 if (funcName ==
"posix_memalign") {
3975 M->getOrInsertFunction(
"free", VoidTy, IntPtrTy);
3976 Builder2.CreateCall(FreeFunc, tofree);
3977 }
else if (funcName ==
"cuMemAllocAsync") {
3978 auto FreeFunc = M->getOrInsertFunction(
3979 "cuMemFreeAsync", VoidTy, IntPtrTy, streamL->getType());
3980 Value *nargs[] = {tofree, streamL};
3981 Builder2.CreateCall(FreeFunc, nargs);
3982 }
else if (funcName ==
"cuMemAlloc" ||
3983 funcName ==
"cuMemAlloc_v2") {
3985 M->getOrInsertFunction(
"cuMemFree", VoidTy, IntPtrTy);
3986 Value *nargs[] = {tofree};
3987 Builder2.CreateCall(FreeFunc, nargs);
3988 }
else if (funcName ==
"cudaMalloc") {
3990 M->getOrInsertFunction(
"cudaFree", VoidTy, IntPtrTy);
3991 Value *nargs[] = {tofree};
3992 Builder2.CreateCall(FreeFunc, nargs);
3993 }
else if (funcName ==
"cudaMallocAsync" ||
3994 funcName ==
"cudaMallocFromPoolAsync") {
3995 auto FreeFunc = M->getOrInsertFunction(
3996 "cudaFreeAsync", VoidTy, IntPtrTy, streamL->getType());
3997 Value *nargs[] = {tofree, streamL};
3998 Builder2.CreateCall(FreeFunc, nargs);
3999 }
else if (funcName ==
"cudaMallocHost") {
4001 M->getOrInsertFunction(
"cudaFreeHost", VoidTy, IntPtrTy);
4002 Value *nargs[] = {tofree};
4003 Builder2.CreateCall(FreeFunc, nargs);
4005 llvm_unreachable(
"unknown function to free");
4016 eraseIfUnused(call,
true,
false);
4031 IRBuilder<> Builder2(newCall->getNextNode());
4032 auto ptrv = gutils->getNewFromOriginal(call.getOperand(0));
4033 if (!isa<PointerType>(ptrv->getType()))
4034 ptrv = BuilderZ.CreateIntToPtr(ptrv,
getUnqual(PT));
4035 auto load = Builder2.CreateLoad(PT, ptrv,
"posix_preread");
4036 Builder2.SetInsertPoint(&call);
4037 getReverseBuilder(Builder2);
4038 auto tofree = gutils->lookupM(load, Builder2, ValueToValueMapTy(),
4040 Value *streamL =
nullptr;
4041 if (funcName ==
"cuMemAllocAsync")
4042 streamL = gutils->getNewFromOriginal(call.getArgOperand(2));
4043 else if (funcName ==
"cudaMallocAsync")
4044 streamL = gutils->getNewFromOriginal(call.getArgOperand(2));
4045 else if (funcName ==
"cudaMallocFromPoolAsync")
4046 streamL = gutils->getNewFromOriginal(call.getArgOperand(3));
4048 streamL = gutils->lookupM(streamL, Builder2);
4050 auto M = gutils->newFunc->getParent();
4051 Type *VoidTy = Type::getVoidTy(M->getContext());
4054 if (funcName ==
"posix_memalign") {
4055 auto FreeFunc = M->getOrInsertFunction(
"free", VoidTy, IntPtrTy);
4056 Builder2.CreateCall(FreeFunc, tofree);
4057 }
else if (funcName ==
"cuMemAllocAsync") {
4058 auto FreeFunc = M->getOrInsertFunction(
"cuMemFreeAsync", VoidTy,
4059 IntPtrTy, streamL->getType());
4060 Value *nargs[] = {tofree, streamL};
4061 Builder2.CreateCall(FreeFunc, nargs);
4062 }
else if (funcName ==
"cuMemAlloc" || funcName ==
"cuMemAlloc_v2") {
4063 auto FreeFunc = M->getOrInsertFunction(
"cuMemFree", VoidTy, IntPtrTy);
4064 Value *nargs[] = {tofree};
4065 Builder2.CreateCall(FreeFunc, nargs);
4066 }
else if (funcName ==
"cudaMalloc") {
4067 auto FreeFunc = M->getOrInsertFunction(
"cudaFree", VoidTy, IntPtrTy);
4068 Value *nargs[] = {tofree};
4069 Builder2.CreateCall(FreeFunc, nargs);
4070 }
else if (funcName ==
"cudaMallocAsync" ||
4071 funcName ==
"cudaMallocFromPoolAsync") {
4072 auto FreeFunc = M->getOrInsertFunction(
"cudaFreeAsync", VoidTy,
4073 IntPtrTy, streamL->getType());
4074 Value *nargs[] = {tofree, streamL};
4075 Builder2.CreateCall(FreeFunc, nargs);
4076 }
else if (funcName ==
"cudaMallocHost") {
4078 M->getOrInsertFunction(
"cudaFreeHost", VoidTy, IntPtrTy);
4079 Value *nargs[] = {tofree};
4080 Builder2.CreateCall(FreeFunc, nargs);
4082 llvm_unreachable(
"unknown function to free");
4091 assert(gutils->invertedPointers.find(&call) ==
4092 gutils->invertedPointers.end());
4096 if (!gutils->isConstantValue(call.getArgOperand(0))) {
4097 IRBuilder<> Builder2(&call);
4098 getForwardBuilder(Builder2);
4099 auto origfree = call.getArgOperand(0);
4100 auto newfree = gutils->getNewFromOriginal(call.getArgOperand(0));
4101 auto tofree = gutils->invertPointerM(origfree, Builder2);
4104 *call.getModule(), &call, newfree->getType(), gutils->getWidth());
4107 if (
auto instArg = dyn_cast<Instruction>(call.getArgOperand(0)))
4108 used = unnecessaryInstructions.find(instArg) ==
4109 unnecessaryInstructions.end();
4111 SmallVector<Value *, 3> args;
4113 args.push_back(newfree);
4116 Constant::getNullValue(call.getArgOperand(0)->getType()));
4118 auto rule = [&args](Value *tofree) { args.push_back(tofree); };
4119 applyChainRule(Builder2, rule, tofree);
4121 for (
size_t i = 1; i < call.arg_size(); i++) {
4122 args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
4125 auto frees = Builder2.CreateCall(free->getFunctionType(), free, args);
4126 frees->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
4128 eraseIfUnused(call);
4131 eraseIfUnused(call);
4133 auto callval = call.getCalledOperand();
4135 for (
auto rmat : gutils->backwardsOnlyShadows) {
4136 if (rmat.second.frees.count(&call)) {
4137 bool shouldFree =
false;
4138 if (rmat.second.primalInitialize) {
4144 IRBuilder<> Builder2(&call);
4145 getForwardBuilder(Builder2);
4146 auto origfree = call.getArgOperand(0);
4147 auto tofree = gutils->invertPointerM(origfree, Builder2);
4148 if (tofree != origfree) {
4149 SmallVector<Value *, 2> args = {tofree};
4151 Builder2.CreateCall(call.getFunctionType(), callval, args);
4152 CI->setAttributes(call.getAttributes());
4160 for (
auto rmat : gutils->rematerializableAllocations) {
4161 if (rmat.second.frees.count(&call)) {
4165 eraseIfUnused(call);
4168 eraseIfUnused(call,
true,
false);
4172 std::map<UsageKey, bool> Seen;
4173 for (
auto pair : gutils->knownRecomputeHeuristic)
4176 bool primalNeededInReverse =
4180 bool cacheWholeAllocation =
4181 gutils->needsCacheWholeAllocation(rmat.first);
4182 if (cacheWholeAllocation) {
4183 primalNeededInReverse =
true;
4187 if (!cacheWholeAllocation) {
4188 eraseIfUnused(call);
4191 assert(!unnecessaryValues.count(rmat.first));
4192 (void)primalNeededInReverse;
4193 assert(primalNeededInReverse);
4198 if (gutils->forwardDeallocations.count(&call)) {
4200 eraseIfUnused(call,
true,
false);
4202 eraseIfUnused(call);
4206 if (gutils->postDominatingFrees.count(&call)) {
4207 eraseIfUnused(call,
true,
false);
4212 if (isa<ConstantPointerNull>(val)) {
4213 llvm::errs() <<
"removing free of null pointer\n";
4214 eraseIfUnused(call,
true,
false);
4219 llvm::errs() <<
"freeing without malloc " << *val <<
"\n";
4220 eraseIfUnused(call,
true,
false);
4224 if (call.hasFnAttr(
"enzyme_sample")) {
4229 bool constval = gutils->isConstantInstruction(&call);
4234 IRBuilder<> Builder2(&call);
4235 getReverseBuilder(Builder2);
4237 auto trace = call.getArgOperand(call.arg_size() - 1);
4238 auto address = call.getArgOperand(0);
4240 auto dtrace = lookup(gutils->getNewFromOriginal(trace), Builder2);
4241 auto daddress = lookup(gutils->getNewFromOriginal(address), Builder2);
4244 if (TR.query(&call)[{-1}].isPossiblePointer()) {
4245 dchoice = gutils->invertPointerM(&call, Builder2);
4247 dchoice = diffe(&call, Builder2);
4250 if (call.hasMetadata(
"enzyme_gradient_setter")) {
4251 auto gradient_setter = cast<Function>(
4252 cast<ValueAsMetadata>(
4253 call.getMetadata(
"enzyme_gradient_setter")->getOperand(0).get())
4257 Builder2, gradient_setter->getFunctionType(), gradient_setter,
4258 daddress, dchoice, dtrace);
4264 if (call.hasFnAttr(
"enzyme_insert_argument")) {
4265 IRBuilder<> Builder2(&call);
4266 getReverseBuilder(Builder2);
4268 auto name = call.getArgOperand(0);
4269 auto arg = call.getArgOperand(1);
4270 auto trace = call.getArgOperand(2);
4272 auto gradient_setter = cast<Function>(
4273 cast<ValueAsMetadata>(
4274 call.getMetadata(
"enzyme_gradient_setter")->getOperand(0).get())
4277 auto dtrace = lookup(gutils->getNewFromOriginal(trace), Builder2);
4278 auto dname = lookup(gutils->getNewFromOriginal(name), Builder2);
4281 if (TR.query(arg)[{-1}].isPossiblePointer()) {
4282 darg = gutils->invertPointerM(arg, Builder2);
4284 darg = diffe(arg, Builder2);
4288 gradient_setter->getFunctionType(),
4289 gradient_setter, dname, darg, dtrace);