Enzyme main
Loading...
Searching...
No Matches
MustExitScalarEvolution.cpp
Go to the documentation of this file.
1
2//===- MustExitScalarEvolution.cpp - ScalarEvolution assuming loops
3// terminate ---===//
4//
5// Enzyme Project
6//
7// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
8// Exceptions. See https://llvm.org/LICENSE.txt for license information.
9// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
10//
11// If using this code in an academic setting, please cite the following:
12// @incollection{enzymeNeurips,
13// title = {Instead of Rewriting Foreign Code for Machine Learning,
14// Automatically Synthesize Fast Gradients},
15// author = {Moses, William S. and Churavy, Valentin},
16// booktitle = {Advances in Neural Information Processing Systems 33},
17// year = {2020},
18// note = {To appear in},
19// }
20//
21//===----------------------------------------------------------------------===//
22//
23// This file defines MustExitScalarEvolution, a subclass of ScalarEvolution
24// that assumes that all loops terminate (and don't loop forever).
25//
26//===----------------------------------------------------------------------===//
27
29#include "FunctionUtils.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Analysis/LoopInfo.h"
32#include "llvm/Analysis/ScalarEvolution.h"
33
34#ifdef __clang__
35#pragma clang diagnostic push
36#pragma clang diagnostic ignored "-Wunused-variable"
37#else
38#pragma GCC diagnostic push
39#pragma GCC diagnostic ignored "-Wunused-variable"
40#endif
41
42#if LLVM_VERSION_MAJOR <= 22
43#define SCEVUse const SCEV *
44#endif
45
46using namespace llvm;
47
// NOTE(review): the line carrying this definition's signature (listing line
// 48) is missing from this dump. Given the file header ("assumes that all
// loops terminate"), this is one of the overrides that unconditionally
// answers true (e.g. a loop-finiteness query) -- confirm against upstream.
49 return true;
50}
51
// computeExitLimitFromCond entry point (its signature line, listing line 52,
// is absent from this dump; the delegation below identifies it).
// Builds a fresh per-query ExitLimit cache and forwards to the cached
// variant, so repeated sub-conditions of ExitCond are analyzed only once.
53 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsExit,
54 bool AllowPredicates) {
// Cache is keyed on (L, ExitIfTrue, AllowPredicates).
55 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
56 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
57 ControlsExit, AllowPredicates);
58}
59
// MustExitScalarEvolution constructor (the line naming it, listing line 60,
// is missing from this dump). Forwards all analyses to the ScalarEvolution
// base, then precomputes GuaranteedUnreachable -- the set of blocks that
// getGuaranteedUnreachable (FunctionUtils.h) reports for F -- which the
// exit-limit routines below use to ignore exits that can never be taken.
61 llvm::TargetLibraryInfo &TLI,
62 llvm::AssumptionCache &AC,
63 llvm::DominatorTree &DT,
64 llvm::LoopInfo &LI)
65 : ScalarEvolution(F, TLI, AC, DT, LI),
66 GuaranteedUnreachable(getGuaranteedUnreachable(&F)) {}
67
// computeExitLimitFromExitingBlock (signature line, listing line 68, missing
// from this dump). Computes the exit limit contributed by ExitingBlock,
// differing from stock ScalarEvolution in that exits leading only to
// guaranteed-unreachable blocks are discounted.
69 const Loop *L, BasicBlock *ExitingBlock, bool AllowPredicates) {
70
// First, prune the loop's exiting-block list: any block whose only
// out-of-loop successors are guaranteed unreachable is nulled out and
// erased, so IsOnlyExit below reflects only *real* exits.
71 SmallVector<BasicBlock *, 8> ExitingBlocks;
72 L->getExitingBlocks(ExitingBlocks);
// NOTE(review): the loop variable deliberately shadows the ExitingBlock
// parameter; inside this loop it refers to the vector element by reference.
73 for (auto &ExitingBlock : ExitingBlocks) {
74 BasicBlock *Exit = nullptr;
75 for (auto *SBB : successors(ExitingBlock)) {
76 if (!L->contains(SBB)) {
77 if (GuaranteedUnreachable.count(SBB))
78 continue;
79 Exit = SBB;
80 break;
81 }
82 }
83 if (!Exit)
84 ExitingBlock = nullptr;
85 }
// Erase-remove the entries nulled above.
86 ExitingBlocks.erase(
87 std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr),
88 ExitingBlocks.end());
89
90 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
91 // If our exiting block does not dominate the latch, then its connection with
92 // loop's exit limit may be far from trivial.
93 const BasicBlock *Latch = L->getLoopLatch();
94 if (!Latch || !DT.dominates(ExitingBlock, Latch))
95 return getCouldNotCompute();
96
97 bool IsOnlyExit = ExitingBlocks.size() == 1;
98 auto *Term = ExitingBlock->getTerminator();
99 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
100 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
101 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
102 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
103 "It should have one successor in loop and one exit block!");
104 // Proceed to the next level to examine the exit condition expression.
105 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
106 /*ControlsExit=*/IsOnlyExit,
107 AllowPredicates);
108 }
109
110 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
111 // For switch, make sure that there is a single exit from the loop.
// As above, out-of-loop successors that are guaranteed unreachable do not
// count as exits for the single-exit requirement.
112 BasicBlock *Exit = nullptr;
113 for (auto *SBB : successors(ExitingBlock))
114 if (!L->contains(SBB)) {
115 if (GuaranteedUnreachable.count(SBB))
116 continue;
117 if (Exit) // Multiple exit successors.
118 return getCouldNotCompute();
119 Exit = SBB;
120 }
121 assert(Exit && "Exiting block must have at least one exit");
122 return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
123 /*ControlsExit=*/IsOnlyExit);
124 }
125
// Any other terminator (invoke, indirectbr, ...) is not analyzable here.
126 return getCouldNotCompute();
127}
128
129ScalarEvolution::ExitLimit
// computeExitLimitFromSingleExitSwitch (its name line, listing line 130/131,
// is missing from this dump; identified by the call site above). Despite the
// parameter's name, ExitingBlock here is the *exit* destination block, as the
// first assert shows.
131 const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock,
132 bool ControlsOnlyExit) {
133 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
134
135 // Give up if the exit is the default dest of a switch.
136 if (Switch->getDefaultDest() == ExitingBlock)
137 return getCouldNotCompute();
138
139 ///! If we're guaranteed unreachable, the default dest does not matter.
140 if (!GuaranteedUnreachable.count(Switch->getDefaultDest()))
141 assert(L->contains(Switch->getDefaultDest()) &&
142 "Default case must not exit the loop!");
// The switch exits iff its condition equals the case value routed to Exit.
143 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
144 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
145
146 // while (X != Y) --> while (X-Y != 0)
147 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
148 if (EL.hasAnyInfo())
149 return EL;
150
151 return getCouldNotCompute();
152}
153
154ScalarEvolution::ExitLimit
// computeExitLimitFromCondCached (name line, listing line 155, missing from
// this dump; identified by the call site in computeExitLimitFromCond).
// Memoizing wrapper: look up the (L, ExitCond, ExitIfTrue, ControlsExit,
// AllowPredicates) tuple in Cache, computing and inserting on a miss.
156 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
157 bool ControlsExit, bool AllowPredicates) {
158
159 if (auto MaybeEL =
160 Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
161 return *MaybeEL;
162
163 ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
164 ControlsExit, AllowPredicates);
165 Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
166 return EL;
167}
168
169ScalarEvolution::ExitLimit
// computeExitLimitFromCondImpl (name line, listing line 170, missing from
// this dump; identified by the call in the cached wrapper above). Mirrors
// LLVM's ScalarEvolution logic: decompose And/Or exit conditions recursively,
// analyze icmp conditions, fold constant conditions, and otherwise fall back
// to exhaustive evaluation. The #if blocks select between the pre-LLVM-16
// ExitLimit field name (MaxNotTaken) and the newer one (ConstantMaxNotTaken),
// and between the LLVM-20+ ArrayRef-based predicate-list constructor and the
// older pointer-based one.
171 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
172 bool ControlsExit, bool AllowPredicates) {
173 // Check if the controlling expression for this loop is an And or Or.
174 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
175 if (BO->getOpcode() == Instruction::And) {
176 // Recurse on the operands of the and.
// Exiting on false => either operand going false exits the loop.
177 bool EitherMayExit = !ExitIfTrue;
178 ExitLimit EL0 = computeExitLimitFromCondCached(
179 Cache, L, BO->getOperand(0), ExitIfTrue,
180 ControlsExit && !EitherMayExit, AllowPredicates);
181 ExitLimit EL1 = computeExitLimitFromCondCached(
182 Cache, L, BO->getOperand(1), ExitIfTrue,
183 ControlsExit && !EitherMayExit, AllowPredicates);
184 const SCEV *BECount = getCouldNotCompute();
185 const SCEV *MaxBECount = getCouldNotCompute();
186 if (EitherMayExit) {
187 // Both conditions must be true for the loop to continue executing.
188 // Choose the less conservative count.
189 if (EL0.ExactNotTaken == getCouldNotCompute() ||
190 EL1.ExactNotTaken == getCouldNotCompute())
191 BECount = getCouldNotCompute();
192 else
193 BECount =
194 getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
195#if LLVM_VERSION_MAJOR >= 16
196 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
197 MaxBECount = EL1.ConstantMaxNotTaken;
198 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
199 MaxBECount = EL0.ConstantMaxNotTaken;
200 else
201 MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
202 EL1.ConstantMaxNotTaken);
203 } else {
204 // Both conditions must be true at the same time for the loop to exit.
205 // For now, be conservative.
206 if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken)
207 MaxBECount = EL0.ConstantMaxNotTaken;
208 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
209 BECount = EL0.ExactNotTaken;
210 }
211#else
212 if (EL0.MaxNotTaken == getCouldNotCompute())
213 MaxBECount = EL1.MaxNotTaken;
214 else if (EL1.MaxNotTaken == getCouldNotCompute())
215 MaxBECount = EL0.MaxNotTaken;
216 else
217 MaxBECount =
218 getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
219 } else {
220 // Both conditions must be true at the same time for the loop to exit.
221 // For now, be conservative.
222 if (EL0.MaxNotTaken == EL1.MaxNotTaken)
223 MaxBECount = EL0.MaxNotTaken;
224 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
225 BECount = EL0.ExactNotTaken;
226 }
227#endif
228 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
229 // to be more aggressive when computing BECount than when computing
230 // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
231 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
232 // EL1.ConstantMaxNotTaken to not.
233 if (isa<SCEVCouldNotCompute>(MaxBECount) &&
234 !isa<SCEVCouldNotCompute>(BECount))
235 MaxBECount = getConstant(getUnsignedRangeMax(BECount));
236
237#if LLVM_VERSION_MAJOR >= 20
238 return ExitLimit(BECount, MaxBECount, MaxBECount, false,
239 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
240#elif LLVM_VERSION_MAJOR >= 16
241 return ExitLimit(BECount, MaxBECount, MaxBECount, false,
242 {&EL0.Predicates, &EL1.Predicates});
243#else
244 return ExitLimit(BECount, MaxBECount, false,
245 {&EL0.Predicates, &EL1.Predicates});
246#endif
247 }
248 if (BO->getOpcode() == Instruction::Or) {
249 // Recurse on the operands of the or.
// Exiting on true => either operand going true exits the loop.
250 bool EitherMayExit = ExitIfTrue;
251 ExitLimit EL0 = computeExitLimitFromCondCached(
252 Cache, L, BO->getOperand(0), ExitIfTrue,
253 ControlsExit && !EitherMayExit, AllowPredicates);
254 ExitLimit EL1 = computeExitLimitFromCondCached(
255 Cache, L, BO->getOperand(1), ExitIfTrue,
256 ControlsExit && !EitherMayExit, AllowPredicates);
257 const SCEV *BECount = getCouldNotCompute();
258 const SCEV *MaxBECount = getCouldNotCompute();
259 if (EitherMayExit) {
260 // Both conditions must be false for the loop to continue executing.
261 // Choose the less conservative count.
262 if (EL0.ExactNotTaken == getCouldNotCompute() ||
263 EL1.ExactNotTaken == getCouldNotCompute())
264 BECount = getCouldNotCompute();
265 else
266 BECount =
267 getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
268#if LLVM_VERSION_MAJOR >= 16
269 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
270 MaxBECount = EL1.ConstantMaxNotTaken;
271 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
272 MaxBECount = EL0.ConstantMaxNotTaken;
273 else
274 MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
275 EL1.ConstantMaxNotTaken);
276 } else {
277 // Both conditions must be false at the same time for the loop to exit.
278 // For now, be conservative.
279 if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken)
280 MaxBECount = EL0.ConstantMaxNotTaken;
281 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
282 BECount = EL0.ExactNotTaken;
283 }
// Note the nested #if: for LLVM >= 16 the return happens here, inside the
// outer version branch; the pre-16 return is at the bottom of the #else.
284#if LLVM_VERSION_MAJOR >= 20
285 return ExitLimit(BECount, MaxBECount, MaxBECount, false,
286 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
287#else
288 return ExitLimit(BECount, MaxBECount, MaxBECount, false,
289 {&EL0.Predicates, &EL1.Predicates});
290#endif
291#else
292 if (EL0.MaxNotTaken == getCouldNotCompute())
293 MaxBECount = EL1.MaxNotTaken;
294 else if (EL1.MaxNotTaken == getCouldNotCompute())
295 MaxBECount = EL0.MaxNotTaken;
296 else
297 MaxBECount =
298 getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
299 } else {
300 // Both conditions must be false at the same time for the loop to exit.
301 // For now, be conservative.
302 if (EL0.MaxNotTaken == EL1.MaxNotTaken)
303 MaxBECount = EL0.MaxNotTaken;
304 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
305 BECount = EL0.ExactNotTaken;
306 }
307 return ExitLimit(BECount, MaxBECount, false,
308 {&EL0.Predicates, &EL1.Predicates});
309#endif
310 }
311 }
312
313 // With an icmp, it may be feasible to compute an exact backedge-taken count.
314 // Proceed to the next level to examine the icmp.
315 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
316 ExitLimit EL =
317 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
318 if (EL.hasFullInfo() || !AllowPredicates)
319 return EL;
320
321 // Try again, but use SCEV predicates this time.
322 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
323 /*AllowPredicates=*/true);
324 }
325
326 // Check for a constant condition. These are normally stripped out by
327 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
328 // preserve the CFG and is temporarily leaving constant conditions
329 // in place.
330 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
331 if (ExitIfTrue == !CI->getZExtValue())
332 // The backedge is always taken.
333 return getCouldNotCompute();
334 else
335 // The backedge is never taken.
336 return getZero(CI->getType());
337 }
338
339 // If it's not an integer or pointer comparison then compute it the hard way.
340 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
341}
342
344 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsExit,
345 bool AllowPredicates) {
346 // If the condition was exit on true, convert the condition to exit on false
347#if LLVM_VERSION_MAJOR >= 20
348 llvm::CmpPredicate Pred = ExitCond->getPredicate();
349#else
350 auto Pred = ExitCond->getPredicate();
351#endif
352 if (ExitIfTrue)
353 Pred = ExitCond->getInversePredicate();
354 const auto OriginalPred = Pred;
355
356#if LLVM_VERSION_MAJOR < 14
357 // Handle common loops like: for (X = "string"; *X; ++X)
358 if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
359 if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
360 ExitLimit ItCnt = computeLoadConstantCompareExitLimit(LI, RHS, L, Pred);
361 if (ItCnt.hasAnyInfo())
362 return ItCnt;
363 }
364#endif
365
366 SCEVUse LHS = getSCEV(ExitCond->getOperand(0));
367 SCEVUse RHS = getSCEV(ExitCond->getOperand(1));
368
369#define PROP_PHI(LHS) \
370 if (auto un = dyn_cast<SCEVUnknown>(LHS)) { \
371 if (auto pn = dyn_cast_or_null<PHINode>(un->getValue())) { \
372 const SCEV *sc = nullptr; \
373 bool failed = false; \
374 for (auto &a : pn->incoming_values()) { \
375 auto subsc = getSCEV(a); \
376 if (sc == nullptr) { \
377 sc = subsc; \
378 continue; \
379 } \
380 if (subsc != sc) { \
381 failed = true; \
382 break; \
383 } \
384 } \
385 if (!failed) { \
386 LHS = sc; \
387 } \
388 } \
389 }
390 PROP_PHI(LHS)
391 PROP_PHI(RHS)
392
393 // Try to evaluate any dependencies out of the loop.
394 LHS = getSCEVAtScope(LHS, L);
395 RHS = getSCEVAtScope(RHS, L);
396
397 // At this point, we would like to compute how many iterations of the
398 // loop the predicate will return true for these inputs.
399 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
400 // If there is a loop-invariant, force it into the RHS.
401 std::swap(LHS, RHS);
402 Pred = ICmpInst::getSwappedPredicate(Pred);
403 }
404
405 // Simplify the operands before analyzing them.
406 (void)SimplifyICmpOperands(Pred, LHS, RHS);
407
408 // If we have a comparison of a chrec against a constant, try to use value
409 // ranges to answer this query.
410 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
411 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
412 if (AddRec->getLoop() == L) {
413 // Form the constant range.
414 ConstantRange CompRange =
415 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
416
417 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
418 if (!isa<SCEVCouldNotCompute>(Ret))
419 return Ret;
420 }
421
422 switch (Pred) {
423 case ICmpInst::ICMP_NE: { // while (X != Y)
424 // Convert to: while (X-Y != 0)
425 ExitLimit EL =
426 howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, AllowPredicates);
427 if (EL.hasAnyInfo())
428 return EL;
429 break;
430 }
431 case ICmpInst::ICMP_EQ: { // while (X == Y)
432 // Convert to: while (X-Y == 0)
433 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
434 if (EL.hasAnyInfo())
435 return EL;
436 break;
437 }
438 case ICmpInst::ICMP_SLT:
439 case ICmpInst::ICMP_ULT:
440 case ICmpInst::ICMP_SLE:
441 case ICmpInst::ICMP_ULE: { // while (X < Y)
442 bool IsSigned = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE;
443
444 if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) {
445 if (!isa<IntegerType>(RHS->getType()))
446 break;
447 SmallVector<SCEVUse, 2> sv = {
448 RHS,
449 getConstant(ConstantInt::get(cast<IntegerType>(RHS->getType()), 1))};
450 // Since this is not an infinite loop by induction, RHS cannot be
451 // int_max/uint_max Therefore adding 1 does not wrap.
452 if (IsSigned)
453 RHS = getAddExpr(sv, SCEV::FlagNSW);
454 else
455 RHS = getAddExpr(sv, SCEV::FlagNUW);
456 }
457 ExitLimit EL =
458 howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, AllowPredicates);
459 if (EL.hasAnyInfo())
460 return EL;
461 break;
462 }
463 case ICmpInst::ICMP_SGT:
464 case ICmpInst::ICMP_UGT:
465 case ICmpInst::ICMP_SGE:
466 case ICmpInst::ICMP_UGE: { // while (X > Y)
467 bool IsSigned = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE;
468 if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) {
469 if (!isa<IntegerType>(RHS->getType()))
470 break;
471 SmallVector<SCEVUse, 2> sv = {
472 RHS,
473 getConstant(ConstantInt::get(cast<IntegerType>(RHS->getType()), -1))};
474 // Since this is not an infinite loop by induction, RHS cannot be
475 // int_min/uint_min Therefore subtracting 1 does not wrap.
476 if (IsSigned)
477 RHS = getAddExpr(sv, SCEV::FlagNSW);
478 else
479 RHS = getAddExpr(sv, SCEV::FlagNUW);
480 }
481 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
482 AllowPredicates);
483 if (EL.hasAnyInfo())
484 return EL;
485 break;
486 }
487 default:
488 break;
489 }
490
491 auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
492
493 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
494 return ExhaustiveCount;
495
496 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
497 ExitCond->getOperand(1), L, OriginalPred);
498}
499
500#if LLVM_VERSION_MAJOR >= 13
501static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
502 ICmpInst::Predicate *Pred,
503 ScalarEvolution *SE) {
504 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
505 if (SE->isKnownPositive(Step)) {
506 *Pred = ICmpInst::ICMP_SLT;
507 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
508 SE->getSignedRangeMax(Step));
509 }
510 if (SE->isKnownNegative(Step)) {
511 *Pred = ICmpInst::ICMP_SGT;
512 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
513 SE->getSignedRangeMin(Step));
514 }
515 return nullptr;
516}
517static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
518 ICmpInst::Predicate *Pred,
519 ScalarEvolution *SE) {
520 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
521 *Pred = ICmpInst::ICMP_ULT;
522
523 return SE->getConstant(APInt::getMinValue(BitWidth) -
524 SE->getUnsignedRangeMax(Step));
525}
526
// Traits machinery letting getPreStartForExtend/getExtendAddRecStart be
// written once and instantiated for both sign- and zero-extension: each
// specialization supplies the matching nowrap flag, the member-function
// pointer performing the extension, and the overflow-limit helper.
527namespace {

529struct ExtendOpTraitsBase {
530 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
531 unsigned);
532};

