//===- CacheUtility.cpp - Caching values in the forward pass for later use -===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
//          Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file defines a base helper class CacheUtility that manages the cache
// of values from the forward pass for later use.
//
//===----------------------------------------------------------------------===//

#include "CacheUtility.h"
#include "FunctionUtils.h"

using namespace llvm;

/// Pack 8 bools together in a single byte
extern "C" {
llvm::cl::opt<bool>
    EfficientBoolCache("enzyme-smallbool", cl::init(false), cl::Hidden,
                       cl::desc("Place 8 bools together in a single byte"));

llvm::cl::opt<bool> EnzymeZeroCache("enzyme-zero-cache", cl::init(false),
                                    cl::Hidden,
                                    cl::desc("Zero initialize the cache"));

llvm::cl::opt<bool>
    EnzymePrintPerf("enzyme-print-perf", cl::init(false), cl::Hidden,
                    cl::desc("Enable Enzyme to print performance info"));

llvm::cl::opt<bool> EfficientMaxCache(
    "enzyme-max-cache", cl::init(false), cl::Hidden,
    cl::desc(
        "Avoid reallocs when possible by potentially overallocating cache"));
}

CacheUtility::~CacheUtility() {}

/// Erase this instruction both from the LLVM module and any local data
/// structures
void CacheUtility::erase(Instruction *I) {
  assert(I);

  if (auto found = findInMap(scopeMap, (Value *)I)) {
    scopeFrees.erase(found->first);
    scopeAllocs.erase(found->first);
    scopeInstructions.erase(found->first);
  }
  if (auto AI = dyn_cast<AllocaInst>(I)) {
    scopeFrees.erase(AI);
    scopeAllocs.erase(AI);
    scopeInstructions.erase(AI);
  }
  scopeMap.erase(I);
  SE.eraseValueFromMap(I);

  if (!I->use_empty()) {
    std::string str;
    raw_string_ostream ss(str);
    ss << "Erased value with a use:\n";
    ss << *newFunc->getParent() << "\n";
    ss << *newFunc << "\n";
    ss << *I << "\n";
    if (CustomErrorHandler) {
      CustomErrorHandler(ss.str().c_str(), wrap(I), ErrorType::InternalError,
                         nullptr, nullptr, nullptr);
    } else {
      EmitFailure("GetIndexError", I->getDebugLoc(), I, ss.str());
    }
    I->replaceAllUsesWith(UndefValue::get(I->getType()));
  }
  assert(I->use_empty());
  I->eraseFromParent();
}

/// Replace value A with value B both in the LLVM module and any local data
/// structures
void CacheUtility::replaceAWithB(Value *A, Value *B, bool storeInCache) {
  auto found = scopeMap.find(A);
  if (found != scopeMap.end()) {
    insert_or_assign2(scopeMap, B, found->second);

    llvm::AllocaInst *cache = found->second.first;
    if (storeInCache) {
      assert(isa<Instruction>(B));
      auto stfound = scopeInstructions.find(cache);
      if (stfound != scopeInstructions.end()) {
        SmallVector<Instruction *, 3> tmpInstructions(stfound->second.begin(),
                                                      stfound->second.end());
        scopeInstructions.erase(stfound);
        for (auto st : tmpInstructions)
          cast<StoreInst>(&*st)->eraseFromParent();
        MDNode *TBAA = nullptr;
        if (auto I = dyn_cast<Instruction>(A))
          TBAA = I->getMetadata(LLVMContext::MD_tbaa);
        storeInstructionInCache(found->second.second, cast<Instruction>(B),
                                cache, TBAA);
      }
    }

    scopeMap.erase(A);
  }
  A->replaceAllUsesWith(B);
}

// Create a new canonical induction variable of Type Ty for Loop L
// Return the variable and the increment instruction
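// For example, for an i64 IV the following is emitted in the loop header,
// with incoming value 0 from predecessors outside the loop and the
// increment from latches inside it:
//   %iv = phi i64 [ 0, %preheader ], [ %iv.next, %latch ]
//   %iv.next = add nuw nsw i64 %iv, 1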
std::pair<PHINode *, Instruction *>
InsertNewCanonicalIV(Loop *L, Type *Ty, const llvm::Twine &Name) {
  assert(L);
  assert(Ty);

  BasicBlock *Header = L->getHeader();
  assert(Header);
  IRBuilder<> B(Header, Header->begin());
  PHINode *CanonicalIV = B.CreatePHI(Ty, 1, Name);

  B.SetInsertPoint(Header->getFirstNonPHIOrDbg());
  Instruction *Inc = cast<Instruction>(
      B.CreateAdd(CanonicalIV, ConstantInt::get(Ty, 1), Name + ".next",
                  /*NUW*/ true, /*NSW*/ true));

  for (BasicBlock *Pred : predecessors(Header)) {
    assert(Pred);
    if (L->contains(Pred)) {
      CanonicalIV->addIncoming(Inc, Pred);
    } else {
      CanonicalIV->addIncoming(ConstantInt::get(Ty, 0), Pred);
    }
  }
  assert(L->getCanonicalInductionVariable() == CanonicalIV);
  return std::pair<PHINode *, Instruction *>(CanonicalIV, Inc);
}

// Find an existing canonical induction variable of Type Ty for Loop L
// Return the variable and the increment instruction
std::pair<PHINode *, Instruction *> FindCanonicalIV(Loop *L, Type *Ty) {
  assert(L);
  assert(Ty);

  BasicBlock *Header = L->getHeader();
  assert(Header);
  for (BasicBlock::iterator II = Header->begin(); isa<PHINode>(II); ++II) {
    PHINode *PN = cast<PHINode>(II);
    if (PN->getType() != Ty)
      continue;

    Instruction *Inc = nullptr;
    bool legal = true;
    for (BasicBlock *Pred : predecessors(Header)) {
      assert(Pred);
      if (L->contains(Pred)) {
        auto Inc2 =
            dyn_cast<BinaryOperator>(PN->getIncomingValueForBlock(Pred));
        if (!Inc2 || Inc2->getOpcode() != Instruction::Add ||
            Inc2->getOperand(0) != PN) {
          legal = false;
          break;
        }
        auto CI = dyn_cast<ConstantInt>(Inc2->getOperand(1));
        if (!CI || !CI->isOne()) {
          legal = false;
          break;
        }
        if (Inc) {
          if (Inc2 != Inc) {
            legal = false;
            break;
          }
        } else
          Inc = Inc2;
      } else {
        auto CI = dyn_cast<ConstantInt>(PN->getIncomingValueForBlock(Pred));
        if (!CI || !CI->isZero()) {
          legal = false;
          break;
        }
      }
    }
    if (!legal)
      continue;
    if (!Inc)
      continue;
    if (Inc != getFirstNonPHIOrDbg(Header))
      Inc->moveBefore(getFirstNonPHIOrDbg(Header));
    return std::make_pair(PN, Inc);
  }
  llvm::errs() << *Header << "\n";
  assert(0 && "Could not find canonical IV");
  return std::pair<PHINode *, Instruction *>(nullptr, nullptr);
}

// Attempt to rewrite all phinodes in the loop in terms of the
// canonical induction variable
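// For example, a header phi such as
//   %j = phi i64 [ 5, %preheader ], [ %j.next, %latch ]
//   %j.next = add i64 %j, 2
// has SCEV {5,+,2}, which SCEVExpander can re-express in terms of the
// canonical IV (conceptually 5 + 2 * %iv), after which %j is erased.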
void RemoveRedundantIVs(
    BasicBlock *Header, PHINode *CanonicalIV, Instruction *Increment,
    MustExitScalarEvolution &SE,
    llvm::function_ref<void(Instruction *, Value *)> replacer,
    llvm::function_ref<void(Instruction *)> eraser) {
  assert(Header);
  assert(CanonicalIV);
  SmallVector<Instruction *, 8> IVsToRemove;

  auto CanonicalSCEV = SE.getSCEV(CanonicalIV);

  for (BasicBlock::iterator II = Header->begin(); isa<PHINode>(II);) {
    PHINode *PN = cast<PHINode>(II);
    ++II;
    if (PN == CanonicalIV)
      continue;
    if (!SE.isSCEVable(PN->getType()))
      continue;
    const SCEV *S = SE.getSCEV(PN);
    if (SE.getCouldNotCompute() == S || isa<SCEVUnknown>(S))
      continue;
    // We may expand code for a phi where it is not legal (computing with
    // subloop expressions). Check that this isn't the case.
    if (!SE.dominates(S, Header))
      continue;

    if (S == CanonicalSCEV) {
      replacer(PN, CanonicalIV);
      eraser(PN);
      continue;
    }

    IRBuilder<> B(PN);
    auto Tmp = B.CreatePHI(PN->getType(), 0);
    for (auto Pred : predecessors(Header))
      Tmp->addIncoming(UndefValue::get(Tmp->getType()), Pred);
    replacer(PN, Tmp);
    eraser(PN);

    // This scope is necessary to ensure scevexpander cleans up before we
    // erase things
#if LLVM_VERSION_MAJOR >= 22
    SCEVExpander Exp(SE, "enzyme");
#else
    SCEVExpander Exp(SE, Header->getParent()->getParent()->getDataLayout(),
                     "enzyme");
#endif

    // We place the expansion at the first non-phi, as it may produce a
    // non-phi instruction and must thus be expanded after all phis
    Value *NewIV =
        Exp.expandCodeFor(S, Tmp->getType(), Header->getFirstNonPHI());

    // Explicitly preserve wrap behavior from the original IV. This is
    // necessary until this PR in llvm is merged:
    // https://github.com/llvm/llvm-project/pull/78199
    if (auto addrec = dyn_cast<SCEVAddRecExpr>(S)) {
      if (addrec->getLoop()->getHeader() == Header) {
        if (auto add_or_mul = dyn_cast<BinaryOperator>(NewIV)) {
#if LLVM_VERSION_MAJOR >= 23
          if (any(addrec->getNoWrapFlags(llvm::SCEV::FlagNUW)))
            add_or_mul->setHasNoUnsignedWrap(true);
          if (any(addrec->getNoWrapFlags(llvm::SCEV::FlagNSW)))
            add_or_mul->setHasNoSignedWrap(true);
#else
          if (addrec->getNoWrapFlags(llvm::SCEV::FlagNUW))
            add_or_mul->setHasNoUnsignedWrap(true);
          if (addrec->getNoWrapFlags(llvm::SCEV::FlagNSW))
            add_or_mul->setHasNoSignedWrap(true);
#endif
        }
      }
    }
    replacer(Tmp, NewIV);
    eraser(Tmp);
  }

  // Replace existing increments with the canonical Increment
  Increment->moveAfter(CanonicalIV->getParent()->getFirstNonPHI());
  SmallVector<Instruction *, 1> toErase;
  for (auto use : CanonicalIV->users()) {
    auto BO = dyn_cast<BinaryOperator>(use);
    if (BO == nullptr)
      continue;
    if (BO->getOpcode() != BinaryOperator::Add)
      continue;
    if (use == Increment)
      continue;

    Value *toadd = nullptr;
    if (BO->getOperand(0) == CanonicalIV) {
      toadd = BO->getOperand(1);
    } else {
      assert(BO->getOperand(1) == CanonicalIV);
      toadd = BO->getOperand(0);
    }
    if (auto CI = dyn_cast<ConstantInt>(toadd)) {
      if (!CI->isOne())
        continue;
      BO->replaceAllUsesWith(Increment);
      toErase.push_back(BO);
    } else {
      continue;
    }
  }
  for (auto BO : toErase)
    eraser(BO);
}

void CanonicalizeLatches(const Loop *L, BasicBlock *Header,
                         BasicBlock *Preheader, PHINode *CanonicalIV,
                         MustExitScalarEvolution &SE, CacheUtility &gutils,
                         Instruction *Increment,
                         ArrayRef<BasicBlock *> latches) {
  // Attempt to explicitly rewrite the latch
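  // Since the canonical IV starts at 0 and counts upward by 1, ordered exit
  // conditions can be strengthened to exact (in)equality tests, e.g.:
  //   i <  n  ->  i != n        i >= n  ->  i == n
  //   i <= n  ->  i != n+1      i >  n  ->  i == n+1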
  if (latches.size() == 1 && isa<BranchInst>(latches[0]->getTerminator()) &&
      cast<BranchInst>(latches[0]->getTerminator())->isConditional())
    for (auto use : CanonicalIV->users()) {
      if (auto cmp = dyn_cast<ICmpInst>(use)) {
        if (cast<BranchInst>(latches[0]->getTerminator())->getCondition() !=
            cmp)
          continue;
        // Force i to be on LHS
        if (cmp->getOperand(0) != CanonicalIV) {
          // Below also swaps predicate correctly
          cmp->swapOperands();
        }
        assert(cmp->getOperand(0) == CanonicalIV);

        auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
        if (cmp->isUnsigned() ||
            (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv))) {

          // valid replacements (since unsigned comparison and i starts at 0
          // counting up)

          // * i < n => i != n, valid since first time i >= n occurs at i == n
          if (cmp->getPredicate() == ICmpInst::ICMP_ULT ||
              cmp->getPredicate() == ICmpInst::ICMP_SLT) {
            cmp->setPredicate(ICmpInst::ICMP_NE);
            goto cend;
          }

          // * i <= n => i != n+1, valid since first time i > n occurs at i ==
          // n+1 [ which we assert is in bitrange as not infinite loop ]
          if (cmp->getPredicate() == ICmpInst::ICMP_ULE ||
              cmp->getPredicate() == ICmpInst::ICMP_SLE) {
            IRBuilder<> builder(Preheader->getTerminator());
            if (auto inst = dyn_cast<Instruction>(cmp->getOperand(1))) {
              builder.SetInsertPoint(inst->getNextNode());
            }
            cmp->setOperand(
                1,
                builder.CreateNUWAdd(
                    cmp->getOperand(1),
                    ConstantInt::get(cmp->getOperand(1)->getType(), 1, false)));
            cmp->setPredicate(ICmpInst::ICMP_NE);
            goto cend;
          }

          // * i >= n => i == n, valid since first time i >= n occurs at i == n
          if (cmp->getPredicate() == ICmpInst::ICMP_UGE ||
              cmp->getPredicate() == ICmpInst::ICMP_SGE) {
            cmp->setPredicate(ICmpInst::ICMP_EQ);
            goto cend;
          }

          // * i > n => i == n+1, valid since first time i > n occurs at i ==
          // n+1 [ which we assert is in bitrange as not infinite loop ]
          if (cmp->getPredicate() == ICmpInst::ICMP_UGT ||
              cmp->getPredicate() == ICmpInst::ICMP_SGT) {
            IRBuilder<> builder(Preheader->getTerminator());
            if (auto inst = dyn_cast<Instruction>(cmp->getOperand(1))) {
              builder.SetInsertPoint(inst->getNextNode());
            }
            cmp->setOperand(
                1,
                builder.CreateNUWAdd(
                    cmp->getOperand(1),
                    ConstantInt::get(cmp->getOperand(1)->getType(), 1, false)));
            cmp->setPredicate(ICmpInst::ICMP_EQ);
            goto cend;
          }
        }
      cend:;
        if (cmp->getPredicate() == ICmpInst::ICMP_NE) {
        }
      }
    }

  // Replace previous increment usage with new increment value
  if (Increment) {
    Increment->moveAfter(CanonicalIV->getParent()->getFirstNonPHI());

    if (latches.size() == 1 && isa<BranchInst>(latches[0]->getTerminator()) &&
        cast<BranchInst>(latches[0]->getTerminator())->isConditional())
      for (auto use : Increment->users()) {
        if (auto cmp = dyn_cast<ICmpInst>(use)) {
          if (cast<BranchInst>(latches[0]->getTerminator())->getCondition() !=
              cmp)
            continue;

          // Force i+1 to be on LHS
          if (cmp->getOperand(0) != Increment) {
            // Below also swaps predicate correctly
            cmp->swapOperands();
          }
          assert(cmp->getOperand(0) == Increment);

          auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
          if (cmp->isUnsigned() ||
              (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv))) {

            // valid replacements (since unsigned comparison and i starts at 0
            // counting up)

            // * i+1 < n => i+1 != n, valid since first time i+1 >= n occurs
            // at i+1 == n
            if (cmp->getPredicate() == ICmpInst::ICMP_ULT ||
                cmp->getPredicate() == ICmpInst::ICMP_SLT) {
              cmp->setPredicate(ICmpInst::ICMP_NE);
              continue;
            }

            // * i+1 <= n => i != n, valid since first time i+1 > n occurs at
            // i+1 == n+1 => i == n
            if (cmp->getPredicate() == ICmpInst::ICMP_ULE ||
                cmp->getPredicate() == ICmpInst::ICMP_SLE) {
              cmp->setOperand(0, CanonicalIV);
              cmp->setPredicate(ICmpInst::ICMP_NE);
              continue;
            }

            // * i+1 >= n => i+1 == n, valid since first time i+1 >= n occurs
            // at i+1 == n
            if (cmp->getPredicate() == ICmpInst::ICMP_UGE ||
                cmp->getPredicate() == ICmpInst::ICMP_SGE) {
              cmp->setPredicate(ICmpInst::ICMP_EQ);
              continue;
            }

            // * i+1 > n => i == n, valid since first time i+1 > n occurs at
            // i+1 == n+1 => i == n
            if (cmp->getPredicate() == ICmpInst::ICMP_UGT ||
                cmp->getPredicate() == ICmpInst::ICMP_SGT) {
              cmp->setOperand(0, CanonicalIV);
              cmp->setPredicate(ICmpInst::ICMP_EQ);
              continue;
            }
          }
        }
      }
  }
}

llvm::AllocaInst *CacheUtility::getDynamicLoopLimit(llvm::Loop *L,
                                                    bool ReverseLimit) {
  assert(L);
  assert(loopContexts.find(L) != loopContexts.end());
  auto &found = loopContexts[L];
  assert(found.dynamic);
  if (found.trueLimit)
    return cast<AllocaInst>(&*found.trueLimit);

  LimitContext lctx(ReverseLimit,
                    ReverseLimit ? found.preheader : &newFunc->getEntryBlock());
  AllocaInst *LimitVar =
      createCacheForScope(lctx, found.var->getType(), "loopLimit",
                          /*shouldfree*/ true);

  for (auto ExitBlock : found.exitBlocks) {
    IRBuilder<> B(ExitBlock, ExitBlock->begin());
    auto Limit = B.CreatePHI(found.var->getType(), 1);

    for (BasicBlock *Pred : predecessors(ExitBlock)) {
      if (L->contains(Pred)) {
        Limit->addIncoming(found.var, Pred);
      } else {
        Limit->addIncoming(UndefValue::get(found.var->getType()), Pred);
      }
    }

    storeInstructionInCache(lctx, Limit, LimitVar);
  }
  found.trueLimit = LimitVar;
  return LimitVar;
}

bool CacheUtility::getContext(BasicBlock *BB, LoopContext &loopContext,
                              bool ReverseLimit) {
  assert(BB->getParent() == newFunc);
  Loop *L = LI.getLoopFor(BB);

  // Not inside a loop
  if (L == nullptr)
    return false;

  // Previously handled this loop
  if (auto found = findInMap(loopContexts, L)) {
    loopContext = *found;
    return true;
  }

  // Need to canonicalize
  loopContexts[L].parent = L->getParentLoop();

  loopContexts[L].header = L->getHeader();
  assert(loopContexts[L].header && "loop must have header");

  loopContexts[L].preheader = L->getLoopPreheader();
  if (!L->getLoopPreheader()) {
    llvm::errs() << "fn: " << *L->getHeader()->getParent() << "\n";
    llvm::errs() << "L: " << *L << "\n";
  }
  assert(loopContexts[L].preheader && "loop must have preheader");
  getExitBlocks(L, loopContexts[L].exitBlocks);

  loopContexts[L].offset = nullptr;
  loopContexts[L].allocLimit = nullptr;
  // A precisely matching canonical IV should have been created during
  // preprocessing.
  auto pair = FindCanonicalIV(L, Type::getInt64Ty(BB->getContext()));
  PHINode *CanonicalIV = pair.first;
  auto incVar = pair.second;
  assert(CanonicalIV);
  loopContexts[L].var = CanonicalIV;
  loopContexts[L].incvar = incVar;
  CanonicalizeLatches(L, loopContexts[L].header, loopContexts[L].preheader,
                      CanonicalIV, SE, *this, incVar,
                      getLatches(L, loopContexts[L].exitBlocks));
  loopContexts[L].antivaralloc =
      IRBuilder<>(inversionAllocs)
          .CreateAlloca(CanonicalIV->getType(), nullptr,
                        CanonicalIV->getName() + "'ac");
  loopContexts[L].antivaralloc->setAlignment(
      Align(cast<IntegerType>(CanonicalIV->getType())->getBitWidth() / 8));

  const SCEV *Limit = nullptr;
  const SCEV *MaxIterations = nullptr;
  {
    const SCEV *MayExitMaxBECount = nullptr;

    SmallVector<BasicBlock *, 8> ExitingBlocks;
    L->getExitingBlocks(ExitingBlocks);

    // Remove all exiting blocks that are guaranteed
    // to result in unreachable
    for (auto &ExitingBlock : ExitingBlocks) {
      BasicBlock *Exit = nullptr;
      for (auto *SBB : successors(ExitingBlock)) {
        if (!L->contains(SBB)) {
          if (SE.GuaranteedUnreachable.count(SBB))
            continue;
          Exit = SBB;
          break;
        }
      }
      if (!Exit)
        ExitingBlock = nullptr;
    }
    ExitingBlocks.erase(
        std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr),
        ExitingBlocks.end());

    // Compute the exit in the scenarios where an unreachable
    // is not hit
    for (BasicBlock *ExitingBlock : ExitingBlocks) {
      assert(L->contains(ExitingBlock));

      ScalarEvolution::ExitLimit EL =
          SE.computeExitLimit(L, ExitingBlock, /*AllowPredicates*/ true);

      bool seenHeaders = false;
      SmallPtrSet<BasicBlock *, 4> Seen;
      std::deque<BasicBlock *> Todo = {ExitingBlock};
      while (Todo.size()) {
        auto cur = Todo.front();
        Todo.pop_front();
        if (Seen.count(cur))
          continue;
        if (!L->contains(cur))
          continue;
        if (cur == loopContexts[L].header) {
          seenHeaders = true;
          break;
        }
        for (auto S : successors(cur)) {
          Todo.push_back(S);
        }
      }
      if (seenHeaders) {
        if (MaxIterations == nullptr ||
            MaxIterations == SE.getCouldNotCompute()) {
          MaxIterations = EL.ExactNotTaken;
        }
        if (MaxIterations != SE.getCouldNotCompute()) {
          if (EL.ExactNotTaken != SE.getCouldNotCompute()) {
            MaxIterations =
                SE.getUMaxFromMismatchedTypes(MaxIterations, EL.ExactNotTaken);
          }
        }

        if (MayExitMaxBECount == nullptr ||
            EL.ExactNotTaken == SE.getCouldNotCompute())
          MayExitMaxBECount = EL.ExactNotTaken;

        if (EL.ExactNotTaken != MayExitMaxBECount) {
          MayExitMaxBECount = SE.getCouldNotCompute();
        }
      }
    }
    if (MayExitMaxBECount == nullptr) {
      MayExitMaxBECount = SE.getCouldNotCompute();
    }
    if (MaxIterations == nullptr) {
      MaxIterations = SE.getCouldNotCompute();
    }
    Limit = MayExitMaxBECount;
  }
  assert(Limit);
  Value *LimitVar = nullptr;

  if (SE.getCouldNotCompute() != Limit) {

    if (CanonicalIV == nullptr) {
      report_fatal_error("Couldn't get canonical IV.");
    }

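    // Special-case limits produced by OpenMP static scheduling: walk the
    // max/add operands of the SCEV looking for a (-1 * %load) term where
    // %load reads an alloca that is passed to a __kmpc_for_static_init_*
    // call. If found, use the full extent (ub - lb) as the allocation limit
    // and (lb_post - lb) as this thread's offset into the cache.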
    SmallPtrSet<const SCEV *, 2> PotentialMins;
    SmallVector<const SCEV *, 2> Todo = {Limit};
    while (Todo.size()) {
      auto S = Todo.back();
      Todo.pop_back();
      if (auto SA = dyn_cast<SCEVSMaxExpr>(S)) {
        for (auto op : SA->operands())
          Todo.push_back(op);
      } else if (auto SA = dyn_cast<SCEVUMaxExpr>(S)) {
        for (auto op : SA->operands())
          Todo.push_back(op);
      } else if (auto SA = dyn_cast<SCEVAddExpr>(S)) {
        for (auto op : SA->operands())
          Todo.push_back(op);
      } else
        PotentialMins.insert(S);
    }
    for (auto op : PotentialMins) {
      auto SM = dyn_cast<SCEVMulExpr>(op);
      if (!SM)
        continue;
      if (SM->getNumOperands() != 2)
        continue;
      for (int i = 0; i < 2; i++)
        if (auto C = dyn_cast<SCEVConstant>(SM->getOperand(i))) {
          // is minus 1
#if LLVM_VERSION_MAJOR > 16
          if (C->getAPInt().isAllOnes())
#else
          if (C->getAPInt().isAllOnesValue())
#endif
          {
            const SCEV *prev = SM->getOperand(1 - i);
            while (true) {
              if (auto ext = dyn_cast<SCEVZeroExtendExpr>(prev)) {
                prev = ext->getOperand();
                continue;
              }
              if (auto ext = dyn_cast<SCEVSignExtendExpr>(prev)) {
                prev = ext->getOperand();
                continue;
              }
              break;
            }
            if (auto V = dyn_cast<SCEVUnknown>(prev)) {
              if (auto omp_lb_post = dyn_cast<LoadInst>(V->getValue())) {
                auto AI =
                    dyn_cast<AllocaInst>(omp_lb_post->getPointerOperand());
                if (AI) {
                  for (auto u : AI->users()) {
                    CallInst *call = dyn_cast<CallInst>(u);
                    if (!call)
                      continue;
                    Function *F = call->getCalledFunction();
                    if (!F)
                      continue;
                    if (F->getName() == "__kmpc_for_static_init_4" ||
                        F->getName() == "__kmpc_for_static_init_4u" ||
                        F->getName() == "__kmpc_for_static_init_8" ||
                        F->getName() == "__kmpc_for_static_init_8u") {
                      Value *lb = nullptr;
                      for (auto u : call->getArgOperand(4)->users()) {
                        if (auto si = dyn_cast<StoreInst>(u)) {
                          lb = si->getValueOperand();
                          break;
                        }
                      }
                      assert(lb);
                      Value *ub = nullptr;
                      for (auto u : call->getArgOperand(5)->users()) {
                        if (auto si = dyn_cast<StoreInst>(u)) {
                          ub = si->getValueOperand();
                          break;
                        }
                      }
                      assert(ub);
                      IRBuilder<> post(omp_lb_post->getNextNode());
                      loopContexts[L].allocLimit = post.CreateZExtOrTrunc(
                          post.CreateSub(ub, lb), CanonicalIV->getType());
                      loopContexts[L].offset = post.CreateZExtOrTrunc(
                          post.CreateSub(omp_lb_post, lb, "", true, true),
                          CanonicalIV->getType());
                      goto endOMP;
                    }
                  }
                }
              }
            }
          }
        }
    }
  endOMP:;

    if (Limit->getType() != CanonicalIV->getType())
      Limit = SE.getZeroExtendExpr(Limit, CanonicalIV->getType());

#if LLVM_VERSION_MAJOR >= 22
    SCEVExpander Exp(SE, "enzyme");
#else
    SCEVExpander Exp(SE, BB->getParent()->getParent()->getDataLayout(),
                     "enzyme");
#endif
    LimitVar = Exp.expandCodeFor(Limit, CanonicalIV->getType(),
                                 loopContexts[L].preheader->getTerminator());
    loopContexts[L].dynamic = false;
    loopContexts[L].maxLimit = LimitVar;
  } else {
    // TODO if assumeDynamicLoopOfSizeOne(L), only lazily allocate the scope
    // cache
    DebugLoc loc = L->getHeader()->begin()->getDebugLoc();
    for (auto &I : *L->getHeader()) {
      if (loc)
        break;
      loc = I.getDebugLoc();
    }
    EmitWarning("NoLimit", loc, L->getHeader(),
                "SE could not compute loop limit of ",
                L->getHeader()->getName(), " of ",
                L->getHeader()->getParent()->getName(), "lim: ", *Limit,
                " maxlim: ", *MaxIterations);

    loopContexts[L].dynamic = true;
    loopContexts[L].maxLimit = nullptr;

    if (assumeDynamicLoopOfSizeOne(L)) {
      LimitVar = nullptr;
    } else {
      LimitVar = getDynamicLoopLimit(L, ReverseLimit);
    }
  }
  loopContexts[L].trueLimit = LimitVar;
  if (EfficientMaxCache && loopContexts[L].dynamic &&
      SE.getCouldNotCompute() != MaxIterations) {
    if (MaxIterations->getType() != CanonicalIV->getType())
      MaxIterations =
          SE.getZeroExtendExpr(MaxIterations, CanonicalIV->getType());

#if LLVM_VERSION_MAJOR >= 22
    SCEVExpander Exp(SE, "enzyme");
#else
    SCEVExpander Exp(SE, BB->getParent()->getParent()->getDataLayout(),
                     "enzyme");
#endif

    loopContexts[L].maxLimit =
        Exp.expandCodeFor(MaxIterations, CanonicalIV->getType(),
                          loopContexts[L].preheader->getTerminator());
  }
  loopContext = loopContexts.find(L)->second;
  return true;
}

/// Caching mechanism: creates a cache of type T in a scope given by ctx
/// (where if ctx is in a loop there will be a corresponding number of slots)
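/// For instance (an illustration, not an exhaustive description): caching a
/// double inside a single static loop of limit N yields an entry-block
/// `alloca double*` holding a malloc of N doubles; a dynamic loop nested
/// inside a static loop of limit N adds a second pointer level (double**),
/// a malloc of N pointers whose entries are grown by realloc as the dynamic
/// loop runs.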
AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
                                              StringRef name, bool shouldFree,
                                              bool allocateInternal,
                                              Value *extraSize) {
  assert(ctx.Block);
  assert(T);

  auto sublimits =
      getSubLimits(/*inForwardPass*/ true, nullptr, ctx, extraSize);

  auto i64 = Type::getInt64Ty(T->getContext());

  // List of types stored in the cache for each Loop-Chunk
  // This is stored from innermost chunk to outermost
  // Thus it begins with the underlying type, and adds pointers
  // to the previous type.
  SmallVector<Type *, 4> types = {T};
  SmallVector<PointerType *, 4> malloctypes;
  bool isi1 = T->isIntegerTy() && cast<IntegerType>(T)->getBitWidth() == 1;
  if (EfficientBoolCache && isi1 && sublimits.size() != 0)
    types[0] = Type::getInt8Ty(T->getContext());
  for (size_t i = 0; i < sublimits.size(); ++i) {
    Type *allocType;
    {
      BasicBlock *BB =
          BasicBlock::Create(newFunc->getContext(), "entry", newFunc);
      IRBuilder<> B(BB);
      auto P = B.CreatePHI(i64, 1);

      CallInst *malloccall;
      Instruction *Zero;
      allocType = cast<PointerType>(CreateAllocation(B, types.back(), P,
                                                     "tmpfortypecalc",
                                                     &malloccall, &Zero)
                                        ->getType());
      malloctypes.push_back(cast<PointerType>(malloccall->getType()));
      for (auto &I : make_early_inc_range(reverse(*BB)))
        I.eraseFromParent();

      BB->eraseFromParent();
    }
    types.push_back(allocType);
  }

  // Allocate the outermost type on the stack
  IRBuilder<> entryBuilder(inversionAllocs);
  entryBuilder.setFastMathFlags(getFast());
  AllocaInst *alloc =
      entryBuilder.CreateAlloca(types.back(), nullptr, name + "_cache");
  {
    ConstantInt *byteSizeOfType = ConstantInt::get(
        i64, newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(
                 types.back()) /
                 8);
    unsigned align =
        getCacheAlignment((unsigned)byteSizeOfType->getZExtValue());
    alloc->setAlignment(Align(align));
  }
  if (sublimits.size() == 0) {
    auto val = getUndefinedValueForType(*newFunc->getParent(), types.back());
    if (!isa<UndefValue>(val))
      scopeInstructions[alloc].push_back(entryBuilder.CreateStore(val, alloc));
  }

  Value *storeInto = alloc;

  // Iterating from outermost chunk to innermost chunk
  // Allocate and store the requisite memory if needed
  // and lookup the next level pointer of the cache
  for (int i = sublimits.size() - 1; i >= 0; i--) {
    const auto &containedloops = sublimits[i].second;

    Type *myType = types[i];

    ConstantInt *byteSizeOfType = ConstantInt::get(
        Type::getInt64Ty(T->getContext()),
        newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(myType) /
            8);

    unsigned bsize = (unsigned)byteSizeOfType->getZExtValue();
    unsigned alignSize = getCacheAlignment(bsize);

    CallInst *malloccall = nullptr;

    // Allocate and store the required memory
    if (allocateInternal) {

      IRBuilder<> allocationBuilder(
          &containedloops.back().first.preheader->back());

      Value *size = sublimits[i].first;
      if (EfficientBoolCache && isi1 && i == 0) {
        size = allocationBuilder.CreateLShr(
            allocationBuilder.CreateAdd(
                size, ConstantInt::get(Type::getInt64Ty(T->getContext()), 7),
                "", true),
            ConstantInt::get(Type::getInt64Ty(T->getContext()), 3));
      }
      if (extraSize && i == 0) {
        ValueToValueMapTy available;
        for (auto &sl : sublimits) {
          for (auto &cl : sl.second) {
            if (cl.first.var)
              available[cl.first.var] = cl.first.var;
          }
        }
        Value *es = unwrapM(extraSize, allocationBuilder, available,
                            UnwrapMode::AttemptFullUnwrapWithLookup);
        assert(es);
        size = allocationBuilder.CreateMul(size, es, "", /*NUW*/ true,
                                           /*NSW*/ true);
      }

      StoreInst *storealloc = nullptr;
      // Statically allocate memory for all iterations if possible
      if (sublimits[i].second.back().first.maxLimit) {
        Instruction *ZeroInst = nullptr;
        Value *firstallocation = CreateAllocation(
            allocationBuilder, myType, size, name + "_malloccache",
            &malloccall,
            /*ZeroMem*/ EnzymeZeroCache ? &ZeroInst : nullptr);

        scopeInstructions[alloc].push_back(malloccall);
        if (firstallocation != malloccall)
          scopeInstructions[alloc].push_back(
              cast<Instruction>(firstallocation));

        for (auto &actx : sublimits[i].second) {
          if (actx.first.offset) {
            malloccall->setMetadata("enzyme_ompfor",
                                    MDNode::get(malloccall->getContext(), {}));
            break;
          }
        }

        if (ZeroInst) {
          if (ZeroInst->getOperand(0) != malloccall) {
            scopeInstructions[alloc].push_back(
                cast<Instruction>(ZeroInst->getOperand(0)));
          }
          scopeInstructions[alloc].push_back(ZeroInst);
        }
        storealloc = allocationBuilder.CreateStore(firstallocation, storeInto);

        scopeAllocs[alloc].push_back(malloccall);

        // Mark the store as invariant since the allocation is static and
        // will not be changed
        if (CachePointerInvariantGroups.find(std::make_pair(
                (Value *)alloc, i)) == CachePointerInvariantGroups.end()) {
          MDNode *invgroup = MDNode::getDistinct(alloc->getContext(), {});
          CachePointerInvariantGroups[std::make_pair((Value *)alloc, i)] =
              invgroup;
        }
        storealloc->setMetadata(
            LLVMContext::MD_invariant_group,
            CachePointerInvariantGroups[std::make_pair((Value *)alloc, i)]);
        scopeInstructions[alloc].push_back(storealloc);
        for (auto post : PostCacheStore(storealloc, allocationBuilder)) {
          scopeInstructions[alloc].push_back(post);
        }
      } else {
        llvm::PointerType *allocType = cast<PointerType>(types[i + 1]);
        llvm::PointerType *mallocType = malloctypes[i];

        // Reallocate memory dynamically as a fallback
        // TODO change this to a power-of-two allocation strategy

        auto zerostore = allocationBuilder.CreateStore(
            getUndefinedValueForType(*newFunc->getParent(), allocType,
                                     /*forceZero*/ true),
            storeInto);
        scopeInstructions[alloc].push_back(zerostore);

        IRBuilder<> build(containedloops.back().first.incvar->getNextNode());
        Value *allocation = build.CreateLoad(allocType, storeInto);

        if (allocation->getType() != mallocType) {
          auto I =
              cast<Instruction>(build.CreateBitCast(allocation, mallocType));
          scopeInstructions[alloc].push_back(I);
          allocation = I;
        }

        CallInst *realloccall = nullptr;
        auto reallocation = CreateReAllocation(
            build, allocation, myType, containedloops.back().first.incvar,
            size, name + "_realloccache", &realloccall,
            EnzymeZeroCache && i == 0);

        scopeInstructions[alloc].push_back(cast<Instruction>(reallocation));

        if (reallocation->getType() != allocType) {
          auto I =
              cast<Instruction>(build.CreateBitCast(reallocation, allocType));
          scopeInstructions[alloc].push_back(I);
          reallocation = I;
        }

        scopeAllocs[alloc].push_back(realloccall);

        storealloc = build.CreateStore(reallocation, storeInto);
        // Unlike the static case we cannot mark the memory as invariant
        // since we are reloading/storing based off the number of loop
        // iterations
        scopeInstructions[alloc].push_back(storealloc);
        for (auto post : PostCacheStore(storealloc, build)) {
          scopeInstructions[alloc].push_back(post);
        }
      }

      // Regardless of how allocated (dynamic vs static), mark it
      // as having the requisite alignment
      storealloc->setAlignment(Align(alignSize));
    }

    // Free the memory, if requested
    if (shouldFree) {
      if (CachePointerInvariantGroups.find(std::make_pair(
              (Value *)alloc, i)) == CachePointerInvariantGroups.end()) {
        MDNode *invgroup = MDNode::getDistinct(alloc->getContext(), {});
        CachePointerInvariantGroups[std::make_pair((Value *)alloc, i)] =
            invgroup;
      }
      Type *nextType = types[i + 1];
      auto freecall = freeCache(
          containedloops.back().first.preheader, sublimits, i, alloc, nextType,
          byteSizeOfType, storeInto,
          CachePointerInvariantGroups[std::make_pair((Value *)alloc, i)]);
      if (freecall && malloccall) {
        auto ident = MDNode::getDistinct(malloccall->getContext(), {});
        malloccall->setMetadata("enzyme_cache_alloc",
                                MDNode::get(malloccall->getContext(), {ident}));
        freecall->setMetadata("enzyme_cache_free",
                              MDNode::get(freecall->getContext(), {ident}));
      }
    }

    // If we are not the final iteration, lookup the next pointer by indexing
    // into the relevant location of the current chunk allocation
    if (i != 0) {
      IRBuilder<> v(&sublimits[i - 1].second.back().first.preheader->back());

      Value *idx = computeIndexOfChunk(
          /*inForwardPass*/ true, v, containedloops,
          /*available*/ ValueToValueMapTy());

      storeInto = v.CreateLoad(types[i + 1], storeInto);
      cast<LoadInst>(storeInto)->setAlignment(Align(alignSize));
      storeInto = v.CreateGEP(types[i], storeInto, idx);
      cast<GetElementPtrInst>(storeInto)->setIsInBounds(true);
    }
  }
  return alloc;
}

Value *CacheUtility::computeIndexOfChunk(
    bool inForwardPass, IRBuilder<> &v,
    ArrayRef<std::pair<LoopContext, llvm::Value *>> containedloops,
    const ValueToValueMapTy &available) {
  // List of loop indices in chunk from innermost to outermost
  SmallVector<Value *, 3> indices;
  // List of cumulative limits in chunk from innermost to outermost
  // where limit[i] = prod(loop limit[0..i])
  SmallVector<Value *, 3> limits;
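  // The flattened index computed below is, for chunk indices i0..ik
  // (innermost to outermost) with limits l0..lk:
  //   idx = i0 + i1*l0 + i2*l0*l1 + ... + ik*(l0*...*l(k-1))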

  // Iterate from innermost loop to outermost loop within a chunk
  for (size_t i = 0; i < containedloops.size(); ++i) {
    const auto &pair = containedloops[i];

    const auto &idx = pair.first;
    Value *var = idx.var;

    // In the SingleIteration case, var may be null (since there's no legal
    // phinode). In that case the current iteration is simply the constant
    // zero.
    if (idx.var == nullptr)
      var = ConstantInt::get(Type::getInt64Ty(newFunc->getContext()), 0);
    else if (available.count(var)) {
      var = available.find(var)->second;
    } else if (!inForwardPass) {
      var = v.CreateLoad(idx.var->getType(), idx.antivaralloc);
    } else {
      var = idx.var;
    }
    if (idx.offset) {
      var = v.CreateAdd(var, lookupM(idx.offset, v), "", /*NUW*/ true,
                        /*NSW*/ true);
    }

    indices.push_back(var);
    Value *lim = pair.second;
    assert(lim);
    if (limits.size() == 0) {
      limits.push_back(lim);
    } else {
      limits.push_back(v.CreateMul(limits.back(), lim, "",
                                   /*NUW*/ true, /*NSW*/ true));
    }
  }

  assert(indices.size() > 0);

  // Compute the index into the pointer
  Value *idx = indices[0];
  for (unsigned ind = 1; ind < indices.size(); ++ind) {
    idx = v.CreateAdd(idx,
                      v.CreateMul(indices[ind], limits[ind - 1], "",
                                  /*NUW*/ true, /*NSW*/ true),
                      "", /*NUW*/ true, /*NSW*/ true);
  }
  return idx;
}

/// Given a LimitContext ctx, representing a location inside a loop nest,
/// break each of the loops up into chunks of loops where each chunk's number
/// of iterations can be computed at the chunk preheader. Every dynamic loop
/// defines the start of a chunk. SubLimitType is a vector of chunk objects.
/// More specifically it is a vector of { # iters in a Chunk (sublimit), Chunk }
/// Each chunk object is a vector of loops contained within the chunk.
/// For every loop, this returns a pair of the LoopContext and the limit of
/// that loop. Both the vector of Chunks and the vector of Loops within a
/// Chunk go from innermost loop to outermost loop.
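/// For example (illustrative): a static loop of limit n0, inside a dynamic
/// loop, inside a static loop of limit n2 yields two chunks,
///   { n0, [inner static, dynamic] } and { n2, [outer static] },
/// since the dynamic loop starts a chunk and counts as a single iteration
/// for allocation purposes.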
CacheUtility::SubLimitType CacheUtility::getSubLimits(bool inForwardPass,
                                                      IRBuilder<> *RB,
                                                      LimitContext ctx,
                                                      Value *extraSize) {
  // Store the LoopContexts in InnerMost => Outermost order
  SmallVector<LoopContext, 4> contexts;

  // Given a "SingleIteration" LimitContext, return a chunking of
  // one loop with size 1, and header/preheader of the BasicBlock
  // This is done to create a context for a block outside a loop
  // and is part of an experimental mechanism for merging stores
  // into a unified memcpy
  if (ctx.ForceSingleIteration) {
    LoopContext idx;
    auto subctx = ctx.Block;
    auto zero = ConstantInt::get(Type::getInt64Ty(newFunc->getContext()), 0);
    // The iteration count is always zero so we can set it as such
    idx.var = nullptr; // = zero;
    idx.incvar = nullptr;
    idx.antivaralloc = nullptr;
    idx.trueLimit = zero;
    idx.maxLimit = zero;
    idx.header = subctx;
    idx.preheader = subctx;
    idx.dynamic = false;
    idx.parent = nullptr;
    idx.exitBlocks = {};
    idx.offset = nullptr;
    idx.allocLimit = nullptr;
    contexts.push_back(idx);
  }

  for (BasicBlock *blk = ctx.Block; blk != nullptr;) {
    LoopContext idx;
    if (!getContext(blk, idx, ctx.ReverseLimit)) {
      break;
    }
    contexts.emplace_back(std::move(idx));
    blk = idx.preheader;
  }

  // Legal preheaders for loop i (indexed from inner => outer)
  SmallVector<BasicBlock *, 4> allocationPreheaders(contexts.size(), nullptr);
  // Limit of loop i (indexed from inner => outer)
  SmallVector<Value *, 4> limits(contexts.size(), nullptr);

  // Iterate from outermost loop to innermost loop
  for (int i = contexts.size() - 1; i >= 0; --i) {
    // The outermost loop's preheader is the preheader directly
    // outside the loop nest
    if ((unsigned)i == contexts.size() - 1) {
      allocationPreheaders[i] = contexts[i].preheader;
    } else if (!contexts[i].maxLimit) {
      // For dynamic loops, the preheader is now forced to be the preheader
      // of that loop
      allocationPreheaders[i] = contexts[i].preheader;
    } else {
      // Otherwise try to use the preheader of the loop just outside this
      // one to allocate all iterations across both loops together
      allocationPreheaders[i] = allocationPreheaders[i + 1];
    }

    // Dynamic loops are considered to have a limit of one for allocation
    // purposes. This is because we want to allocate 1 x (# of iterations
    // inside chunk) inside every dynamic iteration
    if (!contexts[i].maxLimit) {
      limits[i] =
          ConstantInt::get(Type::getInt64Ty(ctx.Block->getContext()), 1);
    } else {
      // Map of previous induction variables we are allowed to use as part
      // of the computation of the number of iterations in this chunk
      ValueToValueMapTy prevMap;

      // Iterate from outermost loop down
      for (int j = contexts.size() - 1;; --j) {
        // If the preheader allocating memory for loop i
        // is distinct from this preheader, we are therefore allocating
        // memory in a different chunk. We can use induction variables
        // from chunks outside us to compute loop bounds so add it to the
        // map
        if (allocationPreheaders[i] != contexts[j].preheader) {
          prevMap[contexts[j].var] = contexts[j].var;
        } else {
          break;
        }
      }

      IRBuilder<> allocationBuilder(&allocationPreheaders[i]->back());
      Value *limitMinus1 = nullptr;

      Value *limit = contexts[i].maxLimit;
      if (contexts[i].allocLimit)
        limit = contexts[i].allocLimit;

      // Attempt to compute the limit of this loop at the corresponding
      // allocation preheader. This is null if it was not legal to compute
      limitMinus1 = unwrapM(limit, allocationBuilder, prevMap,
                            UnwrapMode::AttemptFullUnwrap);

      // We have a loop with static bounds, but whose limit is not available
      // to be computed at the current loop preheader (such as the innermost
      // loop of a triangular iteration domain). Handle this case like a
      // dynamic loop and create a new chunk.
      if (limitMinus1 == nullptr) {
        EmitWarning("NoOuterLimit", *cast<Instruction>(&*limit),
                    "Could not compute outermost loop limit by moving value ",
                    *limit, " computed at block", contexts[i].header->getName(),
                    " function ", contexts[i].header->getParent()->getName());
        allocationPreheaders[i] = contexts[i].preheader;
        allocationBuilder.SetInsertPoint(&allocationPreheaders[i]->back());
        limitMinus1 = unwrapM(limit, allocationBuilder, prevMap,
                              UnwrapMode::AttemptFullUnwrap);
        if (limitMinus1 == nullptr) {
          llvm::errs() << *contexts[i].preheader->getParent() << "\n";
          llvm::errs() << "block: " << *allocationPreheaders[i] << "\n";
          llvm::errs() << "limit: " << *limit << "\n";
        }
        assert(limitMinus1 != nullptr);
      } else if (i == 0 && extraSize &&
                 unwrapM(extraSize, allocationBuilder, prevMap,
                         UnwrapMode::AttemptFullUnwrap) == nullptr) {
        EmitWarning(
            "NoOuterLimit", *cast<Instruction>(extraSize), newFunc,
            cast<Instruction>(extraSize)->getParent(),
            "Could not compute outermost loop limit by moving extraSize value ",
            *extraSize, " computed at block", contexts[i].header->getName(),
            " function ", contexts[i].header->getParent()->getName());
        allocationPreheaders[i] = contexts[i].preheader;
        allocationBuilder.SetInsertPoint(&allocationPreheaders[i]->back());
      }
      assert(limitMinus1 != nullptr);

      ValueToValueMapTy reverseMap;
      // Iterate from outermost loop down
      for (int j = contexts.size() - 1;; --j) {
        // If the preheader allocating memory for loop i
        // is distinct from this preheader, we are therefore allocating
        // memory in a different chunk. We can use induction variables
        // from chunks outside us to compute loop bounds so add it to the
        // map
        if (allocationPreheaders[i] != contexts[j].preheader) {
          if (!inForwardPass) {
            reverseMap[contexts[j].var] = RB->CreateLoad(
                contexts[j].var->getType(), contexts[j].antivaralloc);
          }
        } else {
          break;
        }
      }

      // We now need to compute the actual limit as opposed to the limit
      // minus one.
      if (inForwardPass) {
        // For efficiency, avoid doing this multiple times for
        // the same <limitMinus1, Block requested at> pair by caching inside
        // of LimitCache.
        auto &map = LimitCache[limitMinus1];
        auto found = map.find(allocationPreheaders[i]);
        if (found != map.end() && found->second != nullptr) {
          limits[i] = found->second;
        } else {
          limits[i] = map[allocationPreheaders[i]] =
              allocationBuilder.CreateNUWAdd(
                  limitMinus1, ConstantInt::get(limitMinus1->getType(), 1));
        }
      } else {
        Value *lim = unwrapM(limitMinus1, *RB, reverseMap,
                             UnwrapMode::AttemptFullUnwrapWithLookup,
                             allocationPreheaders[i]);
        if (!lim) {
          llvm::errs() << *newFunc << "\n";
          llvm::errs() << *limitMinus1 << "\n";
        }
        assert(lim);
        limits[i] = RB->CreateNUWAdd(lim, ConstantInt::get(lim->getType(), 1));
      }
    }
  }

  SubLimitType sublimits;

  // Total number of iterations of current chunk of loops
  Value *size = nullptr;
  // Loops inside current chunk (stored innermost to outermost)
  SmallVector<std::pair<LoopContext, Value *>, 3> lims;

  // Iterate from innermost to outermost loops
  for (unsigned i = 0; i < contexts.size(); ++i) {
    IRBuilder<> allocationBuilder(&allocationPreheaders[i]->back());
    lims.push_back(std::make_pair(contexts[i], limits[i]));
    // Compute the cumulative size
    if (size == nullptr) {
      // If starting with no cumulative size, this is the cumulative size
      size = limits[i];
    } else if (!inForwardPass) {
      size = RB->CreateMul(size, limits[i], "",
                           /*NUW*/ true, /*NSW*/ true);
    } else {
      // Otherwise new size = old size * limits[i];
      auto cidx = std::make_tuple(size, limits[i], allocationPreheaders[i]);
      if (SizeCache.find(cidx) == SizeCache.end()) {
        SizeCache[cidx] =
            allocationBuilder.CreateMul(size, limits[i], "",
                                        /*NUW*/ true, /*NSW*/ true);
      }
      size = SizeCache[cidx];
    }

    // If we are starting a new chunk in the next iteration
    // push this chunk to sublimits and clear the cumulative calculations
    if ((i + 1 < contexts.size()) &&
        (allocationPreheaders[i] != allocationPreheaders[i + 1])) {
      sublimits.push_back(std::make_pair(size, lims));
      size = nullptr;
      lims.clear();
    }
  }

  // For any remaining loop chunks, add them to the list
  if (size != nullptr) {
    sublimits.push_back(std::make_pair(size, lims));
    lims.clear();
  }

  return sublimits;
}

/// Given an allocation defined at a particular ctx, store the value val
/// in the cache at the location defined in the given builder
void CacheUtility::storeInstructionInCache(LimitContext ctx,
                                           IRBuilder<> &BuilderM, Value *val,
                                           AllocaInst *cache, MDNode *TBAA) {
  assert(BuilderM.GetInsertBlock()->getParent() == newFunc);
#ifndef NDEBUG
  if (auto inst = dyn_cast<Instruction>(val))
    assert(inst->getParent()->getParent() == newFunc);
#endif
  IRBuilder<> v(BuilderM.GetInsertBlock());
  v.SetInsertPoint(BuilderM.GetInsertBlock(), BuilderM.GetInsertPoint());
  v.setFastMathFlags(getFast());

  // Note for dynamic loops where the allocation is stored somewhere inside
  // the loop, we must ensure that we load the allocation after actually
  // storing the allocation itself.
  // To simplify things and ensure we always store after a
  // potential realloc occurs in this loop, we put our store after
  // any existing stores in the loop.
  // This is okay as there should be no load to the cache in the same block
  // where this instruction is defined as we will just use this instruction
  // TODO check that the store is actually aliasing/related
  if (BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) {
    for (auto I = BuilderM.GetInsertBlock()->rbegin(),
              E = BuilderM.GetInsertBlock()->rend();
         I != E; ++I) {
      if (&*I == &*BuilderM.GetInsertPoint())
        break;
      if (auto si = dyn_cast<StoreInst>(&*I)) {
        auto ni = getNextNonDebugInstructionOrNull(si);
        if (ni != nullptr) {
          v.SetInsertPoint(ni);
        } else {
          v.SetInsertPoint(si->getParent());
        }
      }
    }
  }

  bool isi1 = val->getType()->isIntegerTy() &&
              cast<IntegerType>(val->getType())->getBitWidth() == 1;
  Value *loc = getCachePointer(val->getType(),
                               /*inForwardPass*/ true, v, ctx, cache,
                               /*storeInInstructionsMap*/ true,
                               /*available*/ llvm::ValueToValueMapTy(),
                               /*extraSize*/ nullptr);

  Value *tostore = val;

  // If we are doing the efficient bool cache, the actual value
  // we want to store needs to have the existing surrounding bits
  // set appropriately
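  // E.g. for flattened index i, the byte lives at i >> 3 (the LShr feeding
  // the GEP) and the bit at i & 7; below we clear that bit via
  // (old & ~(1 << (i & 7))) and OR in (val << (i & 7)).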
  if (EfficientBoolCache && isi1) {
    if (auto gep = dyn_cast<GetElementPtrInst>(loc)) {
      auto bo = cast<BinaryOperator>(*gep->idx_begin());
      assert(bo->getOpcode() == BinaryOperator::LShr);
      auto subidx = v.CreateAnd(
          v.CreateTrunc(bo->getOperand(0),
                        Type::getInt8Ty(cache->getContext())),
          ConstantInt::get(Type::getInt8Ty(cache->getContext()), 7));
      auto mask = v.CreateNot(v.CreateShl(
          ConstantInt::get(Type::getInt8Ty(cache->getContext()), 1), subidx));

      Value *loadChunk = v.CreateLoad(mask->getType(), loc);
      auto cleared = v.CreateAnd(loadChunk, mask);

      auto toset = v.CreateShl(
          v.CreateZExt(val, Type::getInt8Ty(cache->getContext())), subidx);
      tostore = v.CreateOr(cleared, toset);
      assert(tostore->getType() == mask->getType());
    }
  }

#if LLVM_VERSION_MAJOR < 17
  if (tostore->getContext().supportsTypedPointers()) {
    if (tostore->getType() != loc->getType()->getPointerElementType()) {
      llvm::errs() << "val: " << *val << "\n";
      llvm::errs() << "tostore: " << *tostore << "\n";
      llvm::errs() << "loc: " << *loc << "\n";
    }
    assert(tostore->getType() == loc->getType()->getPointerElementType());
  }
#endif

  StoreInst *storeinst = v.CreateStore(tostore, loc);

  // If the value stored doesn't change (per efficient bool cache),
  // mark it as invariant
  if (tostore == val) {
    if (ValueInvariantGroups.find(cache) == ValueInvariantGroups.end()) {
      MDNode *invgroup = MDNode::getDistinct(cache->getContext(), {});
      ValueInvariantGroups[cache] = invgroup;
    }
    storeinst->setMetadata(LLVMContext::MD_invariant_group,
                           ValueInvariantGroups[cache]);
  }

  // Set alignment
  ConstantInt *byteSizeOfType =
      ConstantInt::get(Type::getInt64Ty(cache->getContext()),
                       ctx.Block->getParent()
                               ->getParent()
                               ->getDataLayout()
                               .getTypeAllocSizeInBits(val->getType()) /
                           8);
  unsigned align = getCacheAlignment((unsigned)byteSizeOfType->getZExtValue());
  storeinst->setMetadata(LLVMContext::MD_tbaa, TBAA);
  storeinst->setAlignment(Align(align));
  scopeInstructions[cache].push_back(storeinst);
  for (auto post : PostCacheStore(storeinst, v)) {
    scopeInstructions[cache].push_back(post);
  }
}

/// Given an allocation defined at a particular ctx, store the instruction
/// in the cache right after the instruction is executed
void CacheUtility::storeInstructionInCache(LimitContext ctx,
                                           llvm::Instruction *inst,
                                           llvm::AllocaInst *cache,
                                           llvm::MDNode *TBAA) {
  assert(ctx.Block);
  assert(inst);
  assert(cache);

  // Find the correct place to issue the store
  IRBuilder<> v(inst->getParent());
  // If this is a PHINode, we need to store after all phinodes,
  // otherwise just after inst suffices
  if (&*inst->getParent()->rbegin() != inst) {
    auto pn = dyn_cast<PHINode>(inst);
    Instruction *putafter = (pn && pn->getNumIncomingValues() > 0)
                                ? (inst->getParent()->getFirstNonPHI())
                                : getNextNonDebugInstruction(inst);
    assert(putafter);
    v.SetInsertPoint(putafter);
  }
  v.setFastMathFlags(getFast());
  storeInstructionInCache(ctx, v, inst, cache, TBAA);
}

/// Given an allocation specified by the LimitContext ctx and cache, compute a
/// pointer that can hold the underlying type being cached. This value should
/// be computed at BuilderM. Optionally, instructions needed to generate this
/// pointer can be stored in scopeInstructions
Value *CacheUtility::getCachePointer(llvm::Type *T, bool inForwardPass,
                                     IRBuilder<> &BuilderM, LimitContext ctx,
                                     Value *cache, bool storeInInstructionsMap,
                                     const ValueToValueMapTy &available,
                                     Value *extraSize) {
  assert(ctx.Block);
  assert(cache);
  auto sublimits = getSubLimits(inForwardPass, &BuilderM, ctx, extraSize);

  Value *next = cache;
  assert(next->getType()->isPointerTy());

  SmallVector<Type *, 4> types = {T};
  bool isi1 = T->isIntegerTy() && cast<IntegerType>(T)->getBitWidth() == 1;
  if (EfficientBoolCache && isi1 && sublimits.size() != 0)
    types[0] = Type::getInt8Ty(T->getContext());
  auto i64 = Type::getInt64Ty(T->getContext());
  for (size_t i = 0; i < sublimits.size(); ++i) {
    Type *allocType;
    {
      BasicBlock *BB =
          BasicBlock::Create(newFunc->getContext(), "entry", newFunc);
      IRBuilder<> B(BB);
      auto P = B.CreatePHI(i64, 1);

      CallInst *malloccall;
      Instruction *Zero;
      allocType = cast<PointerType>(CreateAllocation(B, types.back(), P,
                                                     "tmpfortypecalc",
                                                     &malloccall, &Zero)
                                        ->getType());
      for (auto &I : make_early_inc_range(reverse(*BB)))
        I.eraseFromParent();

      BB->eraseFromParent();
    }
    types.push_back(allocType);
  }

  // Iterate from outermost loop to innermost loop
  for (int i = sublimits.size() - 1; i >= 0; i--) {
    // Lookup the next allocation pointer
    next = BuilderM.CreateLoad(types[i + 1], next);
    if (storeInInstructionsMap && isa<AllocaInst>(cache))
      scopeInstructions[cast<AllocaInst>(cache)].push_back(
          cast<Instruction>(next));

    if (!next->getType()->isPointerTy()) {
      llvm::errs() << *newFunc << "\n";
      llvm::errs() << "cache: " << *cache << "\n";
      llvm::errs() << "next: " << *next << "\n";
      assert(next->getType()->isPointerTy());
    }

    // Set appropriate invariant lookup flags
    if (CachePointerInvariantGroups.find(std::make_pair(cache, i)) ==
        CachePointerInvariantGroups.end()) {
      MDNode *invgroup = MDNode::getDistinct(cache->getContext(), {});
      CachePointerInvariantGroups[std::make_pair(cache, i)] = invgroup;
    }
    cast<LoadInst>(next)->setMetadata(
        LLVMContext::MD_invariant_group,
        CachePointerInvariantGroups[std::make_pair(cache, i)]);

    // Set dereferenceable and alignment flags
    ConstantInt *byteSizeOfType = ConstantInt::get(
        Type::getInt64Ty(cache->getContext()),
        newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(
            next->getType()) /
            8);
    cast<LoadInst>(next)->setMetadata(
        LLVMContext::MD_dereferenceable,
        MDNode::get(
            cache->getContext(),
            ArrayRef<Metadata *>(ConstantAsMetadata::get(byteSizeOfType))));
    unsigned align =
        getCacheAlignment((unsigned)byteSizeOfType->getZExtValue());
    cast<LoadInst>(next)->setAlignment(Align(align));

    const auto &containedloops = sublimits[i].second;

    if (containedloops.size() > 0) {
      Value *idx = computeIndexOfChunk(inForwardPass, BuilderM, containedloops,
                                       available);
      if (EfficientBoolCache && isi1 && i == 0)
        idx = BuilderM.CreateLShr(
            idx, ConstantInt::get(Type::getInt64Ty(newFunc->getContext()), 3));
      if (i == 0 && extraSize) {
        Value *es = lookupM(extraSize, BuilderM);
        assert(es);
        idx = BuilderM.CreateMul(idx, es, "", /*NUW*/ true, /*NSW*/ true);
      }
      next = BuilderM.CreateGEP(types[i], next, idx);
      cast<GetElementPtrInst>(next)->setIsInBounds(true);
      if (storeInInstructionsMap && isa<AllocaInst>(cache))
        scopeInstructions[cast<AllocaInst>(cache)].push_back(
            cast<Instruction>(next));
    }
    assert(next->getType()->isPointerTy());
  }
  return next;
}

/// Perform the final load from the cache, applying requisite invariant
/// group and alignment
llvm::LoadInst *CacheUtility::loadFromCachePointer(llvm::Type *T,
                                                   llvm::IRBuilder<> &BuilderM,
                                                   llvm::Value *cptr,
                                                   llvm::Value *cache) {
  // Retrieve the actual result
  auto result = BuilderM.CreateLoad(T, cptr);

  // Apply requisite invariant, alignment, etc
  if (ValueInvariantGroups.find(cache) == ValueInvariantGroups.end()) {
    MDNode *invgroup = MDNode::getDistinct(cache->getContext(), {});
    ValueInvariantGroups[cache] = invgroup;
  }
  CacheLookups.insert(result);
  result->setMetadata(LLVMContext::MD_invariant_group,
                      ValueInvariantGroups[cache]);
  ConstantInt *byteSizeOfType = ConstantInt::get(
      Type::getInt64Ty(cache->getContext()),
      newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(
          result->getType()) /
          8);
  unsigned align = getCacheAlignment((unsigned)byteSizeOfType->getZExtValue());
  result->setAlignment(Align(align));

  return result;
}
1611
1612/// Given an allocation specified by the LimitContext ctx and cache, lookup the
1613/// underlying cached value.
1615 Type *T, bool inForwardPass, IRBuilder<> &BuilderM, LimitContext ctx,
1616 Value *cache, bool isi1, const ValueToValueMapTy &available,
1617 Value *extraSize, Value *extraOffset) {
1618 // Get the underlying cache pointer
1619 auto cptr =
1620 getCachePointer(T, inForwardPass, BuilderM, ctx, cache,
1621 /*storeInInstructionsMap*/ false, available, extraSize);
1622
1623 // Optionally apply the additional offset
1624 if (extraOffset) {
1625 cptr = BuilderM.CreateGEP(T, cptr, extraOffset);
1626 cast<GetElementPtrInst>(cptr)->setIsInBounds(true);
1627 }
1628
1629 Value *result = loadFromCachePointer(T, BuilderM, cptr, cache);
1630
1631 // If using the efficient bool cache, do the corresponding
1632 // mask and shift to retrieve the actual value
1633 if (EfficientBoolCache && isi1) {
1634 if (auto gep = dyn_cast<GetElementPtrInst>(cptr)) {
1635 auto bo = cast<BinaryOperator>(*gep->idx_begin());
1636 assert(bo->getOpcode() == BinaryOperator::LShr);
1637 Value *res = BuilderM.CreateLShr(
1638 result,
1639 BuilderM.CreateAnd(
1640 BuilderM.CreateTrunc(bo->getOperand(0),
1641 Type::getInt8Ty(cache->getContext())),
1642 ConstantInt::get(Type::getInt8Ty(cache->getContext()), 7)));
1643 return BuilderM.CreateTrunc(res, Type::getInt1Ty(result->getContext()));
1644 }
1645 }
1646 return result;
1647}
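// --------------------------------------------------------------------------
// Editorial sketch (not part of the original file): a minimal standalone
// model of the 8-bools-per-byte packing that the EfficientBoolCache lookup
// above implements, assuming a plain byte buffer. All names below are
// hypothetical.
#include <cstdint>
#include <vector>

// Read the idx'th packed bool out of a byte buffer.
static bool loadPackedBool(const std::vector<uint8_t> &buf, uint64_t idx) {
  uint64_t byteIdx = idx >> 3;         // CreateLShr(idx, 3): byte index
  unsigned bitPos = unsigned(idx & 7); // CreateAnd(..., 7): bit within byte
  return (buf[byteIdx] >> bitPos) & 1; // CreateLShr + CreateTrunc to i1
}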