//===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The LoopPredication pass tries to convert loop variant range checks to loop
// invariant by widening checks across loop iterations. For example, it will
// convert
//
//   for (i = 0; i < n; i++) {
//     guard(i < len);
//     ...
//   }
//
// to
//
//   for (i = 0; i < n; i++) {
//     guard(n - 1 < len);
//     ...
//   }
//
// After this transformation the condition of the guard is loop invariant, so
// loop-unswitch can later unswitch the loop by this condition, which basically
// predicates the loop by the widened condition:
//
//   if (n - 1 < len)
//     for (i = 0; i < n; i++) {
//       ...
//     }
//   else
//     deoptimize
//
// It's tempting to rely on SCEV here, but it has proven to be problematic.
// Generally the facts SCEV provides about the increment step of add
// recurrences are true if the backedge of the loop is taken, which implicitly
// assumes that the guard doesn't fail. Using these facts to optimize the
// guard results in circular logic where the guard is optimized under the
// assumption that it never fails.
//
// For example, in the loop below the induction variable will be marked as nuw
// based on the guard. Based on nuw, the guard predicate will be considered
// monotonic. Given a monotonic condition it's tempting to replace the induction
// variable in the condition with its value on the last iteration. But this
// transformation is not correct: e.g. e = 4, b = 5 breaks the loop.
//
//   for (int i = b; i != e; i++)
//     guard(i u< len)
//
// One of the ways to reason about this problem is to use an inductive proof
// approach. Given the loop:
//
//   if (B(0)) {
//     do {
//       I = PHI(0, I.INC)
//       I.INC = I + Step
//       guard(G(I));
//     } while (B(I));
//   }
//
// where B(x) and G(x) are predicates that map integers to booleans, we want a
// loop invariant expression M such that the following program has the same
// semantics as the above:
//
//   if (B(0)) {
//     do {
//       I = PHI(0, I.INC)
//       I.INC = I + Step
//       guard(G(0) && M);
//     } while (B(I));
//   }
//
// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
//
// Informal proof that the transformation above is correct:
//
//   By the definition of guards we can rewrite the guard condition to:
//     G(I) && G(0) && M
//
//   Let's prove that for each iteration of the loop:
//     G(0) && M => G(I)
//   And the condition above can be simplified to G(0) && M.
//
//   Induction base.
//     G(0) && M => G(0)
//
//   Induction step. Assuming G(0) && M => G(I) holds on the current
//   iteration, show that G(0) && M => G(I + Step) holds on the subsequent
//   iteration:
//
//     B(I) is true because it's the backedge condition.
//     G(I) is true because the backedge is guarded by this condition.
//
//   So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
//
// Note that we can use anything stronger than M, i.e. any condition which
// implies M.
//
// When Step = 1 (i.e. forward iterating loop), the transformation is supported
// when:
//   * The loop has a single latch with the condition of the form:
//     B(X) = latchStart + X <pred> latchLimit,
//     where <pred> is u<, u<=, s<, or s<=.
//   * The guard condition is of the form
//     G(X) = guardStart + X u< guardLimit
//
//   For the ult latch comparison case M is:
//     forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
//        guardStart + X + 1 u< guardLimit
//
//   The only way the antecedent can be true and the consequent can be false is
//   if
//     X == guardLimit - 1 - guardStart
//   (and guardLimit is non-zero, but we won't use this latter fact).
//   If X == guardLimit - 1 - guardStart then the second half of the antecedent
//   is
//     latchStart + guardLimit - 1 - guardStart u< latchLimit
//   and its negation is
//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
//
//   In other words, if
//     latchLimit u<= latchStart + guardLimit - 1 - guardStart
//   then:
//   (the ranges below are written in ConstantRange notation, where [A, B) is the
//   set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
//
//      forall X . guardStart + X u< guardLimit &&
//                 latchStart + X u< latchLimit =>
//        guardStart + X + 1 u< guardLimit
//   == forall X . guardStart + X u< guardLimit &&
//                 latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
//        guardStart + X + 1 u< guardLimit
//   == forall X . (guardStart + X) in [0, guardLimit) &&
//                 (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
//        (guardStart + X + 1) in [0, guardLimit)
//   == forall X . X in [-guardStart, guardLimit - guardStart) &&
//                 X in [-latchStart, guardLimit - 1 - guardStart) =>
//        X in [-guardStart - 1, guardLimit - guardStart - 1)
//   == true
//
//   So the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
//   Similarly for ule condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart u> latchLimit
//   For slt condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart s>= latchLimit
//   For sle condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart s> latchLimit
//
// When Step = -1 (i.e. reverse iterating loop), the transformation is supported
// when:
//   * The loop has a single latch with the condition of the form:
//     B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
//   * The guard condition is of the form
//     G(X) = X - 1 u< guardLimit
//
//   For the ugt latch comparison case M is:
//     forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
//
//   The only way the antecedent can be true and the consequent can be false is
//   if X == 1.
//   If X == 1 then the second half of the antecedent is
//     1 u> latchLimit, and its negation is latchLimit u>= 1.
//
//   So the widened condition (where guardStart is the value of X - 1 on the
//   first iteration) is:
//     guardStart u< guardLimit && latchLimit u>= 1.
//   Similarly for sgt condition the widened condition is:
//     guardStart u< guardLimit && latchLimit s>= 1.
//   For uge condition the widened condition is:
//     guardStart u< guardLimit && latchLimit u> 1.
//   For sge condition the widened condition is:
//     guardStart u< guardLimit && latchLimit s> 1.
//===----------------------------------------------------------------------===//
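// The forward-loop derivation above can be checked exhaustively on a toy
// integer width. The following sketch is illustrative only and is not part of
// the pass: it assumes 4-bit unsigned arithmetic (every value is reduced
// modulo 16), covers only the u< latch case, and its helper names
// (allGuardsPass, widenedCondition, widenedConditionIsSound) are hypothetical.
// It simulates the do/while loop from the inductive proof and confirms that
// whenever the widened condition holds, no guard in the loop can fail.

```cpp
#include <cassert>
#include <cstdint>

// Toy integer width: all arithmetic below is performed modulo 2^kBits.
constexpr unsigned kBits = 4;
constexpr uint32_t kMask = (1u << kBits) - 1;

// Simulates, for Step = 1 and an unsigned (u<) latch:
//   I = 0;
//   do { guard(guardStart + I u< guardLimit); }
//   while (latchStart + I u< latchLimit);   // then I = I + 1
// Returns false if some executed guard fails.
bool allGuardsPass(uint32_t GuardStart, uint32_t GuardLimit,
                   uint32_t LatchStart, uint32_t LatchLimit) {
  uint32_t I = 0;
  // latchStart + I visits every value within 2^kBits steps, including kMask,
  // which is u>= any latchLimit, so the loop always exits within this bound.
  for (uint32_t Iter = 0; Iter <= kMask; ++Iter) {
    if (((GuardStart + I) & kMask) >= GuardLimit)
      return false; // this guard would deoptimize
    if (((LatchStart + I) & kMask) >= LatchLimit)
      return true; // the latch exits the loop
    I = (I + 1) & kMask;
  }
  return true; // not reached, see the comment above
}

// Widened check for the u< latch, as derived in the header comment:
//   guardStart u< guardLimit &&
//   latchStart + guardLimit - 1 - guardStart u>= latchLimit
bool widenedCondition(uint32_t GuardStart, uint32_t GuardLimit,
                      uint32_t LatchStart, uint32_t LatchLimit) {
  uint32_t RHS = (LatchStart + GuardLimit - 1 - GuardStart) & kMask;
  return GuardStart < GuardLimit && RHS >= LatchLimit;
}

// Exhaustively checks that the widened condition implies that every guard
// executed by the simulated loop passes.
bool widenedConditionIsSound() {
  for (uint32_t GS = 0; GS <= kMask; ++GS)
    for (uint32_t GL = 0; GL <= kMask; ++GL)
      for (uint32_t LS = 0; LS <= kMask; ++LS)
        for (uint32_t LL = 0; LL <= kMask; ++LL)
          if (widenedCondition(GS, GL, LS, LL) &&
              !allGuardsPass(GS, GL, LS, LL))
            return false;
  return true;
}
```

// In this model the latch compares the incremented IV, so
// `for (i = 0; i < n; i++)` corresponds to latchStart = 1. With guardStart = 0,
// guardLimit = 10, latchStart = 1 and latchLimit = 10 the widened condition
// holds; with latchLimit = 11 it does not, and the simulated guard fails when
// i == 10.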

#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "loop-predication"

STATISTIC(TotalConsidered, "Number of guards considered");
STATISTIC(TotalWidened, "Number of checks widened");

using namespace llvm;

static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
                                        cl::Hidden, cl::init(true));

static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
                                        cl::Hidden, cl::init(true));

static cl::opt<bool>
    SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
                            cl::Hidden, cl::init(false));

// This is the scale factor for the latch probability. We use this during
// profitability analysis to find other exiting blocks that are much more
// likely to exit the loop than the latch is.
// This value should be greater than 1 for a sane profitability check.
    "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
    cl::desc("scale factor for the latch probability. Value should be greater "
             "than 1. Lower values are ignored"));

    "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden,
    cl::desc("Whether or not we should predicate guards "
             "expressed as widenable branches to deoptimize blocks"),
    cl::init(true));

    "loop-predication-insert-assumes-of-predicated-guards-conditions",
    cl::desc("Whether or not we should insert assumes of conditions of "
             "predicated guards"),
    cl::init(true));

namespace {
/// Represents an induction variable check:
///   icmp Pred, <induction variable>, <loop invariant limit>
struct LoopICmp {
  ICmpInst::Predicate Pred;
  const SCEVAddRecExpr *IV;
  const SCEV *Limit;
  LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
           const SCEV *Limit)
    : Pred(Pred), IV(IV), Limit(Limit) {}
  LoopICmp() = default;
  void dump() {
    dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
           << ", Limit = " << *Limit << "\n";
  }
};

class LoopPredication {
  AliasAnalysis *AA;
  DominatorTree *DT;
  ScalarEvolution *SE;
  LoopInfo *LI;
  MemorySSAUpdater *MSSAU;

  Loop *L;
  const DataLayout *DL;
  BasicBlock *Preheader;
  LoopICmp LatchCheck;

  bool isSupportedStep(const SCEV* Step);
  std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
  std::optional<LoopICmp> parseLoopLatchICmp();

  /// Return an insertion point suitable for inserting a safe to speculate
  /// instruction whose only user will be 'User' which has operands 'Ops'. A
  /// trivial result would be the User itself, but we try to return a
  /// loop invariant location if possible.
  Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
  /// Same as above, *except* that this uses the SCEV definition of invariant
  /// which is that an expression *can be made* invariant via SCEVExpander.
  /// Thus, this version is only suitable for finding an insert point to be
  /// passed to SCEVExpander!
  Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User,
                            ArrayRef<const SCEV *> Ops);

  /// Return true if the value is known to produce a single fixed value across
  /// all iterations on which it executes. Note that this does not imply
  /// speculation safety. That must be established separately.
  bool isLoopInvariantValue(const SCEV* S);

  Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
                     ICmpInst::Predicate Pred, const SCEV *LHS,
                     const SCEV *RHS);

  std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI,
                                             SCEVExpander &Expander,
                                             Instruction *Guard);
  std::optional<Value *>
  widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
                                      SCEVExpander &Expander,
                                      Instruction *Guard);
  std::optional<Value *>
  widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
                                      SCEVExpander &Expander,
                                      Instruction *Guard);
  void widenChecks(SmallVectorImpl<Value *> &Checks,
                   SmallVectorImpl<Value *> &WidenedChecks,
                   SCEVExpander &Expander, Instruction *Guard);
  bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
  bool widenWidenableBranchGuardConditions(BranchInst *Guard,
                                           SCEVExpander &Expander);
  // If the loop always exits through another block in the loop, we should not
  // predicate based on the latch check. For example, the latch check can be a
  // very coarse grained check and there can be more fine grained exit checks
  // within the loop.
  bool isLoopProfitableToPredicate();

  bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);

public:
  LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
                  LoopInfo *LI, MemorySSAUpdater *MSSAU)
      : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU) {}
  bool runOnLoop(Loop *L);
};

class LoopPredicationLegacyPass : public LoopPass {
public:
  static char ID;
  LoopPredicationLegacyPass() : LoopPass(ID) {
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    if (skipLoop(L))
      return false;
    auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
    std::unique_ptr<MemorySSAUpdater> MSSAU;
    if (MSSAWP)
      MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
    auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr);
    return LP.runOnLoop(L);
  }
};

char LoopPredicationLegacyPass::ID = 0;
} // end namespace

INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
                      "Loop predication", false, false)
INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
                    "Loop predication", false, false)

  return new LoopPredicationLegacyPass();
}

                                           LPMUpdater &U) {
  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (AR.MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
  LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
                     MSSAU ? MSSAU.get() : nullptr);
  if (!LP.runOnLoop(&L))
    return PreservedAnalyses::all();

  if (AR.MSSA)
    PA.preserve<MemorySSAAnalysis>();
  return PA;
}

std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) {
  auto Pred = ICI->getPredicate();
  auto *LHS = ICI->getOperand(0);
  auto *RHS = ICI->getOperand(1);

  const SCEV *LHSS = SE->getSCEV(LHS);
  if (isa<SCEVCouldNotCompute>(LHSS))
    return std::nullopt;
  const SCEV *RHSS = SE->getSCEV(RHS);
  if (isa<SCEVCouldNotCompute>(RHSS))
    return std::nullopt;

  // Canonicalize so that RHS is the loop invariant bound and LHS is the
  // loop computable IV.
  if (SE->isLoopInvariant(LHSS, L)) {
    std::swap(LHS, RHS);
    std::swap(LHSS, RHSS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
  if (!AR || AR->getLoop() != L)
    return std::nullopt;

  return LoopICmp(Pred, AR, RHSS);
}
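// parseLoopICmp canonicalizes the comparison so that the IV is on the
// left-hand side. Exchanging the operands of a comparison is only sound if the
// predicate is swapped as well (u< becomes u>, s<= becomes s>=, equality is
// unchanged), which is what ICmpInst::getSwappedPredicate computes. The sketch
// below checks that rule exhaustively on 8-bit values; the Pred enum and the
// helpers are hypothetical stand-ins for ICmpInst::Predicate, not LLVM APIs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for ICmpInst::Predicate.
enum class Pred { EQ, NE, ULT, ULE, UGT, UGE, SLT, SLE, SGT, SGE };

// Predicate to use after exchanging the two operands of a comparison.
Pred swapped(Pred P) {
  switch (P) {
  case Pred::ULT: return Pred::UGT;
  case Pred::ULE: return Pred::UGE;
  case Pred::UGT: return Pred::ULT;
  case Pred::UGE: return Pred::ULE;
  case Pred::SLT: return Pred::SGT;
  case Pred::SLE: return Pred::SGE;
  case Pred::SGT: return Pred::SLT;
  case Pred::SGE: return Pred::SLE;
  default: return P; // EQ and NE are symmetric
  }
}

// Evaluates "A P B" on 8-bit values (signed predicates use two's complement).
bool holds(Pred P, uint8_t A, uint8_t B) {
  int8_t SA = static_cast<int8_t>(A), SB = static_cast<int8_t>(B);
  switch (P) {
  case Pred::EQ: return A == B;
  case Pred::NE: return A != B;
  case Pred::ULT: return A < B;
  case Pred::ULE: return A <= B;
  case Pred::UGT: return A > B;
  case Pred::UGE: return A >= B;
  case Pred::SLT: return SA < SB;
  case Pred::SLE: return SA <= SB;
  case Pred::SGT: return SA > SB;
  case Pred::SGE: return SA >= SB;
  }
  return false;
}

// Checks that "A P B" equals "B swapped(P) A" for every predicate and pair.
bool swapIsSound() {
  const Pred All[] = {Pred::EQ,  Pred::NE,  Pred::ULT, Pred::ULE, Pred::UGT,
                      Pred::UGE, Pred::SLT, Pred::SLE, Pred::SGT, Pred::SGE};
  for (Pred P : All)
    for (unsigned A = 0; A < 256; ++A)
      for (unsigned B = 0; B < 256; ++B)
        if (holds(P, A, B) != holds(swapped(P), B, A))
          return false;
  return true;
}
```

// Swapping is not the same as inverting: the inverse of u< is u>= (the negated
// condition), which is what ICmpInst::getInversePredicate computes.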

Value *LoopPredication::expandCheck(SCEVExpander &Expander,
                                    Instruction *Guard,
                                    ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS) {
  Type *Ty = LHS->getType();
  assert(Ty == RHS->getType() && "expandCheck operands have different types?");

  if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
    IRBuilder<> Builder(Guard);
    if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
      return Builder.getTrue();
    if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
                                     LHS, RHS))
      return Builder.getFalse();
  }

  Value *LHSV =
      Expander.expandCodeFor(LHS, Ty, findInsertPt(Expander, Guard, {LHS}));
  Value *RHSV =
      Expander.expandCodeFor(RHS, Ty, findInsertPt(Expander, Guard, {RHS}));
  IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
  return Builder.CreateICmp(Pred, LHSV, RHSV);
}

// Returns true if it's safe to truncate the IV to RangeCheckType.
// When the IV type is wider than the range operand type, we can still do loop
// predication, by generating SCEVs for the range and latch that are of the
// same type. We achieve this by generating a SCEV truncate expression for the
// latch IV. This is done iff truncation of the IV is a safe operation,
// without loss of information.
// Another way to achieve this is by generating a wider type SCEV for the
// range check operand; however, this needs a more involved check that
// operands do not overflow. This can lead to loss of information when the
// range operand is of the form: add i32 %offset, %iv. We need to prove that
// sext(x + y) is the same as sext(x) + sext(y).
// This function returns true if we can safely represent the IV type in
// the RangeCheckType without loss of information.
static bool isSafeToTruncateWideIVType(const DataLayout &DL,
                                       ScalarEvolution &SE,
                                       const LoopICmp LatchCheck,
                                       Type *RangeCheckType) {
  if (!EnableIVTruncation)
    return false;
  assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() >
             DL.getTypeSizeInBits(RangeCheckType).getFixedValue() &&
         "Expected latch check IV type to be larger than range check operand "
         "type!");
  // The start and end values of the IV should be known. This is to guarantee
  // that truncating the wide type will not lose information.
  auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
  auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
  if (!Limit || !Start)
    return false;
  // This check makes sure that the IV does not change sign during loop
  // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
  // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
  // IV wraps around, and the truncation of the IV would lose the range of
  // iterations between 2^32 and 2^64.
  if (!SE.getMonotonicPredicateType(LatchCheck.IV, LatchCheck.Pred))
    return false;
  // The active bits should be less than the bits in the RangeCheckType. This
  // guarantees that truncating the latch check to RangeCheckType is a safe
  // operation.
  auto RangeCheckTypeBitSize =
      DL.getTypeSizeInBits(RangeCheckType).getFixedValue();
  return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
         Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
}
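// The last test in isSafeToTruncateWideIVType compares APInt::getActiveBits()
// of the constant start and limit against the width of RangeCheckType. A
// minimal standalone sketch of just that comparison, assuming the constants fit
// in uint64_t; activeBits and endpointsFitNarrowType are hypothetical helpers,
// and the monotonicity check that precedes the test is not modeled.

```cpp
#include <cassert>
#include <cstdint>

// Number of bits needed to represent V as an unsigned value, i.e. the
// position of its highest set bit plus one (0 for V == 0).
unsigned activeBits(uint64_t V) {
  unsigned Bits = 0;
  while (V != 0) {
    ++Bits;
    V >>= 1;
  }
  return Bits;
}

// Both constant endpoints must need strictly fewer bits than the narrow type
// provides.
bool endpointsFitNarrowType(uint64_t Start, uint64_t Limit,
                            unsigned NarrowBits) {
  return activeBits(Start) < NarrowBits && activeBits(Limit) < NarrowBits;
}
```

// Because the comparison is strict, a value that needs every bit of the narrow
// type (for example 2^31 when truncating to i32) is rejected, so the top bit of
// each truncated endpoint is always clear.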

// Return a LoopICmp describing a latch check equivalent to LatchCheck but with
// the requested type if safe to do so. May involve the use of a new IV.
static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
                                                      ScalarEvolution &SE,
                                                      const LoopICmp LatchCheck,
                                                      Type *RangeCheckType) {

  auto *LatchType = LatchCheck.IV->getType();
  if (RangeCheckType == LatchType)
    return LatchCheck;
  // For now, bail out if latch type is narrower than range type.
  if (DL.getTypeSizeInBits(LatchType).getFixedValue() <
      DL.getTypeSizeInBits(RangeCheckType).getFixedValue())
    return std::nullopt;
  if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType))
    return std::nullopt;
  // We can now safely identify the truncated version of the IV and limit for
  // RangeCheckType.
  LoopICmp NewLatchCheck;
  NewLatchCheck.Pred = LatchCheck.Pred;
  NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
      SE.getTruncateExpr(LatchCheck.IV, RangeCheckType));
  if (!NewLatchCheck.IV)
    return std::nullopt;
  NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType);
  LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
                    << " can be represented as range check type: "
                    << *RangeCheckType << "\n");
  LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
  LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
  return NewLatchCheck;
}

bool LoopPredication::isSupportedStep(const SCEV* Step) {
  return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
}

Instruction *LoopPredication::findInsertPt(Instruction *Use,
                                           ArrayRef<Value*> Ops) {
  for (Value *Op : Ops)
    if (!L->isLoopInvariant(Op))
      return Use;
  return Preheader->getTerminator();
}

Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander,
                                           Instruction *Use,
                                           ArrayRef<const SCEV *> Ops) {
  // Subtlety: SCEV considers things to be invariant if the value produced is
  // the same across iterations. This is not the same as being able to
  // evaluate outside the loop, which is what we actually need here.
  for (const SCEV *Op : Ops)
    if (!SE->isLoopInvariant(Op, L) ||
        !Expander.isSafeToExpandAt(Op, Preheader->getTerminator()))
      return Use;
  return Preheader->getTerminator();
}

bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
  // Handling expressions which produce invariant results, but *haven't* yet
  // been removed from the loop serves two important purposes.
  // 1) Most importantly, it resolves a pass ordering cycle which would
  // otherwise need us to iterate licm, loop-predication, and either
  // loop-unswitch or loop-peeling to make progress on examples with lots of
  // predicable range checks in a row. (Since, in the general case, we can't
  // hoist the length checks until the dominating checks have been discharged
  // as we can't prove doing so is safe.)
  // 2) As a nice side effect, this exposes the value of peeling or unswitching
  // much more obviously in the IR. Otherwise, the cost modeling for other
  // transforms would end up needing to duplicate all of this logic to model a
  // check which becomes predictable based on a modeled peel or unswitch.
  //
  // The cost of doing so in the worst case is an extra fill from the stack in
  // the loop to materialize the loop invariant test value instead of checking
  // against the original IV, which is presumably in a register inside the loop.
  // Such cases are presumably rare, and hint at missing opportunities for
  // other passes.

  if (SE->isLoopInvariant(S, L))
    // Note: This is the SCEV variant, so the original Value* may be within the
    // loop even though SCEV has proven it is loop invariant.
    return true;

  // Handle a particularly important case which SCEV doesn't yet know about,
  // which shows up in range checks on arrays with immutable lengths.
  // TODO: This should be sunk inside SCEV.
  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
    if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
      if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
        if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) ||
            LI->hasMetadata(LLVMContext::MD_invariant_load))
          return true;
  return false;
}

std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
    LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
    Instruction *Guard) {
  auto *Ty = RangeCheck.IV->getType();
  // Generate the widened condition for the forward loop:
  //   guardStart u< guardLimit &&
  //   latchLimit <pred> guardLimit - 1 - guardStart + latchStart
  // where <pred> depends on the latch condition predicate. See the file
  // header comment for the reasoning.
  const SCEV *GuardStart = RangeCheck.IV->getStart();
  const SCEV *GuardLimit = RangeCheck.Limit;
  const SCEV *LatchStart = LatchCheck.IV->getStart();
  const SCEV *LatchLimit = LatchCheck.Limit;
  // Subtlety: We need all the values to be *invariant* across all iterations,
  // but we only need to check expansion safety for those which *aren't*
  // already guaranteed to dominate the guard.
  if (!isLoopInvariantValue(GuardStart) ||
      !isLoopInvariantValue(GuardLimit) ||
      !isLoopInvariantValue(LatchStart) ||
      !isLoopInvariantValue(LatchLimit)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
      !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }

  // guardLimit - guardStart + latchStart - 1
  const SCEV *RHS =
      SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
                     SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
  auto LimitCheckPred =
      ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);

  LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
  LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");

  auto *LimitCheck =
      expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
  auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
                                          GuardStart, GuardLimit);
  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
  return Builder.CreateFreeze(
      Builder.CreateAnd(FirstIterationCheck, LimitCheck));
}
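// widenICmpRangeCheckIncrementingLoop emits "latchLimit <pred> RHS" with
// RHS = guardLimit - guardStart + latchStart - 1. Reading the header comment's
// widened conditions in that operand order gives u< -> u<=, u<= -> u<,
// s< -> s<=, s<= -> s<: the strictness of the latch predicate is flipped. The
// sketch below encodes that table for 8-bit wrapping arithmetic; LatchPred,
// limitCheckPred, evalPred and widenedIncrementing are hypothetical names that
// model the header's formulas, not SCEV expansion.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the latch predicates accepted when Step = 1.
enum class LatchPred { ULT, ULE, SLT, SLE };

// Predicate P such that the limit check reads "latchLimit P RHS"; e.g. the
// header's "RHS u>= latchLimit" for a u< latch is "latchLimit u<= RHS".
LatchPred limitCheckPred(LatchPred P) {
  switch (P) {
  case LatchPred::ULT: return LatchPred::ULE;
  case LatchPred::ULE: return LatchPred::ULT;
  case LatchPred::SLT: return LatchPred::SLE;
  case LatchPred::SLE: return LatchPred::SLT;
  }
  return P;
}

// Evaluates "A P B" on 8-bit values (signed predicates use two's complement).
bool evalPred(LatchPred P, uint8_t A, uint8_t B) {
  int8_t SA = static_cast<int8_t>(A), SB = static_cast<int8_t>(B);
  switch (P) {
  case LatchPred::ULT: return A < B;
  case LatchPred::ULE: return A <= B;
  case LatchPred::SLT: return SA < SB;
  case LatchPred::SLE: return SA <= SB;
  }
  return false;
}

// guardStart u< guardLimit &&
// latchLimit <limitCheckPred> guardLimit - guardStart + latchStart - 1
bool widenedIncrementing(LatchPred P, uint8_t GuardStart, uint8_t GuardLimit,
                         uint8_t LatchStart, uint8_t LatchLimit) {
  // Conversion to uint8_t wraps modulo 256, like the IR arithmetic would.
  uint8_t RHS = static_cast<uint8_t>(GuardLimit - GuardStart + LatchStart - 1);
  return GuardStart < GuardLimit &&
         evalPred(limitCheckPred(P), LatchLimit, RHS);
}
```

// For `for (i = 0; i < n; i++) guard(i u< len)` the latch IV starts at 1, so
// with n = len = 100 the widened check holds; with n = 101, or with a u<= latch
// and n = 100, it correctly fails because i would reach len.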

std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
    LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
    Instruction *Guard) {
  auto *Ty = RangeCheck.IV->getType();
  const SCEV *GuardStart = RangeCheck.IV->getStart();
  const SCEV *GuardLimit = RangeCheck.Limit;
  const SCEV *LatchStart = LatchCheck.IV->getStart();
  const SCEV *LatchLimit = LatchCheck.Limit;
  // Subtlety: We need all the values to be *invariant* across all iterations,
  // but we only need to check expansion safety for those which *aren't*
  // already guaranteed to dominate the guard.
  if (!isLoopInvariantValue(GuardStart) ||
      !isLoopInvariantValue(GuardLimit) ||
      !isLoopInvariantValue(LatchStart) ||
      !isLoopInvariantValue(LatchLimit)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
      !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  // The post-decrement value of the latch check IV must be the same as the
  // range check IV.
  auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
  if (RangeCheck.IV != PostDecLatchCheckIV) {
    LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
                      << *PostDecLatchCheckIV
                      << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
    return std::nullopt;
  }

  // Generate the widened condition for CountDownLoop:
  //   guardStart u< guardLimit &&
  //   latchLimit <pred> 1.
  // See the header comment for reasoning of the checks.
  auto LimitCheckPred =
      ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
  auto *FirstIterationCheck = expandCheck(Expander, Guard,
                                          ICmpInst::ICMP_ULT,
                                          GuardStart, GuardLimit);
  auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
                                 SE->getOne(Ty));
  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
  return Builder.CreateFreeze(
      Builder.CreateAnd(FirstIterationCheck, LimitCheck));
}
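// The count-down widening from the header comment can be checked the same way.
// This sketch assumes 4-bit unsigned arithmetic and a u> latch; X is the latch
// IV, the guard tests X - 1 (the post-decrement IV that the range check IV
// must equal), and guardStart is therefore latchStart - 1. The helper names
// are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint32_t kMask = 0xF; // toy 4-bit unsigned arithmetic

// Simulates, for Step = -1 and a u> latch:
//   X = latchStart;
//   do { guard(X - 1 u< guardLimit); } while (X u> latchLimit);  // X = X - 1
// Returns false if some executed guard fails.
bool allGuardsPassCountDown(uint32_t LatchStart, uint32_t GuardLimit,
                            uint32_t LatchLimit) {
  uint32_t X = LatchStart;
  // X visits every value within 16 steps, including 0, which is never
  // u> latchLimit, so the loop always exits within this bound.
  for (uint32_t Iter = 0; Iter <= kMask; ++Iter) {
    if (((X - 1) & kMask) >= GuardLimit)
      return false; // this guard would deoptimize
    if (X <= LatchLimit)
      return true; // the latch exits the loop
    X = (X - 1) & kMask;
  }
  return true; // not reached, see the comment above
}

// Widened check: guardStart u< guardLimit && latchLimit u>= 1.
bool widenedCountDown(uint32_t LatchStart, uint32_t GuardLimit,
                      uint32_t LatchLimit) {
  uint32_t GuardStart = (LatchStart - 1) & kMask;
  return GuardStart < GuardLimit && LatchLimit >= 1;
}

// Exhaustively checks that the widened check implies all guards pass.
bool countDownWideningIsSound() {
  for (uint32_t LS = 0; LS <= kMask; ++LS)
    for (uint32_t GL = 0; GL <= kMask; ++GL)
      for (uint32_t LL = 0; LL <= kMask; ++LL)
        if (widenedCountDown(LS, GL, LL) &&
            !allGuardsPassCountDown(LS, GL, LL))
          return false;
  return true;
}
```

// With latchLimit = 0 the loop also runs for X == 1, which leads to X == 0 and
// a guard on 0 - 1 (all ones), the X == 1 case singled out in the header.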

                               LoopICmp& RC) {
  // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the
  // ULT/UGE form for ease of handling by our caller.
  if (ICmpInst::isEquality(RC.Pred) &&
      RC.IV->getStepRecurrence(*SE)->isOne() &&
      SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit))
    RC.Pred = RC.Pred == ICmpInst::ICMP_NE ?
      ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
}
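// The NE/EQ normalization relies on the IV having step 1 and starting u<= the
// limit: such an IV reaches the limit before it can wrap, so on every value it
// takes up to the limit, "IV != Limit" agrees with "IV u< Limit" and
// "IV == Limit" with "IV u>= Limit". A brute-force sketch of that claim over
// 8-bit values, plus a counterexample showing why the u<= precondition is
// needed; both functions are hypothetical and exist only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// For every Start u<= Limit, walks the IV values Start, Start + 1, ..., Limit
// (a step-1 IV that has not yet passed the limit) and checks that the
// equality form and the unsigned-relational form of the check agree.
bool equalityNormalizationIsSound() {
  for (unsigned Start = 0; Start < 256; ++Start)
    for (unsigned Limit = Start; Limit < 256; ++Limit)
      for (unsigned IV = Start; IV <= Limit; ++IV) {
        if ((IV != Limit) != (IV < Limit))
          return false;
        if ((IV == Limit) != (IV >= Limit))
          return false;
      }
  return true;
}

// Without Start u<= Limit the forms diverge: an 8-bit IV starting at 250
// with limit 5 wraps before reaching the limit, so at IV = 250
// "IV != 5" is true while "IV u< 5" is false.
bool wrapsWithoutPrecondition() {
  uint8_t IV = 250, Limit = 5;
  return (IV != Limit) && !(IV < Limit);
}
```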

/// If ICI can be widened to a loop invariant condition, emits the loop
/// invariant condition in the loop preheader and returns it; otherwise
/// returns std::nullopt.
std::optional<Value *>
LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
                                     Instruction *Guard) {
  LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
  LLVM_DEBUG(ICI->dump());

  // The latch condition parsed by parseLoopLatchICmp is expected to be:
  //   ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=
  //   (u>, u>=, s>, or s>= for count-down loops).
  // We are looking for the range checks of the form:
  //   i u< guardLimit
  auto RangeCheck = parseLoopICmp(ICI);
  if (!RangeCheck) {
    LLVM_DEBUG(dbgs() << "Failed to parse the range check condition!\n");
    return std::nullopt;
  }
  LLVM_DEBUG(dbgs() << "Guard check:\n");
  LLVM_DEBUG(RangeCheck->dump());
  if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
    LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
                      << RangeCheck->Pred << ")!\n");
    return std::nullopt;
  }
  auto *RangeCheckIV = RangeCheck->IV;
  if (!RangeCheckIV->isAffine()) {
    LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
    return std::nullopt;
  }
  auto *Step = RangeCheckIV->getStepRecurrence(*SE);
  // We cannot just compare with latch IV step because the latch and range IVs
  // may have different types.
  if (!isSupportedStep(Step)) {
    LLVM_DEBUG(dbgs() << "Range check and latch IVs have different steps!\n");
    return std::nullopt;
  }
  auto *Ty = RangeCheckIV->getType();
  auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty);
  if (!CurrLatchCheckOpt) {
    LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
                         "corresponding to range type: "
                      << *Ty << "\n");
    return std::nullopt;
  }

  LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
  // At this point, the range and latch step should have the same type, but need
  // not have the same value (we support both 1 and -1 steps).
  assert(Step->getType() ==
             CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
         "Range and latch steps should be of same type!");
  if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
    LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
    return std::nullopt;
  }

  if (Step->isOne())
    return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
                                               Expander, Guard);
  else {
    assert(Step->isAllOnesValue() && "Step should be -1!");
    return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
                                               Expander, Guard);
  }
}

void LoopPredication::widenChecks(SmallVectorImpl<Value *> &Checks,
                                  SmallVectorImpl<Value *> &WidenedChecks,
                                  SCEVExpander &Expander, Instruction *Guard) {
  for (auto &Check : Checks)
    if (ICmpInst *ICI = dyn_cast<ICmpInst>(Check))
      if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) {
        WidenedChecks.push_back(Check);
        Check = *NewRangeCheck;
      }
}
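// widenChecks rewrites the check list in place and records each replaced
// original in WidenedChecks; the callers use WidenedChecks to bail out when
// nothing was widened, to update TotalWidened and, for widenable branches, to
// build the assumed condition. A minimal sketch of the same pattern, assuming
// strings stand in for Value * and a hypothetical Widen callback stands in for
// widenICmpRangeCheck:

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Replaces each check that Widen() can handle with its widened form, in
// place, and records the original checks that were replaced. Checks that
// cannot be widened are left untouched.
void widenChecksSketch(
    std::vector<std::string> &Checks, std::vector<std::string> &WidenedChecks,
    const std::function<std::optional<std::string>(const std::string &)>
        &Widen) {
  for (std::string &Check : Checks)
    if (std::optional<std::string> NewCheck = Widen(Check)) {
      WidenedChecks.push_back(Check); // remember the original first
      Check = *NewCheck;              // then overwrite it in place
    }
}
```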

bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
                                           SCEVExpander &Expander) {
  LLVM_DEBUG(dbgs() << "Processing guard:\n");
  LLVM_DEBUG(Guard->dump());

  TotalConsidered++;
  SmallVector<Value *, 4> Checks;
  SmallVector<Value *> WidenedChecks;
  parseWidenableGuard(Guard, Checks);
  widenChecks(Checks, WidenedChecks, Expander, Guard);
  if (WidenedChecks.empty())
    return false;

  TotalWidened += WidenedChecks.size();

  // Emit the new guard condition
  IRBuilder<> Builder(findInsertPt(Guard, Checks));
  Value *AllChecks = Builder.CreateAnd(Checks);
  auto *OldCond = Guard->getOperand(0);
  Guard->setOperand(0, AllChecks);
    Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard));
    Builder.CreateAssumption(OldCond);
  }
  RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);

  LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
  return true;
}

bool LoopPredication::widenWidenableBranchGuardConditions(
    BranchInst *BI, SCEVExpander &Expander) {
  assert(isGuardAsWidenableBranch(BI) && "Must be!");
  LLVM_DEBUG(dbgs() << "Processing guard:\n");
  LLVM_DEBUG(BI->dump());

  TotalConsidered++;
  SmallVector<Value *, 4> Checks;
  SmallVector<Value *> WidenedChecks;
  parseWidenableGuard(BI, Checks);
  // At the moment, our matching logic for widenable conditions implicitly
  // assumes we preserve the form: (br (and Cond, WC())). FIXME
  auto WC = extractWidenableCondition(BI);
  Checks.push_back(WC);
  widenChecks(Checks, WidenedChecks, Expander, BI);
  if (WidenedChecks.empty())
    return false;

  TotalWidened += WidenedChecks.size();

  // Emit the new guard condition
  IRBuilder<> Builder(findInsertPt(BI, Checks));
  Value *AllChecks = Builder.CreateAnd(Checks);
  auto *OldCond = BI->getCondition();
  BI->setCondition(AllChecks);
  if (InsertAssumesOfPredicatedGuardsConditions) {
    BasicBlock *IfTrueBB = BI->getSuccessor(0);
    Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
    // If this block has other predecessors, we might not be able to use the
    // widened condition directly. In this case, create a Phi whose incoming
    // value is `true` from every other predecessor and the widened condition
    // from the guard block.
    Value *AssumeCond = Builder.CreateAnd(WidenedChecks);
    if (!IfTrueBB->getUniquePredecessor()) {
      auto *GuardBB = BI->getParent();
      auto *PN = Builder.CreatePHI(AssumeCond->getType(), pred_size(IfTrueBB),
                                   "assume.cond");
      for (auto *Pred : predecessors(IfTrueBB))
        PN->addIncoming(Pred == GuardBB ? AssumeCond : Builder.getTrue(), Pred);
      AssumeCond = PN;
    }
    Builder.CreateAssumption(AssumeCond);
  }
  RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
  assert(isGuardAsWidenableBranch(BI) &&
         "Stopped being a guard after transform?");

  LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
  return true;
}
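When the guarded successor has more than one predecessor, the assume cannot use the widened checks directly, because they are only known to hold on the edge from the guard block. widenWidenableBranchGuardConditions therefore builds a PHI that yields the widened condition on that edge and `true` on every other edge. The sketch below reproduces only that incoming-value selection, with plain strings standing in for IR blocks and values; `buildAssumePhiIncoming`, `IncomingList`, and the block names are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// (predecessor block, incoming value) pairs, mirroring PN->addIncoming(...).
using IncomingList = std::vector<std::pair<std::string, std::string>>;

// Hypothetical helper: for each predecessor of the guarded successor, feed
// the widened condition from the guard block and "true" from every other
// block, so the later assume only constrains paths that went through the guard.
IncomingList buildAssumePhiIncoming(const std::vector<std::string> &Preds,
                                    const std::string &GuardBB,
                                    const std::string &AssumeCond) {
  IncomingList PN;
  PN.reserve(Preds.size());
  for (const std::string &Pred : Preds)
    PN.emplace_back(Pred, Pred == GuardBB ? AssumeCond : std::string("true"));
  return PN;
}
```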

std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
  using namespace PatternMatch;

  BasicBlock *LoopLatch = L->getLoopLatch();
  if (!LoopLatch) {
    LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
    return std::nullopt;
  }

  auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
  if (!BI || !BI->isConditional()) {
    LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
    return std::nullopt;
  }
  BasicBlock *TrueDest = BI->getSuccessor(0);
  assert(
      (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) &&
      "One of the latch's destinations must be the header");

  auto *ICI = dyn_cast<ICmpInst>(BI->getCondition());
  if (!ICI) {
    LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n");
    return std::nullopt;
  }
  auto Result = parseLoopICmp(ICI);
  if (!Result) {
    LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
    return std::nullopt;
  }

  if (TrueDest != L->getHeader())
    Result->Pred = ICmpInst::getInversePredicate(Result->Pred);

  // Check that the IV is affine first, so that we don't try to compute the
  // step recurrence of a non-affine one.
  if (!Result->IV->isAffine()) {
    LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
    return std::nullopt;
  }

  auto *Step = Result->IV->getStepRecurrence(*SE);
  if (!isSupportedStep(Step)) {
    LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
    return std::nullopt;
  }

  auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
    if (Step->isOne()) {
      return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
             Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
    } else {
      assert(Step->isAllOnesValue() && "Step should be -1!");
      return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
             Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
    }
  };

  normalizePredicate(SE, L, *Result);
  if (IsUnsupportedPredicate(Step, Result->Pred)) {
    LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
                      << ")!\n");
    return std::nullopt;
  }

  return Result;
}
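parseLoopLatchICmp first rewrites the latch comparison so that it describes the condition for staying in the loop (inverting it when the header is the false successor). It then accepts only "less than" predicates (strict or not) for a step of +1 and only "greater than" predicates for a step of -1. The standalone sketch below models those two steps with a hypothetical `Pred` enum. It deliberately omits `normalizePredicate` and the SCEV affinity checks, so it approximates the filtering logic rather than reimplementing it.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for ICmpInst::Predicate.
enum class Pred { EQ, NE, ULT, ULE, UGT, UGE, SLT, SLE, SGT, SGE };

// Logical negation of a predicate, in the spirit of getInversePredicate.
Pred inverse(Pred P) {
  switch (P) {
  case Pred::EQ:  return Pred::NE;
  case Pred::NE:  return Pred::EQ;
  case Pred::ULT: return Pred::UGE;
  case Pred::UGE: return Pred::ULT;
  case Pred::ULE: return Pred::UGT;
  case Pred::UGT: return Pred::ULE;
  case Pred::SLT: return Pred::SGE;
  case Pred::SGE: return Pred::SLT;
  case Pred::SLE: return Pred::SGT;
  case Pred::SGT: return Pred::SLE;
  }
  return P; // not reached for valid enumerators
}

// Same acceptance rule as the IsUnsupportedPredicate lambda above.
bool isUnsupportedPredicate(int Step, Pred P) {
  if (Step == 1)
    return P != Pred::ULT && P != Pred::SLT && P != Pred::ULE && P != Pred::SLE;
  assert(Step == -1 && "Step should be -1!");
  return P != Pred::UGT && P != Pred::SGT && P != Pred::UGE && P != Pred::SGE;
}

// Returns the "stay in the loop" predicate, or nullopt if unsupported.
std::optional<Pred> latchPredicate(Pred P, bool TrueDestIsHeader, int Step) {
  if (!TrueDestIsHeader)
    P = inverse(P); // the branch condition is the exit condition; negate it
  if (Step != 1 && Step != -1)
    return std::nullopt; // only unit strides are modeled here
  if (isUnsupportedPredicate(Step, P))
    return std::nullopt;
  return P;
}
```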

bool LoopPredication::isLoopProfitableToPredicate() {
  if (SkipProfitabilityChecks)
    return true;

  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
  L->getExitEdges(ExitEdges);
  // If there is only one exiting edge in the loop, it is always profitable to
  // predicate the loop.
  if (ExitEdges.size() == 1)
    return true;

  // Calculate the exiting probabilities of all exiting edges from the loop,
  // starting with the LatchExitProbability.
  // Heuristic for profitability: If any of the exiting blocks' probability of
  // exiting the loop is larger than the (scaled) probability of exiting
  // through the latch block, it's not profitable to predicate the loop.
  auto *LatchBlock = L->getLoopLatch();
  assert(LatchBlock && "Should have a single latch at this point!");
  auto *LatchTerm = LatchBlock->getTerminator();
  assert(LatchTerm->getNumSuccessors() == 2 &&
         "expected to be an exiting block with 2 succs!");
  unsigned LatchBrExitIdx =
      LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
  // We compute branch probabilities without BPI. We do not rely on BPI since
  // loop predication is usually run in an LPM and BPI is only preserved
  // lossily within loop pass managers, while BPI has an inherent notion of
  // being complete for an entire function.

  // If the latch exits into a deoptimize or an unreachable block, do not
  // predicate on that latch check.
  auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
  if (isa<UnreachableInst>(LatchExitBlock->getTerminator()) ||
      LatchExitBlock->getTerminatingDeoptimizeCall())
    return false;

  // Latch terminator has no valid profile data, so nothing to check
  // profitability on.
  if (!hasValidBranchWeightMD(*LatchTerm))
    return true;

  auto ComputeBranchProbability =
      [&](const BasicBlock *ExitingBlock,
          const BasicBlock *ExitBlock) -> BranchProbability {
    auto *Term = ExitingBlock->getTerminator();
    unsigned NumSucc = Term->getNumSuccessors();
    if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
      SmallVector<uint32_t> Weights;
      extractBranchWeights(ProfileData, Weights);
      uint64_t Numerator = 0, Denominator = 0;
      for (auto [i, Weight] : llvm::enumerate(Weights)) {
        if (Term->getSuccessor(i) == ExitBlock)
          Numerator += Weight;
        Denominator += Weight;
      }
      // If all weights are zero, act as if there was no profile data
      if (Denominator == 0)
        return BranchProbability::getBranchProbability(1, NumSucc);
      return BranchProbability::getBranchProbability(Numerator, Denominator);
    } else {
      assert(LatchBlock != ExitingBlock &&
             "Latch term should always have profile data!");
      // No profile data, so we choose the weight as 1/num_of_succ(Src)
      return BranchProbability::getBranchProbability(1, NumSucc);
    }
  };

  BranchProbability LatchExitProbability =
      ComputeBranchProbability(LatchBlock, LatchExitBlock);

  // Protect against degenerate inputs provided by the user. Providing a value
  // less than one can invert the definition of profitable loop predication.
  float ScaleFactor = LatchExitProbabilityScale;
  if (ScaleFactor < 1) {
    LLVM_DEBUG(
        dbgs()
        << "Ignored user setting for loop-predication-latch-probability-scale: "
        << LatchExitProbabilityScale << "\n");
    LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
    ScaleFactor = 1.0;
  }
  const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;

  for (const auto &ExitEdge : ExitEdges) {
    BranchProbability ExitingBlockProbability =
        ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
    // Some exiting edge has higher probability than the latch exiting edge.
    // No longer profitable to predicate.
    if (ExitingBlockProbability > LatchProbabilityThreshold)
      return false;
  }

  // We have concluded that the most probable way to exit from the
  // loop is through the latch (or there's no profile information and all
  // exits are equally likely).
  return true;
}
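The profitability heuristic in isLoopProfitableToPredicate derives each exit edge's probability from its branch weights (falling back to 1/NumSucc without usable profile data) and rejects the loop if any exit is more likely than the latch exit scaled by `loop-predication-latch-probability-scale`. The sketch below restates that arithmetic on a hypothetical `ExitingBranch` summary with a single exit successor per block. It uses plain doubles instead of LLVM's fixed-point `BranchProbability`, so rounding details differ from the pass.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical summary of an exiting block's terminator.
struct ExitingBranch {
  std::vector<uint32_t> Weights; // branch weights per successor; empty = no profile
  unsigned NumSucc;              // number of successors
  unsigned ExitIdx;              // successor index that leaves the loop
};

// Probability of taking the exit edge, falling back to 1/NumSucc when there
// is no usable profile (no weights, or all weights zero).
double exitProbability(const ExitingBranch &B) {
  uint64_t Num = 0, Den = 0;
  for (unsigned I = 0; I < B.Weights.size(); ++I) {
    if (I == B.ExitIdx)
      Num += B.Weights[I];
    Den += B.Weights[I];
  }
  if (Den == 0)
    return 1.0 / B.NumSucc;
  return static_cast<double>(Num) / static_cast<double>(Den);
}

// Profitable unless some other exit is more likely than the latch exit
// scaled by Scale; scales below 1 are clamped to 1, as in the pass.
bool isProfitableToPredicate(const ExitingBranch &Latch,
                             const std::vector<ExitingBranch> &OtherExits,
                             float Scale) {
  if (Scale < 1)
    Scale = 1.0f;
  double Threshold = exitProbability(Latch) * Scale;
  for (const ExitingBranch &E : OtherExits)
    if (exitProbability(E) > Threshold)
      return false;
  return true;
}
```

Comparing the latch against itself is omitted here because, with a scale of at least 1, the latch can never exceed its own threshold.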

/// If we can (cheaply) find a widenable branch which controls entry into the
/// loop, return it.
static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
  // Walk back through any unconditionally executed blocks and see if we can
  // find a widenable condition which seems to control execution of this loop.
  // Note that we predict that maythrow calls are likely untaken and thus that
  // it's profitable to widen a branch before a maythrow call with a condition
  // afterwards even though that may cause the slow path to run in a case where
  // it wouldn't have otherwise.
  BasicBlock *BB = L->getLoopPreheader();
  if (!BB)
    return nullptr;
  do {
    if (BasicBlock *Pred = BB->getSinglePredecessor())
      if (BB == Pred->getSingleSuccessor()) {
        BB = Pred;
        continue;
      }
    break;
  } while (true);

  if (BasicBlock *Pred = BB->getSinglePredecessor()) {
    if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()))
      if (BI->getSuccessor(0) == BB && isWidenableBranch(BI))
        return BI;
  }
  return nullptr;
}
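FindWidenableTerminatorAboveLoop climbs from the preheader through a chain of blocks that are entered unconditionally (single predecessor whose only successor is the current block), then accepts the predecessor's terminator only if it is a widenable branch whose first (taken) successor leads to the loop. The toy CFG below, with hypothetical `Block` and `findWidenableAboveLoop` names, mirrors that walk using block indices instead of LLVM basic blocks.

```cpp
#include <cassert>
#include <vector>

// Hypothetical mini-CFG node: predecessor/successor edges by block index,
// plus whether the block's terminator is a widenable branch.
struct Block {
  std::vector<int> Preds;
  std::vector<int> Succs;
  bool EndsInWidenableBranch = false;
};

// Returns the index of the block holding a widenable branch that controls
// entry to the loop, or -1 if none is found.
int findWidenableAboveLoop(const std::vector<Block> &CFG, int Preheader) {
  if (Preheader < 0)
    return -1;
  int BB = Preheader;
  // Walk back through blocks that are entered unconditionally.
  while (CFG[BB].Preds.size() == 1) {
    int Pred = CFG[BB].Preds[0];
    if (CFG[Pred].Succs.size() != 1 || CFG[Pred].Succs[0] != BB)
      break;
    BB = Pred;
  }
  // The block above must end in a widenable branch whose first
  // (taken) successor leads toward the loop.
  if (CFG[BB].Preds.size() == 1) {
    int Pred = CFG[BB].Preds[0];
    if (CFG[Pred].EndsInWidenableBranch && !CFG[Pred].Succs.empty() &&
        CFG[Pred].Succs[0] == BB)
      return Pred;
  }
  return -1;
}
```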

/// Return the minimum of all analyzeable exit counts. This is an upper bound
/// on the actual exit count. If there are fewer than two analyzeable exits,
/// returns SCEVCouldNotCompute.
static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
                                                       DominatorTree &DT,
                                                       Loop *L) {
  SmallVector<BasicBlock *, 16> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  SmallVector<const SCEV *, 4> ExitCounts;
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
    if (isa<SCEVCouldNotCompute>(ExitCount))
      continue;
    assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
           "We should only have known counts for exiting blocks that "
           "dominate latch!");
    ExitCounts.push_back(ExitCount);
  }
  if (ExitCounts.size() < 2)
    return SE.getCouldNotCompute();
  return SE.getUMinFromMismatchedTypes(ExitCounts);
}
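getMinAnalyzeableBackedgeTakenCount ignores exits whose count SCEV cannot compute, gives up unless at least two counts are known, and otherwise returns their unsigned minimum. The sketch below uses concrete integers instead of SCEV expressions, with `std::nullopt` standing in for `SCEVCouldNotCompute`; the helper name is hypothetical, and the zero-extension that `getUMinFromMismatchedTypes` performs on mixed types is not modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Each entry is one exiting block's exit count; std::nullopt plays the role
// of SCEVCouldNotCompute. Returns the minimum known count, or nullopt when
// fewer than two counts are known.
std::optional<uint64_t>
minAnalyzableExitCount(const std::vector<std::optional<uint64_t>> &Counts) {
  std::vector<uint64_t> Known;
  for (const std::optional<uint64_t> &EC : Counts)
    if (EC)
      Known.push_back(*EC);
  if (Known.size() < 2)
    return std::nullopt;
  return *std::min_element(Known.begin(), Known.end());
}
```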

/// This implements an analogous but entirely distinct transform from the main
/// loop predication transform. This one is phrased in terms of using a
/// widenable branch *outside* the loop to allow us to simplify loop exits in a
/// following loop. This is close in spirit to the IndVarSimplify transform
/// of the same name, but is materially different: widening loosens legality
/// sharply.
bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
  // The transformation performed here aims to widen a widenable condition
  // above the loop such that all analyzeable exits leading to deopt are dead.
  // It assumes that the latch is the dominant exit for profitability and that
  // exits branching to deoptimizing blocks are rarely taken. It relies on the
  // semantics of widenable expressions for legality. (I.e., being able to
  // take the widenable path spuriously allows us to ignore exit order,
  // unanalyzeable exits, side effects, exceptional exits, and other challenges
  // which restrict the applicability of the non-WC based version of this
  // transform in IndVarSimplify.)
  //
  // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may
  // imply flags on the expression being hoisted and inserting new uses (flags
  // are only correct for current uses). The result is that we may be
  // inserting a branch on the value which can be either poison or undef. In
  // this case, the branch can legally go either way; we just need to avoid
  // introducing UB. This is achieved through the use of the freeze
  // instruction.

  SmallVector<BasicBlock *, 16> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  if (ExitingBlocks.empty())
    return false; // Nothing to do.

  auto *Latch = L->getLoopLatch();
  if (!Latch)
    return false;

  auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, *LI);
  if (!WidenableBR)
    return false;

  const SCEV *LatchEC = SE->getExitCount(L, Latch);
  if (isa<SCEVCouldNotCompute>(LatchEC))
    return false; // profitability - want hot exit in analyzeable set

  // At this point, we have found an analyzeable latch, and a widenable
  // condition above the loop. If we have a widenable exit within the loop
  // (for which we can't compute exit counts), drop the ability to further
  // widen so that we gain the ability to analyze its exit count and perform
  // this transform. TODO: It'd be nice to know for sure the exit became
  // analyzeable after dropping widenability.
  bool ChangedLoop = false;

  for (auto *ExitingBB : ExitingBlocks) {
    if (LI->getLoopFor(ExitingBB) != L)
      continue;

    auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
    if (!BI)
      continue;

    if (auto WC = extractWidenableCondition(BI))
      if (L->contains(BI->getSuccessor(0))) {
        assert(WC->hasOneUse() && "Not appropriate widenable branch!");
        WC->user_back()->replaceUsesOfWith(
            WC, ConstantInt::getTrue(BI->getContext()));
        ChangedLoop = true;
      }
  }
  if (ChangedLoop)
    SE->forgetLoop(L);

  // The insertion point for the widening should be at the widenable
  // condition, not at the WidenableBR. If we do this at the WidenableBR, we
  // can incorrectly change a loop-invariant condition to a loop-varying one.
  auto *IP = cast<Instruction>(WidenableBR->getCondition());

  // The use of umin(all analyzeable exits) instead of latch is subtle, but
  // important for profitability. We may have a loop which hasn't been fully
  // canonicalized just yet. If the exit we chose to widen is provably never
  // taken, we want the widened form to *also* be provably never taken. We
  // can't guarantee this as a current unanalyzeable exit may later become
  // analyzeable, but we can at least avoid the obvious cases.
  const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
  if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
      !SE->isLoopInvariant(MinEC, L) ||
      !Rewriter.isSafeToExpandAt(MinEC, IP))
    return ChangedLoop;

  Rewriter.setInsertPoint(IP);
  IRBuilder<> B(IP);

  bool InvalidateLoop = false;
  Value *MinECV = nullptr; // lazily generated if needed
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    // If our exiting block exits multiple loops, we can only rewrite the
    // innermost one. Otherwise, we're changing how many times the innermost
    // loop runs before it exits.
    if (LI->getLoopFor(ExitingBB) != L)
      continue;

    // Can't rewrite non-branch yet.
    auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
    if (!BI)
      continue;

    // If already constant, nothing to do.
    if (isa<Constant>(BI->getCondition()))
      continue;

    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
    if (isa<SCEVCouldNotCompute>(ExitCount) ||
        ExitCount->getType()->isPointerTy() ||
        !Rewriter.isSafeToExpandAt(ExitCount, WidenableBR))
      continue;

    const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
    BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
    if (!ExitBB->getPostdominatingDeoptimizeCall())
      continue;

    /// Here we can be fairly sure that executing this exit will most likely
    /// lead to executing llvm.experimental.deoptimize.
    /// This is a profitability heuristic, not a legality constraint.

    // If we found a widenable exit condition, do two things:
    // 1) fold the widened exit test into the widenable condition
    // 2) fold the branch to untaken - avoids infinite looping

    Value *ECV = Rewriter.expandCodeFor(ExitCount);
    if (!MinECV)
      MinECV = Rewriter.expandCodeFor(MinEC);
    Value *RHS = MinECV;
    if (ECV->getType() != RHS->getType()) {
      Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
      ECV = B.CreateZExt(ECV, WiderTy);
      RHS = B.CreateZExt(RHS, WiderTy);
    }
    assert(!Latch || DT->dominates(ExitingBB, Latch));
    Value *NewCond = B.CreateICmp(ICmpInst::ICMP_UGT, ECV, RHS);
    // Freeze poison or undef to an arbitrary bit pattern to ensure we can
    // branch without introducing UB. See NOTE ON POISON/UNDEF above for
    // context.
    NewCond = B.CreateFreeze(NewCond);

    widenWidenableBranch(WidenableBR, NewCond);

    Value *OldCond = BI->getCondition();
    BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
    InvalidateLoop = true;
  }

  if (InvalidateLoop)
    // We just mutated a bunch of loop exits, changing their exit counts
    // widely. We need to force recomputation of the exit counts given these
    // changes. Note that all of the inserted exits are never taken, and
    // should be removed next time the CFG is modified.
    SE->forgetLoop(L);

  // Always return `true` since we have moved the WidenableBR's condition.
  return true;
}
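For each analyzeable exit that leads to a deoptimize call, predicateLoopExits folds `ExitCount u> MinEC` into the widenable branch above the loop and then makes the in-loop exit branch never taken. The intuition: if every such exit would fire strictly later than the earliest analyzeable exit, none of them can be the one that actually leaves the loop. The toy model below (hypothetical `LoopExit`, `widenedEntryCondition`, `firstExitTaken`) checks that intuition by simulation with concrete counts; it ignores the freeze, type widening, and SCEV expansion that the real code performs, and it treats exits with unknown counts as never firing.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical loop exit, listed in the order the exits are tested inside
// one iteration of the loop body.
struct LoopExit {
  std::optional<uint64_t> Count; // backedge-taken count when it fires; nullopt = unknown
  bool LeadsToDeopt;             // exit block reaches a deoptimize call
};

// The condition folded into the widenable branch above the loop: every
// analyzable deopt exit must fire strictly after MinEC, the smallest
// analyzable exit count.
bool widenedEntryCondition(const std::vector<LoopExit> &Exits,
                           uint64_t MinEC) {
  bool Cond = true;
  for (const LoopExit &E : Exits)
    if (E.Count && E.LeadsToDeopt)
      Cond = Cond && (*E.Count > MinEC);
  return Cond;
}

// Simulates the original loop: in iteration K the exits are tested in order
// and the first one whose Count equals K is taken. Exits with unknown counts
// never fire in this model. Requires at least one known count.
int firstExitTaken(const std::vector<LoopExit> &Exits) {
  for (uint64_t K = 0;; ++K)
    for (int I = 0; I < static_cast<int>(Exits.size()); ++I)
      if (Exits[I].Count && *Exits[I].Count == K)
        return I;
}
```

When the widened condition fails, the widenable branch simply takes its deoptimizing path before the loop starts, which the widenable-condition semantics permit.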

bool LoopPredication::runOnLoop(Loop *Loop) {
  L = Loop;

  LLVM_DEBUG(dbgs() << "Analyzing ");
  LLVM_DEBUG(L->dump());

  Module *M = L->getHeader()->getModule();

  // There is nothing to do if the module uses neither guard intrinsics nor
  // widenable conditions.
  auto *GuardDecl =
      M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
  bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
  auto *WCDecl = M->getFunction(
      Intrinsic::getName(Intrinsic::experimental_widenable_condition));
  bool HasWidenableConditions =
      PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty();
  if (!HasIntrinsicGuards && !HasWidenableConditions)
    return false;

  DL = &M->getDataLayout();

  Preheader = L->getLoopPreheader();
  if (!Preheader)
    return false;

  auto LatchCheckOpt = parseLoopLatchICmp();
  if (!LatchCheckOpt)
    return false;
  LatchCheck = *LatchCheckOpt;

  LLVM_DEBUG(dbgs() << "Latch check:\n");
  LLVM_DEBUG(LatchCheck.dump());

  if (!isLoopProfitableToPredicate()) {
    LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
    return false;
  }
  // Collect all the guards into a vector and process later, so as not
  // to invalidate the instruction iterator.
  SmallVector<IntrinsicInst *, 4> Guards;
  SmallVector<BranchInst *, 4> GuardsAsWidenableBranches;
  for (const auto BB : L->blocks()) {
    for (auto &I : *BB)
      if (isGuard(&I))
        Guards.push_back(cast<IntrinsicInst>(&I));
    if (PredicateWidenableBranchGuards &&
        isGuardAsWidenableBranch(BB->getTerminator()))
      GuardsAsWidenableBranches.push_back(
          cast<BranchInst>(BB->getTerminator()));
  }

  SCEVExpander Expander(*SE, *DL, "loop-predication");
  bool Changed = false;
  for (auto *Guard : Guards)
    Changed |= widenGuardConditions(Guard, Expander);
  for (auto *Guard : GuardsAsWidenableBranches)
    Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
  Changed |= predicateLoopExits(L, Expander);

  if (MSSAU && VerifyMemorySSA)
    MSSAU->getMemorySSA()->verifyMemorySSA();
  return Changed;
}
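runOnLoop collects guards before widening them, presumably because widening inserts instructions and may delete others (for example, the old condition via RecursivelyDeleteTriviallyDeadInstructions), which would invalidate an iterator walking the same block. The sketch below shows the same collect-then-mutate pattern on a hypothetical toy block (a `std::vector` of `std::unique_ptr`), where erasing inside a range-for over the block would be undefined behavior. `Inst`, `Block`, `hasUses`, and `widenAllGuards` are illustrative names, not LLVM APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical toy instruction: a guard refers to the condition it checks.
struct Inst {
  bool IsGuard = false;
  Inst *Cond = nullptr;
};

using Block = std::vector<std::unique_ptr<Inst>>;

// True if some instruction in BB still uses V as its condition.
bool hasUses(const Block &BB, const Inst *V) {
  return std::any_of(BB.begin(), BB.end(),
                     [&](const std::unique_ptr<Inst> &P) { return P->Cond == V; });
}

// Widening a guard may erase its old condition once it is dead. Erasing from
// BB while range-iterating over it would invalidate the loop's iterators, so
// the guards are collected first and processed afterwards.
int widenAllGuards(Block &BB) {
  std::vector<Inst *> Guards;
  for (const std::unique_ptr<Inst> &I : BB)
    if (I->IsGuard)
      Guards.push_back(I.get());

  for (Inst *G : Guards) {
    Inst *OldCond = G->Cond;
    G->Cond = nullptr; // stand-in for installing the widened condition
    if (!OldCond || OldCond->IsGuard || hasUses(BB, OldCond))
      continue; // only delete conditions that are now trivially dead
    BB.erase(std::remove_if(BB.begin(), BB.end(),
                            [&](const std::unique_ptr<Inst> &P) {
                              return P.get() == OldCond;
                            }),
             BB.end());
  }
  return static_cast<int>(Guards.size());
}
```

Holding raw pointers to the collected guards stays valid across the erasures because each instruction lives in its own heap allocation and guards themselves are never erased.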