534// Used to make code generic over signed and unsigned overflow.
535template <typename ExtendOp> struct ExtendOpTraits {
536 // Members present:
537 //
538 // static const SCEV::NoWrapFlags WrapType;
539 //
540 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
541 //
542 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
543 // ICmpInst::Predicate *Pred,
544 // ScalarEvolution *SE);
545};

547template <>
548struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
549 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

551 static const GetExtendExprTy GetExtendExpr;

553 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
554 ICmpInst::Predicate *Pred,
555 ScalarEvolution *SE) {
556 return getSignedOverflowLimitForStep(Step, Pred, SE);
557 }
558};

560const ExtendOpTraitsBase::GetExtendExprTy
561 ExtendOpTraits<SCEVSignExtendExpr>::GetExtendExpr =
562 &ScalarEvolution::getSignExtendExpr;

564template <>
565struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
566 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

568 static const GetExtendExprTy GetExtendExpr;

570 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
571 ICmpInst::Predicate *Pred,
572 ScalarEvolution *SE) {
573 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
574 }
575};

577const ExtendOpTraitsBase::GetExtendExprTy
578 ExtendOpTraits<SCEVZeroExtendExpr>::GetExtendExpr =
579 &ScalarEvolution::getZeroExtendExpr;

581} // end anonymous namespace
582
583static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags) {
584 return TestFlags == ScalarEvolution::maskFlags(Flags, TestFlags);
585};
586
// Given an AddRec whose start is "PreStart + Step", try to prove that the
// recurrence {PreStart,+,Step} does not wrap (NSW or NUW per ExtendOpTy),
// which lets the caller extend the narrower pre-start instead of the start.
// Returns PreStart on success, nullptr otherwise.  Three proof strategies
// are attempted in order; see the numbered comments below.
587template <typename ExtendOpTy>
588static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
589 ScalarEvolution *SE, unsigned Depth) {
590 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
591 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

593 const Loop *L = AR->getLoop();
594 const SCEV *Start = AR->getStart();
595 const SCEV *Step = AR->getStepRecurrence(*SE);

597 // Check for a simple looking step prior to loop entry.
598 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
599 if (!SA)
600 return nullptr;

602 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
603 // subtraction is expensive. For this purpose, perform a quick and dirty
604 // difference, by checking for Step in the operand list.
605 SmallVector<SCEVUse, 4> DiffOps;
606 for (const SCEV *Op : SA->operands())
607 if (Op != Step)
608 DiffOps.push_back(Op);

// If no operand was removed, Start does not textually contain Step.
610 if (DiffOps.size() == SA->getNumOperands())
611 return nullptr;

613 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
614 // `Step`:

616 // 1. NSW/NUW flags on the step increment.
617 auto PreStartFlags =
618 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
619 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
620 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
621 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

623 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
624 // "S+X does not sign/unsign-overflow".
625 //

// LLVM 23 changed NoWrapFlags into a type needing an explicit any() test.
627 const SCEV *BECount = SE->getBackedgeTakenCount(L);
628#if LLVM_VERSION_MAJOR >= 23
629 if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
630 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
631 return PreStart;
632#else
633 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
634 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
635 return PreStart;
636#endif

638 // 2. Direct overflow check on the step operation's expression.
639 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
640 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
641 const SCEV *OperandExtendedStart =
642 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
643 (SE->*GetExtendExpr)(Step, WideTy, Depth));
644 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
645#if LLVM_VERSION_MAJOR >= 23
646 if (PreAR && any(AR->getNoWrapFlags(WrapType)))
647#else
648 if (PreAR && AR->getNoWrapFlags(WrapType))
649#endif
650 {
651 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
652 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
653 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
654 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
655 }
656 return PreStart;
657 }

