31 llvm::Type *SRetType =
nullptr) {
32 auto Attrs = arg->getParent()->getAttributes();
37 .
getAttribute(AttributeList::FirstArgIndex + arg->getArgNo(),
43 if (tracked.
count == 0) {
47 bool hasReturnRootingAfterArg =
false;
48 for (
size_t i = arg->getArgNo() + 1; i < arg->getParent()->arg_size(); i++) {
49 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
50 "enzymejl_returnRoots")) {
51 hasReturnRootingAfterArg =
true;
57 if (!hasReturnRootingAfterArg) {
61 SmallVector<Value *> storedValues;
63 auto &DL = arg->getParent()->getParent()->getDataLayout();
64 SmallVector<size_t> sret_offsets;
66 std::deque<std::pair<llvm::Type *, std::vector<unsigned>>> todo = {
68 while (!todo.empty()) {
69 auto cur = std::move(todo[0]);
71 auto path = std::move(cur.second);
74 if (
auto PT = dyn_cast<PointerType>(ty)) {
78 SmallVector<Constant *, 1> IdxList;
80 ConstantInt::get(Type::getInt64Ty(PT->getContext()), 0));
84 ConstantInt::get(Type::getInt32Ty(PT->getContext()), v));
85 auto nullp = ConstantPointerNull::get(PointerType::getUnqual(SRetType));
86 auto gep = ConstantExpr::getGetElementPtr(SRetType, nullp, IdxList);
88 if (gep == ConstantPointerNull::get(PointerType::getUnqual(PT))) {
89 sret_offsets.push_back(0);
92#if LLVM_VERSION_MAJOR >= 20
93 SmallMapVector<Value *, APInt, 4> VariableOffsets;
95 MapVector<Value *, APInt> VariableOffsets;
97 auto width = DL.getPointerSize() * 8;
98 APInt Offset(width, 0);
99 bool success =
collectOffset(cast<GEPOperator>(gep), DL, width,
100 VariableOffsets, Offset);
102 llvm_unreachable(
"Illegal offset collection");
103 sret_offsets.push_back(Offset.getZExtValue());
107 if (
auto AT = dyn_cast<ArrayType>(ty)) {
108 for (
size_t i = 0; i < AT->getNumElements(); i++) {
109 std::vector<unsigned> path2(path);
111 todo.emplace_back(AT->getElementType(), path2);
116 if (
auto VT = dyn_cast<VectorType>(ty)) {
117 for (
size_t i = 0; i < VT->getElementCount().getKnownMinValue(); i++) {
118 std::vector<unsigned> path2(path);
120 todo.emplace_back(VT->getElementType(), path2);
125 if (
auto ST = dyn_cast<StructType>(ty)) {
126 for (
size_t i = 0; i < ST->getNumElements(); i++) {
127 std::vector<unsigned> path2(path);
129 todo.emplace_back(ST->getTypeAtIndex(i), path2);
138 assert(I->getParent()->getParent() == arg->getParent());
140 if (isa<ICmpInst>(I)) {
143 if (isa<LoadInst>(I)) {
146 if (
auto SI = dyn_cast<StoreInst>(I)) {
147 assert(SI->getValueOperand() != cur);
152 storedValues.push_back(SI->getValueOperand());
157 if (isa<MemSetInst>(I))
160 if (
auto MSI = dyn_cast<MemTransferInst>(I)) {
161 if (
auto Len = dyn_cast<ConstantInt>(MSI->getLength())) {
163 for (
auto offset : sret_offsets) {
164 if (byteOffset + Len->getSExtValue() <= offset)
166 if (offset + DL.getPointerSize() <= byteOffset)
177 llvm::raw_string_ostream ss(s);
178 ss <<
"Unknown user of sret-like argument\n";
180 wrap(cur), wrap(arg),
nullptr);
187 while (!storedValues.empty()) {
188 auto sv = storedValues.pop_back_val();
189 if (
auto I = dyn_cast<Instruction>(sv)) {
190 assert(I->getParent()->getParent() == arg->getParent());
192 bool foundUse =
false;
193 for (
auto &U : sv->uses()) {
194 if (
auto SI = dyn_cast<StoreInst>(U.getUser())) {
195 if (SI->getValueOperand() == sv) {
200 if (
auto evi = dyn_cast<ExtractValueInst>(base)) {
201 base = evi->getAggregateOperand();
203 if (
auto arg2 = dyn_cast<Argument>(base)) {
207 "enzymejl_returnRoots")
217 if (
auto IVI = dyn_cast<InsertValueInst>(sv)) {
219 IVI->getInsertedValueOperand()->getType());
220 if (tracked.
count == 0) {
221 storedValues.push_back(IVI->getAggregateOperand());
224 if (isa<UndefValue>(IVI->getAggregateOperand()) ||
225 isa<PoisonValue>(IVI->getAggregateOperand()) ||
226 isa<ConstantAggregateZero>(IVI->getAggregateOperand())) {
227 storedValues.push_back(IVI->getInsertedValueOperand());
230 storedValues.push_back(IVI->getAggregateOperand());
231 storedValues.push_back(IVI->getInsertedValueOperand());
234 if (
auto ST = dyn_cast<StructType>(sv->getType())) {
236 for (
size_t i = 0; i < ST->getNumElements(); i++) {
239 if (tracked.
count == 0) {
242 std::map<std::vector<unsigned>,
bool> paths_to_cover;
244 std::deque<std::pair<llvm::Type *, std::vector<unsigned>>> todo =
245 {{ST->getElementType(i), {}}};
246 while (!todo.empty()) {
247 auto cur = std::move(todo[0]);
249 auto path = std::move(cur.second);
252 if (
auto PT = dyn_cast<PointerType>(ty)) {
254 paths_to_cover[path] =
false;
259 if (
auto AT = dyn_cast<ArrayType>(ty)) {
260 for (
size_t k = 0; k < AT->getNumElements(); k++) {
261 std::vector<unsigned> path2(path);
263 todo.emplace_back(AT->getElementType(), path2);
268 if (
auto VT = dyn_cast<VectorType>(ty)) {
270 k < VT->getElementCount().getKnownMinValue(); k++) {
271 std::vector<unsigned> path2(path);
273 todo.emplace_back(VT->getElementType(), path2);
278 if (
auto ST2 = dyn_cast<StructType>(ty)) {
279 for (
size_t k = 0; k < ST2->getNumElements(); k++) {
280 std::vector<unsigned> path2(path);
282 todo.emplace_back(ST2->getTypeAtIndex(k), path2);
289 for (
auto u : sv->users()) {
290 if (
auto ev0 = dyn_cast<ExtractValueInst>(u)) {
291 if (ev0->getIndices()[0] == i) {
292 std::vector<unsigned> extract_path;
293 for (
size_t k = 1; k < ev0->getNumIndices(); ++k) {
294 extract_path.push_back(ev0->getIndices()[k]);
296 storedValues.push_back(ev0);
299 for (
auto &pair : paths_to_cover) {
300 const auto &p = pair.first;
302 if (extract_path.size() > p.size()) {
305 for (
size_t idx = 0; idx < extract_path.size(); ++idx) {
306 if (extract_path[idx] != p[idx]) {
320 bool fullyCovered =
true;
321 for (
const auto &pair : paths_to_cover) {
323 fullyCovered =
false;
329 llvm::errs() <<
" failed to find extracted pointer for " << *sv
330 <<
" at index " << i <<
"\n";
339 if (!isa<PointerType>(sv->getType()) ||
341 llvm::errs() <<
" sf: " << *arg->getParent() <<
"\n";
342 llvm::errs() <<
" arg: " << *arg <<
"\n";
343 llvm::errs() <<
"Pointer of wrong type: " << *sv <<
"\n";
348 bool saw_bitcast =
false;
349 for (
auto u : sv->users()) {
350 if (
auto ev0 = dyn_cast<CastInst>(u)) {
351 auto t2 = ev0->getType();
352 if (isa<PointerType>(t2) &&
isSpecialPtr(cast<PointerType>(t2))) {
354 storedValues.push_back(ev0);
363 if (hasReturnRootingAfterArg) {
365 llvm::raw_string_ostream ss(s);
366 ss <<
"Could not find use of stored value\n";
367 ss <<
" sv: " << *sv <<
"\n";
369 nullptr, wrap(arg),
nullptr);
540 auto RT = F->getReturnType();
541 std::set<size_t> srets;
542 std::set<size_t> enzyme_srets;
544 std::set<size_t> reroot_enzyme_srets;
546 std::set<size_t> noroot_enzyme_srets;
548 std::set<size_t> rroots;
550 std::set<size_t> reret_roots;
552 auto FT = F->getFunctionType();
553 auto Attrs = F->getAttributes();
555 std::map<size_t, size_t> selected_roots;
558 std::map<size_t, size_t> srets_without_stores;
560 for (
size_t i = 0, end = FT->getNumParams(); i < end; i++) {
561 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
562 Attribute::StructRet))
564 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
"enzyme_sret")) {
565 bool anyJLStore =
false;
566 enzyme_srets.insert(i);
570 reroot_enzyme_srets.insert(i);
571 }
else if (anyJLStore) {
577 Attrs.getAttribute(AttributeList::FirstArgIndex + i,
"enzyme_sret")
581 srets_without_stores[i] = count;
582 noroot_enzyme_srets.insert(i);
587 !Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
"enzyme_sret_v"));
589 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
590 "enzymejl_returnRoots")) {
595 reret_roots.insert(i);
597 selected_roots[i] = sret_idx;
600 assert(!Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
601 "enzymejl_returnRoots_v"));
605 if (srets.size() == 1) {
606 assert(*srets.begin() == 0);
607 assert(enzyme_srets.size() == 0);
608 llvm::Type *SRetType = F->getParamStructRetType(0);
612 if (!tracked.
count) {
616 bool anyJLStore =
false;
617 bool rerooting =
needsReRooting(F->getArg(0), anyJLStore, SRetType);
622 assert(rroots.size() == 1);
623 assert(*rroots.begin() == 1);
626#if LLVM_VERSION_MAJOR >= 16
629 llvm::raw_string_ostream ss(s);
630 ss <<
"Illegal GC setup in which rerooting is required\n";
631 ss <<
" + F: " << *F <<
"\n";
633 nullptr,
nullptr,
nullptr);
641 "enzymejl_returnRoots")
642 .getValueAsString());
644 assert(count == tracked.
count);
648 F->addParamAttr(0, Attribute::get(F->getContext(),
"enzyme_sret",
650 Attrs = F->getAttributes();
653 enzyme_srets.insert(i);
655 reroot_enzyme_srets.insert(i);
656 }
else if (anyJLStore) {
659 srets_without_stores[i] = count;
660 noroot_enzyme_srets.insert(i);
663 }
else if (srets.size() == 0 && enzyme_srets.size() == 0 &&
664 rroots.size() == 0) {
672 for (
auto &pair : srets_without_stores) {
674 reroot_enzyme_srets.insert(pair.first);
677 assert(srets.size() == 0);
679 SmallVector<Type *, 1> Types;
680 if (!RT->isVoidTy()) {
684 auto T_jlvalue = StructType::get(F->getContext(), {});
689 for (
auto idx : enzyme_srets) {
691 Attrs.getAttribute(AttributeList::FirstArgIndex + idx,
"enzyme_sret")
694#if LLVM_VERSION_MAJOR < 17
695 if (F->getContext().supportsTypedPointers()) {
696 auto T = FT->getParamType(idx)->getPointerElementType();
699 llvm::raw_string_ostream ss(s);
700 ss <<
"Type mismatch in FixupJuliaCallingConvention:\n";
701 ss <<
" + T: " << *T <<
"\n";
702 ss <<
" + SRetType: " << *SRetType <<
"\n";
703 EmitFailure(
"TypeMismatch", F->getSubprogram(), F, ss.str());
707 Types.push_back(SRetType);
708 if (reroot_enzyme_srets.count(idx)) {
712 for (
auto idx : rroots) {
716 "enzymejl_returnRoots")
717 .getValueAsString());
718 auto T = ArrayType::get(T_prjlvalue, count);
719#if LLVM_VERSION_MAJOR < 17
720 if (F->getContext().supportsTypedPointers()) {
721 auto NT = FT->getParamType(idx)->getPointerElementType();
725 if (reret_roots.count(idx)) {
732 Types.size() <= 1 ? nullptr : StructType::get(F->getContext(), Types);
733 Type *sretTy =
nullptr;
735 sretTy = Types.size() == 1 ? Types[0] : ST;
737 ArrayType *roots_AT =
738 numRooting ? ArrayType::get(T_prjlvalue, numRooting) :
nullptr;
748 reroot_enzyme_srets.clear();
749 }
else if (countF.
count) {
751 llvm::errs() <<
" sretTy: " << *sretTy <<
"\n";
752 llvm::errs() <<
" numRooting: " << numRooting <<
"\n";
753 llvm::errs() <<
" tracked.count: " << countF.
count <<
"\n";
756 if (numRooting != countF.
count) {
758 llvm::raw_string_ostream ss(s);
759 ss <<
"Illegal GC setup in which numRooting (" << numRooting
760 <<
") != tracked.count (" << countF.
count <<
")\n";
761 ss <<
" sretTy: " << *sretTy <<
"\n";
762 ss <<
" Types.size(): " << Types.size() <<
"\n";
763 for (
size_t i = 0; i < Types.size(); i++) {
764 ss <<
" + Types[" << i <<
"] = " << *Types[i] <<
"\n";
766 ss <<
" F: " << *F <<
"\n";
768 nullptr,
nullptr,
nullptr);
770 assert(numRooting == countF.
count);
774 AttributeList NewAttrs;
775 SmallVector<Type *, 1> types;
779 NewAttrs = NewAttrs.addAttribute(
780 F->getContext(), AttributeList::FirstArgIndex + nexti,
781 Attribute::get(F->getContext(), Attribute::StructRet, sretTy));
782 NewAttrs = NewAttrs.addAttribute(F->getContext(),
783 AttributeList::FirstArgIndex + nexti,
788 NewAttrs = NewAttrs.addAttribute(
789 F->getContext(), AttributeList::FirstArgIndex + nexti,
790 "enzymejl_returnRoots", std::to_string(numRooting));
791 NewAttrs = NewAttrs.addAttribute(F->getContext(),
792 AttributeList::FirstArgIndex + nexti,
794 NewAttrs = NewAttrs.addAttribute(F->getContext(),
795 AttributeList::FirstArgIndex + nexti,
796 Attribute::WriteOnly);
801 for (
size_t i = 0, end = FT->getNumParams(); i < end; i++) {
802 if (enzyme_srets.count(i) || rroots.count(i))
805 for (
auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) {
806 NewAttrs = NewAttrs.addAttribute(
807 F->getContext(), AttributeList::FirstArgIndex + nexti, attr);
809 types.push_back(F->getFunctionType()->getParamType(i));
813 for (
auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
814 NewAttrs = NewAttrs.addAttribute(F->getContext(),
815 AttributeList::FunctionIndex, attr);
817 FunctionType *FTy = FunctionType::get(Type::getVoidTy(F->getContext()), types,
821 auto &M = *F->getParent();
822 Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(),
825 ValueToValueMapTy VMap;
827 Function::arg_iterator DestI = NewF->arg_begin();
828 Argument *sret =
nullptr;
833 Argument *roots =
nullptr;
841 std::map<size_t, PHINode *> delArgMap;
842 for (Argument &I : F->args()) {
843 auto i = I.getArgNo();
844 if (enzyme_srets.count(i) || rroots.count(i)) {
845 VMap[&I] = delArgMap[i] = PHINode::Create(I.getType(), 0);
848 assert(DestI != NewF->arg_end());
849 DestI->setName(I.getName());
850 VMap[&I] = &*DestI++;
854 SmallPtrSet<Function *, 1> calls_todo;
857 SmallVector<ReturnInst *, 8> Returns;
858 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
859 Returns,
"",
nullptr);
861 SmallVector<CallInst *, 1> callers;
862 for (
auto U : F->users()) {
863 auto CI = dyn_cast<CallInst>(U);
865 assert(CI->getCalledFunction() == F);
866 callers.push_back(CI);
870 size_t curOffset = 0;
871 size_t sretCount = 0;
872 if (!RT->isVoidTy()) {
873 for (
auto &RT : Returns) {
875 Value *gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, 0) : sret;
876 Value *rval = RT->getReturnValue();
877 B.CreateStore(rval, gep);
885 auto NR = B.CreateRetVoid();
886 RT->eraseFromParent();
897 for (
size_t i = 0, end = FT->getNumParams(); i < end; i++) {
899 if (enzyme_srets.count(i)) {
900 auto argFound = delArgMap.find(i);
901 assert(argFound != delArgMap.end());
902 auto arg = argFound->second;
904 SmallVector<Instruction *, 1> uses;
905 SmallVector<unsigned, 1> op;
906 for (
auto &U : arg->uses()) {
907 auto I = cast<Instruction>(U.getUser());
909 op.push_back(U.getOperandNo());
911 IRBuilder<> EB(&NewF->getEntryBlock().front());
913 ST ? EB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret;
914 for (
size_t i = 0; i < uses.size(); i++) {
915 uses[i]->setOperand(op[i], gep);
918 if (reroot_enzyme_srets.count(i)) {
921 for (
auto &RT : Returns) {
923 if (noroot_enzyme_srets.count(i)) {
924 for (
size_t i = 0; i < cnt; i++) {
925 B.CreateStore(ConstantPointerNull::get(T_prjlvalue),
926 B.CreateConstInBoundsGEP2_32(roots_AT, roots, 0,
944 if (rroots.count(i)) {
945 auto attr = Attrs.getAttribute(AttributeList::FirstArgIndex + i,
946 "enzymejl_returnRoots");
947 auto attrv = attr.getValueAsString();
948 assert(attrv.size());
951 auto argFound = delArgMap.find(i);
952 assert(argFound != delArgMap.end());
953 auto arg = argFound->second;
955 SmallVector<Instruction *, 1> uses;
956 SmallVector<unsigned, 1> op;
957 for (
auto &U : arg->uses()) {
958 auto I = cast<Instruction>(U.getUser());
960 op.push_back(U.getOperandNo());
962 IRBuilder<> EB(&NewF->getEntryBlock().front());
964 Value *gep =
nullptr;
970 if (curOffset != 0) {
971 gep = EB.CreateConstInBoundsGEP2_32(roots_AT, roots, 0, curOffset);
973 if (subCount != numRooting) {
974 gep = EB.CreatePointerCast(
975 gep,
getUnqual(ArrayType::get(T_prjlvalue, subCount)));
977 curOffset += subCount;
978 if (reret_roots.count(i))
983 ST ? EB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret;
985 if (!reret_roots.count(i)) {
988 llvm::raw_string_ostream ss(s);
989 ss <<
"Illegal GC setup in which there was no roots_AT, but a new "
991 << *sret <<
"), but no rereturned roots at index i=" << i
994 nullptr,
nullptr,
nullptr);
1000 for (
size_t i = 0; i < uses.size(); i++) {
1001 uses[i]->setOperand(op[i], gep);
1009 assert(curOffset == numRooting);
1010 assert(sretCount == Types.size());
1013 auto &DL = F->getParent()->getDataLayout();
1016 for (
auto CI : callers) {
1017 auto Attrs = CI->getAttributes();
1018 AttributeList NewAttrs;
1020 IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front());
1021 SmallVector<Value *, 1> vals;
1023 Value *sret =
nullptr;
1025 sret = EB.CreateAlloca(sretTy, 0,
"stack_sret");
1026 vals.push_back(sret);
1027 NewAttrs = NewAttrs.addAttribute(
1028 F->getContext(), AttributeList::FirstArgIndex + nexti,
1029 Attribute::get(F->getContext(), Attribute::StructRet, sretTy));
1032 AllocaInst *roots =
nullptr;
1034 roots = EB.CreateAlloca(roots_AT, 0,
"stack_roots_AT");
1035 vals.push_back(roots);
1036 NewAttrs = NewAttrs.addAttribute(
1038 F->getContext(), AttributeList::FirstArgIndex + nexti,
1039 "enzymejl_returnRoots", std::to_string(numRooting));
1043 for (
auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1044 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1045 AttributeList::FunctionIndex, attr);
1047 SmallVector<std::tuple<Value *, Value *, Type *>> preCallReplacements;
1048 SmallVector<std::tuple<Value *, Value *, Type *, bool>>
1049 postCallReplacements;
1052 size_t local_root_count = 0;
1053 size_t sretCount = 0;
1054 if (!RT->isVoidTy()) {
1062 for (
size_t i = 0, end = CI->arg_size(); i < end; i++) {
1064 if (enzyme_srets.count(i)) {
1065 auto val = CI->getArgOperand(i);
1067 if (isa<UndefValue>(val) || isa<PoisonValue>(val) ||
1068 isa<ConstantPointerNull>(val)) {
1070 llvm::raw_string_ostream ss(s);
1071 ss <<
"Unsupported constant argument in "
1072 "FixupJuliaCallingConvention\n";
1073 ss <<
" + val: " << *val <<
"\n";
1074 ss <<
" + Function being rewritten: " << F->getName() <<
"\n";
1075 ss <<
" + CI erring: " << *CI <<
"\n";
1076 ss <<
" + Function containing CI: "
1077 << CI->getParent()->getParent()->getName() <<
"\n";
1080 nullptr,
nullptr,
nullptr);
1082 EmitFailure(
"UnsupportedArgument", CI->getDebugLoc(), CI,
1089 IRBuilder<> GEPB(cast<Instruction>(sret)->getNextNode());
1090 gep = GEPB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount);
1093 bool handled =
false;
1094 if (
auto AI = dyn_cast<AllocaInst>(
getBaseObject(val,
false))) {
1095 if (AI->getAllocatedType() == Types[sretCount] ||
1097 DL.getTypeSizeInBits(AI->getAllocatedType()) ==
1098 DL.getTypeSizeInBits(Types[sretCount]))) {
1099 AI->replaceAllUsesWith(gep);
1100 AI->eraseFromParent();
1106 assert(!isa<UndefValue>(val));
1107 assert(!isa<PoisonValue>(val));
1108 assert(!isa<ConstantPointerNull>(val));
1114 bool should_sret = sret_jlvalue;
1125 postCallReplacements.emplace_back(val, gep, Types[sretCount],
1133 preCallReplacements.emplace_back(val, gep, Types[sretCount]);
1137 if (roots_AT && reroot_enzyme_srets.count(i)) {
1145 if (rroots.count(i)) {
1146 auto val = CI->getArgOperand(i);
1147 if (isa<UndefValue>(val) || isa<PoisonValue>(val) ||
1148 isa<ConstantPointerNull>(val)) {
1150 llvm::raw_string_ostream ss(s);
1151 ss <<
"Unsupported constant argument in "
1152 "FixupJuliaCallingConvention\n";
1153 ss <<
" + val: " << *val <<
"\n";
1154 ss <<
" + Function being rewritten: " << F->getName() <<
"\n";
1155 ss <<
" + CI erring: " << *CI <<
"\n";
1156 ss <<
" + Function containing CI: "
1157 << CI->getParent()->getParent()->getName() <<
"\n";
1160 nullptr,
nullptr,
nullptr);
1162 EmitFailure(
"UnsupportedArgument", CI->getDebugLoc(), CI,
1167 auto attr = Attrs.getAttribute(AttributeList::FirstArgIndex + i,
1168 "enzymejl_returnRoots");
1169 auto attrv = attr.getValueAsString();
1170 assert(attrv.size());
1173 Value *gep =
nullptr;
1177 IRBuilder<> GEPB(cast<Instruction>(roots)->getNextNode());
1179 if (local_root_count != 0) {
1180 gep = GEPB.CreateConstInBoundsGEP2_32(roots_AT, roots, 0,
1184 if (subCount != numRooting) {
1185 gep = GEPB.CreatePointerCast(
1186 gep,
getUnqual(ArrayType::get(T_prjlvalue, subCount)));
1188 local_root_count += subCount;
1189 if (reret_roots.count(i))
1192 assert(reret_roots.count(i));
1194 IRBuilder<> GEPB(cast<Instruction>(sret)->getNextNode());
1197 gep = GEPB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount);
1202 bool handled =
false;
1203 if (
auto AI = dyn_cast<AllocaInst>(
getBaseObject(val,
false))) {
1204 if (AI->getAllocatedType() ==
1205 ArrayType::get(T_prjlvalue, subCount)) {
1206 AI->replaceAllUsesWith(gep);
1207 AI->eraseFromParent();
1213 assert(!isa<UndefValue>(val));
1214 assert(!isa<PoisonValue>(val));
1215 assert(!isa<ConstantPointerNull>(val));
1217 preCallReplacements.emplace_back(
1218 val, gep, ArrayType::get(T_prjlvalue, subCount));
1219 postCallReplacements.emplace_back(
1220 val, gep, ArrayType::get(T_prjlvalue, subCount),
true);
1225 for (
auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i))
1226 NewAttrs = NewAttrs.addAttribute(
1227 F->getContext(), AttributeList::FirstArgIndex + nexti, attr);
1228 vals.push_back(CI->getArgOperand(i));
1232 assert(sretCount == Types.size());
1233 assert(local_root_count == numRooting);
1240 for (
auto &&[val, gep, ty] : preCallReplacements) {
1245 SmallVector<OperandBundleDef, 1> Bundles;
1246 for (
unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I)
1247 Bundles.emplace_back(CI->getOperandBundleAt(I));
1249 if (!NewF->getFunctionType()->isVarArg() &&
1250 NewF->getFunctionType()->getNumParams() != vals.size()) {
1251 llvm::errs() <<
"NewF: " << *NewF <<
"\n";
1252 for (
size_t i = 0; i < vals.size(); i++) {
1253 llvm::errs() <<
" Args[" << i <<
"] = " << *vals[i] <<
"\n";
1256 auto NC = B.CreateCall(NewF, vals, Bundles);
1257 NC->setAttributes(NewAttrs);
1259 SmallVector<std::pair<unsigned, MDNode *>, 4> TheMDs;
1260 CI->getAllMetadataOtherThanDebugLoc(TheMDs);
1261 SmallVector<unsigned, 1> toCopy;
1262 for (
auto pair : TheMDs)
1263 if (pair.first != LLVMContext::MD_range) {
1264 toCopy.push_back(pair.first);
1266 if (!toCopy.empty())
1267 NC->copyMetadata(*CI, toCopy);
1268 NC->setDebugLoc(CI->getDebugLoc());
1270 if (!RT->isVoidTy()) {
1271 auto gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, 0) : sret;
1272 auto ld = B.CreateLoad(RT, gep);
1273 if (
auto MD = CI->getMetadata(LLVMContext::MD_range))
1274 ld->setMetadata(LLVMContext::MD_range, MD);
1276 Value *replacement = ld;
1286 CI->replaceAllUsesWith(replacement);
1289 for (
auto &&[val, gep, ty, jlvalue] : postCallReplacements) {
1291 auto ld = B.CreateLoad(ty, gep);
1292 auto SI = B.CreateStore(ld, val);
1293 if (val->getType()->getPointerAddressSpace() == 10)
1301 NC->setCallingConv(CI->getCallingConv());
1302 CI->eraseFromParent();
1304 NewF->setAttributes(NewAttrs);
1305 SmallVector<std::pair<unsigned, MDNode *>, 1> MD;
1306 F->getAllMetadata(MD);
1307 for (
auto pair : MD)
1308 if (pair.first != LLVMContext::MD_dbg)
1309 NewF->addMetadata(pair.first, *pair.second);
1311 NewF->setCallingConv(F->getCallingConv());
1312 F->eraseFromParent();
1324 auto RT = F->getReturnType();
1325 auto FT = F->getFunctionType();
1326 auto Attrs = F->getAttributes();
1328 AttributeList NewAttrs;
1329 SmallVector<Type *, 1> types;
1330 SmallSet<size_t, 1> changed;
1331 for (
auto pair : llvm::enumerate(FT->params())) {
1332 auto T = pair.value();
1333 auto i = pair.index();
1337 for (
auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) {
1338 if (attr.isStringAttribute() &&
1339 attr.getKindAsString() ==
"enzyme_sret_v") {
1341 kind =
"enzyme_sret";
1342 value = attr.getValueAsString();
1343 }
else if (attr.isStringAttribute() &&
1344 attr.getKindAsString() ==
"enzymejl_rooted_typ_v") {
1346 kind =
"enzymejl_rooted_typ";
1347 value = attr.getValueAsString();
1348 }
else if (attr.isStringAttribute() &&
1349 attr.getKindAsString() ==
"enzymejl_returnRoots_v") {
1351 kind =
"enzymejl_returnRoots";
1352 value = attr.getValueAsString();
1354 NewAttrs = NewAttrs.addAttribute(
1355 F->getContext(), AttributeList::FirstArgIndex + types.size(), attr);
1358 if (
auto AT = dyn_cast<ArrayType>(T)) {
1359 if (
auto PT = dyn_cast<PointerType>(AT->getElementType())) {
1360 auto AS = PT->getAddressSpace();
1361 if (AS == 11 || AS == 12 || AS == 13 || sretv) {
1362 for (
unsigned i = 0; i < AT->getNumElements(); i++) {
1364 NewAttrs = NewAttrs.addAttribute(
1365 F->getContext(), AttributeList::FirstArgIndex + types.size(),
1366 Attribute::get(F->getContext(), kind, value));
1368 types.push_back(PT);
1378 if (changed.size() == 0)
1381 for (
auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1382 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1383 AttributeList::FunctionIndex, attr);
1385 for (
auto attr : Attrs.getAttributes(AttributeList::ReturnIndex))
1386 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1387 AttributeList::ReturnIndex, attr);
1390 FunctionType::get(FT->getReturnType(), types, FT->isVarArg());
1393 Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(),
1394 F->getName(), F->getParent());
1396 ValueToValueMapTy VMap;
1398 Function::arg_iterator DestI = NewF->arg_begin();
1402 SmallVector<Instruction *, 1> toInsert;
1403 for (Argument &I : F->args()) {
1404 auto T = I.getType();
1405 if (
auto AT = dyn_cast<ArrayType>(T)) {
1406 if (changed.count(I.getArgNo())) {
1407 Value *V = UndefValue::get(T);
1408 for (
unsigned i = 0; i < AT->getNumElements(); i++) {
1409 DestI->setName(I.getName() +
"." +
1411 unsigned idx[1] = {i};
1412 auto IV = InsertValueInst::Create(V, (llvm::Value *)&*DestI++, idx);
1413 toInsert.push_back(IV);
1420 DestI->setName(I.getName());
1421 VMap[&I] = &*DestI++;
1424 SmallVector<ReturnInst *, 8> Returns;
1425 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
1426 Returns,
"",
nullptr);
1429 IRBuilder<> EB(&*NewF->getEntryBlock().begin());
1430 for (
auto I : toInsert)
1434 SmallVector<CallInst *, 1> callers;
1435 for (
auto U : F->users()) {
1436 auto CI = dyn_cast<CallInst>(U);
1438 assert(CI->getCalledFunction() == F);
1439 callers.push_back(CI);
1442 for (
auto CI : callers) {
1443 auto Attrs = CI->getAttributes();
1444 AttributeList NewAttrs;
1447 for (
auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1448 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1449 AttributeList::FunctionIndex, attr);
1451 for (
auto attr : Attrs.getAttributes(AttributeList::ReturnIndex))
1452 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1453 AttributeList::ReturnIndex, attr);
1455 SmallVector<Value *, 1> vals;
1456 for (
size_t j = 0, end = CI->arg_size(); j < end; j++) {
1458 auto T = CI->getArgOperand(j)->getType();
1459 if (
auto AT = dyn_cast<ArrayType>(T)) {
1460 if (isa<PointerType>(AT->getElementType())) {
1461 if (changed.count(j)) {
1466 Attrs.getAttributes(AttributeList::FirstArgIndex + j)) {
1467 if (attr.isStringAttribute() &&
1468 attr.getKindAsString() ==
"enzyme_sret_v") {
1470 kind =
"enzyme_sret";
1471 value = attr.getValueAsString();
1472 }
else if (attr.isStringAttribute() &&
1473 attr.getKindAsString() ==
"enzymejl_returnRoots_v") {
1475 kind =
"enzymejl_returnRoots";
1476 value = attr.getValueAsString();
1477 }
else if (attr.isStringAttribute() &&
1478 attr.getKindAsString() ==
"enzymejl_rooted_typ_v") {
1480 kind =
"enzymejl_rooted_typ_v";
1481 value = attr.getValueAsString();
1484 for (
unsigned i = 0; i < AT->getNumElements(); i++) {
1486 NewAttrs = NewAttrs.addAttribute(
1487 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1488 Attribute::get(F->getContext(), kind, value));
1497 for (
auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + j)) {
1498 if (attr.isStringAttribute() &&
1499 attr.getKindAsString() ==
"enzyme_sret_v") {
1500 NewAttrs = NewAttrs.addAttribute(
1501 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1502 Attribute::get(F->getContext(),
"enzyme_sret",
1503 attr.getValueAsString()));
1504 }
else if (attr.isStringAttribute() &&
1505 attr.getKindAsString() ==
"enzymejl_returnRoots_v") {
1506 NewAttrs = NewAttrs.addAttribute(
1507 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1508 Attribute::get(F->getContext(),
"enzymejl_returnRoots",
1509 attr.getValueAsString()));
1510 }
else if (attr.isStringAttribute() &&
1511 attr.getKindAsString() ==
"enzymejl_rooted_typ_v") {
1512 NewAttrs = NewAttrs.addAttribute(
1513 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1514 Attribute::get(F->getContext(),
"enzymejl_rooted_typ",
1515 attr.getValueAsString()));
1517 NewAttrs = NewAttrs.addAttribute(
1518 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1523 vals.push_back(CI->getArgOperand(j));
1526 SmallVector<OperandBundleDef, 1> Bundles;
1527 for (
unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I)
1528 Bundles.emplace_back(CI->getOperandBundleAt(I));
1529 auto NC = B.CreateCall(NewF, vals, Bundles);
1530 NC->setAttributes(NewAttrs);
1532 SmallVector<std::pair<unsigned, MDNode *>, 4> TheMDs;
1533 CI->getAllMetadataOtherThanDebugLoc(TheMDs);
1534 SmallVector<unsigned, 1> toCopy;
1535 for (
auto pair : TheMDs)
1536 toCopy.push_back(pair.first);
1537 if (!toCopy.empty())
1538 NC->copyMetadata(*CI, toCopy);
1539 NC->setDebugLoc(CI->getDebugLoc());
1541 if (!RT->isVoidTy()) {
1543 CI->replaceAllUsesWith(NC);
1546 NC->setCallingConv(CI->getCallingConv());
1547 CI->eraseFromParent();
1549 NewF->setAttributes(NewAttrs);
1550 SmallVector<std::pair<unsigned, MDNode *>, 1> MD;
1551 F->getAllMetadata(MD);
1552 for (
auto pair : MD)
1553 if (pair.first != LLVMContext::MD_dbg)
1554 NewF->addMetadata(pair.first, *pair.second);
1556 NewF->setCallingConv(F->getCallingConv());
1557 F->eraseFromParent();