Enzyme main
Loading...
Searching...
No Matches
FixupJuliaCallingConvention.cpp
Go to the documentation of this file.
1#include "CApi.h"
2#include "FunctionUtils.h"
3#include "GradientUtils.h"
4#include "Utils.h"
5#include "llvm/ADT/SmallSet.h"
6#include "llvm/IR/Function.h"
7#include "llvm/IR/IRBuilder.h"
8#include "llvm/IR/Instructions.h"
9#include "llvm/IR/LegacyPassManager.h"
10#include "llvm/IR/Module.h"
11#include "llvm/Pass.h"
12#include "llvm/Transforms/Utils/BasicBlockUtils.h"
13#include "llvm/Transforms/Utils/Cloning.h"
14
15#define addAttribute addAttributeAtIndex
16#define removeAttribute removeAttributeAtIndex
17#define getAttribute getAttributeAtIndex
18#define hasAttribute hasAttributeAtIndex
19
20#if LLVM_VERSION_MAJOR >= 17
21#include "llvm/TargetParser/Triple.h"
22#endif
23
24using namespace llvm;
25
26extern bool
27DetectPointerArgOfFn(llvm::Function &F,
28 llvm::SmallPtrSetImpl<llvm::Function *> &calls_todo);
29
30bool needsReRooting(llvm::Argument *arg, bool &anyJLStore,
31 llvm::Type *SRetType = nullptr) {
32 auto Attrs = arg->getParent()->getAttributes();
33
34 if (!SRetType)
36 Attrs
37 .getAttribute(AttributeList::FirstArgIndex + arg->getArgNo(),
38 "enzyme_sret")
39 .getValueAsString(),
40 &arg->getContext());
41
42 CountTrackedPointers tracked(SRetType);
43 if (tracked.count == 0) {
44 return false;
45 }
46
47 bool hasReturnRootingAfterArg = false;
48 for (size_t i = arg->getArgNo() + 1; i < arg->getParent()->arg_size(); i++) {
49 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
50 "enzymejl_returnRoots")) {
51 hasReturnRootingAfterArg = true;
52 break;
53 }
54 }
55
56 // If there is no returnRoots, we _must_ reroot the arg.
57 if (!hasReturnRootingAfterArg) {
58 return true;
59 }
60
61 SmallVector<Value *> storedValues;
62
63 auto &DL = arg->getParent()->getParent()->getDataLayout();
64 SmallVector<size_t> sret_offsets;
65 {
66 std::deque<std::pair<llvm::Type *, std::vector<unsigned>>> todo = {
67 {SRetType, {}}};
68 while (!todo.empty()) {
69 auto cur = std::move(todo[0]);
70 todo.pop_front();
71 auto path = std::move(cur.second);
72 auto ty = cur.first;
73
74 if (auto PT = dyn_cast<PointerType>(ty)) {
75 if (!isSpecialPtr(PT))
76 continue;
77
78 SmallVector<Constant *, 1> IdxList;
79 IdxList.push_back(
80 ConstantInt::get(Type::getInt64Ty(PT->getContext()), 0));
81
82 for (auto v : path)
83 IdxList.push_back(
84 ConstantInt::get(Type::getInt32Ty(PT->getContext()), v));
85 auto nullp = ConstantPointerNull::get(PointerType::getUnqual(SRetType));
86 auto gep = ConstantExpr::getGetElementPtr(SRetType, nullp, IdxList);
87
88 if (gep == ConstantPointerNull::get(PointerType::getUnqual(PT))) {
89 sret_offsets.push_back(0);
90 continue;
91 }
92#if LLVM_VERSION_MAJOR >= 20
93 SmallMapVector<Value *, APInt, 4> VariableOffsets;
94#else
95 MapVector<Value *, APInt> VariableOffsets;
96#endif
97 auto width = DL.getPointerSize() * 8;
98 APInt Offset(width, 0);
99 bool success = collectOffset(cast<GEPOperator>(gep), DL, width,
100 VariableOffsets, Offset);
101 if (!success)
102 llvm_unreachable("Illegal offset collection");
103 sret_offsets.push_back(Offset.getZExtValue());
104 continue;
105 }
106
107 if (auto AT = dyn_cast<ArrayType>(ty)) {
108 for (size_t i = 0; i < AT->getNumElements(); i++) {
109 std::vector<unsigned> path2(path);
110 path2.push_back(i);
111 todo.emplace_back(AT->getElementType(), path2);
112 }
113 continue;
114 }
115
116 if (auto VT = dyn_cast<VectorType>(ty)) {
117 for (size_t i = 0; i < VT->getElementCount().getKnownMinValue(); i++) {
118 std::vector<unsigned> path2(path);
119 path2.push_back(i);
120 todo.emplace_back(VT->getElementType(), path2);
121 }
122 continue;
123 }
124
125 if (auto ST = dyn_cast<StructType>(ty)) {
126 for (size_t i = 0; i < ST->getNumElements(); i++) {
127 std::vector<unsigned> path2(path);
128 path2.push_back(i);
129 todo.emplace_back(ST->getTypeAtIndex(i), path2);
130 }
131 continue;
132 }
133 }
134 }
135
136 bool legal = true;
137 for (auto &&[I, cur, byteOffset] : findAllUsersOf(arg)) {
138 assert(I->getParent()->getParent() == arg->getParent());
139
140 if (isa<ICmpInst>(I)) {
141 continue;
142 }
143 if (isa<LoadInst>(I)) {
144 continue;
145 }
146 if (auto SI = dyn_cast<StoreInst>(I)) {
147 assert(SI->getValueOperand() != cur);
148
149 if (CountTrackedPointers(SI->getValueOperand()->getType()).count == 0)
150 continue;
151
152 storedValues.push_back(SI->getValueOperand());
153 anyJLStore = true;
154 continue;
155 }
156
157 if (isa<MemSetInst>(I))
158 continue;
159
160 if (auto MSI = dyn_cast<MemTransferInst>(I)) {
161 if (auto Len = dyn_cast<ConstantInt>(MSI->getLength())) {
162 bool mlegal = true;
163 for (auto offset : sret_offsets) {
164 if (byteOffset + Len->getSExtValue() <= offset)
165 continue;
166 if (offset + DL.getPointerSize() <= byteOffset)
167 continue;
168 mlegal = false;
169 break;
170 }
171 if (mlegal)
172 break;
173 }
174 }
175
176 std::string s;
177 llvm::raw_string_ostream ss(s);
178 ss << "Unknown user of sret-like argument\n";
179 CustomErrorHandler(ss.str().c_str(), wrap(I), ErrorType::GCRewrite,
180 wrap(cur), wrap(arg), nullptr);
181 legal = false;
182 anyJLStore = true;
183 break;
184 }
185
186 if (legal) {
187 while (!storedValues.empty()) {
188 auto sv = storedValues.pop_back_val();
189 if (auto I = dyn_cast<Instruction>(sv)) {
190 assert(I->getParent()->getParent() == arg->getParent());
191 }
192 bool foundUse = false;
193 for (auto &U : sv->uses()) {
194 if (auto SI = dyn_cast<StoreInst>(U.getUser())) {
195 if (SI->getValueOperand() == sv) {
196 auto base = getBaseObject(SI->getPointerOperand());
197 if (base == arg) {
198 continue;
199 }
200 if (auto evi = dyn_cast<ExtractValueInst>(base)) {
201 base = evi->getAggregateOperand();
202 }
203 if (auto arg2 = dyn_cast<Argument>(base)) {
204 if (Attrs
205 .getAttribute(AttributeList::FirstArgIndex +
206 arg2->getArgNo(),
207 "enzymejl_returnRoots")
208 .isValid()) {
209 foundUse = true;
210 break;
211 }
212 }
213 }
214 }
215 }
216 if (!foundUse) {
217 if (auto IVI = dyn_cast<InsertValueInst>(sv)) {
218 CountTrackedPointers tracked(
219 IVI->getInsertedValueOperand()->getType());
220 if (tracked.count == 0) {
221 storedValues.push_back(IVI->getAggregateOperand());
222 continue;
223 }
224 if (isa<UndefValue>(IVI->getAggregateOperand()) ||
225 isa<PoisonValue>(IVI->getAggregateOperand()) ||
226 isa<ConstantAggregateZero>(IVI->getAggregateOperand())) {
227 storedValues.push_back(IVI->getInsertedValueOperand());
228 continue;
229 }
230 storedValues.push_back(IVI->getAggregateOperand());
231 storedValues.push_back(IVI->getInsertedValueOperand());
232 continue;
233 }
234 if (auto ST = dyn_cast<StructType>(sv->getType())) {
235 bool legal = true;
236 for (size_t i = 0; i < ST->getNumElements(); i++) {
237
238 CountTrackedPointers tracked(ST->getElementType(i));
239 if (tracked.count == 0) {
240 continue;
241 }
242 std::map<std::vector<unsigned>, bool> paths_to_cover;
243 {
244 std::deque<std::pair<llvm::Type *, std::vector<unsigned>>> todo =
245 {{ST->getElementType(i), {}}};
246 while (!todo.empty()) {
247 auto cur = std::move(todo[0]);
248 todo.pop_front();
249 auto path = std::move(cur.second);
250 auto ty = cur.first;
251
252 if (auto PT = dyn_cast<PointerType>(ty)) {
253 if (isSpecialPtr(PT)) {
254 paths_to_cover[path] = false;
255 }
256 continue;
257 }
258
259 if (auto AT = dyn_cast<ArrayType>(ty)) {
260 for (size_t k = 0; k < AT->getNumElements(); k++) {
261 std::vector<unsigned> path2(path);
262 path2.push_back(k);
263 todo.emplace_back(AT->getElementType(), path2);
264 }
265 continue;
266 }
267
268 if (auto VT = dyn_cast<VectorType>(ty)) {
269 for (size_t k = 0;
270 k < VT->getElementCount().getKnownMinValue(); k++) {
271 std::vector<unsigned> path2(path);
272 path2.push_back(k);
273 todo.emplace_back(VT->getElementType(), path2);
274 }
275 continue;
276 }
277
278 if (auto ST2 = dyn_cast<StructType>(ty)) {
279 for (size_t k = 0; k < ST2->getNumElements(); k++) {
280 std::vector<unsigned> path2(path);
281 path2.push_back(k);
282 todo.emplace_back(ST2->getTypeAtIndex(k), path2);
283 }
284 continue;
285 }
286 }
287 }
288
289 for (auto u : sv->users()) {
290 if (auto ev0 = dyn_cast<ExtractValueInst>(u)) {
291 if (ev0->getIndices()[0] == i) {
292 std::vector<unsigned> extract_path;
293 for (size_t k = 1; k < ev0->getNumIndices(); ++k) {
294 extract_path.push_back(ev0->getIndices()[k]);
295 }
296 storedValues.push_back(ev0);
297
298 // Mark paths covered
299 for (auto &pair : paths_to_cover) {
300 const auto &p = pair.first;
301 bool match = true;
302 if (extract_path.size() > p.size()) {
303 match = false;
304 } else {
305 for (size_t idx = 0; idx < extract_path.size(); ++idx) {
306 if (extract_path[idx] != p[idx]) {
307 match = false;
308 break;
309 }
310 }
311 }
312 if (match) {
313 pair.second = true;
314 }
315 }
316 }
317 }
318 }
319
320 bool fullyCovered = true;
321 for (const auto &pair : paths_to_cover) {
322 if (!pair.second) {
323 fullyCovered = false;
324 break;
325 }
326 }
327
328 if (!fullyCovered) {
329 llvm::errs() << " failed to find extracted pointer for " << *sv
330 << " at index " << i << "\n";
331 legal = false;
332 break;
333 }
334 }
335 if (legal) {
336 continue;
337 }
338 }
339 if (!isa<PointerType>(sv->getType()) ||
340 !isSpecialPtr(cast<PointerType>(sv->getType()))) {
341 llvm::errs() << " sf: " << *arg->getParent() << "\n";
342 llvm::errs() << " arg: " << *arg << "\n";
343 llvm::errs() << "Pointer of wrong type: " << *sv << "\n";
344 assert(0);
345 }
346
347 {
348 bool saw_bitcast = false;
349 for (auto u : sv->users()) {
350 if (auto ev0 = dyn_cast<CastInst>(u)) {
351 auto t2 = ev0->getType();
352 if (isa<PointerType>(t2) && isSpecialPtr(cast<PointerType>(t2))) {
353 saw_bitcast = true;
354 storedValues.push_back(ev0);
355 break;
356 }
357 }
358 }
359 if (saw_bitcast)
360 continue;
361 }
362
363 if (hasReturnRootingAfterArg) {
364 std::string s;
365 llvm::raw_string_ostream ss(s);
366 ss << "Could not find use of stored value\n";
367 ss << " sv: " << *sv << "\n";
368 CustomErrorHandler(ss.str().c_str(), wrap(sv), ErrorType::GCRewrite,
369 nullptr, wrap(arg), nullptr);
370 }
371 legal = false;
372 break;
373 }
374 }
375 }
376
377 return !legal;
378}
379
380// For a given enzymejl_returnRoots, which we assume is loaded after
381// the call (and therefore is needed to be preserved), check whether
382// there is an existing enzyme_sret for whom the roots could be assigned to,
383// or if an additional sret argument is required.
384// This is because count(sret_type) == returnRoots for the final merged type.
385// As a result, there _must_ be a sret_type corresponding to the return root.
386bool needsReReturning(llvm::Argument *arg, size_t &sret_idx,
387 std::map<size_t, size_t> &srets_without_stores) {
388 auto Attrs = arg->getParent()->getAttributes();
389
390 bool hasSRetBeforeArg = false;
391 for (size_t i = 0; i < arg->getArgNo(); i++) {
392 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, "enzyme_sret")) {
393 hasSRetBeforeArg = true;
394 break;
395 }
396 }
397
398 if (!hasSRetBeforeArg) {
399 assert(srets_without_stores.size() == 0);
400 return true;
401 }
402
403 if (srets_without_stores.size() == 0) {
404 return true;
405 }
406
407 size_t subCount = convertRRootCountFromString(
408 Attrs
409 .getAttribute(AttributeList::FirstArgIndex + arg->getArgNo(),
410 "enzymejl_returnRoots")
411 .getValueAsString());
412
413 for (auto &pair : srets_without_stores) {
414 if (pair.second == subCount) {
415 sret_idx = pair.first;
416 srets_without_stores.erase(sret_idx);
417 return false;
418 }
419 }
420
421 llvm_unreachable("Unsupported needsReRooting");
422 return true;
423}
424
425static bool isOpaque(llvm::Type *T) {
426#if LLVM_VERSION_MAJOR >= 20
427 return T->isPointerTy();
428#else
429 return T->isOpaquePointerTy();
430#endif
431}
432
433static void removeRange(std::vector<std::pair<uint64_t, uint64_t>> &ranges,
434 uint64_t start, uint64_t end) {
435 std::vector<std::pair<uint64_t, uint64_t>> nextRanges;
436 for (auto &range : ranges) {
437 if (end <= range.first || start >= range.second) {
438 nextRanges.push_back(range);
439 } else {
440 if (start > range.first) {
441 nextRanges.push_back({range.first, start});
442 }
443 if (end < range.second) {
444 nextRanges.push_back({end, range.second});
445 }
446 }
447 }
448 ranges = std::move(nextRanges);
449}
450static bool isReadOnlyNoCapture(Function *F, unsigned argNo) {
451 return F->hasParamAttribute(argNo, Attribute::ReadOnly) &&
452 F->getArg(argNo)->hasNoCaptureAttr();
453}
454
455static bool isGuaranteedToFullyWrite(Function *F, unsigned argNo, Type *T) {
456 if (F->isDeclaration())
457 return false;
458
459 auto &DL = F->getParent()->getDataLayout();
460 auto size = DL.getTypeAllocSize(T);
461
462 std::vector<std::pair<uint64_t, uint64_t>> ranges = {{0, size}};
463 std::vector<std::pair<Value *, uint64_t>> worklist = {{F->getArg(argNo), 0}};
464 std::set<Value *> seen = {F->getArg(argNo)};
465
466 PostDominatorTree PDT(*F);
467
468 while (!worklist.empty()) {
469 auto item = worklist.back();
470 worklist.pop_back();
471 Value *val = item.first;
472 uint64_t offset = item.second;
473
474 for (auto *U : val->users()) {
475 if (auto *BI = dyn_cast<CastInst>(U)) {
476 if (seen.insert(BI).second)
477 worklist.push_back({BI, offset});
478 continue;
479 }
480
481 if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
482 APInt gepOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
483 if (GEP->accumulateConstantOffset(DL, gepOffset)) {
484 if (seen.insert(GEP).second)
485 worklist.push_back({GEP, offset + gepOffset.getZExtValue()});
486 continue;
487 }
488 }
489
490 if (auto *I = dyn_cast<Instruction>(U)) {
491 if (I->getParent() != &F->getEntryBlock() &&
492 !PDT.dominates(I->getParent(), &F->getEntryBlock()))
493 continue;
494
495 if (auto *SI = dyn_cast<StoreInst>(I)) {
496 if (SI->getPointerOperand() == val) {
497 auto storeSize =
498 DL.getTypeAllocSize(SI->getValueOperand()->getType());
499 removeRange(ranges, offset, offset + storeSize);
500 if (ranges.empty())
501 return true;
502 continue;
503 }
504 }
505
506 if (auto *MSI = dyn_cast<MemSetInst>(I)) {
507 if (MSI->getDest() == val) {
508 if (auto *CI = dyn_cast<ConstantInt>(MSI->getLength())) {
509 removeRange(ranges, offset, offset + CI->getZExtValue());
510 if (ranges.empty())
511 return true;
512 continue;
513 }
514 }
515 }
516
517 if (auto *MCI = dyn_cast<MemCpyInst>(I)) {
518 if (MCI->getDest() == val) {
519 if (auto *CI = dyn_cast<ConstantInt>(MCI->getLength())) {
520 removeRange(ranges, offset, offset + CI->getZExtValue());
521 if (ranges.empty())
522 return true;
523 continue;
524 }
525 }
526 }
527 }
528 }
529 }
530
531 return ranges.empty();
532}
533
// TODO: for sret/sret_v, check whether the function actually stores the
// jlvalue_t's into the sret. If so, confirm that those values are also saved
// elsewhere in a returnroot.
536void EnzymeFixupJuliaCallingConvention(Function *F, bool sret_jlvalue) {
537 if (F->empty())
538 return;
539
540 auto RT = F->getReturnType();
541 std::set<size_t> srets;
542 std::set<size_t> enzyme_srets;
543
544 std::set<size_t> reroot_enzyme_srets;
545
546 std::set<size_t> noroot_enzyme_srets;
547
548 std::set<size_t> rroots;
549
550 std::set<size_t> reret_roots;
551
552 auto FT = F->getFunctionType();
553 auto Attrs = F->getAttributes();
554
555 std::map<size_t, size_t> selected_roots;
556
557 // Map from the sret index to the number of stores, as unused
558 std::map<size_t, size_t> srets_without_stores;
559
560 for (size_t i = 0, end = FT->getNumParams(); i < end; i++) {
561 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
562 Attribute::StructRet))
563 srets.insert(i);
564 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, "enzyme_sret")) {
565 bool anyJLStore = false;
566 enzyme_srets.insert(i);
567 if (needsReRooting(F->getArg(i), anyJLStore)) {
568 // Case 1: jlvalue_t's were stored into the sret, but were not stored
569 // into an existing rooted argument.
570 reroot_enzyme_srets.insert(i);
571 } else if (anyJLStore) {
572 // Case 2: jlvalue_t's were stored into the sret, and the were stored
573 // into an existing rooted argument.
574 } else {
575 // Case 3: No jlvalue_t's were stored into the sret.
576 llvm::Type *SRetType = convertSRetTypeFromString(
577 Attrs.getAttribute(AttributeList::FirstArgIndex + i, "enzyme_sret")
578 .getValueAsString(),
579 &F->getContext());
580 if (auto count = CountTrackedPointers(SRetType).count) {
581 srets_without_stores[i] = count;
582 noroot_enzyme_srets.insert(i);
583 }
584 }
585 }
586 assert(
587 !Attrs.hasAttribute(AttributeList::FirstArgIndex + i, "enzyme_sret_v"));
588
589 if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
590 "enzymejl_returnRoots")) {
591 rroots.insert(i);
592 size_t sret_idx;
593 // Existing
594 if (needsReReturning(F->getArg(i), sret_idx, srets_without_stores)) {
595 reret_roots.insert(i);
596 } else {
597 selected_roots[i] = sret_idx;
598 }
599 }
600 assert(!Attrs.hasAttribute(AttributeList::FirstArgIndex + i,
601 "enzymejl_returnRoots_v"));
602 }
603
604 // Regular julia function, needing no intervention
605 if (srets.size() == 1) {
606 assert(*srets.begin() == 0);
607 assert(enzyme_srets.size() == 0);
608 llvm::Type *SRetType = F->getParamStructRetType(0);
609 CountTrackedPointers tracked(SRetType);
610
611 // No jlvaluet to rewrite
612 if (!tracked.count) {
613 return;
614 }
615
616 bool anyJLStore = false;
617 bool rerooting = needsReRooting(F->getArg(0), anyJLStore, SRetType);
618
619 // We now assume we have an sret.
620 // If it is properly rooted, we don't have any work to do
621 if (rroots.size()) {
622 assert(rroots.size() == 1);
623 assert(*rroots.begin() == 1);
624 // GVN is only powerful enough at LLVM 16+
625 // (https://godbolt.org/z/ebY3exW9K)
626#if LLVM_VERSION_MAJOR >= 16
627 if (rerooting) {
628 std::string s;
629 llvm::raw_string_ostream ss(s);
630 ss << "Illegal GC setup in which rerooting is required\n";
631 ss << " + F: " << *F << "\n";
633 nullptr, nullptr, nullptr);
634 }
635 assert(!rerooting);
636#endif
637
638 size_t count = convertRRootCountFromString(
639 Attrs
640 .getAttribute(AttributeList::FirstArgIndex + 1,
641 "enzymejl_returnRoots")
642 .getValueAsString());
643
644 assert(count == tracked.count);
645 return;
646 }
647
648 F->addParamAttr(0, Attribute::get(F->getContext(), "enzyme_sret",
649 convertSRetTypeToString(SRetType)));
650 Attrs = F->getAttributes();
651 srets.clear();
652 size_t i = 0;
653 enzyme_srets.insert(i);
654 if (rerooting) {
655 reroot_enzyme_srets.insert(i);
656 } else if (anyJLStore) {
657 } else {
658 if (auto count = CountTrackedPointers(SRetType).count) {
659 srets_without_stores[i] = count;
660 noroot_enzyme_srets.insert(i);
661 }
662 }
663 } else if (srets.size() == 0 && enzyme_srets.size() == 0 &&
664 rroots.size() == 0) {
665 // No sret/rooting, no intervention needed.
666 return;
667 }
668
669 // Number of additional roots, which contain actually no data at all.
670 // Consider this additional rerooting of the sret, except this time
671 // just fill it with 0's
672 for (auto &pair : srets_without_stores) {
673 assert(pair.second);
674 reroot_enzyme_srets.insert(pair.first);
675 }
676
677 assert(srets.size() == 0);
678
679 SmallVector<Type *, 1> Types;
680 if (!RT->isVoidTy()) {
681 Types.push_back(RT);
682 }
683
684 auto T_jlvalue = StructType::get(F->getContext(), {});
685 auto T_prjlvalue = PointerType::get(T_jlvalue, AddressSpace::Tracked);
686
687 size_t numRooting = RT->isVoidTy() ? 0 : CountTrackedPointers(RT).count;
688
689 for (auto idx : enzyme_srets) {
690 llvm::Type *SRetType = convertSRetTypeFromString(
691 Attrs.getAttribute(AttributeList::FirstArgIndex + idx, "enzyme_sret")
692 .getValueAsString(),
693 &F->getContext());
694#if LLVM_VERSION_MAJOR < 17
695 if (F->getContext().supportsTypedPointers()) {
696 auto T = FT->getParamType(idx)->getPointerElementType();
697 if (T != SRetType) {
698 std::string s;
699 llvm::raw_string_ostream ss(s);
700 ss << "Type mismatch in FixupJuliaCallingConvention:\n";
701 ss << " + T: " << *T << "\n";
702 ss << " + SRetType: " << *SRetType << "\n";
703 EmitFailure("TypeMismatch", F->getSubprogram(), F, ss.str());
704 }
705 }
706#endif
707 Types.push_back(SRetType);
708 if (reroot_enzyme_srets.count(idx)) {
709 numRooting += CountTrackedPointers(SRetType).count;
710 }
711 }
712 for (auto idx : rroots) {
713 size_t count = convertRRootCountFromString(
714 Attrs
715 .getAttribute(AttributeList::FirstArgIndex + idx,
716 "enzymejl_returnRoots")
717 .getValueAsString());
718 auto T = ArrayType::get(T_prjlvalue, count);
719#if LLVM_VERSION_MAJOR < 17
720 if (F->getContext().supportsTypedPointers()) {
721 auto NT = FT->getParamType(idx)->getPointerElementType();
722 assert(NT == T);
723 }
724#endif
725 if (reret_roots.count(idx)) {
726 Types.push_back(T);
727 }
728 numRooting += count;
729 }
730
731 StructType *ST =
732 Types.size() <= 1 ? nullptr : StructType::get(F->getContext(), Types);
733 Type *sretTy = nullptr;
734 if (Types.size())
735 sretTy = Types.size() == 1 ? Types[0] : ST;
736
737 ArrayType *roots_AT =
738 numRooting ? ArrayType::get(T_prjlvalue, numRooting) : nullptr;
739
740 if (sretTy) {
741 CountTrackedPointers countF(sretTy);
742 // If all fields of the sret struct are tracked pointers, the struct itself
743 // acts as a root anchor on the caller's stack frame. In this scenario, we
744 // do not allocate an additional explicit ReturnRoots array argument.
745 if (countF.all) {
746 roots_AT = nullptr;
747 numRooting = 0;
748 reroot_enzyme_srets.clear();
749 } else if (countF.count) {
750 if (!roots_AT) {
751 llvm::errs() << " sretTy: " << *sretTy << "\n";
752 llvm::errs() << " numRooting: " << numRooting << "\n";
753 llvm::errs() << " tracked.count: " << countF.count << "\n";
754 }
755 assert(roots_AT);
756 if (numRooting != countF.count) {
757 std::string s;
758 llvm::raw_string_ostream ss(s);
759 ss << "Illegal GC setup in which numRooting (" << numRooting
760 << ") != tracked.count (" << countF.count << ")\n";
761 ss << " sretTy: " << *sretTy << "\n";
762 ss << " Types.size(): " << Types.size() << "\n";
763 for (size_t i = 0; i < Types.size(); i++) {
764 ss << " + Types[" << i << "] = " << *Types[i] << "\n";
765 }
766 ss << " F: " << *F << "\n";
768 nullptr, nullptr, nullptr);
769 }
770 assert(numRooting == countF.count);
771 }
772 }
773
774 AttributeList NewAttrs;
775 SmallVector<Type *, 1> types;
776 size_t nexti = 0;
777 if (sretTy) {
778 types.push_back(getUnqual(sretTy));
779 NewAttrs = NewAttrs.addAttribute(
780 F->getContext(), AttributeList::FirstArgIndex + nexti,
781 Attribute::get(F->getContext(), Attribute::StructRet, sretTy));
782 NewAttrs = NewAttrs.addAttribute(F->getContext(),
783 AttributeList::FirstArgIndex + nexti,
784 Attribute::NoAlias);
785 nexti++;
786 }
787 if (roots_AT) {
788 NewAttrs = NewAttrs.addAttribute(
789 F->getContext(), AttributeList::FirstArgIndex + nexti,
790 "enzymejl_returnRoots", std::to_string(numRooting));
791 NewAttrs = NewAttrs.addAttribute(F->getContext(),
792 AttributeList::FirstArgIndex + nexti,
793 Attribute::NoAlias);
794 NewAttrs = NewAttrs.addAttribute(F->getContext(),
795 AttributeList::FirstArgIndex + nexti,
796 Attribute::WriteOnly);
797 types.push_back(getUnqual(roots_AT));
798 nexti++;
799 }
800
801 for (size_t i = 0, end = FT->getNumParams(); i < end; i++) {
802 if (enzyme_srets.count(i) || rroots.count(i))
803 continue;
804
805 for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) {
806 NewAttrs = NewAttrs.addAttribute(
807 F->getContext(), AttributeList::FirstArgIndex + nexti, attr);
808 }
809 types.push_back(F->getFunctionType()->getParamType(i));
810 nexti++;
811 }
812
813 for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
814 NewAttrs = NewAttrs.addAttribute(F->getContext(),
815 AttributeList::FunctionIndex, attr);
816
817 FunctionType *FTy = FunctionType::get(Type::getVoidTy(F->getContext()), types,
818 FT->isVarArg());
819
820 // Create the new function
821 auto &M = *F->getParent();
822 Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(),
823 F->getName(), &M);
824
825 ValueToValueMapTy VMap;
826 // Loop over the arguments, copying the names of the mapped arguments over...
827 Function::arg_iterator DestI = NewF->arg_begin();
828 Argument *sret = nullptr;
829 if (sretTy) {
830 sret = &*DestI;
831 DestI++;
832 }
833 Argument *roots = nullptr;
834 if (roots_AT) {
835 roots = &*DestI;
836 DestI++;
837 }
838
839 // To handle the deleted args, it needs to be replaced by a non-arg operand.
840 // This map contains the temporary phi nodes corresponding
841 std::map<size_t, PHINode *> delArgMap;
842 for (Argument &I : F->args()) {
843 auto i = I.getArgNo();
844 if (enzyme_srets.count(i) || rroots.count(i)) {
845 VMap[&I] = delArgMap[i] = PHINode::Create(I.getType(), 0);
846 continue;
847 }
848 assert(DestI != NewF->arg_end());
849 DestI->setName(I.getName()); // Copy the name over...
850 VMap[&I] = &*DestI++; // Add mapping to VMap
851 }
852 // Compute the readonly/nocapture/etc properties for analysis use later.
853 {
854 SmallPtrSet<Function *, 1> calls_todo;
855 (void)DetectPointerArgOfFn(*F, calls_todo);
856 }
857 SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
858 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
859 Returns, "", nullptr);
860
861 SmallVector<CallInst *, 1> callers;
862 for (auto U : F->users()) {
863 auto CI = dyn_cast<CallInst>(U);
864 assert(CI);
865 assert(CI->getCalledFunction() == F);
866 callers.push_back(CI);
867 }
868
869 {
870 size_t curOffset = 0;
871 size_t sretCount = 0;
872 if (!RT->isVoidTy()) {
873 for (auto &RT : Returns) {
874 IRBuilder<> B(RT);
875 Value *gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, 0) : sret;
876 Value *rval = RT->getReturnValue();
877 B.CreateStore(rval, gep);
878
879 if (roots) {
880 moveSRetToFromRoots(B, rval->getType(), rval, roots_AT, roots,
881 /*rootOffset*/ 0,
883 }
884
885 auto NR = B.CreateRetVoid();
886 RT->eraseFromParent();
887 RT = NR;
888 }
889 if (roots_AT)
890 curOffset = CountTrackedPointers(RT).count;
891 sretCount++;
892 }
893
894 // TODO this must be re-ordered to interleave the sret/roots/etc args as
895 // required.
896
897 for (size_t i = 0, end = FT->getNumParams(); i < end; i++) {
898
899 if (enzyme_srets.count(i)) {
900 auto argFound = delArgMap.find(i);
901 assert(argFound != delArgMap.end());
902 auto arg = argFound->second;
903 assert(arg);
904 SmallVector<Instruction *, 1> uses;
905 SmallVector<unsigned, 1> op;
906 for (auto &U : arg->uses()) {
907 auto I = cast<Instruction>(U.getUser());
908 uses.push_back(I);
909 op.push_back(U.getOperandNo());
910 }
911 IRBuilder<> EB(&NewF->getEntryBlock().front());
912 auto gep =
913 ST ? EB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret;
914 for (size_t i = 0; i < uses.size(); i++) {
915 uses[i]->setOperand(op[i], gep);
916 }
917
918 if (reroot_enzyme_srets.count(i)) {
919 assert(roots_AT);
920 auto cnt = CountTrackedPointers(Types[sretCount]).count;
921 for (auto &RT : Returns) {
922 IRBuilder<> B(RT);
923 if (noroot_enzyme_srets.count(i)) {
924 for (size_t i = 0; i < cnt; i++) {
925 B.CreateStore(ConstantPointerNull::get(T_prjlvalue),
926 B.CreateConstInBoundsGEP2_32(roots_AT, roots, 0,
927 i + curOffset));
928 }
929 } else {
930 moveSRetToFromRoots(B, Types[sretCount], gep, roots_AT, roots,
931 curOffset,
933 }
934 }
935 curOffset += cnt;
936 }
937
938 delete arg;
939
940 sretCount++;
941 continue;
942 }
943
944 if (rroots.count(i)) {
945 auto attr = Attrs.getAttribute(AttributeList::FirstArgIndex + i,
946 "enzymejl_returnRoots");
947 auto attrv = attr.getValueAsString();
948 assert(attrv.size());
949 size_t subCount = convertRRootCountFromString(attrv);
950
951 auto argFound = delArgMap.find(i);
952 assert(argFound != delArgMap.end());
953 auto arg = argFound->second;
954 assert(arg);
955 SmallVector<Instruction *, 1> uses;
956 SmallVector<unsigned, 1> op;
957 for (auto &U : arg->uses()) {
958 auto I = cast<Instruction>(U.getUser());
959 uses.push_back(I);
960 op.push_back(U.getOperandNo());
961 }
962 IRBuilder<> EB(&NewF->getEntryBlock().front());
963
964 Value *gep = nullptr;
965 if (roots_AT) {
966 assert(roots);
967 assert(roots_AT);
968
969 gep = roots;
970 if (curOffset != 0) {
971 gep = EB.CreateConstInBoundsGEP2_32(roots_AT, roots, 0, curOffset);
972 }
973 if (subCount != numRooting) {
974 gep = EB.CreatePointerCast(
975 gep, getUnqual(ArrayType::get(T_prjlvalue, subCount)));
976 }
977 curOffset += subCount;
978 if (reret_roots.count(i))
979 sretCount++;
980 } else {
981 assert(sret);
982 gep =
983 ST ? EB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret;
984
985 if (!reret_roots.count(i)) {
986
987 std::string s;
988 llvm::raw_string_ostream ss(s);
989 ss << "Illegal GC setup in which there was no roots_AT, but a new "
990 "sret ("
991 << *sret << "), but no rereturned roots at index i=" << i
992 << "\n";
993 CustomErrorHandler(s.c_str(), wrap(gep), ErrorType::InternalError,
994 nullptr, nullptr, nullptr);
995 }
996
997 sretCount++;
998 }
999
1000 for (size_t i = 0; i < uses.size(); i++) {
1001 uses[i]->setOperand(op[i], gep);
1002 }
1003
1004 delete arg;
1005 continue;
1006 }
1007 }
1008
1009 assert(curOffset == numRooting);
1010 assert(sretCount == Types.size());
1011 }
1012
1013 auto &DL = F->getParent()->getDataLayout();
1014
1015 // TODO fix caller side
1016 for (auto CI : callers) {
1017 auto Attrs = CI->getAttributes();
1018 AttributeList NewAttrs;
1019 IRBuilder<> B(CI);
1020 IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front());
1021 SmallVector<Value *, 1> vals;
1022 size_t nexti = 0;
1023 Value *sret = nullptr;
1024 if (sretTy) {
1025 sret = EB.CreateAlloca(sretTy, 0, "stack_sret");
1026 vals.push_back(sret);
1027 NewAttrs = NewAttrs.addAttribute(
1028 F->getContext(), AttributeList::FirstArgIndex + nexti,
1029 Attribute::get(F->getContext(), Attribute::StructRet, sretTy));
1030 nexti++;
1031 }
1032 AllocaInst *roots = nullptr;
1033 if (roots_AT) {
1034 roots = EB.CreateAlloca(roots_AT, 0, "stack_roots_AT");
1035 vals.push_back(roots);
1036 NewAttrs = NewAttrs.addAttribute(
1037
1038 F->getContext(), AttributeList::FirstArgIndex + nexti,
1039 "enzymejl_returnRoots", std::to_string(numRooting));
1040 nexti++;
1041 }
1042
1043 for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1044 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1045 AttributeList::FunctionIndex, attr);
1046
1047 SmallVector<std::tuple<Value *, Value *, Type *>> preCallReplacements;
1048 SmallVector<std::tuple<Value *, Value *, Type *, bool>>
1049 postCallReplacements;
1050
1051 {
1052 size_t local_root_count = 0;
1053 size_t sretCount = 0;
1054 if (!RT->isVoidTy()) {
1055 if (roots_AT) {
1056 local_root_count += CountTrackedPointers(RT).count;
1057 }
1058 sretCount++;
1059 }
1060
1061 /// TODO continue from here down for external rewrites
1062 for (size_t i = 0, end = CI->arg_size(); i < end; i++) {
1063
1064 if (enzyme_srets.count(i)) {
1065 auto val = CI->getArgOperand(i);
1066
1067 if (isa<UndefValue>(val) || isa<PoisonValue>(val) ||
1068 isa<ConstantPointerNull>(val)) {
1069 std::string s;
1070 llvm::raw_string_ostream ss(s);
1071 ss << "Unsupported constant argument in "
1072 "FixupJuliaCallingConvention\n";
1073 ss << " + val: " << *val << "\n";
1074 ss << " + Function being rewritten: " << F->getName() << "\n";
1075 ss << " + CI erring: " << *CI << "\n";
1076 ss << " + Function containing CI: "
1077 << CI->getParent()->getParent()->getName() << "\n";
1078 if (CustomErrorHandler) {
1079 CustomErrorHandler(s.c_str(), wrap(CI), ErrorType::InternalError,
1080 nullptr, nullptr, nullptr);
1081 } else {
1082 EmitFailure("UnsupportedArgument", CI->getDebugLoc(), CI,
1083 ss.str());
1084 }
1085 }
1086
1087 Value *gep = sret;
1088 if (ST) {
1089 IRBuilder<> GEPB(cast<Instruction>(sret)->getNextNode());
1090 gep = GEPB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount);
1091 }
1092
1093 bool handled = false;
1094 if (auto AI = dyn_cast<AllocaInst>(getBaseObject(val, false))) {
1095 if (AI->getAllocatedType() == Types[sretCount] ||
1096 (isOpaque(AI->getType()) &&
1097 DL.getTypeSizeInBits(AI->getAllocatedType()) ==
1098 DL.getTypeSizeInBits(Types[sretCount]))) {
1099 AI->replaceAllUsesWith(gep);
1100 AI->eraseFromParent();
1101 handled = true;
1102 }
1103 }
1104
1105 if (!handled) {
1106 assert(!isa<UndefValue>(val));
1107 assert(!isa<PoisonValue>(val));
1108 assert(!isa<ConstantPointerNull>(val));
1109
1110 // On Julia 1.12+, the sret does not actually contain the jlvaluet
1111 // (and it should not). However, if the sret does not contain a
1112 // return roots (per tracked pointers), we do still need to perform
1113 // the store.
1114 bool should_sret = sret_jlvalue;
1115 if (!should_sret) {
1116 CountTrackedPointers tracked(Types[sretCount]);
1117 if (tracked.count && tracked.all)
1118 should_sret = true;
1119 }
1120
1121 // Don't bother to copy back in if the original function doesn't
1122 // store anything.
1123 bool copyBack = !isReadOnlyNoCapture(F, i);
1124 if (copyBack) {
1125 postCallReplacements.emplace_back(val, gep, Types[sretCount],
1126 should_sret);
1127 }
1128 // Only copy in the inital value if the function reads, or we are
1129 // going to copy back and the function doesn't store all bytes.
1130 if (!isWriteOnly(CI, i) ||
1131 (copyBack &&
1132 !isGuaranteedToFullyWrite(F, i, Types[sretCount]))) {
1133 preCallReplacements.emplace_back(val, gep, Types[sretCount]);
1134 }
1135 }
1136
1137 if (roots_AT && reroot_enzyme_srets.count(i)) {
1138 local_root_count += CountTrackedPointers(Types[sretCount]).count;
1139 }
1140
1141 sretCount++;
1142 continue;
1143 }
1144
1145 if (rroots.count(i)) {
1146 auto val = CI->getArgOperand(i);
1147 if (isa<UndefValue>(val) || isa<PoisonValue>(val) ||
1148 isa<ConstantPointerNull>(val)) {
1149 std::string s;
1150 llvm::raw_string_ostream ss(s);
1151 ss << "Unsupported constant argument in "
1152 "FixupJuliaCallingConvention\n";
1153 ss << " + val: " << *val << "\n";
1154 ss << " + Function being rewritten: " << F->getName() << "\n";
1155 ss << " + CI erring: " << *CI << "\n";
1156 ss << " + Function containing CI: "
1157 << CI->getParent()->getParent()->getName() << "\n";
1158 if (CustomErrorHandler) {
1159 CustomErrorHandler(s.c_str(), wrap(CI), ErrorType::InternalError,
1160 nullptr, nullptr, nullptr);
1161 } else {
1162 EmitFailure("UnsupportedArgument", CI->getDebugLoc(), CI,
1163 ss.str());
1164 }
1165 }
1166
1167 auto attr = Attrs.getAttribute(AttributeList::FirstArgIndex + i,
1168 "enzymejl_returnRoots");
1169 auto attrv = attr.getValueAsString();
1170 assert(attrv.size());
1171 size_t subCount = convertRRootCountFromString(attrv);
1172
1173 Value *gep = nullptr;
1174
1175 if (roots_AT) {
1176 assert(roots);
1177 IRBuilder<> GEPB(cast<Instruction>(roots)->getNextNode());
1178 gep = roots;
1179 if (local_root_count != 0) {
1180 gep = GEPB.CreateConstInBoundsGEP2_32(roots_AT, roots, 0,
1181 local_root_count);
1182 }
1183
1184 if (subCount != numRooting) {
1185 gep = GEPB.CreatePointerCast(
1186 gep, getUnqual(ArrayType::get(T_prjlvalue, subCount)));
1187 }
1188 local_root_count += subCount;
1189 if (reret_roots.count(i))
1190 sretCount++;
1191 } else {
1192 assert(reret_roots.count(i));
1193 assert(sret);
1194 IRBuilder<> GEPB(cast<Instruction>(sret)->getNextNode());
1195 gep = sret;
1196 if (ST) {
1197 gep = GEPB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount);
1198 }
1199 sretCount++;
1200 }
1201
1202 bool handled = false;
1203 if (auto AI = dyn_cast<AllocaInst>(getBaseObject(val, false))) {
1204 if (AI->getAllocatedType() ==
1205 ArrayType::get(T_prjlvalue, subCount)) {
1206 AI->replaceAllUsesWith(gep);
1207 AI->eraseFromParent();
1208 handled = true;
1209 }
1210 }
1211
1212 if (!handled) {
1213 assert(!isa<UndefValue>(val));
1214 assert(!isa<PoisonValue>(val));
1215 assert(!isa<ConstantPointerNull>(val));
1216 // TODO consider doing pre-emptive pre zero of the section?
1217 preCallReplacements.emplace_back(
1218 val, gep, ArrayType::get(T_prjlvalue, subCount));
1219 postCallReplacements.emplace_back(
1220 val, gep, ArrayType::get(T_prjlvalue, subCount), true);
1221 }
1222 continue;
1223 }
1224
1225 for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i))
1226 NewAttrs = NewAttrs.addAttribute(
1227 F->getContext(), AttributeList::FirstArgIndex + nexti, attr);
1228 vals.push_back(CI->getArgOperand(i));
1229 nexti++;
1230 }
1231
1232 assert(sretCount == Types.size());
1233 assert(local_root_count == numRooting);
1234 }
1235
1236 // Because we will += into the corresponding derivative sret, we need to
1237 // pass in the values that were actually there before the call
1238 // TODO we can optimize this further and avoid the copy in the primal and/or
1239 // forward mode as the copy is _only_ needed for the adjoint.
1240 for (auto &&[val, gep, ty] : preCallReplacements) {
1241 copyNonJLValueInto(B, ty, ty, gep, {}, ty, val, {}, /*shouldZero*/ true);
1242 }
1243
1244 // Actually perform the call, copying over relevant information.
1245 SmallVector<OperandBundleDef, 1> Bundles;
1246 for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I)
1247 Bundles.emplace_back(CI->getOperandBundleAt(I));
1248
1249 if (!NewF->getFunctionType()->isVarArg() &&
1250 NewF->getFunctionType()->getNumParams() != vals.size()) {
1251 llvm::errs() << "NewF: " << *NewF << "\n";
1252 for (size_t i = 0; i < vals.size(); i++) {
1253 llvm::errs() << " Args[" << i << "] = " << *vals[i] << "\n";
1254 }
1255 }
1256 auto NC = B.CreateCall(NewF, vals, Bundles);
1257 NC->setAttributes(NewAttrs);
1258
1259 SmallVector<std::pair<unsigned, MDNode *>, 4> TheMDs;
1260 CI->getAllMetadataOtherThanDebugLoc(TheMDs);
1261 SmallVector<unsigned, 1> toCopy;
1262 for (auto pair : TheMDs)
1263 if (pair.first != LLVMContext::MD_range) {
1264 toCopy.push_back(pair.first);
1265 }
1266 if (!toCopy.empty())
1267 NC->copyMetadata(*CI, toCopy);
1268 NC->setDebugLoc(CI->getDebugLoc());
1269
1270 if (!RT->isVoidTy()) {
1271 auto gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, 0) : sret;
1272 auto ld = B.CreateLoad(RT, gep);
1273 if (auto MD = CI->getMetadata(LLVMContext::MD_range))
1274 ld->setMetadata(LLVMContext::MD_range, MD);
1275 ld->takeName(CI);
1276 Value *replacement = ld;
1277
1278 // We don't need to override the jlvalue_t's with the rooted versions here
1279 // since we already stored the full value into the sret above.
1280 // if (fromRoots) {
1281 // replacement = moveSRetToFromRoots(B, replacement->getType(),
1282 // replacement, root_AT, root, /*rootOffset*/0,
1283 // SRetRootMovement::RootPointerToSRetValue);
1284 //}
1285
1286 CI->replaceAllUsesWith(replacement);
1287 }
1288
1289 for (auto &&[val, gep, ty, jlvalue] : postCallReplacements) {
1290 if (jlvalue) {
1291 auto ld = B.CreateLoad(ty, gep);
1292 auto SI = B.CreateStore(ld, val);
1293 if (val->getType()->getPointerAddressSpace() == 10)
1294 PostCacheStore(SI, B);
1295 } else {
1296 copyNonJLValueInto(B, ty, ty, val, {}, ty, gep, {},
1297 /*shouldZero*/ false);
1298 }
1299 }
1300
1301 NC->setCallingConv(CI->getCallingConv());
1302 CI->eraseFromParent();
1303 }
1304 NewF->setAttributes(NewAttrs);
1305 SmallVector<std::pair<unsigned, MDNode *>, 1> MD;
1306 F->getAllMetadata(MD);
1307 for (auto pair : MD)
1308 if (pair.first != LLVMContext::MD_dbg)
1309 NewF->addMetadata(pair.first, *pair.second);
1310 NewF->takeName(F);
1311 NewF->setCallingConv(F->getCallingConv());
1312 F->eraseFromParent();
1313}
1314
1315#include "llvm/Passes/PassBuilder.h"
1316
1317#include <string>
1318
1319using namespace llvm;
1320
1322 if (F->empty())
1323 return;
1324 auto RT = F->getReturnType();
1325 auto FT = F->getFunctionType();
1326 auto Attrs = F->getAttributes();
1327
1328 AttributeList NewAttrs;
1329 SmallVector<Type *, 1> types;
1330 SmallSet<size_t, 1> changed;
1331 for (auto pair : llvm::enumerate(FT->params())) {
1332 auto T = pair.value();
1333 auto i = pair.index();
1334 bool sretv = false;
1335 StringRef kind;
1336 StringRef value;
1337 for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) {
1338 if (attr.isStringAttribute() &&
1339 attr.getKindAsString() == "enzyme_sret_v") {
1340 sretv = true;
1341 kind = "enzyme_sret";
1342 value = attr.getValueAsString();
1343 } else if (attr.isStringAttribute() &&
1344 attr.getKindAsString() == "enzymejl_rooted_typ_v") {
1345 sretv = true;
1346 kind = "enzymejl_rooted_typ";
1347 value = attr.getValueAsString();
1348 } else if (attr.isStringAttribute() &&
1349 attr.getKindAsString() == "enzymejl_returnRoots_v") {
1350 sretv = true;
1351 kind = "enzymejl_returnRoots";
1352 value = attr.getValueAsString();
1353 } else {
1354 NewAttrs = NewAttrs.addAttribute(
1355 F->getContext(), AttributeList::FirstArgIndex + types.size(), attr);
1356 }
1357 }
1358 if (auto AT = dyn_cast<ArrayType>(T)) {
1359 if (auto PT = dyn_cast<PointerType>(AT->getElementType())) {
1360 auto AS = PT->getAddressSpace();
1361 if (AS == 11 || AS == 12 || AS == 13 || sretv) {
1362 for (unsigned i = 0; i < AT->getNumElements(); i++) {
1363 if (sretv) {
1364 NewAttrs = NewAttrs.addAttribute(
1365 F->getContext(), AttributeList::FirstArgIndex + types.size(),
1366 Attribute::get(F->getContext(), kind, value));
1367 }
1368 types.push_back(PT);
1369 }
1370 changed.insert(i);
1371 continue;
1372 }
1373 }
1374 }
1375 assert(!sretv);
1376 types.push_back(T);
1377 }
1378 if (changed.size() == 0)
1379 return;
1380
1381 for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1382 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1383 AttributeList::FunctionIndex, attr);
1384
1385 for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex))
1386 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1387 AttributeList::ReturnIndex, attr);
1388
1389 FunctionType *FTy =
1390 FunctionType::get(FT->getReturnType(), types, FT->isVarArg());
1391
1392 // Create the new function
1393 Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(),
1394 F->getName(), F->getParent());
1395
1396 ValueToValueMapTy VMap;
1397 // Loop over the arguments, copying the names of the mapped arguments over...
1398 Function::arg_iterator DestI = NewF->arg_begin();
1399
1400 // To handle the deleted args, it needs to be replaced by a non-arg operand.
1401 // This map contains the temporary phi nodes corresponding
1402 SmallVector<Instruction *, 1> toInsert;
1403 for (Argument &I : F->args()) {
1404 auto T = I.getType();
1405 if (auto AT = dyn_cast<ArrayType>(T)) {
1406 if (changed.count(I.getArgNo())) {
1407 Value *V = UndefValue::get(T);
1408 for (unsigned i = 0; i < AT->getNumElements(); i++) {
1409 DestI->setName(I.getName() + "." +
1410 std::to_string(i)); // Copy the name over...
1411 unsigned idx[1] = {i};
1412 auto IV = InsertValueInst::Create(V, (llvm::Value *)&*DestI++, idx);
1413 toInsert.push_back(IV);
1414 V = IV;
1415 }
1416 VMap[&I] = V;
1417 continue;
1418 }
1419 }
1420 DestI->setName(I.getName()); // Copy the name over...
1421 VMap[&I] = &*DestI++; // Add mapping to VMap
1422 }
1423
1424 SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
1425 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
1426 Returns, "", nullptr);
1427
1428 {
1429 IRBuilder<> EB(&*NewF->getEntryBlock().begin());
1430 for (auto I : toInsert)
1431 EB.Insert(I);
1432 }
1433
1434 SmallVector<CallInst *, 1> callers;
1435 for (auto U : F->users()) {
1436 auto CI = dyn_cast<CallInst>(U);
1437 assert(CI);
1438 assert(CI->getCalledFunction() == F);
1439 callers.push_back(CI);
1440 }
1441
1442 for (auto CI : callers) {
1443 auto Attrs = CI->getAttributes();
1444 AttributeList NewAttrs;
1445 IRBuilder<> B(CI);
1446
1447 for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
1448 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1449 AttributeList::FunctionIndex, attr);
1450
1451 for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex))
1452 NewAttrs = NewAttrs.addAttribute(F->getContext(),
1453 AttributeList::ReturnIndex, attr);
1454
1455 SmallVector<Value *, 1> vals;
1456 for (size_t j = 0, end = CI->arg_size(); j < end; j++) {
1457
1458 auto T = CI->getArgOperand(j)->getType();
1459 if (auto AT = dyn_cast<ArrayType>(T)) {
1460 if (isa<PointerType>(AT->getElementType())) {
1461 if (changed.count(j)) {
1462 bool sretv = false;
1463 std::string kind;
1464 StringRef value;
1465 for (auto attr :
1466 Attrs.getAttributes(AttributeList::FirstArgIndex + j)) {
1467 if (attr.isStringAttribute() &&
1468 attr.getKindAsString() == "enzyme_sret_v") {
1469 sretv = true;
1470 kind = "enzyme_sret";
1471 value = attr.getValueAsString();
1472 } else if (attr.isStringAttribute() &&
1473 attr.getKindAsString() == "enzymejl_returnRoots_v") {
1474 sretv = true;
1475 kind = "enzymejl_returnRoots";
1476 value = attr.getValueAsString();
1477 } else if (attr.isStringAttribute() &&
1478 attr.getKindAsString() == "enzymejl_rooted_typ_v") {
1479 sretv = true;
1480 kind = "enzymejl_rooted_typ_v";
1481 value = attr.getValueAsString();
1482 }
1483 }
1484 for (unsigned i = 0; i < AT->getNumElements(); i++) {
1485 if (sretv)
1486 NewAttrs = NewAttrs.addAttribute(
1487 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1488 Attribute::get(F->getContext(), kind, value));
1489 vals.push_back(
1490 GradientUtils::extractMeta(B, CI->getArgOperand(j), i));
1491 }
1492 continue;
1493 }
1494 }
1495 }
1496
1497 for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + j)) {
1498 if (attr.isStringAttribute() &&
1499 attr.getKindAsString() == "enzyme_sret_v") {
1500 NewAttrs = NewAttrs.addAttribute(
1501 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1502 Attribute::get(F->getContext(), "enzyme_sret",
1503 attr.getValueAsString()));
1504 } else if (attr.isStringAttribute() &&
1505 attr.getKindAsString() == "enzymejl_returnRoots_v") {
1506 NewAttrs = NewAttrs.addAttribute(
1507 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1508 Attribute::get(F->getContext(), "enzymejl_returnRoots",
1509 attr.getValueAsString()));
1510 } else if (attr.isStringAttribute() &&
1511 attr.getKindAsString() == "enzymejl_rooted_typ_v") {
1512 NewAttrs = NewAttrs.addAttribute(
1513 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1514 Attribute::get(F->getContext(), "enzymejl_rooted_typ",
1515 attr.getValueAsString()));
1516 } else {
1517 NewAttrs = NewAttrs.addAttribute(
1518 F->getContext(), AttributeList::FirstArgIndex + vals.size(),
1519 attr);
1520 }
1521 }
1522
1523 vals.push_back(CI->getArgOperand(j));
1524 }
1525
1526 SmallVector<OperandBundleDef, 1> Bundles;
1527 for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I)
1528 Bundles.emplace_back(CI->getOperandBundleAt(I));
1529 auto NC = B.CreateCall(NewF, vals, Bundles);
1530 NC->setAttributes(NewAttrs);
1531
1532 SmallVector<std::pair<unsigned, MDNode *>, 4> TheMDs;
1533 CI->getAllMetadataOtherThanDebugLoc(TheMDs);
1534 SmallVector<unsigned, 1> toCopy;
1535 for (auto pair : TheMDs)
1536 toCopy.push_back(pair.first);
1537 if (!toCopy.empty())
1538 NC->copyMetadata(*CI, toCopy);
1539 NC->setDebugLoc(CI->getDebugLoc());
1540
1541 if (!RT->isVoidTy()) {
1542 NC->takeName(CI);
1543 CI->replaceAllUsesWith(NC);
1544 }
1545
1546 NC->setCallingConv(CI->getCallingConv());
1547 CI->eraseFromParent();
1548 }
1549 NewF->setAttributes(NewAttrs);
1550 SmallVector<std::pair<unsigned, MDNode *>, 1> MD;
1551 F->getAllMetadata(MD);
1552 for (auto pair : MD)
1553 if (pair.first != LLVMContext::MD_dbg)
1554 NewF->addMetadata(pair.first, *pair.second);
1555 NewF->takeName(F);
1556 NewF->setCallingConv(F->getCallingConv());
1557 F->eraseFromParent();
1558}
1559
1561 : public PassInfoMixin<FixupJuliaCallingConventionNewPM> {
1562 bool sret_jlvalue;
1563
1564public:
1566 : sret_jlvalue(sret_jlvalue) {}
1567
1568 PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
1569 bool changed = false;
1570 SmallVector<llvm::Function *, 16> Functions;
1571 for (auto &F : M) {
1572 if (F.empty())
1573 continue;
1574 Functions.push_back(&F);
1575 }
1576 for (auto *F : Functions) {
1577 EnzymeFixupJuliaCallingConvention(F, sret_jlvalue);
1578 changed = true;
1579 }
1580 return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1581 }
1582};
1583
1585 : public PassInfoMixin<FixupBatchedJuliaCallingConventionNewPM> {
1586public:
/// New-PM entry point of the batched fixup pass: collects every function of
/// \p M that has a body and reports all analyses invalidated if at least one
/// was found.
/// NOTE(review): as shown here, the loop over `Functions` only sets `changed`.
/// Unlike FixupJuliaCallingConventionNewPM::run, it invokes no rewrite
/// (presumably EnzymeFixupBatchedJuliaCallingConvention(F) belongs here). The
/// call may have been lost in extraction -- confirm against the real source.
1587 PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
1588 bool changed = false;
// Snapshot the defined functions up front; declarations (no body) are skipped.
1589 SmallVector<llvm::Function *, 16> Functions;
1590 for (auto &F : M) {
1591 if (F.empty())
1592 continue;
1593 Functions.push_back(&F);
1594 }
1595 for (auto *F : Functions) {
1597 changed = true;
1598 }
// `changed` is set for every collected function, so any function with a body
// invalidates all analyses.
1599 return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1600 }
1601};
1602
1603// Expose New PM pass for registration
/// Parse a pass-pipeline element and, if it names one of the Julia fixup
/// passes, add that pass to \p MPM.
///
/// Recognized names:
///  - "enzyme-fixup-julia":         adds FixupJuliaCallingConventionNewPM(false)
///  - "enzyme-fixup-julia-sret":    adds FixupJuliaCallingConventionNewPM(true)
///  - "enzyme-fixup-batched-julia": see the review note inside the body
///
/// \returns true if \p Name was recognized, false otherwise (in which case
/// nothing is added).
1604bool registerFixupJuliaPass(StringRef Name, ModulePassManager &MPM) {
1605 if (Name == "enzyme-fixup-julia") {
1606 MPM.addPass(FixupJuliaCallingConventionNewPM(false));
1607 return true;
1608 }
1609 if (Name == "enzyme-fixup-julia-sret") {
1610 MPM.addPass(FixupJuliaCallingConventionNewPM(true));
1611 return true;
1612 }
1613 if (Name == "enzyme-fixup-batched-julia") {
// NOTE(review): as shown, this branch reports the name as handled without
// adding any pass. Presumably it should add
// FixupBatchedJuliaCallingConventionNewPM(); the addPass line may have been
// lost in extraction -- confirm against the real source.
1615 return true;
1616 }
1617 return false;
1618}
static bool isReadOnlyNoCapture(Function *F, unsigned argNo)
bool needsReRooting(llvm::Argument *arg, bool &anyJLStore, llvm::Type *SRetType=nullptr)
static void removeRange(std::vector< std::pair< uint64_t, uint64_t > > &ranges, uint64_t start, uint64_t end)
void EnzymeFixupBatchedJuliaCallingConvention(Function *F)
bool needsReReturning(llvm::Argument *arg, size_t &sret_idx, std::map< size_t, size_t > &srets_without_stores)
#define getAttribute
static bool isOpaque(llvm::Type *T)
bool DetectPointerArgOfFn(llvm::Function &F, llvm::SmallPtrSetImpl< llvm::Function * > &calls_todo)
static bool isGuaranteedToFullyWrite(Function *F, unsigned argNo, Type *T)
bool registerFixupJuliaPass(StringRef Name, ModulePassManager &MPM)
void EnzymeFixupJuliaCallingConvention(Function *F, bool sret_jlvalue)
llvm::SmallVector< llvm::Instruction *, 2 > PostCacheStore(llvm::StoreInst *SI, llvm::IRBuilder<> &B)
Definition Utils.cpp:423
LLVMValueRef(* CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, const void *, LLVMValueRef, LLVMBuilderRef)
Definition Utils.cpp:62
llvm::Value * moveSRetToFromRoots(llvm::IRBuilder<> &B, llvm::Type *jltype, llvm::Value *sret, llvm::Type *root_ty, llvm::Value *rootRet, size_t rootOffset, SRetRootMovement direction)
Definition Utils.cpp:4714
SmallVector< std::tuple< Instruction *, Value *, size_t >, 1 > findAllUsersOf(Value *AI)
Definition Utils.cpp:3210
void copyNonJLValueInto(llvm::IRBuilder<> &B, llvm::Type *curType, llvm::Type *dstType, llvm::Value *dst, llvm::ArrayRef< unsigned > dstPrefix0, llvm::Type *srcType, llvm::Value *src, llvm::ArrayRef< unsigned > srcPrefix0, bool shouldZero)
Definition Utils.cpp:4837
bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, MapVector< Value *, APInt > &VariableOffsets, APInt &ConstantOffset)
Definition Utils.cpp:4169
static size_t convertRRootCountFromString(llvm::StringRef str)
Definition Utils.h:2481
static llvm::PointerType * getUnqual(llvm::Type *T)
Definition Utils.h:1179
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
static llvm::Type * convertSRetTypeFromString(llvm::StringRef str, llvm::LLVMContext *C=nullptr)
Definition Utils.h:2433
static llvm::Value * getBaseObject(llvm::Value *V, bool offsetAllowed=true)
Definition Utils.h:1507
static bool isSpecialPtr(llvm::Type *Ty)
Definition Utils.h:2354
@ Tracked
Definition Utils.h:2341
static std::string convertSRetTypeToString(llvm::Type *T)
Definition Utils.h:2428
static bool isWriteOnly(const llvm::Function *F, ssize_t arg=-1)
Definition Utils.h:1788
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
static llvm::Value * extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg, unsigned off, const llvm::Twine &name="")
Helper routine to extract a nested element from a struct/array. This is.