659 // 3. Loop precondition.
660 ICmpInst::Predicate Pred;
661 const SCEV *OverflowLimit =
662 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

664 if (OverflowLimit &&
665 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
666 return PreStart;

668 return nullptr;
669}
670
671template <typename ExtendOpTy>
672static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
673 ScalarEvolution *SE, unsigned Depth) {
674 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
675
676 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
677 if (!PreStart)
678 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
679
680 return SE->getAddExpr(
681 (SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, Depth),
682 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
683}
684
// Attempt to widen the no-wrap flags of an add / addrec / mul expression
// with operands Ops, using range analysis and structural facts; returns the
// possibly-strengthened flag set.  Mirrors the helper of the same name in
// LLVM's ScalarEvolution.cpp.
685static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE,
686 SCEVTypes Type,
687 const ArrayRef<const SCEV *> Ops,
688 SCEV::NoWrapFlags Flags) {
689 using namespace std::placeholders;

691 using OBO = OverflowingBinaryOperator;

693 bool CanAnalyze =
694 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
695 (void)CanAnalyze;
696 assert(CanAnalyze && "don't call from other places!");

698 auto SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
699 SCEV::NoWrapFlags SignOrUnsignWrap =
700 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);

702 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
703 auto IsKnownNonNegative = [&](const SCEV *S) {
704 return SE->isKnownNonNegative(S);
705 };

