LLVM 17.0.0git
LoopPredication.cpp
1//===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The LoopPredication pass tries to convert loop-variant range checks into
10// loop-invariant ones by widening the checks across loop iterations. For
11// example, it will convert
12//
13// for (i = 0; i < n; i++) {
14// guard(i < len);
15// ...
16// }
17//
18// to
19//
20// for (i = 0; i < n; i++) {
21// guard(n - 1 < len);
22// ...
23// }
24//
25// After this transformation the condition of the guard is loop invariant, so
26// loop-unswitch can later unswitch the loop by this condition which basically
27// predicates the loop by the widened condition:
28//
29// if (n - 1 < len)
30// for (i = 0; i < n; i++) {
31// ...
32// }
33// else
34// deoptimize
35//
36// It's tempting to rely on SCEV here, but it has proven to be problematic.
37// Generally the facts SCEV provides about the increment step of add
38// recurrences are true if the backedge of the loop is taken, which implicitly
39// assumes that the guard doesn't fail. Using these facts to optimize the
40// guard results in circular logic, where the guard is optimized under the
41// assumption that it never fails.
42//
43// For example, in the loop below the induction variable will be marked as nuw
44// based on the guard. Based on nuw, the guard predicate will be considered
45// monotonic. Given a monotonic condition it's tempting to replace the induction
46// variable in the condition with its value on the last iteration. But this
47// transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
48//
49// for (int i = b; i != e; i++)
50// guard(i u< len)
51//
52// One of the ways to reason about this problem is to use an inductive proof
53// approach. Given the loop:
54//
55// if (B(0)) {
56// do {
57// I = PHI(0, I.INC)
58// I.INC = I + Step
59// guard(G(I));
60// } while (B(I));
61// }
62//
63// where B(x) and G(x) are predicates that map integers to booleans, we want a
64// loop invariant expression M such that the following program has the same
65// semantics as the above:
66//
67// if (B(0)) {
68// do {
69// I = PHI(0, I.INC)
70// I.INC = I + Step
71// guard(G(0) && M);
72// } while (B(I));
73// }
74//
75// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
76//
77// Informal proof that the transformation above is correct:
78//
79// By the definition of guards we can rewrite the guard condition to:
80// G(I) && G(0) && M
81//
82// Let's prove that for each iteration of the loop:
83// G(0) && M => G(I)
84// And the condition above can be simplified to G(0) && M.
85//
86// Induction base.
87// G(0) && M => G(0)
88//
89// Induction step. Assuming G(0) && M => G(I) on the subsequent
90// iteration:
91//
92// B(I) is true because it's the backedge condition.
93// G(I) is true because the backedge is guarded by this condition.
94//
95// So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
96//
97// Note that we can use anything stronger than M, i.e. any condition which
98// implies M.
99//
100// When Step = 1 (i.e. forward iterating loop), the transformation is supported
101// when:
102// * The loop has a single latch with the condition of the form:
103// B(X) = latchStart + X <pred> latchLimit,
104// where <pred> is u<, u<=, s<, or s<=.
105// * The guard condition is of the form
106// G(X) = guardStart + X u< guardLimit
107//
108// For the ult latch comparison case M is:
109// forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
110// guardStart + X + 1 u< guardLimit
111//
112// The only way the antecedent can be true and the consequent can be false is
113// if
114// X == guardLimit - 1 - guardStart
115// (and guardLimit is non-zero, but we won't use this latter fact).
116// If X == guardLimit - 1 - guardStart then the second half of the antecedent is
117// latchStart + guardLimit - 1 - guardStart u< latchLimit
118// and its negation is
119// latchStart + guardLimit - 1 - guardStart u>= latchLimit
120//
121// In other words, if
122// latchLimit u<= latchStart + guardLimit - 1 - guardStart
123// then:
124// (the ranges below are written in ConstantRange notation, where [A, B) is the
125// set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
126//
127// forall X . guardStart + X u< guardLimit &&
128// latchStart + X u< latchLimit =>
129// guardStart + X + 1 u< guardLimit
130// == forall X . guardStart + X u< guardLimit &&
131// latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
132// guardStart + X + 1 u< guardLimit
133// == forall X . (guardStart + X) in [0, guardLimit) &&
134// (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
135// (guardStart + X + 1) in [0, guardLimit)
136// == forall X . X in [-guardStart, guardLimit - guardStart) &&
137// X in [-latchStart, guardLimit - 1 - guardStart) =>
138// X in [-guardStart - 1, guardLimit - guardStart - 1)
139// == true
140//
141// So the widened condition is:
142// guardStart u< guardLimit &&
143// latchStart + guardLimit - 1 - guardStart u>= latchLimit
144// Similarly for ule condition the widened condition is:
145// guardStart u< guardLimit &&
146// latchStart + guardLimit - 1 - guardStart u> latchLimit
147// For slt condition the widened condition is:
148// guardStart u< guardLimit &&
149// latchStart + guardLimit - 1 - guardStart s>= latchLimit
150// For sle condition the widened condition is:
151// guardStart u< guardLimit &&
152// latchStart + guardLimit - 1 - guardStart s> latchLimit
153//
154// When Step = -1 (i.e. reverse iterating loop), the transformation is supported
155// when:
156// * The loop has a single latch with the condition of the form:
157// B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
158// * The guard condition is of the form
159// G(X) = X - 1 u< guardLimit
160//
161// For the ugt latch comparison case M is:
162// forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
163//
164// The only way the antecedent can be true and the consequent can be false is if
165// X == 1.
166// If X == 1 then the second half of the antecedent is
167// 1 u> latchLimit, and its negation is latchLimit u>= 1.
168//
169// So the widened condition is:
170// guardStart u< guardLimit && latchLimit u>= 1.
171// Similarly for sgt condition the widened condition is:
172// guardStart u< guardLimit && latchLimit s>= 1.
173// For uge condition the widened condition is:
174// guardStart u< guardLimit && latchLimit u> 1.
175// For sge condition the widened condition is:
176// guardStart u< guardLimit && latchLimit s> 1.
177//===----------------------------------------------------------------------===//
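// A minimal, self-contained sketch (not part of the pass) that brute-forces the
// widened conditions derived in the header comment above, using 4-bit wrapping
// unsigned arithmetic so that every parameter combination can be tried. Only
// the ult (forward) and ugt (reverse) latch cases are covered. The names Mask,
// countForwardViolations and countReverseViolations are hypothetical, and the
// loop shapes follow the B(X)/G(X) model from the comment rather than real IR.

```cpp
#include <cassert>

// All arithmetic wraps modulo 2^4.
static constexpr unsigned Mask = 0xF;

// Forward loop (Step = 1), ult latch:
//   G(X) = guardStart + X u< guardLimit
//   B(X) = latchStart + X u< latchLimit
// Counts parameter tuples for which the widened condition holds but some
// executed iteration would still fail the original guard.
static unsigned countForwardViolations() {
  unsigned Violations = 0;
  for (unsigned GS = 0; GS <= Mask; ++GS)
    for (unsigned GL = 0; GL <= Mask; ++GL)
      for (unsigned LS = 0; LS <= Mask; ++LS)
        for (unsigned LL = 0; LL <= Mask; ++LL) {
          // latchStart + guardLimit - 1 - guardStart, modulo 2^4.
          unsigned RHS = (LS + GL - 1 - GS) & Mask;
          bool Widened = GS < GL && RHS >= LL;
          if (!Widened)
            continue;
          // Simulate the loop. B(X) must fail for some X because
          // latchStart + X eventually equals 15, which is never u< latchLimit.
          for (unsigned I = 0;; I = (I + 1) & Mask) {
            if (!(((GS + I) & Mask) < GL)) { // original guard G(I)
              ++Violations;
              break;
            }
            if (!(((LS + I) & Mask) < LL)) // backedge condition B(I)
              break;
          }
        }
  return Violations;
}

// Reverse loop (Step = -1), ugt latch:
//   G(X) = X - 1 u< guardLimit
//   B(X) = X u> latchLimit
// The first guarded value is guardStart = Start - 1.
static unsigned countReverseViolations() {
  unsigned Violations = 0;
  for (unsigned Start = 0; Start <= Mask; ++Start)
    for (unsigned GL = 0; GL <= Mask; ++GL)
      for (unsigned LL = 0; LL <= Mask; ++LL) {
        unsigned GuardStart = (Start - 1) & Mask;
        bool Widened = GuardStart < GL && LL >= 1;
        if (!Widened)
          continue;
        // X counts down; B(X) fails at X == 0 at the latest.
        for (unsigned X = Start;; X = (X - 1) & Mask) {
          if (!(((X - 1) & Mask) < GL)) { // original guard G(X)
            ++Violations;
            break;
          }
          if (!(X > LL)) // backedge condition B(X)
            break;
        }
      }
  return Violations;
}
```

// Both counters are expected to return 0: whenever the widened condition
// holds, no executed iteration fails the original guard.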
178
190#include "llvm/IR/Function.h"
192#include "llvm/IR/Module.h"
193#include "llvm/IR/PatternMatch.h"
196#include "llvm/Pass.h"
198#include "llvm/Support/Debug.h"
203
204#include <optional>
205
206#define DEBUG_TYPE "loop-predication"
207
208STATISTIC(TotalConsidered, "Number of guards considered");
209STATISTIC(TotalWidened, "Number of checks widened");
210
211using namespace llvm;
212
213static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
214 cl::Hidden, cl::init(true));
215
216static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
217 cl::Hidden, cl::init(true));
218
219static cl::opt<bool>
220 SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
221 cl::Hidden, cl::init(false));
222
223// This is the scale factor for the latch probability. We use this during
224// profitability analysis to find other exiting blocks that have a much higher
225// probability of exiting the loop instead of loop exiting via latch.
226// This value should be greater than 1 for a sane profitability check.
228 "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
229 cl::desc("scale factor for the latch probability. Value should be greater "
230 "than 1. Lower values are ignored"));
231
233 "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden,
234 cl::desc("Whether or not we should predicate guards "
235 "expressed as widenable branches to deoptimize blocks"),
236 cl::init(true));
237
239 "loop-predication-insert-assumes-of-predicated-guards-conditions",
241 cl::desc("Whether or not we should insert assumes of conditions of "
242 "predicated guards"),
243 cl::init(true));
244
245namespace {
246/// Represents an induction variable check:
247/// icmp Pred, <induction variable>, <loop invariant limit>
248struct LoopICmp {
249 ICmpInst::Predicate Pred;
250 const SCEVAddRecExpr *IV;
251 const SCEV *Limit;
252 LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
253 const SCEV *Limit)
254 : Pred(Pred), IV(IV), Limit(Limit) {}
255 LoopICmp() = default;
256 void dump() {
257 dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
258 << ", Limit = " << *Limit << "\n";
259 }
260};
261
262class LoopPredication {
263 AliasAnalysis *AA;
264 DominatorTree *DT;
265 ScalarEvolution *SE;
266 LoopInfo *LI;
267 MemorySSAUpdater *MSSAU;
268
269 Loop *L;
270 const DataLayout *DL;
272 LoopICmp LatchCheck;
273
274 bool isSupportedStep(const SCEV* Step);
275 std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
276 std::optional<LoopICmp> parseLoopLatchICmp();
277
278 /// Return an insertion point suitable for inserting a safe to speculate
279 /// instruction whose only user will be 'User' which has operands 'Ops'. A
280 /// trivial result would be at the User itself, but we try to return a
281 /// loop invariant location if possible.
282 Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
283 /// Same as above, *except* that this uses the SCEV definition of invariant
284 /// which is that an expression *can be made* invariant via SCEVExpander.
285 /// Thus, this version is only suitable for finding an insert point to be
286 /// passed to SCEVExpander!
287 Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User,
288 ArrayRef<const SCEV *> Ops);
289
290 /// Return true if the value is known to produce a single fixed value across
291 /// all iterations on which it executes. Note that this does not imply
292 /// speculation safety. That must be established separately.
293 bool isLoopInvariantValue(const SCEV* S);
294
295 Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
296 ICmpInst::Predicate Pred, const SCEV *LHS,
297 const SCEV *RHS);
298
299 std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI,
300 SCEVExpander &Expander,
301 Instruction *Guard);
302 std::optional<Value *>
303 widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
304 SCEVExpander &Expander,
305 Instruction *Guard);
306 std::optional<Value *>
307 widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
308 SCEVExpander &Expander,
309 Instruction *Guard);
310 unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition,
311 SCEVExpander &Expander, Instruction *Guard);
312 bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
313 bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander);
314 // If the loop always exits through another block in the loop, we should not
315 // predicate based on the latch check. For example, the latch check can be a
316 // very coarse grained check and there can be more fine grained exit checks
317 // within the loop.
318 bool isLoopProfitableToPredicate();
319
320 bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
321
322public:
323 LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
324 LoopInfo *LI, MemorySSAUpdater *MSSAU)
325 : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
326 bool runOnLoop(Loop *L);
327};
328
329class LoopPredicationLegacyPass : public LoopPass {
330public:
331 static char ID;
332 LoopPredicationLegacyPass() : LoopPass(ID) {
334 }
335
336 void getAnalysisUsage(AnalysisUsage &AU) const override {
340 }
341
342 bool runOnLoop(Loop *L, LPPassManager &LPM) override {
343 if (skipLoop(L))
344 return false;
345 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
346 auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
347 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
348 auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
349 std::unique_ptr<MemorySSAUpdater> MSSAU;
350 if (MSSAWP)
351 MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
352 auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
353 LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr);
354 return LP.runOnLoop(L);
355 }
356};
357
358char LoopPredicationLegacyPass::ID = 0;
359} // end namespace
360
361INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
362 "Loop predication", false, false)
365INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
367
369 return new LoopPredicationLegacyPass();
370}
371
374 LPMUpdater &U) {
375 std::unique_ptr<MemorySSAUpdater> MSSAU;
376 if (AR.MSSA)
377 MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
378 LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
379 MSSAU ? MSSAU.get() : nullptr);
380 if (!LP.runOnLoop(&L))
381 return PreservedAnalyses::all();
382
384 if (AR.MSSA)
385 PA.preserve<MemorySSAAnalysis>();
386 return PA;
387}
388
389std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) {
390 auto Pred = ICI->getPredicate();
391 auto *LHS = ICI->getOperand(0);
392 auto *RHS = ICI->getOperand(1);
393
394 const SCEV *LHSS = SE->getSCEV(LHS);
395 if (isa<SCEVCouldNotCompute>(LHSS))
396 return std::nullopt;
397 const SCEV *RHSS = SE->getSCEV(RHS);
398 if (isa<SCEVCouldNotCompute>(RHSS))
399 return std::nullopt;
400
401 // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV
402 if (SE->isLoopInvariant(LHSS, L)) {
403 std::swap(LHS, RHS);
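404 std::swap(LHSS, RHSS);
405 Pred = ICmpInst::getSwappedPredicate(Pred);
406 }
407
408 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
409 if (!AR || AR->getLoop() != L)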
410 return std::nullopt;
411
412 return LoopICmp(Pred, AR, RHSS);
413}
414
415Value *LoopPredication::expandCheck(SCEVExpander &Expander,
416 Instruction *Guard,
417 ICmpInst::Predicate Pred, const SCEV *LHS,
418 const SCEV *RHS) {
419 Type *Ty = LHS->getType();
420 assert(Ty == RHS->getType() && "expandCheck operands have different types?");
421
422 if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
423 IRBuilder<> Builder(Guard);
424 if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
425 return Builder.getTrue();
426 if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
427 LHS, RHS))
428 return Builder.getFalse();
429 }
430
431 Value *LHSV =
432 Expander.expandCodeFor(LHS, Ty, findInsertPt(Expander, Guard, {LHS}));
433 Value *RHSV =
434 Expander.expandCodeFor(RHS, Ty, findInsertPt(Expander, Guard, {RHS}));
435 IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
436 return Builder.CreateICmp(Pred, LHSV, RHSV);
437}
438
439// Returns true if it's safe to truncate the IV to RangeCheckType.
440// When the IV type is wider than the range operand type, we can still do loop
441// predication, by generating SCEVs for the range and latch that are of the
442// same type. We achieve this by generating a SCEV truncate expression for the
443// latch IV. This is done iff truncation of the IV is a safe operation,
444// without loss of information.
445// Another way to achieve this is by generating a wider type SCEV for the
446// range check operand, however, this needs a more involved check that
447// operands do not overflow. This can lead to loss of information when the
448// range operand is of the form: add i32 %offset, %iv. We need to prove that
449// sext(x + y) is same as sext(x) + sext(y).
450// This function returns true if we can safely represent the IV type in
451// the RangeCheckType without loss of information.
452static bool isSafeToTruncateWideIVType(const DataLayout &DL,
453 ScalarEvolution &SE,
454 const LoopICmp LatchCheck,
455 Type *RangeCheckType) {
456 if (!EnableIVTruncation)
457 return false;
458 assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() >
459 DL.getTypeSizeInBits(RangeCheckType).getFixedValue() &&
460 "Expected latch check IV type to be larger than range check operand "
461 "type!");
462 // The start and end values of the IV should be known. This is to guarantee
463 // that truncating the wide type will not lose information.
464 auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
465 auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
466 if (!Limit || !Start)
467 return false;
468 // This check makes sure that the IV does not change sign during loop
469 // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
470 // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
471 // IV wraps around, and the truncation of the IV would lose the range of
472 // iterations between 2^32 and 2^64.
473 if (!SE.getMonotonicPredicateType(LatchCheck.IV, LatchCheck.Pred))
474 return false;
475 // The active bits should be less than the bits in the RangeCheckType. This
476 // guarantees that truncating the latch check to RangeCheckType is a safe
477 // operation.
478 auto RangeCheckTypeBitSize =
479 DL.getTypeSizeInBits(RangeCheckType).getFixedValue();
480 return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
481 Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
482}
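
// A minimal sketch of only the final bit-width test used above, written
// without LLVM types. activeBits is a hypothetical stand-in for
// APInt::getActiveBits on a 64-bit value, and fitsNarrowType is a hypothetical
// helper; the command-line flag and the monotonicity check are not modelled.

```cpp
#include <cassert>
#include <cstdint>

// Number of bits needed to represent V (0 when V == 0).
static unsigned activeBits(uint64_t V) {
  unsigned Bits = 0;
  while (V) {
    ++Bits;
    V >>= 1;
  }
  return Bits;
}

// Both constants must need strictly fewer bits than the narrow type provides,
// so that truncating the wide IV cannot drop set bits or change the sign bit
// of the narrow value.
static bool fitsNarrowType(uint64_t Start, uint64_t Limit,
                           unsigned NarrowBits) {
  return activeBits(Start) < NarrowBits && activeBits(Limit) < NarrowBits;
}
```

// With Start = 5 and Limit = 2 (the i64-to-i32 example in the comment above),
// both values need at most 3 bits, so the check passes. A limit of 2^31 needs
// 32 active bits and is rejected for a 32-bit range check type.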
483
484
485// Return a LoopICmp describing a latch check equivalent to LatchCheck but with
486// the requested type if safe to do so. May involve the use of a new IV.
487static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
488 ScalarEvolution &SE,
489 const LoopICmp LatchCheck,
490 Type *RangeCheckType) {
491
492 auto *LatchType = LatchCheck.IV->getType();
493 if (RangeCheckType == LatchType)
494 return LatchCheck;
495 // For now, bail out if latch type is narrower than range type.
496 if (DL.getTypeSizeInBits(LatchType).getFixedValue() <
497 DL.getTypeSizeInBits(RangeCheckType).getFixedValue())
498 return std::nullopt;
499 if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType))
500 return std::nullopt;
501 // We can now safely identify the truncated version of the IV and limit for
502 // RangeCheckType.
503 LoopICmp NewLatchCheck;
504 NewLatchCheck.Pred = LatchCheck.Pred;
505 NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
506 SE.getTruncateExpr(LatchCheck.IV, RangeCheckType));
507 if (!NewLatchCheck.IV)
508 return std::nullopt;
509 NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType);
510 LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
511 << " can be represented as range check type: "
512 << *RangeCheckType << "\n");
513 LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
514 LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
515 return NewLatchCheck;
516}
517
518bool LoopPredication::isSupportedStep(const SCEV* Step) {
519 return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
520}
521
522Instruction *LoopPredication::findInsertPt(Instruction *Use,
523 ArrayRef<Value*> Ops) {
524 for (Value *Op : Ops)
525 if (!L->isLoopInvariant(Op))
526 return Use;
528}
529
530Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander,
531 Instruction *Use,
532 ArrayRef<const SCEV *> Ops) {
533 // Subtlety: SCEV considers things to be invariant if the value produced is
534 // the same across iterations. This is not the same as being able to
535 // evaluate outside the loop, which is what we actually need here.
536 for (const SCEV *Op : Ops)
537 if (!SE->isLoopInvariant(Op, L) ||
539 return Use;
541}
542
543bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
544 // Handling expressions which produce invariant results, but *haven't* yet
545 // been removed from the loop serves two important purposes.
546 // 1) Most importantly, it resolves a pass ordering cycle which would
547 // otherwise need us to iterate licm, loop-predication, and either
548 // loop-unswitch or loop-peeling to make progress on examples with lots of
549 // predicable range checks in a row. (Since, in the general case, we can't
550 // hoist the length checks until the dominating checks have been discharged
551 // as we can't prove doing so is safe.)
552 // 2) As a nice side effect, this exposes the value of peeling or unswitching
553 // much more obviously in the IR. Otherwise, the cost modeling for other
554 // transforms would end up needing to duplicate all of this logic to model a
555 // check which becomes predictable based on a modeled peel or unswitch.
556 //
557 // The cost of doing so in the worst case is an extra fill from the stack in
558 // the loop to materialize the loop invariant test value instead of checking
559 // against the original IV which is presumably in a register inside the loop.
560 // Such cases are presumably rare, and hint at missing opportunities for
561 // other passes.
562
563 if (SE->isLoopInvariant(S, L))
564 // Note: This is the SCEV variant, so the original Value* may be within the
565 // loop even though SCEV has proven it is loop invariant.
566 return true;
567
568 // Handle a particular important case which SCEV doesn't yet know about which
569 // shows up in range checks on arrays with immutable lengths.
570 // TODO: This should be sunk inside SCEV.
571 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
572 if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
573 if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
576 return true;
577 return false;
578}
579
580std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
581 LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
582 Instruction *Guard) {
583 auto *Ty = RangeCheck.IV->getType();
584 // Generate the widened condition for the forward loop:
585 // guardStart u< guardLimit &&
586 // latchLimit <pred> guardLimit - 1 - guardStart + latchStart
587 // where <pred> depends on the latch condition predicate. See the file
588 // header comment for the reasoning.
589 // guardLimit - guardStart + latchStart - 1
590 const SCEV *GuardStart = RangeCheck.IV->getStart();
591 const SCEV *GuardLimit = RangeCheck.Limit;
592 const SCEV *LatchStart = LatchCheck.IV->getStart();
593 const SCEV *LatchLimit = LatchCheck.Limit;
594 // Subtlety: We need all the values to be *invariant* across all iterations,
595 // but we only need to check expansion safety for those which *aren't*
596 // already guaranteed to dominate the guard.
597 if (!isLoopInvariantValue(GuardStart) ||
598 !isLoopInvariantValue(GuardLimit) ||
599 !isLoopInvariantValue(LatchStart) ||
600 !isLoopInvariantValue(LatchLimit)) {
601 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
602 return std::nullopt;
603 }
604 if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
605 !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
606 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
607 return std::nullopt;
608 }
609
610 // guardLimit - guardStart + latchStart - 1
611 const SCEV *RHS =
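612 SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
613 SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
614 auto LimitCheckPred =
615 ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);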
616
617 LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
618 LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
619 LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
620
621 auto *LimitCheck =
622 expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
623 auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
624 GuardStart, GuardLimit);
625 IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
626 return Builder.CreateFreeze(
627 Builder.CreateAnd(FirstIterationCheck, LimitCheck));
628}
629
630std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
631 LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
632 Instruction *Guard) {
633 auto *Ty = RangeCheck.IV->getType();
634 const SCEV *GuardStart = RangeCheck.IV->getStart();
635 const SCEV *GuardLimit = RangeCheck.Limit;
636 const SCEV *LatchStart = LatchCheck.IV->getStart();
637 const SCEV *LatchLimit = LatchCheck.Limit;
638 // Subtlety: We need all the values to be *invariant* across all iterations,
639 // but we only need to check expansion safety for those which *aren't*
640 // already guaranteed to dominate the guard.
641 if (!isLoopInvariantValue(GuardStart) ||
642 !isLoopInvariantValue(GuardLimit) ||
643 !isLoopInvariantValue(LatchStart) ||
644 !isLoopInvariantValue(LatchLimit)) {
645 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
646 return std::nullopt;
647 }
648 if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
649 !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
650 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
651 return std::nullopt;
652 }
653 // The decrement of the latch check IV should be the same as the
654 // rangeCheckIV.
655 auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
656 if (RangeCheck.IV != PostDecLatchCheckIV) {
657 LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
658 << *PostDecLatchCheckIV
659 << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
660 return std::nullopt;
661 }
662
663 // Generate the widened condition for CountDownLoop:
664 // guardStart u< guardLimit &&
665 // latchLimit <pred> 1.
666 // See the header comment for reasoning of the checks.
667 auto LimitCheckPred =
668 ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
669 auto *FirstIterationCheck = expandCheck(Expander, Guard,
670 ICmpInst::ICMP_ULT,
671 GuardStart, GuardLimit);
672 auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
673 SE->getOne(Ty));
674 IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
675 return Builder.CreateFreeze(
676 Builder.CreateAnd(FirstIterationCheck, LimitCheck));
677}
678
680 LoopICmp& RC) {
681 // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the
682 // ULT/UGE form for ease of handling by our caller.
683 if (ICmpInst::isEquality(RC.Pred) &&
684 RC.IV->getStepRecurrence(*SE)->isOne() &&
685 SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit))
686 RC.Pred = RC.Pred == ICmpInst::ICMP_NE ?
687 ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
688}
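
// A small sketch of why the NE-to-ULT normalization above needs both the
// unit step and the Start u<= Limit precondition. It uses an 8-bit IV in plain
// C++; neAgreesWithUlt and allOrderedPairsAgree are hypothetical helper names,
// and nothing here touches ScalarEvolution.

```cpp
#include <cassert>
#include <cstdint>

// Returns true if, for an 8-bit IV starting at Start and stepping by 1,
// `I != Limit` and `I u< Limit` agree on every value the IV takes, up to and
// including the first value where `I != Limit` is false.
static bool neAgreesWithUlt(uint8_t Start, uint8_t Limit) {
  uint8_t I = Start;
  for (;;) {
    bool NE = I != Limit;
    bool ULT = I < Limit;
    if (NE != ULT)
      return false;
    if (!NE)
      return true;
    ++I; // wraps modulo 2^8
  }
}

// Checks every pair with Start u<= Limit.
static bool allOrderedPairsAgree() {
  for (unsigned Start = 0; Start < 256; ++Start)
    for (unsigned Limit = Start; Limit < 256; ++Limit)
      if (!neAgreesWithUlt(static_cast<uint8_t>(Start),
                           static_cast<uint8_t>(Limit)))
        return false;
  return true;
}
```

// When Start u> Limit the two predicates already disagree on the first
// iteration, which is the b = 5, e = 4 situation from the file header.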
689
690/// If ICI can be widened to a loop invariant condition emits the loop
691/// invariant condition in the loop preheader and return it, otherwise
692/// returns std::nullopt.
693std::optional<Value *>
694LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
695 Instruction *Guard) {
696 LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
697 LLVM_DEBUG(ICI->dump());
698
699 // parseLoopStructure guarantees that the latch condition is:
700 // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
701 // We are looking for the range checks of the form:
702 // i u< guardLimit
703 auto RangeCheck = parseLoopICmp(ICI);
704 if (!RangeCheck) {
705 LLVM_DEBUG(dbgs() << "Failed to parse the range check condition!\n");
706 return std::nullopt;
707 }
708 LLVM_DEBUG(dbgs() << "Guard check:\n");
709 LLVM_DEBUG(RangeCheck->dump());
710 if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
711 LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
712 << RangeCheck->Pred << ")!\n");
713 return std::nullopt;
714 }
715 auto *RangeCheckIV = RangeCheck->IV;
716 if (!RangeCheckIV->isAffine()) {
717 LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
718 return std::nullopt;
719 }
720 auto *Step = RangeCheckIV->getStepRecurrence(*SE);
721 // We cannot just compare with latch IV step because the latch and range IVs
722 // may have different types.
723 if (!isSupportedStep(Step)) {
724 LLVM_DEBUG(dbgs() << "Range check and latch IVs have different steps!\n");
725 return std::nullopt;
726 }
727 auto *Ty = RangeCheckIV->getType();
728 auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty);
729 if (!CurrLatchCheckOpt) {
730 LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
731 "corresponding to range type: "
732 << *Ty << "\n");
733 return std::nullopt;
734 }
735
736 LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
737 // At this point, the range and latch step should have the same type, but need
738 // not have the same value (we support both 1 and -1 steps).
739 assert(Step->getType() ==
740 CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
741 "Range and latch steps should be of same type!");
742 if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
743 LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
744 return std::nullopt;
745 }
746
747 if (Step->isOne())
748 return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
749 Expander, Guard);
750 else {
751 assert(Step->isAllOnesValue() && "Step should be -1!");
752 return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
753 Expander, Guard);
754 }
755}
756
757unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks,
758 Value *Condition,
759 SCEVExpander &Expander,
760 Instruction *Guard) {
761 unsigned NumWidened = 0;
762 // The guard condition is expected to be in form of:
763 // cond1 && cond2 && cond3 ...
764 // Iterate over subconditions looking for icmp conditions which can be
765 // widened across loop iterations. While widening these conditions, remember
766 // the resulting list of subconditions in the Checks vector.
767 SmallVector<Value *, 4> Worklist(1, Condition);
768 SmallPtrSet<Value *, 4> Visited;
769 Visited.insert(Condition);
770 Value *WideableCond = nullptr;
771 do {
772 Value *Condition = Worklist.pop_back_val();
773 Value *LHS, *RHS;
774 using namespace llvm::PatternMatch;
775 if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
776 if (Visited.insert(LHS).second)
777 Worklist.push_back(LHS);
778 if (Visited.insert(RHS).second)
779 Worklist.push_back(RHS);
780 continue;
781 }
782
783 if (match(Condition,
784 m_Intrinsic<Intrinsic::experimental_widenable_condition>())) {
785 // Pick any, we don't care which
786 WideableCond = Condition;
787 continue;
788 }
789
790 if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
791 if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander,
792 Guard)) {
793 Checks.push_back(*NewRangeCheck);
794 NumWidened++;
795 continue;
796 }
797 }
798
799 // Save the condition as is if we can't widen it
800 Checks.push_back(Condition);
801 } while (!Worklist.empty());
802 // At the moment, our matching logic for wideable conditions implicitly
803 // assumes we preserve the form: (br (and Cond, WC())). FIXME
804 // Note that if there were multiple calls to wideable condition in the
805 // traversal, we only need to keep one, and which one is arbitrary.
806 if (WideableCond)
807 Checks.push_back(WideableCond);
808 return NumWidened;
809}
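
// A minimal sketch of the worklist traversal used above, reduced to plain C++:
// an `and` tree is flattened into its leaf subconditions, visiting each node
// at most once. Cond, flattenConjunction and exampleLeaves are hypothetical
// stand-ins; widening and the widenable-condition special case are omitted.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Either a named leaf (LHS and RHS null) or an `and` node (both set).
struct Cond {
  std::string Name;
  const Cond *LHS;
  const Cond *RHS;
};

static std::vector<std::string> flattenConjunction(const Cond *Root) {
  std::vector<std::string> Leaves;
  std::vector<const Cond *> Worklist{Root};
  std::set<const Cond *> Visited{Root};
  do {
    const Cond *C = Worklist.back();
    Worklist.pop_back();
    if (C->LHS && C->RHS) {
      // Split the conjunction and queue each operand once.
      if (Visited.insert(C->LHS).second)
        Worklist.push_back(C->LHS);
      if (Visited.insert(C->RHS).second)
        Worklist.push_back(C->RHS);
      continue;
    }
    // A leaf is kept as is (the pass would try to widen it first).
    Leaves.push_back(C->Name);
  } while (!Worklist.empty());
  return Leaves;
}

// Flattens (a && b) && c.
static std::vector<std::string> exampleLeaves() {
  Cond A{"a", nullptr, nullptr};
  Cond B{"b", nullptr, nullptr};
  Cond C{"c", nullptr, nullptr};
  Cond AB{"", &A, &B};
  Cond Root{"", &AB, &C};
  return flattenConjunction(&Root);
}
```

// Because the worklist is used as a stack, leaves come out in reverse order of
// how they are written: c, b, a for the example tree.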
810
811bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
812 SCEVExpander &Expander) {
813 LLVM_DEBUG(dbgs() << "Processing guard:\n");
814 LLVM_DEBUG(Guard->dump());
815
816 TotalConsidered++;
817 SmallVector<Value *, 4> Checks;
818 unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander,
819 Guard);
820 if (NumWidened == 0)
821 return false;
822
823 TotalWidened += NumWidened;
824
825 // Emit the new guard condition
826 IRBuilder<> Builder(findInsertPt(Guard, Checks));
827 Value *AllChecks = Builder.CreateAnd(Checks);
828 auto *OldCond = Guard->getOperand(0);
829 Guard->setOperand(0, AllChecks);
831 Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard));
832 Builder.CreateAssumption(OldCond);
833 }
834 RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
835
836 LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
837 return true;
838}
839
840bool LoopPredication::widenWidenableBranchGuardConditions(
841 BranchInst *BI, SCEVExpander &Expander) {
842 assert(isGuardAsWidenableBranch(BI) && "Must be!");
843 LLVM_DEBUG(dbgs() << "Processing guard:\n");
844 LLVM_DEBUG(BI->dump());
845
846 Value *Cond, *WC;
847 BasicBlock *IfTrueBB, *IfFalseBB;
848 bool Parsed = parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB);
849 assert(Parsed && "Must be able to parse widenable branch");
850 (void)Parsed;
851
852 TotalConsidered++;
853 SmallVector<Value *, 4> Checks;
854 unsigned NumWidened = collectChecks(Checks, BI->getCondition(),
855 Expander, BI);
856 if (NumWidened == 0)
857 return false;
858
859 TotalWidened += NumWidened;
860
861 // Emit the new guard condition
862 IRBuilder<> Builder(findInsertPt(BI, Checks));
863 Value *AllChecks = Builder.CreateAnd(Checks);
864 auto *OldCond = BI->getCondition();
865 BI->setCondition(AllChecks);
866 if (InsertAssumesOfPredicatedGuardsConditions) {
867 Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
868 // If this block has other predecessors, we might not be able to use Cond.
869 // In this case, create a Phi where every other input is true and input
870 // from guard block is Cond.
871 Value *AssumeCond = Cond;
872 if (!IfTrueBB->getUniquePredecessor()) {
873 auto *GuardBB = BI->getParent();
874 auto *PN = Builder.CreatePHI(Cond->getType(), pred_size(IfTrueBB),
875 "assume.cond");
876 for (auto *Pred : predecessors(IfTrueBB))
877 PN->addIncoming(Pred == GuardBB ? Cond : Builder.getTrue(), Pred);
878 AssumeCond = PN;
879 }
880 Builder.CreateAssumption(AssumeCond);
881 }
882 RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
883 assert(isGuardAsWidenableBranch(BI) &&
884 "Stopped being a guard after transform?");
885
886 LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
887 return true;
888}
889
890std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
891 using namespace PatternMatch;
892
893 BasicBlock *LoopLatch = L->getLoopLatch();
894 if (!LoopLatch) {
895 LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
896 return std::nullopt;
897 }
898
899 auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
900 if (!BI || !BI->isConditional()) {
901 LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
902 return std::nullopt;
903 }
904 BasicBlock *TrueDest = BI->getSuccessor(0);
905 assert(
906 (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) &&
907 "One of the latch's destinations must be the header");
908
909 auto *ICI = dyn_cast<ICmpInst>(BI->getCondition());
910 if (!ICI) {
911 LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n");
912 return std::nullopt;
913 }
914 auto Result = parseLoopICmp(ICI);
915 if (!Result) {
916 LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
917 return std::nullopt;
918 }
919
920 if (TrueDest != L->getHeader())
921 Result->Pred = ICmpInst::getInversePredicate(Result->Pred);
922
923 // Check affine first, so if it's not we don't try to compute the step
924 // recurrence.
925 if (!Result->IV->isAffine()) {
926 LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
927 return std::nullopt;
928 }
929
930 auto *Step = Result->IV->getStepRecurrence(*SE);
931 if (!isSupportedStep(Step)) {
932 LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
933 return std::nullopt;
934 }
935
936 auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
937 if (Step->isOne()) {
938 return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
939 Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
940 } else {
941 assert(Step->isAllOnesValue() && "Step should be -1!");
942 return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
943 Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
944 }
945 };
946
947 normalizePredicate(SE, L, *Result);
948 if (IsUnsupportedPredicate(Step, Result->Pred)) {
949 LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
950 << ")!\n");
951 return std::nullopt;
952 }
953
954 return Result;
955}
956
957bool LoopPredication::isLoopProfitableToPredicate() {
958 if (SkipProfitabilityChecks)
959 return true;
960
961 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
962 L->getExitEdges(ExitEdges);
963 // If there is only one exiting edge in the loop, it is always profitable to
964 // predicate the loop.
965 if (ExitEdges.size() == 1)
966 return true;
967
968 // Calculate the exiting probabilities of all exiting edges from the loop,
969 // starting with the LatchExitProbability.
970 // Heuristic for profitability: If any of the exiting blocks' probability of
971 // exiting the loop is larger than exiting through the latch block, it's not
972 // profitable to predicate the loop.
973 auto *LatchBlock = L->getLoopLatch();
974 assert(LatchBlock && "Should have a single latch at this point!");
975 auto *LatchTerm = LatchBlock->getTerminator();
976 assert(LatchTerm->getNumSuccessors() == 2 &&
977 "expected to be an exiting block with 2 succs!");
978 unsigned LatchBrExitIdx =
979 LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
980 // We compute branch probabilities without BPI. We do not rely on BPI since
981 // Loop predication is usually run in an LPM and BPI is only preserved
982 // lossily within loop pass managers, while BPI has an inherent notion of
983 // being complete for an entire function.
984
985 // If the latch exits into a deoptimize or an unreachable block, do not
986 // predicate on that latch check.
987 auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
988 if (isa<UnreachableInst>(LatchTerm) ||
989 LatchExitBlock->getTerminatingDeoptimizeCall())
990 return false;
991
992 // Latch terminator has no valid profile data, so nothing to check
993 // profitability on.
994 if (!hasValidBranchWeightMD(*LatchTerm))
995 return true;
996
997 auto ComputeBranchProbability =
998 [&](const BasicBlock *ExitingBlock,
999 const BasicBlock *ExitBlock) -> BranchProbability {
1000 auto *Term = ExitingBlock->getTerminator();
1001 unsigned NumSucc = Term->getNumSuccessors();
1002 if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
1003 SmallVector<uint32_t> Weights;
1004 extractBranchWeights(ProfileData, Weights);
1005 uint64_t Numerator = 0, Denominator = 0;
1006 for (auto [i, Weight] : llvm::enumerate(Weights)) {
1007 if (Term->getSuccessor(i) == ExitBlock)
1008 Numerator += Weight;
1009 Denominator += Weight;
1010 }
1011 return BranchProbability::getBranchProbability(Numerator, Denominator);
1012 } else {
1013 assert(LatchBlock != ExitingBlock &&
1014 "Latch term should always have profile data!");
1015 // No profile data, so we choose the weight as 1/num_of_succ(Src)
1016 return BranchProbability::getBranchProbability(1, NumSucc);
1017 }
1018 };
1019
1020 BranchProbability LatchExitProbability =
1021 ComputeBranchProbability(LatchBlock, LatchExitBlock);
1022
1023 // Protect against degenerate inputs provided by the user. Providing a value
1024 // less than one can invert the definition of profitable loop predication.
1025 float ScaleFactor = LatchExitProbabilityScale;
1026 if (ScaleFactor < 1) {
1027 LLVM_DEBUG(
1028 dbgs()
1029 << "Ignored user setting for loop-predication-latch-probability-scale: "
1030 << LatchExitProbabilityScale << "\n");
1031 LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
1032 ScaleFactor = 1.0;
1033 }
1034 const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;
1035
1036 for (const auto &ExitEdge : ExitEdges) {
1037 BranchProbability ExitingBlockProbability =
1038 ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
1039 // Some exiting edge has higher probability than the latch exiting edge.
1040 // No longer profitable to predicate.
1041 if (ExitingBlockProbability > LatchProbabilityThreshold)
1042 return false;
1043 }
1044
1045 // We have concluded that the most probable way to exit from the
1046 // loop is through the latch (or there's no profile information and all
1047 // exits are equally likely).
1048 return true;
1049}
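// Illustrative sketch, not part of LoopPredication.cpp: a standalone model of
// the profitability heuristic above. ExitEdgeModel, exitProbability and
// isProfitable are hypothetical names, and plain doubles stand in for
// BranchProbability, so rounding may differ from the real pass. The sketch
// assumes that any supplied branch weights do not sum to zero.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of one exiting edge: the branch weights on the exiting
// block's terminator plus the index of the successor that leaves the loop.
struct ExitEdgeModel {
  std::vector<uint32_t> Weights; // empty means "no profile data"
  unsigned NumSuccessors;
  unsigned ExitSuccIdx;
};

// Probability of taking the exit edge: the exit successor's share of the
// weights when profile data exists, otherwise 1 / number-of-successors.
double exitProbability(const ExitEdgeModel &E) {
  if (E.Weights.empty())
    return 1.0 / E.NumSuccessors;
  uint64_t Total = 0;
  for (uint32_t W : E.Weights)
    Total += W;
  return static_cast<double>(E.Weights[E.ExitSuccIdx]) / Total;
}

// Profitable when no exit edge is more likely than Scale * P(latch exit).
// Scale values below 1 are clamped to 1, mirroring the guard in the pass.
bool isProfitable(const ExitEdgeModel &Latch,
                  const std::vector<ExitEdgeModel> &AllExits, double Scale) {
  if (Scale < 1.0)
    Scale = 1.0;
  const double Threshold = exitProbability(Latch) * Scale;
  for (const ExitEdgeModel &E : AllExits)
    if (exitProbability(E) > Threshold)
      return false;
  return true;
}
```

// Under this model, a latch that exits 1% of the time with a scale of 2.0
// gives a threshold of 2%, so a second exit without profile data (modelled as
// 50%) makes the loop unprofitable to predicate.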
1050
1051/// If we can (cheaply) find a widenable branch which controls entry into the
1052/// loop, return it.
1053static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
1054 // Walk back through any unconditionally executed blocks and see if we can find
1055 // a widenable condition which seems to control execution of this loop. Note
1056 // that we predict that maythrow calls are likely untaken and thus that it's
1057 // profitable to widen a branch before a maythrow call with a condition
1058 // afterwards even though that may cause the slow path to run in a case where
1059 // it wouldn't have otherwise.
1060 BasicBlock *BB = L->getLoopPreheader();
1061 if (!BB)
1062 return nullptr;
1063 do {
1064 if (BasicBlock *Pred = BB->getSinglePredecessor())
1065 if (BB == Pred->getSingleSuccessor()) {
1066 BB = Pred;
1067 continue;
1068 }
1069 break;
1070 } while (true);
1071
1072 if (BasicBlock *Pred = BB->getSinglePredecessor()) {
1073 auto *Term = Pred->getTerminator();
1074
1075 Value *Cond, *WC;
1076 BasicBlock *IfTrueBB, *IfFalseBB;
1077 if (parseWidenableBranch(Term, Cond, WC, IfTrueBB, IfFalseBB) &&
1078 IfTrueBB == BB)
1079 return cast<BranchInst>(Term);
1080 }
1081 return nullptr;
1082}
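// Illustrative sketch, not part of LoopPredication.cpp: the upward walk above
// modelled on a toy CFG. Block and findWidenableBlockAbove are hypothetical
// names; EndsInWidenableBranch and IfTrueSucc stand in for what
// parseWidenableBranch reports. Blocks are identified by their index in the
// table, and the sketch assumes the chain of blocks is acyclic.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy CFG node; Preds/Succs hold indices into the block table.
struct Block {
  std::vector<int> Preds, Succs;
  bool EndsInWidenableBranch = false; // stand-in for parseWidenableBranch
  int IfTrueSucc = -1;                // "true" successor of that branch
};

// Starting from the preheader, climb while the current block has exactly one
// predecessor and that predecessor has exactly one successor. Then return the
// remaining single predecessor if it ends in a widenable branch whose true
// edge leads to the block we stopped at. Returns -1 if there is none.
int findWidenableBlockAbove(const std::vector<Block> &CFG, int Preheader) {
  if (Preheader < 0)
    return -1;
  int BB = Preheader;
  while (CFG[BB].Preds.size() == 1 &&
         CFG[CFG[BB].Preds[0]].Succs.size() == 1)
    BB = CFG[BB].Preds[0];
  if (CFG[BB].Preds.size() != 1)
    return -1;
  const int Pred = CFG[BB].Preds[0];
  if (CFG[Pred].EndsInWidenableBranch && CFG[Pred].IfTrueSucc == BB)
    return Pred;
  return -1;
}
```

// The walk deliberately stops at the first block that is not reached
// unconditionally, so only a branch that directly controls entry into the
// straight-line code leading to the loop is reported.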
1083
1084/// Return the minimum of all analyzeable exit counts. This is an upper bound
1085/// on the actual exit count. If there are not at least two analyzeable exits,
1086/// returns SCEVCouldNotCompute.
1087static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
1088 DominatorTree &DT,
1089 Loop *L) {
1090 SmallVector<BasicBlock *, 16> ExitingBlocks;
1091 L->getExitingBlocks(ExitingBlocks);
1092
1093 SmallVector<const SCEV *, 4> ExitCounts;
1094 for (BasicBlock *ExitingBB : ExitingBlocks) {
1095 const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
1096 if (isa<SCEVCouldNotCompute>(ExitCount))
1097 continue;
1098 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
1099 "We should only have known counts for exiting blocks that "
1100 "dominate latch!");
1101 ExitCounts.push_back(ExitCount);
1102 }
1103 if (ExitCounts.size() < 2)
1104 return SE.getCouldNotCompute();
1105 return SE.getUMinFromMismatchedTypes(ExitCounts);
1106}
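// Illustrative sketch, not part of LoopPredication.cpp: the same "minimum of
// the analyzeable exit counts, or nothing if fewer than two are known" rule on
// plain integers. minAnalyzableExitCount is a hypothetical name, and
// std::nullopt stands in for SCEVCouldNotCompute.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Each entry models one exiting block: a known exit count, or std::nullopt
// where SCEV would answer SCEVCouldNotCompute.
std::optional<uint64_t>
minAnalyzableExitCount(const std::vector<std::optional<uint64_t>> &Counts) {
  std::optional<uint64_t> Min;
  unsigned NumKnown = 0;
  for (const std::optional<uint64_t> &C : Counts) {
    if (!C)
      continue; // unanalyzable exits are skipped, not treated as failure
    ++NumKnown;
    if (!Min || *C < *Min)
      Min = *C;
  }
  // With fewer than two known exits, report "could not compute".
  if (NumKnown < 2)
    return std::nullopt;
  return Min;
}
```

// Skipped exits can only leave the loop earlier than the reported value, which
// is why the result is an upper bound rather than the exact count.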
1107
1108/// This implements an analogous, but entirely distinct transform from the main
1109/// loop predication transform. This one is phrased in terms of using a
1110/// widenable branch *outside* the loop to allow us to simplify loop exits in a
1111/// following loop. This is close in spirit to the IndVarSimplify transform
1112/// of the same name, but is materially different: widening loosens legality
1113/// sharply.
1114bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
1115 // The transformation performed here aims to widen a widenable condition
1116 // above the loop such that all analyzeable exits leading to deopt are dead.
1117 // It assumes that the latch is the dominant exit for profitability and that
1118 // exits branching to deoptimizing blocks are rarely taken. It relies on the
1119 // semantics of widenable expressions for legality. (i.e. being able to fall
1120 // down the widenable path spuriously allows us to ignore exit order,
1121 // unanalyzeable exits, side effects, exceptional exits, and other challenges
1122 // which restrict the applicability of the non-WC based version of this
1123 // transform in IndVarSimplify.)
1124 //
1125 // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may
1126 // imply flags on the expression being hoisted and inserting new uses (flags
1127 // are only correct for current uses). The result is that we may be
1128 // inserting a branch on the value which can be either poison or undef. In
1129 // this case, the branch can legally go either way; we just need to avoid
1130 // introducing UB. This is achieved through the use of the freeze
1131 // instruction.
1132
1133 SmallVector<BasicBlock *, 16> ExitingBlocks;
1134 L->getExitingBlocks(ExitingBlocks);
1135
1136 if (ExitingBlocks.empty())
1137 return false; // Nothing to do.
1138
1139 auto *Latch = L->getLoopLatch();
1140 if (!Latch)
1141 return false;
1142
1143 auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, *LI);
1144 if (!WidenableBR)
1145 return false;
1146
1147 const SCEV *LatchEC = SE->getExitCount(L, Latch);
1148 if (isa<SCEVCouldNotCompute>(LatchEC))
1149 return false; // profitability - want hot exit in analyzeable set
1150
1151 // At this point, we have found an analyzeable latch, and a widenable
1152 // condition above the loop. If we have a widenable exit within the loop
1153 // (for which we can't compute exit counts), drop the ability to further
1154 // widen so that we gain the ability to analyze its exit count and perform this
1155 // transform. TODO: It'd be nice to know for sure the exit became
1156 // analyzeable after dropping widenability.
1157 bool ChangedLoop = false;
1158
1159 for (auto *ExitingBB : ExitingBlocks) {
1160 if (LI->getLoopFor(ExitingBB) != L)
1161 continue;
1162
1163 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1164 if (!BI)
1165 continue;
1166
1167 Use *Cond, *WC;
1168 BasicBlock *IfTrueBB, *IfFalseBB;
1169 if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) &&
1170 L->contains(IfTrueBB)) {
1171 WC->set(ConstantInt::getTrue(IfTrueBB->getContext()));
1172 ChangedLoop = true;
1173 }
1174 }
1175 if (ChangedLoop)
1176 SE->forgetLoop(L);
1177
1178 // The use of umin(all analyzeable exits) instead of latch is subtle, but
1179 // important for profitability. We may have a loop which hasn't been fully
1180 // canonicalized just yet. If the exit we chose to widen is provably never
1181 // taken, we want the widened form to *also* be provably never taken. We
1182 // can't guarantee this as a current unanalyzeable exit may later become
1183 // analyzeable, but we can at least avoid the obvious cases.
1184 const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
1185 if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
1186 !SE->isLoopInvariant(MinEC, L) ||
1187 !Rewriter.isSafeToExpandAt(MinEC, WidenableBR))
1188 return ChangedLoop;
1189
1190 // Subtlety: We need to avoid inserting additional uses of the WC. We know
1191 // that it can only have one transitive use at the moment, and thus moving
1192 // that use to just before the branch and inserting code before it and then
1193 // modifying the operand is legal.
1194 auto *IP = cast<Instruction>(WidenableBR->getCondition());
1195 // Here we unconditionally modify the IR, so after this point we should return
1196 // only true!
1197 IP->moveBefore(WidenableBR);
1198 if (MSSAU)
1199 if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(IP))
1200 MSSAU->moveToPlace(MUD, WidenableBR->getParent(),
1201 MemorySSA::BeforeTerminator);
1202 Rewriter.setInsertPoint(IP);
1203 IRBuilder<> B(IP);
1204
1205 bool InvalidateLoop = false;
1206 Value *MinECV = nullptr; // lazily generated if needed
1207 for (BasicBlock *ExitingBB : ExitingBlocks) {
1208 // If our exiting block exits multiple loops, we can only rewrite the
1209 // innermost one. Otherwise, we're changing how many times the innermost
1210 // loop runs before it exits.
1211 if (LI->getLoopFor(ExitingBB) != L)
1212 continue;
1213
1214 // Can't rewrite non-branch yet.
1215 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1216 if (!BI)
1217 continue;
1218
1219 // If already constant, nothing to do.
1220 if (isa<Constant>(BI->getCondition()))
1221 continue;
1222
1223 const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
1224 if (isa<SCEVCouldNotCompute>(ExitCount) ||
1225 ExitCount->getType()->isPointerTy() ||
1226 !Rewriter.isSafeToExpandAt(ExitCount, WidenableBR))
1227 continue;
1228
1229 const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
1230 BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
1231 if (!ExitBB->getPostdominatingDeoptimizeCall())
1232 continue;
1233
1234 /// Here we can be fairly sure that executing this exit will most likely
1235 /// lead to executing llvm.experimental.deoptimize.
1236 /// This is a profitability heuristic, not a legality constraint.
1237
1238 // If we found a widenable exit condition, do two things:
1239 // 1) fold the widened exit test into the widenable condition
1240 // 2) fold the branch to untaken - avoids infinite looping
1241
1242 Value *ECV = Rewriter.expandCodeFor(ExitCount);
1243 if (!MinECV)
1244 MinECV = Rewriter.expandCodeFor(MinEC);
1245 Value *RHS = MinECV;
1246 if (ECV->getType() != RHS->getType()) {
1247 Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
1248 ECV = B.CreateZExt(ECV, WiderTy);
1249 RHS = B.CreateZExt(RHS, WiderTy);
1250 }
1251 assert(!Latch || DT->dominates(ExitingBB, Latch));
1252 Value *NewCond = B.CreateICmp(ICmpInst::ICMP_UGT, ECV, RHS);
1253 // Freeze poison or undef to an arbitrary bit pattern to ensure we can
1254 // branch without introducing UB. See NOTE ON POISON/UNDEF above for
1255 // context.
1256 NewCond = B.CreateFreeze(NewCond);
1257
1258 widenWidenableBranch(WidenableBR, NewCond);
1259
1260 Value *OldCond = BI->getCondition();
1261 BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
1262 InvalidateLoop = true;
1263 }
1264
1265 if (InvalidateLoop)
1266 // We just mutated a bunch of loop exits, changing their exit counts
1267 // widely. We need to force recomputation of the exit counts given these
1268 // changes. Note that all of the inserted exits are never taken, and
1269 // should be removed next time the CFG is modified.
1270 SE->forgetLoop(L);
1271
1272 // Always return true since we have moved the WidenableBR's condition.
1273 return true;
1274}
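// Illustrative sketch, not part of LoopPredication.cpp: a model of the
// condition that predicateLoopExits folds into the widenable branch above the
// loop. ExitModel and widenedCondition are hypothetical names, exit counts are
// plain integers instead of SCEV expressions, and the freeze of the compare is
// not modelled because plain integers are never poison or undef.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

struct ExitModel {
  uint64_t ExitCount; // backedges taken before this exit would fire
  bool LeadsToDeopt;  // stand-in for getPostdominatingDeoptimizeCall()
};

// Conjunction of "ExitCount u> MinEC" over all exits that lead to deopt,
// where MinEC is the minimum exit count over every modelled exit.
bool widenedCondition(const std::vector<ExitModel> &Exits) {
  assert(!Exits.empty() && "sketch expects at least one exit");
  uint64_t MinEC = Exits.front().ExitCount;
  for (const ExitModel &E : Exits)
    MinEC = std::min(MinEC, E.ExitCount);
  bool Cond = true;
  for (const ExitModel &E : Exits)
    if (E.LeadsToDeopt)
      Cond = Cond && (E.ExitCount > MinEC); // ICMP_UGT in the pass
  return Cond;
}
```

// When the condition holds, some other exit (typically the latch) fires
// strictly before every deopt exit, so those exits can be folded to untaken.
// When it does not hold, including on a tie, the widenable branch takes the
// slow path before the loop is entered.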
1275
1276bool LoopPredication::runOnLoop(Loop *Loop) {
1277 L = Loop;
1278
1279 LLVM_DEBUG(dbgs() << "Analyzing ");
1280 LLVM_DEBUG(L->dump());
1281
1282 Module *M = L->getHeader()->getModule();
1283
1284 // There is nothing to do if the module doesn't use guards
1285 auto *GuardDecl =
1286 M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
1287 bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
1288 auto *WCDecl = M->getFunction(
1289 Intrinsic::getName(Intrinsic::experimental_widenable_condition));
1290 bool HasWidenableConditions =
1291 PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty();
1292 if (!HasIntrinsicGuards && !HasWidenableConditions)
1293 return false;
1294
1295 DL = &M->getDataLayout();
1296
1297 Preheader = L->getLoopPreheader();
1298 if (!Preheader)
1299 return false;
1300
1301 auto LatchCheckOpt = parseLoopLatchICmp();
1302 if (!LatchCheckOpt)
1303 return false;
1304 LatchCheck = *LatchCheckOpt;
1305
1306 LLVM_DEBUG(dbgs() << "Latch check:\n");
1307 LLVM_DEBUG(LatchCheck.dump());
1308
1309 if (!isLoopProfitableToPredicate()) {
1310 LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
1311 return false;
1312 }
1313 // Collect all the guards into a vector and process later, so as not
1314 // to invalidate the instruction iterator.
1315 SmallVector<IntrinsicInst *, 4> Guards;
1316 SmallVector<BranchInst *, 4> GuardsAsWidenableBranches;
1317 for (const auto BB : L->blocks()) {
1318 for (auto &I : *BB)
1319 if (isGuard(&I))
1320 Guards.push_back(cast<IntrinsicInst>(&I));
1321 if (PredicateWidenableBranchGuards &&
1322 isGuardAsWidenableBranch(BB->getTerminator()))
1323 GuardsAsWidenableBranches.push_back(
1324 cast<BranchInst>(BB->getTerminator()));
1325 }
1326
1327 SCEVExpander Expander(*SE, *DL, "loop-predication");
1328 bool Changed = false;
1329 for (auto *Guard : Guards)
1330 Changed |= widenGuardConditions(Guard, Expander);
1331 for (auto *Guard : GuardsAsWidenableBranches)
1332 Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
1333 Changed |= predicateLoopExits(L, Expander);
1334
1335 if (MSSAU && VerifyMemorySSA)
1336 MSSAU->getMemorySSA()->verifyMemorySSA();
1337 return Changed;
1338}