707 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
708 Flags =
709 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);

// Refresh the mask after the possible strengthening above.
711 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);

// Constant-on-the-left binary add/mul: use guaranteed-no-wrap regions to
// derive NSW/NUW from the non-constant operand's value range.
713 if (SignOrUnsignWrap != SignOrUnsignMask &&
714 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
715 isa<SCEVConstant>(Ops[0])) {

717 auto Opcode = [&] {
718 switch (Type) {
719 case scAddExpr:
720 return Instruction::Add;
721 case scMulExpr:
722 return Instruction::Mul;
723 default:
724 llvm_unreachable("Unexpected SCEV op.");
725 }
726 }();

728 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();

730 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
731 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
732 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
733 Opcode, C, OBO::NoSignedWrap);
734 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
735 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
736 }

738 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
739 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
740 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
741 Opcode, C, OBO::NoUnsignedWrap);
742 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
743 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
744 }
745 }

747 // <0,+,nonnegative><nw> is also nuw
748 // TODO: Add corresponding nsw case
749 if (Type == scAddRecExpr && hasFlags(Flags, SCEV::FlagNW) &&
750 !hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 && Ops[0]->isZero() &&
751 IsKnownNonNegative(Ops[1]))
752 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);

754 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
755 if (Type == scMulExpr && !hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2) {
756 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
757 if (UDiv->getOperand(1) == Ops[1])
758 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
759 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
760 if (UDiv->getOperand(1) == Ops[0])
761 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
762 }

764 return Flags;
765}
766
767ScalarEvolution::ExitLimit MustExitScalarEvolution::howManyLessThans(
768 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
769 bool ControlsExit, bool AllowPredicates) {
770#if LLVM_VERSION_MAJOR >= 20
771 SmallVector<const SCEVPredicate *, 4> Predicates;
772#else
773 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
774#endif
775
776 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
777 bool PredicatedIV = false;
778
779 auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
780 // Can we prove this loop *must* be UB if overflow of IV occurs?
781 // Reasoning goes as follows:
782 // * Suppose the IV did self wrap.
783 // * If Stride evenly divides the iteration space, then once wrap
784 // occurs, the loop must revisit the same values.
785 // * We know that RHS is invariant, and that none of those values
786 // caused this exit to be taken previously. Thus, this exit is
787 // dynamically dead.
788 // * If this is the sole exit, then a dead exit implies the loop
789 // must be infinite if there are no abnormal exits.
790 // * If the loop were infinite, then it must either not be mustprogress
791 // or have side effects. Otherwise, it must be UB.
792 // * It can't (by assumption), be UB so we have contradicted our
793 // premise and can conclude the IV did not in fact self-wrap.
794 if (!isLoopInvariant(RHS, L))
795 return false;
796
797 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
798 if (!StrideC || !StrideC->getAPInt().isPowerOf2())
799 return false;
800
801 if (!ControlsExit || !loopHasNoAbnormalExits(L))
802 return false;
803
804 return loopIsFiniteByAssumption(L);
805 };
806
807 if (!IV) {
808 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
809 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
810 if (AR && AR->getLoop() == L && AR->isAffine()) {
811 auto Flags = AR->getNoWrapFlags();
812 if (!hasFlags(Flags, SCEV::FlagNW) && canAssumeNoSelfWrap(AR)) {
813 Flags = setFlags(Flags, SCEV::FlagNW);
814
815 SmallVector<const SCEV *, 4> Operands{AR->operands()};
816 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
817
818 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
819 }
820 if (AR->hasNoUnsignedWrap()) {
821 // Emulate what getZeroExtendExpr would have done during construction
822 // if we'd been able to infer the fact just above at that time.
823 const SCEV *Step = AR->getStepRecurrence(*this);
824 Type *Ty = ZExt->getType();
825 auto *S = getAddRecExpr(
826 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
827 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
828 IV = dyn_cast<SCEVAddRecExpr>(S);
829 }
830 }
831 }
832 }
833
834 if (!IV && AllowPredicates) {
835 // Try to make this an AddRec using runtime tests, in the first X
836 // iterations of this loop, where X is the SCEV expression found by the
837 // algorithm below.
838 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
839 PredicatedIV = true;
840 }
841
842 // Avoid weird loops
843 if (!IV || IV->getLoop() != L || !IV->isAffine())
844 return getCouldNotCompute();
845
846 // A precondition of this method is that the condition being analyzed
847 // reaches an exiting branch which dominates the latch. Given that, we can
848 // assume that an increment which violates the nowrap specification and
849 // produces poison must cause undefined behavior when the resulting poison
850 // value is branched upon and thus we can conclude that the backedge is
851 // taken no more often than would be required to produce that poison value.
852 // Note that a well defined loop can exit on the iteration which violates
853 // the nowrap specification if there is another exit (either explicit or
854 // implicit/exceptional) which causes the loop to execute before the
855 // exiting instruction we're analyzing would trigger UB.
856 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
857
858#if LLVM_VERSION_MAJOR >= 23
859 bool NoWrap = ControlsExit && any(IV->getNoWrapFlags(WrapType));
860#else
861 bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
862#endif
863
864 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
865
866 const SCEV *Stride = IV->getStepRecurrence(*this);
867
868 bool PositiveStride = isKnownPositive(Stride);
869
870 // Avoid negative or zero stride values.
871 if (!PositiveStride) {
872 // We can compute the correct backedge taken count for loops with unknown
873 // strides if we can prove that the loop is not an infinite loop with side
874 // effects. Here's the loop structure we are trying to handle -
875 //
876 // i = start
877 // do {
878 // A[i] = i;
879 // i += s;
880 // } while (i < end);
881 //
882 // The backedge taken count for such loops is evaluated as -
883 // (max(end, start + stride) - start - 1) /u stride
884 //
885 // The additional preconditions that we need to check to prove correctness
886 // of the above formula is as follows -
887 //
888 // a) IV is either nuw or nsw depending upon signedness (indicated by the
889 // NoWrap flag).
890 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
891 // no side effects within the loop)
892 // c) loop has a single static exit (with no abnormal exits)
893 //
894 // Precondition a) implies that if the stride is negative, this is a single
895 // trip loop. The backedge taken count formula reduces to zero in this case.
896 //
897 // Precondition b) and c) combine to imply that if rhs is invariant in L,
898 // then a zero stride means the backedge can't be taken without executing
899 // undefined behavior.
900 //
901 // The positive stride case is the same as isKnownPositive(Stride) returning
902 // true (original behavior of the function).
903 //
904 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
905 !loopHasNoAbnormalExits(L)) {
906 return getCouldNotCompute();
907 }
908
909 // This bailout is protecting the logic in computeMaxBECountForLT which
910 // has not yet been sufficiently audited or tested with negative strides.
911 // We used to filter out all known-non-positive cases here, we're in the
912 // process of being less restrictive bit by bit.
913 if (IsSigned && isKnownNonPositive(Stride))
914 return getCouldNotCompute();
915
916 if (!isKnownNonZero(Stride)) {
917 // If we have a step of zero, and RHS isn't invariant in L, we don't know
918 // if it might eventually be greater than start and if so, on which
919 // iteration. We can't even produce a useful upper bound.
920 if (!isLoopInvariant(RHS, L))
921 return getCouldNotCompute();
922
923 // We allow a potentially zero stride, but we need to divide by stride
924 // below. Since the loop can't be infinite and this check must control
925 // the sole exit, we can infer the exit must be taken on the first
926 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
927 // we know the numerator in the divides below must be zero, so we can
928 // pick an arbitrary non-zero value for the denominator (e.g. stride)
929 // and produce the right result.
930 // FIXME: Handle the case where Stride is poison?
931 auto wouldZeroStrideBeUB = [&]() {
932 // Proof by contradiction. Suppose the stride were zero. If we can
933 // prove that the backedge *is* taken on the first iteration, then since
934 // we know this condition controls the sole exit, we must have an
935 // infinite loop. We can't have a (well defined) infinite loop per
936 // check just above.
937 // Note: The (Start - Stride) term is used to get the start' term from
938 // (start' + stride,+,stride). Remember that we only care about the
939 // result of this expression when stride == 0 at runtime.
940 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
941 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
942 };
943 if (!wouldZeroStrideBeUB()) {
944 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
945 }
946 }
947 } else if (!Stride->isOne() && !NoWrap) {
948 auto isUBOnWrap = [&]() {
949 // From no-self-wrap, we need to then prove no-(un)signed-wrap. This
950 // follows trivially from the fact that every (un)signed-wrapped, but
951 // not self-wrapped value must be LT than the last value before
952 // (un)signed wrap. Since we know that last value didn't exit, nor
953 // will any smaller one.
954 return canAssumeNoSelfWrap(IV);
955 };
956
957 // Avoid proven overflow cases: this will ensure that the backedge taken
958 // count will not generate any unsigned overflow. Relaxed no-overflow
959 // conditions exploit NoWrapFlags, allowing to optimize in presence of
960 // undefined behaviors like the case of C language.
961 if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
962 return getCouldNotCompute();
963 }
964
965 // On all paths just preceding, we established the following invariant:
966 // IV can be assumed not to overflow up to and including the exiting
967 // iteration. We proved this in one of two ways:
968 // 1) We can show overflow doesn't occur before the exiting iteration
969 // 1a) canIVOverflowOnLT, and b) step of one
970 // 2) We can show that if overflow occurs, the loop must execute UB
971 // before any possible exit.
972 // Note that we have not yet proved RHS invariant (in general).
973
974 const SCEV *Start = IV->getStart();
975
976 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
977 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
978 // Use integer-typed versions for actual computation; we can't subtract
979 // pointers in general.
980 const SCEV *OrigStart = Start;
981 const SCEV *OrigRHS = RHS;
982 if (Start->getType()->isPointerTy()) {
983 Start = getLosslessPtrToIntExpr(Start);
984 if (isa<SCEVCouldNotCompute>(Start))
985 return Start;
986 }
987 if (RHS->getType()->isPointerTy()) {
988 RHS = getLosslessPtrToIntExpr(RHS);
989 if (isa<SCEVCouldNotCompute>(RHS))
990 return RHS;
991 }
992
993 // When the RHS is not invariant, we do not know the end bound of the loop and
994 // cannot calculate the ExactBECount needed by ExitLimit. However, we can
995 // calculate the MaxBECount, given the start, stride and max value for the end
996 // bound of the loop (RHS), and the fact that IV does not overflow (which is
997 // checked above).
998 if (!isLoopInvariant(RHS, L)) {
999 const SCEV *MaxBECount = computeMaxBECountForLT(
1000 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
1001#if LLVM_VERSION_MAJOR >= 16
1002 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
1003 MaxBECount, false /*MaxOrZero*/, Predicates);
1004#else
1005 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
1006 false /*MaxOrZero*/, Predicates);
1007#endif
1008 }
1009
1010 // We use the expression (max(End,Start)-Start)/Stride to describe the
1011 // backedge count, as if the backedge is taken at least once max(End,Start)
1012 // is End and so the result is as above, and if not max(End,Start) is Start
1013 // so we get a backedge count of zero.
1014 const SCEV *BECount = nullptr;
1015 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
1016 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
1017 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
1018 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
1019 // Can we prove max(RHS,Start) > Start - Stride?
1020 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
1021 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
1022 // In this case, we can use a refined formula for computing backedge taken
1023 // count. The general formula remains:
1024 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
1025 // We want to use the alternate formula:
1026 // "((End - 1) - (Start - Stride)) /u Stride"
1027 // Let's do a quick case analysis to show these are equivalent under
1028 // our precondition that max(RHS,Start) > Start - Stride.
1029 // * For RHS <= Start, the backedge-taken count must be zero.
1030 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
1031 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
1032 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
1033 // of Stride. For 0 stride, we've used umax(Stride,1) above, reducing
1034 // this to the stride of 1 case.
1035 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
1036 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
1037 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
1038 // "(RHS - (Start - Stride) - 1) /u Stride".
1039 // Our preconditions trivially imply no overflow in that form.
1040 const SCEV *MinusOne = getMinusOne(Stride->getType());
1041 const SCEV *Numerator =
1042 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
1043 BECount = getUDivExpr(Numerator, Stride);
1044 }
1045
1046 const SCEV *BECountIfBackedgeTaken = nullptr;
1047 if (!BECount) {
1048 auto canProveRHSGreaterThanEqualStart = [&]() {
1049 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
1050 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart))
1051 return true;
1052
1053 // (RHS > Start - 1) implies RHS >= Start.
1054 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
1055 // "Start - 1" doesn't overflow.
1056 // * For signed comparison, if Start - 1 does overflow, it's equal
1057 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
1058 // * For unsigned comparison, if Start - 1 does overflow, it's equal
1059 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
1060 //
1061 // FIXME: Should isLoopEntryGuardedByCond do this for us?
1062 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
1063 auto *StartMinusOne =
1064 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
1065 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
1066 };
1067
1068 // If we know that RHS >= Start in the context of loop, then we know that
1069 // max(RHS, Start) = RHS at this point.
1070 const SCEV *End;
1071 if (canProveRHSGreaterThanEqualStart()) {
1072 End = RHS;
1073 } else {
1074 // If RHS < Start, the backedge will be taken zero times. So in
1075 // general, we can write the backedge-taken count as:
1076 //
1077 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
1078 //
1079 // We convert it to the following to make it more convenient for SCEV:
1080 //
1081 // ceil(max(RHS, Start) - Start) / Stride
1082 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
1083
1084 // See what would happen if we assume the backedge is taken. This is
1085 // used to compute MaxBECount.
1086 BECountIfBackedgeTaken =
1087 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
1088 }
1089
1090 // At this point, we know:
1091 //
1092 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
1093 // 2. The index variable doesn't overflow.
1094 //
1095 // Therefore, we know N exists such that
1096 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
1097 // doesn't overflow.
1098 //
1099 // Using this information, try to prove whether the addition in
1100 // "(Start - End) + (Stride - 1)" has unsigned overflow.
1101 const SCEV *One = getOne(Stride->getType());
1102 bool MayAddOverflow = [&] {
1103 if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
1104 if (StrideC->getAPInt().isPowerOf2()) {
1105 // Suppose Stride is a power of two, and Start/End are unsigned
1106 // integers. Let UMAX be the largest representable unsigned
1107 // integer.
1108 //
1109 // By the preconditions of this function, we know
1110 // "(Start + Stride * N) >= End", and this doesn't overflow.
1111 // As a formula:
1112 //
1113 // End <= (Start + Stride * N) <= UMAX
1114 //
1115 // Subtracting Start from all the terms:
1116 //
1117 // End - Start <= Stride * N <= UMAX - Start
1118 //
1119 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
1120 //
1121 // End - Start <= Stride * N <= UMAX
1122 //
1123 // Stride * N is a multiple of Stride. Therefore,
1124 //
1125 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
1126 //
1127 // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
1128 // Therefore, UMAX mod Stride == Stride - 1. So we can write:
1129 //
1130 // End - Start <= Stride * N <= UMAX - Stride + 1
1131 //
1132 // Dropping the middle term:
1133 //
1134 // End - Start <= UMAX - Stride + 1
1135 //
1136 // Adding Stride - 1 to both sides:
1137 //
1138 // (End - Start) + (Stride - 1) <= UMAX
1139 //
1140 // In other words, the addition doesn't have unsigned overflow.
1141 //
1142 // A similar proof works if we treat Start/End as signed values.
1143 // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
1144 // use signed max instead of unsigned max. Note that we're trying
1145 // to prove a lack of unsigned overflow in either case.
1146 return false;
1147 }
1148 }
1149 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
1150 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
1151 // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
1152 // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
1153 //
1154 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
1155 return false;
1156 }
1157 return true;
1158 }();
1159
1160 const SCEV *Delta = getMinusSCEV(End, Start);
1161 if (!MayAddOverflow) {
1162 // floor((D + (S - 1)) / S)
1163 // We prefer this formulation if it's legal because it's fewer operations.
1164 BECount =
1165 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
1166 } else {
1167 BECount = getUDivCeilSCEV(Delta, Stride);
1168 }
1169 }
1170
1171 const SCEV *MaxBECount;
1172 bool MaxOrZero = false;
1173 if (isa<SCEVConstant>(BECount)) {
1174 MaxBECount = BECount;
1175 } else if (BECountIfBackedgeTaken &&
1176 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
1177 // If we know exactly how many times the backedge will be taken if it's
1178 // taken at least once, then the backedge count will either be that or
1179 // zero.
1180 MaxBECount = BECountIfBackedgeTaken;
1181 MaxOrZero = true;
1182 } else {
1183 MaxBECount = computeMaxBECountForLT(
1184 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
1185 }
1186
1187 if (isa<SCEVCouldNotCompute>(MaxBECount) &&
1188 !isa<SCEVCouldNotCompute>(BECount))
1189 MaxBECount = getConstant(getUnsignedRangeMax(BECount));
1190#if LLVM_VERSION_MAJOR >= 16
1191 return ExitLimit(BECount, MaxBECount, MaxBECount, MaxOrZero, Predicates);
1192#else
1193 return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
1194#endif
1195}
1196#else
1198 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
1199 bool ControlsExit, bool AllowPredicates) {
1200 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
1201
1202 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
1203
1204 if (!IV && AllowPredicates) {
1205 // Try to make this an AddRec using runtime tests, in the first X
1206 // iterations of this loop, where X is the SCEV expression found by the
1207 // algorithm below.
1208 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
1209 }
1210
1211 // Avoid weird loops
1212 if (!IV || IV->getLoop() != L || !IV->isAffine())
1213 return getCouldNotCompute();
1214
1215 bool NoWrap = ControlsExit && true; // changed this to assume no wrap for inc
1216 // IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
1217
1218 const SCEV *Stride = IV->getStepRecurrence(*this);
1219
1220 bool PositiveStride = isKnownPositive(Stride);
1221
1222 // Avoid negative or zero stride values.
1223 if (!PositiveStride) {
1224 // We can compute the correct backedge taken count for loops with unknown
1225 // strides if we can prove that the loop is not an infinite loop with side
1226 // effects. Here's the loop structure we are trying to handle -
1227 //
1228 // i = start
1229 // do {
1230 // A[i] = i;
1231 // i += s;
1232 // } while (i < end);
1233 //
1234 // The backedge taken count for such loops is evaluated as -
1235 // (max(end, start + stride) - start - 1) /u stride
1236 //
1237 // The additional preconditions that we need to check to prove correctness
1238 // of the above formula is as follows -
1239 //
1240 // a) IV is either nuw or nsw depending upon signedness (indicated by the
1241 // NoWrap flag).
1242 // b) loop is single exit with no side effects. // don't need this
1243 //
1244 //
1245 // Precondition a) implies that if the stride is negative, this is a single
1246 // trip loop. The backedge taken count formula reduces to zero in this case.
1247 //
1248 // Precondition b) implies that the unknown stride cannot be zero otherwise
1249 // we have UB.
1250 //
1251 // The positive stride case is the same as isKnownPositive(Stride) returning
1252 // true (original behavior of the function).
1253 //
1254 // We want to make sure that the stride is truly unknown as there are edge
1255 // cases where ScalarEvolution propagates no wrap flags to the
1256 // post-increment/decrement IV even though the increment/decrement operation
1257 // itself is wrapping. The computed backedge taken count may be wrong in
1258 // such cases. This is prevented by checking that the stride is not known to
1259 // be either positive or non-positive. For example, no wrap flags are
1260 // propagated to the post-increment IV of this loop with a trip count of 2 -
1261 //
1262 // unsigned char i;
1263 // for(i=127; i<128; i+=129)
1264 // A[i] = i;
1265 //
1266 if (!NoWrap) // THIS LINE CHANGED
1267 return getCouldNotCompute();
1268 } else if (!Stride->isOne() &&
1269 doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
1270 // Avoid proven overflow cases: this will ensure that the backedge taken
1271 // count will not generate any unsigned overflow. Relaxed no-overflow
1272 // conditions exploit NoWrapFlags, allowing to optimize in presence of
1273 // undefined behaviors like the case of C language.
1274 return getCouldNotCompute();
1275
1276 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
1277 const SCEV *Start = IV->getStart();
1278 const SCEV *End = RHS;
1279 // When the RHS is not invariant, we do not know the end bound of the loop and
1280 // cannot calculate the ExactBECount needed by ExitLimit. However, we can
1281 // calculate the MaxBECount, given the start, stride and max value for the end
1282 // bound of the loop (RHS), and the fact that IV does not overflow (which is
1283 // checked above).
1284 if (!isLoopInvariant(RHS, L)) {
1285 const SCEV *MaxBECount = computeMaxBECountForLT(
1286 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
1287 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
1288 false /*MaxOrZero*/, Predicates);
1289 }
1290 // If the backedge is taken at least once, then it will be taken
1291 // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
1292 // is the LHS value of the less-than comparison the first time it is evaluated
1293 // and End is the RHS.
1294 const SCEV *BECountIfBackedgeTaken =
1295 computeBECount(getMinusSCEV(End, Start), Stride, false);
1296 // If the loop entry is guarded by the result of the backedge test of the
1297 // first loop iteration, then we know the backedge will be taken at least
1298 // once and so the backedge taken count is as above. If not then we use the
1299 // expression (max(End,Start)-Start)/Stride to describe the backedge count,
1300 // as if the backedge is taken at least once max(End,Start) is End and so the
1301 // result is as above, and if not max(End,Start) is Start so we get a backedge
1302 // count of zero.
1303 const SCEV *BECount;
1304 if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
1305 BECount = BECountIfBackedgeTaken;
1306 else {
1307 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
1308 BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
1309 }
1310
1311 const SCEV *MaxBECount;
1312 bool MaxOrZero = false;
1313 if (isa<SCEVConstant>(BECount))
1314 MaxBECount = BECount;
1315 else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
1316 // If we know exactly how many times the backedge will be taken if it's
1317 // taken at least once, then the backedge count will either be that or
1318 // zero.
1319 MaxBECount = BECountIfBackedgeTaken;
1320 MaxOrZero = true;
1321 } else {
1322 MaxBECount = computeMaxBECountForLT(
1323 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
1324 }
1325
1326 if (isa<SCEVCouldNotCompute>(MaxBECount) &&
1327 !isa<SCEVCouldNotCompute>(BECount))
1328 MaxBECount = getConstant(getUnsignedRangeMax(BECount));
1329
1330 return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
1331}
1332
1333#ifdef __clang__
1334#pragma clang diagnostic pop
1335#else
1336#pragma GCC diagnostic pop
1337#endif
1338
1339#endif
static llvm::SmallPtrSet< llvm::BasicBlock *, 4 > getGuaranteedUnreachable(llvm::Function *F)
#define PROP_PHI(LHS)
#define SCEVUse
ScalarEvolution::ExitLimit howManyLessThans(const llvm::SCEV *LHS, const llvm::SCEV *RHS, const llvm::Loop *L, bool IsSigned, bool ControlsExit, bool AllowPredicates)
llvm::SmallPtrSet< llvm::BasicBlock *, 4 > GuaranteedUnreachable
ScalarEvolution::ExitLimit computeExitLimitFromICmp(const llvm::Loop *L, llvm::ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsExit, bool AllowPredicates=false)
ScalarEvolution::ExitLimit computeExitLimitFromCond(const llvm::Loop *L, llvm::Value *ExitCond, bool ExitIfTrue, bool ControlsExit, bool AllowPredicates)
ScalarEvolution::ExitLimit computeExitLimitFromSingleExitSwitch(const llvm::Loop *L, llvm::SwitchInst *Switch, llvm::BasicBlock *ExitingBB, bool IsSubExpr)
ScalarEvolution::ExitLimit computeExitLimit(const llvm::Loop *L, llvm::BasicBlock *ExitingBlock, bool AllowPredicates)
ScalarEvolution::ExitLimit computeExitLimitFromCondCached(ExitLimitCacheTy &Cache, const llvm::Loop *L, llvm::Value *ExitCond, bool ExitIfTrue, bool ControlsExit, bool AllowPredicates)
ScalarEvolution::ExitLimit computeExitLimitFromCondImpl(ExitLimitCacheTy &Cache, const llvm::Loop *L, llvm::Value *ExitCond, bool ExitIfTrue, bool ControlsExit, bool AllowPredicates)
MustExitScalarEvolution(llvm::Function &F, llvm::TargetLibraryInfo &TLI, llvm::AssumptionCache &AC, llvm::DominatorTree &DT, llvm::LoopInfo &LI)
bool loopIsFiniteByAssumption(const llvm::Loop *L)