LLVM 20.0.0git
LoopPredication.cpp
//===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The LoopPredication pass tries to convert loop variant range checks to loop
// invariant by widening checks across loop iterations. For example, it will
// convert
//
//   for (i = 0; i < n; i++) {
//     guard(i < len);
//     ...
//   }
//
// to
//
//   for (i = 0; i < n; i++) {
//     guard(n - 1 < len);
//     ...
//   }
//
// After this transformation the condition of the guard is loop invariant, so
// loop-unswitch can later unswitch the loop by this condition which basically
// predicates the loop by the widened condition:
//
//   if (n - 1 < len)
//     for (i = 0; i < n; i++) {
//       ...
//     }
//   else
//     deoptimize
//
// It's tempting to rely on SCEV here, but it has proven to be problematic.
// Generally the facts SCEV provides about the increment step of add
// recurrences are true if the backedge of the loop is taken, which implicitly
// assumes that the guard doesn't fail. Using these facts to optimize the
// guard results in a circular logic where the guard is optimized under the
// assumption that it never fails.
//
// For example, in the loop below the induction variable will be marked as nuw
// based on the guard. Based on nuw the guard predicate will be considered
// monotonic. Given a monotonic condition it's tempting to replace the induction
// variable in the condition with its value on the last iteration. But this
// transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
//
//   for (int i = b; i != e; i++)
//     guard(i u< len)
//
// One of the ways to reason about this problem is to use an inductive proof
// approach. Given the loop:
//
//   if (B(0)) {
//     do {
//       I = PHI(0, I.INC)
//       I.INC = I + Step
//       guard(G(I));
//     } while (B(I));
//   }
//
// where B(x) and G(x) are predicates that map integers to booleans, we want a
// loop invariant expression M such that the following program has the same
// semantics as the above:
//
//   if (B(0)) {
//     do {
//       I = PHI(0, I.INC)
//       I.INC = I + Step
//       guard(G(0) && M);
//     } while (B(I));
//   }
//
// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
//
// Informal proof that the transformation above is correct:
//
//   By the definition of guards we can rewrite the guard condition to:
//     G(I) && G(0) && M
//
//   Let's prove that for each iteration of the loop:
//     G(0) && M => G(I)
//   And the condition above can be simplified to G(0) && M.
//
//   Induction base.
//     G(0) && M => G(0)
//
//   Induction step. Assuming G(0) && M => G(I) on the subsequent
//   iteration:
//
//     B(I) is true because it's the backedge condition.
//     G(I) is true because the backedge is guarded by this condition.
//
//   So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
//
// Note that we can use anything stronger than M, i.e. any condition which
// implies M.
//
// When Step = 1 (i.e. forward iterating loop), the transformation is supported
// when:
//   * The loop has a single latch with the condition of the form:
//     B(X) = latchStart + X <pred> latchLimit,
//     where <pred> is u<, u<=, s<, or s<=.
//   * The guard condition is of the form
//     G(X) = guardStart + X u< guardLimit
//
//   For the ult latch comparison case M is:
//     forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
//        guardStart + X + 1 u< guardLimit
//
//   The only way the antecedent can be true and the consequent can be false is
//   if
//     X == guardLimit - 1 - guardStart
//   (and guardLimit is non-zero, but we won't use this latter fact).
//   If X == guardLimit - 1 - guardStart then the second half of the antecedent is
//     latchStart + guardLimit - 1 - guardStart u< latchLimit
//   and its negation is
//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
//
//   In other words, if
//     latchLimit u<= latchStart + guardLimit - 1 - guardStart
//   then:
//   (the ranges below are written in ConstantRange notation, where [A, B) is the
//   set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
//
//      forall X . guardStart + X u< guardLimit &&
//                 latchStart + X u< latchLimit =>
//        guardStart + X + 1 u< guardLimit
//   == forall X . guardStart + X u< guardLimit &&
//                 latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
//        guardStart + X + 1 u< guardLimit
//   == forall X . (guardStart + X) in [0, guardLimit) &&
//                 (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
//        (guardStart + X + 1) in [0, guardLimit)
//   == forall X . X in [-guardStart, guardLimit - guardStart) &&
//                 X in [-latchStart, guardLimit - 1 - guardStart) =>
//         X in [-guardStart - 1, guardLimit - guardStart - 1)
//   == true
//
//   So the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
//   Similarly for ule condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart u> latchLimit
//   For slt condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart s>= latchLimit
//   For sle condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart s> latchLimit
//
// When Step = -1 (i.e. reverse iterating loop), the transformation is supported
// when:
//   * The loop has a single latch with the condition of the form:
//     B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
//   * The guard condition is of the form
//     G(X) = X - 1 u< guardLimit
//
//   For the ugt latch comparison case M is:
//     forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
//
//   The only way the antecedent can be true and the consequent can be false is if
//     X == 1.
//   If X == 1 then the second half of the antecedent is
//     1 u> latchLimit, and its negation is latchLimit u>= 1.
//
//   So the widened condition is:
//     guardStart u< guardLimit && latchLimit u>= 1.
//   Similarly for sgt condition the widened condition is:
//     guardStart u< guardLimit && latchLimit s>= 1.
//   For uge condition the widened condition is:
//     guardStart u< guardLimit && latchLimit u> 1.
//   For sge condition the widened condition is:
//     guardStart u< guardLimit && latchLimit s> 1.
//===----------------------------------------------------------------------===//
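// Illustration only, not part of the pass: the widened conditions derived in
// the header comment can be sanity-checked by brute force in a small modular
// integer width. The sketch below assumes Bits-wide unsigned arithmetic and
// covers only the ult (forward) and ugt (reverse) latch cases. The names
// Counts, checkForwardUlt and checkBackwardUgt are hypothetical helpers, not
// LLVM APIs. For every combination of inputs where the widened condition
// holds, it runs the loop and counts iterations on which the original guard
// would have failed.

```cpp
#include <cassert>
#include <cstdint>

// How often the widened condition held, and how often the original guard
// still failed on some iteration even though it held.
struct Counts {
  unsigned widenedHolds = 0;
  unsigned violations = 0;
};

// Forward loop (Step = 1), ult latch:
//   G(X) = guardStart + X u< guardLimit
//   B(X) = latchStart + X u< latchLimit
// All arithmetic is modulo 2^Bits.
Counts checkForwardUlt(unsigned Bits) {
  const uint32_t N = 1u << Bits, Mask = N - 1;
  Counts C;
  for (uint32_t gs = 0; gs < N; ++gs)
    for (uint32_t gl = 0; gl < N; ++gl)
      for (uint32_t ls = 0; ls < N; ++ls)
        for (uint32_t ll = 0; ll < N; ++ll) {
          // guardStart u< guardLimit &&
          // latchStart + guardLimit - 1 - guardStart u>= latchLimit
          bool Widened = gs < gl && ll <= ((ls + gl - 1 - gs) & Mask);
          if (!Widened)
            continue;
          ++C.widenedHolds;
          for (uint32_t x = 0; x < N; ++x) {
            if (!(((gs + x) & Mask) < gl)) { // guard(G(X)) fails
              ++C.violations;
              break;
            }
            if (!(((ls + x) & Mask) < ll)) // backedge B(X) not taken
              break;
          }
        }
  return C;
}

// Reverse loop (Step = -1), ugt latch, where X is the IV value itself:
//   G(X) = X - 1 u< guardLimit
//   B(X) = X u> latchLimit
Counts checkBackwardUgt(unsigned Bits) {
  const uint32_t N = 1u << Bits, Mask = N - 1;
  Counts C;
  for (uint32_t start = 0; start < N; ++start)
    for (uint32_t gl = 0; gl < N; ++gl)
      for (uint32_t ll = 0; ll < N; ++ll) {
        // guardStart u< guardLimit && latchLimit u>= 1,
        // with guardStart = start - 1.
        bool Widened = ((start - 1) & Mask) < gl && ll >= 1;
        if (!Widened)
          continue;
        ++C.widenedHolds;
        uint32_t x = start;
        for (uint32_t iter = 0; iter < N; ++iter) {
          if (!(((x - 1) & Mask) < gl)) { // guard(G(X)) fails
            ++C.violations;
            break;
          }
          if (!(x > ll)) // backedge B(X) not taken
            break;
          x = (x - 1) & Mask;
        }
      }
  return C;
}
```

// Calling checkForwardUlt(4) or checkBackwardUgt(4) enumerates every 4-bit
// input; a non-zero violations count would contradict the derivation above.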
178
180#include "llvm/ADT/Statistic.h"
190#include "llvm/IR/Function.h"
192#include "llvm/IR/Module.h"
193#include "llvm/IR/PatternMatch.h"
195#include "llvm/Pass.h"
197#include "llvm/Support/Debug.h"
203#include <optional>
204
205#define DEBUG_TYPE "loop-predication"
206
207STATISTIC(TotalConsidered, "Number of guards considered");
208STATISTIC(TotalWidened, "Number of checks widened");
209
210using namespace llvm;
211
212static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
213 cl::Hidden, cl::init(true));
214
215static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
216 cl::Hidden, cl::init(true));
217
218static cl::opt<bool>
219 SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
220 cl::Hidden, cl::init(false));
221
// This is the scale factor for the latch probability. We use this during
// profitability analysis to find other exiting blocks that have a much higher
// probability of exiting the loop than the latch does.
// This value should be greater than 1 for a sane profitability check.
227 "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
228 cl::desc("scale factor for the latch probability. Value should be greater "
229 "than 1. Lower values are ignored"));
230
232 "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden,
233 cl::desc("Whether or not we should predicate guards "
234 "expressed as widenable branches to deoptimize blocks"),
235 cl::init(true));
236
238 "loop-predication-insert-assumes-of-predicated-guards-conditions",
240 cl::desc("Whether or not we should insert assumes of conditions of "
241 "predicated guards"),
242 cl::init(true));
243
244namespace {
245/// Represents an induction variable check:
246/// icmp Pred, <induction variable>, <loop invariant limit>
247struct LoopICmp {
249 const SCEVAddRecExpr *IV;
250 const SCEV *Limit;
251 LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
252 const SCEV *Limit)
253 : Pred(Pred), IV(IV), Limit(Limit) {}
254 LoopICmp() = default;
255 void dump() {
256 dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
257 << ", Limit = " << *Limit << "\n";
258 }
259};
260
261class LoopPredication {
262 AliasAnalysis *AA;
263 DominatorTree *DT;
264 ScalarEvolution *SE;
265 LoopInfo *LI;
266 MemorySSAUpdater *MSSAU;
267
268 Loop *L;
269 const DataLayout *DL;
270 BasicBlock *Preheader;
271 LoopICmp LatchCheck;
272
273 bool isSupportedStep(const SCEV* Step);
274 std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
275 std::optional<LoopICmp> parseLoopLatchICmp();
276
  /// Return an insertion point suitable for inserting a safe to speculate
  /// instruction whose only user will be 'User' which has operands 'Ops'. A
  /// trivial result would be at the User itself, but we try to return a
  /// loop invariant location if possible.
281 Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
282 /// Same as above, *except* that this uses the SCEV definition of invariant
283 /// which is that an expression *can be made* invariant via SCEVExpander.
284 /// Thus, this version is only suitable for finding an insert point to be
285 /// passed to SCEVExpander!
286 Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User,
288
289 /// Return true if the value is known to produce a single fixed value across
290 /// all iterations on which it executes. Note that this does not imply
291 /// speculation safety. That must be established separately.
292 bool isLoopInvariantValue(const SCEV* S);
293
294 Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
295 ICmpInst::Predicate Pred, const SCEV *LHS,
296 const SCEV *RHS);
297
298 std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI,
299 SCEVExpander &Expander,
300 Instruction *Guard);
301 std::optional<Value *>
302 widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
303 SCEVExpander &Expander,
304 Instruction *Guard);
305 std::optional<Value *>
306 widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
307 SCEVExpander &Expander,
308 Instruction *Guard);
309 void widenChecks(SmallVectorImpl<Value *> &Checks,
310 SmallVectorImpl<Value *> &WidenedChecks,
311 SCEVExpander &Expander, Instruction *Guard);
312 bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
313 bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander);
314 // If the loop always exits through another block in the loop, we should not
315 // predicate based on the latch check. For example, the latch check can be a
316 // very coarse grained check and there can be more fine grained exit checks
317 // within the loop.
318 bool isLoopProfitableToPredicate();
319
320 bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
321
322public:
323 LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
324 LoopInfo *LI, MemorySSAUpdater *MSSAU)
325 : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
326 bool runOnLoop(Loop *L);
327};
328
329} // end namespace
330
333 LPMUpdater &U) {
334 std::unique_ptr<MemorySSAUpdater> MSSAU;
335 if (AR.MSSA)
336 MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
337 LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
338 MSSAU ? MSSAU.get() : nullptr);
339 if (!LP.runOnLoop(&L))
340 return PreservedAnalyses::all();
341
343 if (AR.MSSA)
344 PA.preserve<MemorySSAAnalysis>();
345 return PA;
346}
347
348std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) {
349 auto Pred = ICI->getPredicate();
350 auto *LHS = ICI->getOperand(0);
351 auto *RHS = ICI->getOperand(1);
352
353 const SCEV *LHSS = SE->getSCEV(LHS);
354 if (isa<SCEVCouldNotCompute>(LHSS))
355 return std::nullopt;
356 const SCEV *RHSS = SE->getSCEV(RHS);
357 if (isa<SCEVCouldNotCompute>(RHSS))
358 return std::nullopt;
359
360 // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV
361 if (SE->isLoopInvariant(LHSS, L)) {
362 std::swap(LHS, RHS);
363 std::swap(LHSS, RHSS);
365 }
366
367 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
368 if (!AR || AR->getLoop() != L)
369 return std::nullopt;
370
371 return LoopICmp(Pred, AR, RHSS);
372}
373
374Value *LoopPredication::expandCheck(SCEVExpander &Expander,
375 Instruction *Guard,
376 ICmpInst::Predicate Pred, const SCEV *LHS,
377 const SCEV *RHS) {
378 Type *Ty = LHS->getType();
379 assert(Ty == RHS->getType() && "expandCheck operands have different types?");
380
381 if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
382 IRBuilder<> Builder(Guard);
383 if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
384 return Builder.getTrue();
385 if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
386 LHS, RHS))
387 return Builder.getFalse();
388 }
389
390 Value *LHSV =
391 Expander.expandCodeFor(LHS, Ty, findInsertPt(Expander, Guard, {LHS}));
392 Value *RHSV =
393 Expander.expandCodeFor(RHS, Ty, findInsertPt(Expander, Guard, {RHS}));
394 IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
395 return Builder.CreateICmp(Pred, LHSV, RHSV);
396}
397
// Returns true if it's safe to truncate the IV to RangeCheckType.
// When the IV type is wider than the range operand type, we can still do loop
// predication, by generating SCEVs for the range and latch that are of the
// same type. We achieve this by generating a SCEV truncate expression for the
// latch IV. This is done only if truncation of the IV is a safe operation,
// without loss of information.
// Another way to achieve this is by generating a wider type SCEV for the
// range check operand; however, this needs a more involved check that
// operands do not overflow. This can lead to loss of information when the
// range operand is of the form: add i32 %offset, %iv. We need to prove that
// sext(x + y) is the same as sext(x) + sext(y).
// This function returns true if we can safely represent the IV type in
// the RangeCheckType without loss of information.
412 ScalarEvolution &SE,
413 const LoopICmp LatchCheck,
414 Type *RangeCheckType) {
416 return false;
417 assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() >
418 DL.getTypeSizeInBits(RangeCheckType).getFixedValue() &&
419 "Expected latch check IV type to be larger than range check operand "
420 "type!");
421 // The start and end values of the IV should be known. This is to guarantee
422 // that truncating the wide type will not lose information.
423 auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
424 auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
425 if (!Limit || !Start)
426 return false;
427 // This check makes sure that the IV does not change sign during loop
428 // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
429 // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
430 // IV wraps around, and the truncation of the IV would lose the range of
431 // iterations between 2^32 and 2^64.
432 if (!SE.getMonotonicPredicateType(LatchCheck.IV, LatchCheck.Pred))
433 return false;
434 // The active bits should be less than the bits in the RangeCheckType. This
435 // guarantees that truncating the latch check to RangeCheckType is a safe
436 // operation.
437 auto RangeCheckTypeBitSize =
438 DL.getTypeSizeInBits(RangeCheckType).getFixedValue();
439 return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
440 Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
441}
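
// Illustration only, not part of the pass: a minimal sketch of the final
// "active bits" test above, written against plain 64-bit integers rather
// than APInt. activeBits and fitsInNarrowType are hypothetical names. The
// sketch assumes activeBits matches APInt::getActiveBits for non-negative
// values (the bit width minus the number of leading zero bits). Requiring
// strictly fewer active bits than the narrow width presumably keeps the top
// bit of the narrow type clear, so the value reads the same whether the
// narrow type is interpreted as signed or unsigned.

```cpp
#include <cassert>
#include <cstdint>

// Number of bits needed to represent V; 0 for V == 0.
unsigned activeBits(uint64_t V) {
  unsigned Bits = 0;
  while (V) {
    ++Bits;
    V >>= 1;
  }
  return Bits;
}

// True if both constants fit in NarrowBits with the top bit left clear.
bool fitsInNarrowType(uint64_t Start, uint64_t Limit, unsigned NarrowBits) {
  return activeBits(Start) < NarrowBits && activeBits(Limit) < NarrowBits;
}
```

// For example, a start of 5 and a limit of 2 fit in a 32-bit range check
// type, while a limit of 2^31 does not because it needs all 32 bits.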
442
443
// Return a LoopICmp describing a latch check equivalent to LatchCheck but with
// the requested type if safe to do so. May involve the use of a new IV.
446static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
447 ScalarEvolution &SE,
448 const LoopICmp LatchCheck,
449 Type *RangeCheckType) {
450
451 auto *LatchType = LatchCheck.IV->getType();
452 if (RangeCheckType == LatchType)
453 return LatchCheck;
454 // For now, bail out if latch type is narrower than range type.
455 if (DL.getTypeSizeInBits(LatchType).getFixedValue() <
456 DL.getTypeSizeInBits(RangeCheckType).getFixedValue())
457 return std::nullopt;
458 if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType))
459 return std::nullopt;
460 // We can now safely identify the truncated version of the IV and limit for
461 // RangeCheckType.
462 LoopICmp NewLatchCheck;
463 NewLatchCheck.Pred = LatchCheck.Pred;
464 NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
465 SE.getTruncateExpr(LatchCheck.IV, RangeCheckType));
466 if (!NewLatchCheck.IV)
467 return std::nullopt;
468 NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType);
469 LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
470 << "can be represented as range check type:"
471 << *RangeCheckType << "\n");
472 LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
473 LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
474 return NewLatchCheck;
475}
476
477bool LoopPredication::isSupportedStep(const SCEV* Step) {
478 return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
479}
480
481Instruction *LoopPredication::findInsertPt(Instruction *Use,
482 ArrayRef<Value*> Ops) {
483 for (Value *Op : Ops)
484 if (!L->isLoopInvariant(Op))
485 return Use;
486 return Preheader->getTerminator();
487}
488
489Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander,
492 // Subtlety: SCEV considers things to be invariant if the value produced is
493 // the same across iterations. This is not the same as being able to
494 // evaluate outside the loop, which is what we actually need here.
495 for (const SCEV *Op : Ops)
496 if (!SE->isLoopInvariant(Op, L) ||
497 !Expander.isSafeToExpandAt(Op, Preheader->getTerminator()))
498 return Use;
499 return Preheader->getTerminator();
500}
501
502bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
  // Handling expressions which produce invariant results, but *haven't* yet
  // been removed from the loop serves two important purposes.
  // 1) Most importantly, it resolves a pass ordering cycle which would
  // otherwise need us to iterate licm, loop-predication, and either
  // loop-unswitch or loop-peeling to make progress on examples with lots of
  // predicable range checks in a row. (Since, in the general case, we can't
  // hoist the length checks until the dominating checks have been discharged
  // as we can't prove doing so is safe.)
  // 2) As a nice side effect, this exposes the value of peeling or unswitching
  // much more obviously in the IR. Otherwise, the cost modeling for other
  // transforms would end up needing to duplicate all of this logic to model a
  // check which becomes predictable based on a modeled peel or unswitch.
  //
  // The cost of doing so in the worst case is an extra fill from the stack in
  // the loop to materialize the loop invariant test value instead of checking
  // against the original IV which is presumably in a register inside the loop.
  // Such cases are presumably rare, and hint at missing opportunities for
  // other passes.
521
  if (SE->isLoopInvariant(S, L))
    // Note: This is the SCEV variant, so the original Value* may be within the
    // loop even though SCEV has proven it is loop invariant.
    return true;

  // Handle a particularly important case which SCEV doesn't yet know about
  // which shows up in range checks on arrays with immutable lengths.
  // TODO: This should be sunk inside SCEV.
530 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
531 if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
532 if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
533 if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) ||
534 LI->hasMetadata(LLVMContext::MD_invariant_load))
535 return true;
536 return false;
537}
538
539std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
540 LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
541 Instruction *Guard) {
542 auto *Ty = RangeCheck.IV->getType();
  // Generate the widened condition for the forward loop:
  //   guardStart u< guardLimit &&
  //   latchLimit <pred> guardLimit - 1 - guardStart + latchStart
  // where <pred> depends on the latch condition predicate. See the file
  // header comment for the reasoning.
549 const SCEV *GuardStart = RangeCheck.IV->getStart();
550 const SCEV *GuardLimit = RangeCheck.Limit;
551 const SCEV *LatchStart = LatchCheck.IV->getStart();
552 const SCEV *LatchLimit = LatchCheck.Limit;
553 // Subtlety: We need all the values to be *invariant* across all iterations,
554 // but we only need to check expansion safety for those which *aren't*
555 // already guaranteed to dominate the guard.
556 if (!isLoopInvariantValue(GuardStart) ||
557 !isLoopInvariantValue(GuardLimit) ||
558 !isLoopInvariantValue(LatchStart) ||
559 !isLoopInvariantValue(LatchLimit)) {
560 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
561 return std::nullopt;
562 }
563 if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
564 !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
565 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
566 return std::nullopt;
567 }
568
569 // guardLimit - guardStart + latchStart - 1
570 const SCEV *RHS =
571 SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
572 SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
573 auto LimitCheckPred =
575
576 LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
577 LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
578 LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
579
580 auto *LimitCheck =
581 expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
582 auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
583 GuardStart, GuardLimit);
584 IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
585 return Builder.CreateFreeze(
586 Builder.CreateAnd(FirstIterationCheck, LimitCheck));
587}
588
589std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
590 LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
591 Instruction *Guard) {
592 auto *Ty = RangeCheck.IV->getType();
593 const SCEV *GuardStart = RangeCheck.IV->getStart();
594 const SCEV *GuardLimit = RangeCheck.Limit;
595 const SCEV *LatchStart = LatchCheck.IV->getStart();
596 const SCEV *LatchLimit = LatchCheck.Limit;
597 // Subtlety: We need all the values to be *invariant* across all iterations,
598 // but we only need to check expansion safety for those which *aren't*
599 // already guaranteed to dominate the guard.
600 if (!isLoopInvariantValue(GuardStart) ||
601 !isLoopInvariantValue(GuardLimit) ||
602 !isLoopInvariantValue(LatchStart) ||
603 !isLoopInvariantValue(LatchLimit)) {
604 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
605 return std::nullopt;
606 }
607 if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
608 !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
609 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
610 return std::nullopt;
611 }
612 // The decrement of the latch check IV should be the same as the
613 // rangeCheckIV.
614 auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
615 if (RangeCheck.IV != PostDecLatchCheckIV) {
616 LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
617 << *PostDecLatchCheckIV
618 << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
619 return std::nullopt;
620 }
621
622 // Generate the widened condition for CountDownLoop:
623 // guardStart u< guardLimit &&
624 // latchLimit <pred> 1.
625 // See the header comment for reasoning of the checks.
626 auto LimitCheckPred =
628 auto *FirstIterationCheck = expandCheck(Expander, Guard,
630 GuardStart, GuardLimit);
631 auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
632 SE->getOne(Ty));
633 IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
634 return Builder.CreateFreeze(
635 Builder.CreateAnd(FirstIterationCheck, LimitCheck));
636}
637
639 LoopICmp& RC) {
640 // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the
641 // ULT/UGE form for ease of handling by our caller.
642 if (ICmpInst::isEquality(RC.Pred) &&
643 RC.IV->getStepRecurrence(*SE)->isOne() &&
644 SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit))
645 RC.Pred = RC.Pred == ICmpInst::ICMP_NE ?
647}
648
649/// If ICI can be widened to a loop invariant condition emits the loop
650/// invariant condition in the loop preheader and return it, otherwise
651/// returns std::nullopt.
652std::optional<Value *>
653LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
654 Instruction *Guard) {
655 LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
656 LLVM_DEBUG(ICI->dump());
657
  // parseLoopLatchICmp guarantees that the latch condition is:
  //   ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=
  //   (or u>, u>=, s>, or s>= for a count-down loop).
  // We are looking for the range checks of the form:
  //   i u< guardLimit
662 auto RangeCheck = parseLoopICmp(ICI);
663 if (!RangeCheck) {
664 LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
665 return std::nullopt;
666 }
667 LLVM_DEBUG(dbgs() << "Guard check:\n");
668 LLVM_DEBUG(RangeCheck->dump());
669 if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
670 LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
671 << RangeCheck->Pred << ")!\n");
672 return std::nullopt;
673 }
674 auto *RangeCheckIV = RangeCheck->IV;
675 if (!RangeCheckIV->isAffine()) {
676 LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
677 return std::nullopt;
678 }
679 const SCEV *Step = RangeCheckIV->getStepRecurrence(*SE);
680 // We cannot just compare with latch IV step because the latch and range IVs
681 // may have different types.
682 if (!isSupportedStep(Step)) {
683 LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
684 return std::nullopt;
685 }
686 auto *Ty = RangeCheckIV->getType();
687 auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty);
688 if (!CurrLatchCheckOpt) {
689 LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
690 "corresponding to range type: "
691 << *Ty << "\n");
692 return std::nullopt;
693 }
694
695 LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
696 // At this point, the range and latch step should have the same type, but need
697 // not have the same value (we support both 1 and -1 steps).
698 assert(Step->getType() ==
699 CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
700 "Range and latch steps should be of same type!");
701 if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
702 LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
703 return std::nullopt;
704 }
705
706 if (Step->isOne())
707 return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
708 Expander, Guard);
709 else {
710 assert(Step->isAllOnesValue() && "Step should be -1!");
711 return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
712 Expander, Guard);
713 }
714}
715
716void LoopPredication::widenChecks(SmallVectorImpl<Value *> &Checks,
717 SmallVectorImpl<Value *> &WidenedChecks,
718 SCEVExpander &Expander, Instruction *Guard) {
719 for (auto &Check : Checks)
720 if (ICmpInst *ICI = dyn_cast<ICmpInst>(Check))
721 if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) {
722 WidenedChecks.push_back(Check);
723 Check = *NewRangeCheck;
724 }
725}
726
727bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
728 SCEVExpander &Expander) {
729 LLVM_DEBUG(dbgs() << "Processing guard:\n");
730 LLVM_DEBUG(Guard->dump());
731
732 TotalConsidered++;
734 SmallVector<Value *> WidenedChecks;
735 parseWidenableGuard(Guard, Checks);
736 widenChecks(Checks, WidenedChecks, Expander, Guard);
737 if (WidenedChecks.empty())
738 return false;
739
740 TotalWidened += WidenedChecks.size();
741
742 // Emit the new guard condition
743 IRBuilder<> Builder(findInsertPt(Guard, Checks));
744 Value *AllChecks = Builder.CreateAnd(Checks);
745 auto *OldCond = Guard->getOperand(0);
746 Guard->setOperand(0, AllChecks);
748 Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard));
749 Builder.CreateAssumption(OldCond);
750 }
751 RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
752
753 LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
754 return true;
755}
756
757bool LoopPredication::widenWidenableBranchGuardConditions(
758 BranchInst *BI, SCEVExpander &Expander) {
759 assert(isGuardAsWidenableBranch(BI) && "Must be!");
760 LLVM_DEBUG(dbgs() << "Processing guard:\n");
761 LLVM_DEBUG(BI->dump());
762
763 TotalConsidered++;
765 SmallVector<Value *> WidenedChecks;
766 parseWidenableGuard(BI, Checks);
  // At the moment, our matching logic for widenable conditions implicitly
  // assumes we preserve the form: (br (and Cond, WC())). FIXME
769 auto WC = extractWidenableCondition(BI);
770 Checks.push_back(WC);
771 widenChecks(Checks, WidenedChecks, Expander, BI);
772 if (WidenedChecks.empty())
773 return false;
774
775 TotalWidened += WidenedChecks.size();
776
777 // Emit the new guard condition
778 IRBuilder<> Builder(findInsertPt(BI, Checks));
779 Value *AllChecks = Builder.CreateAnd(Checks);
780 auto *OldCond = BI->getCondition();
781 BI->setCondition(AllChecks);
783 BasicBlock *IfTrueBB = BI->getSuccessor(0);
784 Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
785 // If this block has other predecessors, we might not be able to use Cond.
786 // In this case, create a Phi where every other input is `true` and input
787 // from guard block is Cond.
788 Value *AssumeCond = Builder.CreateAnd(WidenedChecks);
789 if (!IfTrueBB->getUniquePredecessor()) {
790 auto *GuardBB = BI->getParent();
791 auto *PN = Builder.CreatePHI(AssumeCond->getType(), pred_size(IfTrueBB),
792 "assume.cond");
793 for (auto *Pred : predecessors(IfTrueBB))
794 PN->addIncoming(Pred == GuardBB ? AssumeCond : Builder.getTrue(), Pred);
795 AssumeCond = PN;
796 }
797 Builder.CreateAssumption(AssumeCond);
798 }
799 RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
801 "Stopped being a guard after transform?");
802
803 LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
804 return true;
805}
806
807std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
808 using namespace PatternMatch;
809
810 BasicBlock *LoopLatch = L->getLoopLatch();
811 if (!LoopLatch) {
812 LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
813 return std::nullopt;
814 }
815
816 auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
817 if (!BI || !BI->isConditional()) {
818 LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
819 return std::nullopt;
820 }
821 BasicBlock *TrueDest = BI->getSuccessor(0);
822 assert(
823 (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) &&
824 "One of the latch's destinations must be the header");
825
826 auto *ICI = dyn_cast<ICmpInst>(BI->getCondition());
827 if (!ICI) {
828 LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n");
829 return std::nullopt;
830 }
831 auto Result = parseLoopICmp(ICI);
832 if (!Result) {
833 LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
834 return std::nullopt;
835 }
836
837 if (TrueDest != L->getHeader())
838 Result->Pred = ICmpInst::getInversePredicate(Result->Pred);
839
840 // Check affine first, so if it's not we don't try to compute the step
841 // recurrence.
842 if (!Result->IV->isAffine()) {
843 LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
844 return std::nullopt;
845 }
846
847 const SCEV *Step = Result->IV->getStepRecurrence(*SE);
848 if (!isSupportedStep(Step)) {
849 LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
850 return std::nullopt;
851 }
852
853 auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
854 if (Step->isOne()) {
855 return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
856 Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
857 } else {
858 assert(Step->isAllOnesValue() && "Step should be -1!");
859 return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
860 Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
861 }
862 };
863
864 normalizePredicate(SE, L, *Result);
865 if (IsUnsupportedPredicate(Step, Result->Pred)) {
866 LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
867 << ")!\n");
868 return std::nullopt;
869 }
870
871 return Result;
872}
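
The predicate filter at the end of `parseLoopLatchICmp` can be illustrated without any LLVM headers. The sketch below is only a model: the `Pred` enum and `isSupportedLatchPredicate` are hypothetical stand-ins for `ICmpInst::Predicate` and the `IsUnsupportedPredicate` lambda, the step is a plain `int`, and the `normalizePredicate` call that runs before the check is not modelled.

```cpp
#include <cassert>

// Hypothetical stand-in for the subset of ICmpInst::Predicate used here.
enum class Pred { EQ, NE, ULT, ULE, UGT, UGE, SLT, SLE, SGT, SGE };

// Returns true when a latch comparison `IV Pred Limit` is one the pass can
// reason about for the given step. Assumes Step is +1 or -1, mirroring the
// assertion in the original lambda.
bool isSupportedLatchPredicate(int Step, Pred P) {
  assert((Step == 1 || Step == -1) && "only unit strides are modelled");
  if (Step == 1) // counting up: the latch must be a "less than" style test
    return P == Pred::ULT || P == Pred::ULE || P == Pred::SLT ||
           P == Pred::SLE;
  // counting down: the latch must be a "greater than" style test
  return P == Pred::UGT || P == Pred::UGE || P == Pred::SGT || P == Pred::SGE;
}
```

Equality predicates are rejected in both directions here; in the pass itself, `normalizePredicate` gets a chance to rewrite the comparison before this filter is applied.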
873
874bool LoopPredication::isLoopProfitableToPredicate() {
875 if (SkipProfitabilityChecks)
876 return true;
877
878 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
879 L->getExitEdges(ExitEdges);
880 // If there is only one exiting edge in the loop, it is always profitable to
881 // predicate the loop.
882 if (ExitEdges.size() == 1)
883 return true;
884
885 // Calculate the exiting probabilities of all exiting edges from the loop,
886 // starting with the LatchExitProbability.
887 // Heuristic for profitability: If any of the exiting blocks' probability of
888 // exiting the loop is larger than exiting through the latch block, it's not
889 // profitable to predicate the loop.
890 auto *LatchBlock = L->getLoopLatch();
891 assert(LatchBlock && "Should have a single latch at this point!");
892 auto *LatchTerm = LatchBlock->getTerminator();
893 assert(LatchTerm->getNumSuccessors() == 2 &&
894 "expected to be an exiting block with 2 succs!");
895 unsigned LatchBrExitIdx =
896 LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
897 // We compute branch probabilities without BPI. We do not rely on BPI since
898 // Loop predication is usually run in an LPM and BPI is only preserved
899 // lossily within loop pass managers, while BPI has an inherent notion of
900 // being complete for an entire function.
901
902 // If the latch exits into a deoptimize or an unreachable block, do not
903 // predicate on that latch check.
904 auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
905 if (isa<UnreachableInst>(LatchTerm) ||
906 LatchExitBlock->getTerminatingDeoptimizeCall())
907 return false;
908
909 // Latch terminator has no valid profile data, so nothing to check
910 // profitability on.
911 if (!hasValidBranchWeightMD(*LatchTerm))
912 return true;
913
914 auto ComputeBranchProbability =
915 [&](const BasicBlock *ExitingBlock,
916 const BasicBlock *ExitBlock) -> BranchProbability {
917 auto *Term = ExitingBlock->getTerminator();
918 unsigned NumSucc = Term->getNumSuccessors();
919 if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
920 SmallVector<uint32_t> Weights;
921 extractBranchWeights(ProfileData, Weights);
922 uint64_t Numerator = 0, Denominator = 0;
923 for (auto [i, Weight] : llvm::enumerate(Weights)) {
924 if (Term->getSuccessor(i) == ExitBlock)
925 Numerator += Weight;
926 Denominator += Weight;
927 }
928 // If all weights are zero act as if there was no profile data
929 if (Denominator == 0)
930 return BranchProbability::getBranchProbability(1, NumSucc);
931 return BranchProbability::getBranchProbability(Numerator, Denominator);
932 } else {
933 assert(LatchBlock != ExitingBlock &&
934 "Latch term should always have profile data!");
935 // No profile data, so we choose the weight as 1/num_of_succ(Src)
936 return BranchProbability::getBranchProbability(1, NumSucc);
937 }
938 };
939
940 BranchProbability LatchExitProbability =
941 ComputeBranchProbability(LatchBlock, LatchExitBlock);
942
943 // Protect against degenerate inputs provided by the user. Providing a value
944 // less than one can invert the definition of profitable loop predication.
945 float ScaleFactor = LatchExitProbabilityScale;
946 if (ScaleFactor < 1) {
947 LLVM_DEBUG(
948 dbgs()
949 << "Ignored user setting for loop-predication-latch-probability-scale: "
950 << LatchExitProbabilityScale << "\n");
951 LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
952 ScaleFactor = 1.0;
953 }
954 const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;
955
956 for (const auto &ExitEdge : ExitEdges) {
957 BranchProbability ExitingBlockProbability =
958 ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
959 // Some exiting edge has higher probability than the latch exiting edge.
960 // No longer profitable to predicate.
961 if (ExitingBlockProbability > LatchProbabilityThreshold)
962 return false;
963 }
964
965 // We have concluded that the most probable way to exit from the
966 // loop is through the latch (or there's no profile information and all
967 // exits are equally likely).
968 return true;
969}
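
The profitability heuristic above can be modelled in isolation. This is a minimal sketch under stated assumptions, not the pass's implementation: `ExitingBranch`, `exitProbability`, and `isProfitable` are hypothetical names, probabilities are plain `double` values rather than the fixed-point `BranchProbability` type, and the early exits for a single exit edge and for deoptimizing latch exits are left out.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an exiting branch: Weights[i] is the profile
// weight of successor i (empty means "no profile data"), and ExitIdx is the
// successor that leaves the loop.
struct ExitingBranch {
  std::vector<uint32_t> Weights;
  unsigned NumSucc;
  unsigned ExitIdx;
};

// Probability of taking the exit edge. Falls back to 1/NumSucc when there is
// no profile data or all weights are zero.
double exitProbability(const ExitingBranch &B) {
  uint64_t Num = 0, Den = 0;
  for (unsigned I = 0; I < B.Weights.size(); ++I) {
    if (I == B.ExitIdx)
      Num += B.Weights[I];
    Den += B.Weights[I];
  }
  if (Den == 0)
    return 1.0 / B.NumSucc;
  return static_cast<double>(Num) / static_cast<double>(Den);
}

// Predication is considered unprofitable when any exit is more likely than
// the (scaled) latch exit. Scale values below 1 are clamped to 1.
bool isProfitable(const ExitingBranch &Latch,
                  const std::vector<ExitingBranch> &Exits, double Scale) {
  if (Latch.Weights.empty())
    return true; // nothing to compare against
  if (Scale < 1.0)
    Scale = 1.0;
  const double Threshold = exitProbability(Latch) * Scale;
  for (const ExitingBranch &E : Exits)
    if (exitProbability(E) > Threshold)
      return false;
  return true;
}
```

With a latch that exits 1% of the time and a scale of 2.0, an exit without profile data on a two-way branch (modelled as 50%) makes the loop unprofitable, while an exit taken roughly 0.1% of the time does not.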
970
971/// If we can (cheaply) find a widenable branch which controls entry into the
972/// loop, return it.
973static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
974 // Walk back through any unconditionally executed blocks and see if we can find
975 // a widenable condition which seems to control execution of this loop. Note
976 // that we predict that maythrow calls are likely untaken and thus that it's
977 // profitable to widen a branch before a maythrow call with a condition
978 // afterwards even though that may cause the slow path to run in a case where
979 // it wouldn't have otherwise.
980 BasicBlock *BB = L->getLoopPreheader();
981 if (!BB)
982 return nullptr;
983 do {
984 if (BasicBlock *Pred = BB->getSinglePredecessor())
985 if (BB == Pred->getSingleSuccessor()) {
986 BB = Pred;
987 continue;
988 }
989 break;
990 } while (true);
991
992 if (BasicBlock *Pred = BB->getSinglePredecessor()) {
993 if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()))
994 if (BI->getSuccessor(0) == BB && isWidenableBranch(BI))
995 return BI;
996 }
997 return nullptr;
998}
999
1000/// Return the minimum of all analyzeable exit counts. This is an upper bound
1001/// on the actual exit count. If there are not at least two analyzeable exits,
1002/// returns SCEVCouldNotCompute.
1003static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
1004 DominatorTree &DT,
1005 Loop *L) {
1006 SmallVector<BasicBlock *, 16> ExitingBlocks;
1007 L->getExitingBlocks(ExitingBlocks);
1008
1009 SmallVector<const SCEV *, 4> ExitCounts;
1010 for (BasicBlock *ExitingBB : ExitingBlocks) {
1011 const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
1012 if (isa<SCEVCouldNotCompute>(ExitCount))
1013 continue;
1014 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
1015 "We should only have known counts for exiting blocks that "
1016 "dominate latch!");
1017 ExitCounts.push_back(ExitCount);
1018 }
1019 if (ExitCounts.size() < 2)
1020 return SE.getCouldNotCompute();
1021 return SE.getUMinFromMismatchedTypes(ExitCounts);
1022}
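
The "minimum over analyzable exits, or nothing if fewer than two are known" rule can be sketched with ordinary integers. This is an illustration only: `ExitCount` and `minAnalyzableExitCount` are hypothetical names, `std::nullopt` plays the role of `SCEVCouldNotCompute`, and concrete 64-bit counts replace symbolic SCEV expressions (so no type widening is needed).

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of one exit's count: nullopt means "could not compute".
using ExitCount = std::optional<uint64_t>;

// Returns the unsigned minimum of all known exit counts, or nullopt when
// fewer than two exits are analyzable.
ExitCount minAnalyzableExitCount(const std::vector<ExitCount> &Counts) {
  std::vector<uint64_t> Known;
  for (const ExitCount &EC : Counts)
    if (EC)
      Known.push_back(*EC); // skip exits we cannot analyze
  if (Known.size() < 2)
    return std::nullopt;
  return *std::min_element(Known.begin(), Known.end());
}
```

The result is an upper bound on how many times the backedge is taken, because an unanalyzable exit may still leave the loop earlier.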
1023
1024/// This implements an analogous, but entirely distinct transform from the main
1025/// loop predication transform. This one is phrased in terms of using a
1026/// widenable branch *outside* the loop to allow us to simplify loop exits in a
1027/// following loop. This is close in spirit to the IndVarSimplify transform
1028/// of the same name, but is materially different: widening loosens legality
1029/// sharply.
1030bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
1031 // The transformation performed here aims to widen a widenable condition
1032 // above the loop such that all analyzeable exits leading to deopt are dead.
1033 // It assumes that the latch is the dominant exit for profitability and that
1034 // exits branching to deoptimizing blocks are rarely taken. It relies on the
1035 // semantics of widenable expressions for legality. (i.e. being able to fall
1036 // down the widenable path spuriously allows us to ignore exit order,
1037 // unanalyzeable exits, side effects, exceptional exits, and other challenges
1038 // which restrict the applicability of the non-WC based version of this
1039 // transform in IndVarSimplify.)
1040 //
1041 // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may
1042 // imply flags on the expression being hoisted and inserting new uses (flags
1043 // are only correct for current uses). The result is that we may be
1044 // inserting a branch on the value which can be either poison or undef. In
1045 // this case, the branch can legally go either way; we just need to avoid
1046 // introducing UB. This is achieved through the use of the freeze
1047 // instruction.
1048
1049 SmallVector<BasicBlock *, 16> ExitingBlocks;
1050 L->getExitingBlocks(ExitingBlocks);
1051
1052 if (ExitingBlocks.empty())
1053 return false; // Nothing to do.
1054
1055 auto *Latch = L->getLoopLatch();
1056 if (!Latch)
1057 return false;
1058
1059 auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, *LI);
1060 if (!WidenableBR)
1061 return false;
1062
1063 const SCEV *LatchEC = SE->getExitCount(L, Latch);
1064 if (isa<SCEVCouldNotCompute>(LatchEC))
1065 return false; // profitability - want hot exit in analyzeable set
1066
1067 // At this point, we have found an analyzeable latch, and a widenable
1068 // condition above the loop. If we have a widenable exit within the loop
1069 // (for which we can't compute exit counts), drop the ability to further
1070 // widen so that we gain ability to analyze its exit count and perform this
1071 // transform. TODO: It'd be nice to know for sure the exit became
1072 // analyzeable after dropping widenability.
1073 bool ChangedLoop = false;
1074
1075 for (auto *ExitingBB : ExitingBlocks) {
1076 if (LI->getLoopFor(ExitingBB) != L)
1077 continue;
1078
1079 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1080 if (!BI)
1081 continue;
1082
1083 if (auto WC = extractWidenableCondition(BI))
1084 if (L->contains(BI->getSuccessor(0))) {
1085 assert(WC->hasOneUse() && "Not appropriate widenable branch!");
1086 WC->user_back()->replaceUsesOfWith(
1087 WC, ConstantInt::getTrue(BI->getContext()));
1088 ChangedLoop = true;
1089 }
1090 }
1091 if (ChangedLoop)
1092 SE->forgetLoop(L);
1093
1094 // The insertion point for the widening should be at the widenable call, not
1095 // at the WidenableBR. If we do this at the widenableBR, we can incorrectly
1096 // change a loop-invariant condition to a loop-varying one.
1097 auto *IP = cast<Instruction>(WidenableBR->getCondition());
1098
1099 // The use of umin(all analyzeable exits) instead of latch is subtle, but
1100 // important for profitability. We may have a loop which hasn't been fully
1101 // canonicalized just yet. If the exit we chose to widen is provably never
1102 // taken, we want the widened form to *also* be provably never taken. We
1103 // can't guarantee this as a current unanalyzeable exit may later become
1104 // analyzeable, but we can at least avoid the obvious cases.
1105 const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
1106 if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
1107 !SE->isLoopInvariant(MinEC, L) ||
1108 !Rewriter.isSafeToExpandAt(MinEC, IP))
1109 return ChangedLoop;
1110
1111 Rewriter.setInsertPoint(IP);
1112 IRBuilder<> B(IP);
1113
1114 bool InvalidateLoop = false;
1115 Value *MinECV = nullptr; // lazily generated if needed
1116 for (BasicBlock *ExitingBB : ExitingBlocks) {
1117 // If our exiting block exits multiple loops, we can only rewrite the
1118 // innermost one. Otherwise, we're changing how many times the innermost
1119 // loop runs before it exits.
1120 if (LI->getLoopFor(ExitingBB) != L)
1121 continue;
1122
1123 // Can't rewrite non-branch yet.
1124 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1125 if (!BI)
1126 continue;
1127
1128 // If already constant, nothing to do.
1129 if (isa<Constant>(BI->getCondition()))
1130 continue;
1131
1132 const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
1133 if (isa<SCEVCouldNotCompute>(ExitCount) ||
1134 ExitCount->getType()->isPointerTy() ||
1135 !Rewriter.isSafeToExpandAt(ExitCount, WidenableBR))
1136 continue;
1137
1138 const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
1139 BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
1140 if (!ExitBB->getPostdominatingDeoptimizeCall())
1141 continue;
1142
1143 /// Here we can be fairly sure that executing this exit will most likely
1144 /// lead to executing llvm.experimental.deoptimize.
1145 /// This is a profitability heuristic, not a legality constraint.
1146
1147 // If we found a widenable exit condition, do two things:
1148 // 1) fold the widened exit test into the widenable condition
1149 // 2) fold the branch to untaken - avoids infinite looping
1150
1151 Value *ECV = Rewriter.expandCodeFor(ExitCount);
1152 if (!MinECV)
1153 MinECV = Rewriter.expandCodeFor(MinEC);
1154 Value *RHS = MinECV;
1155 if (ECV->getType() != RHS->getType()) {
1156 Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
1157 ECV = B.CreateZExt(ECV, WiderTy);
1158 RHS = B.CreateZExt(RHS, WiderTy);
1159 }
1160 assert(!Latch || DT->dominates(ExitingBB, Latch));
1161 Value *NewCond = B.CreateICmp(ICmpInst::ICMP_UGT, ECV, RHS);
1162 // Freeze poison or undef to an arbitrary bit pattern to ensure we can
1163 // branch without introducing UB. See NOTE ON POISON/UNDEF above for
1164 // context.
1165 NewCond = B.CreateFreeze(NewCond);
1166
1167 widenWidenableBranch(WidenableBR, NewCond);
1168
1169 Value *OldCond = BI->getCondition();
1170 BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
1171 InvalidateLoop = true;
1172 }
1173
1174 if (InvalidateLoop)
1175 // We just mutated a bunch of loop exits changing their exit counts
1176 // widely. We need to force recomputation of the exit counts given these
1177 // changes. Note that all of the inserted exits are never taken, and
1178 // should be removed next time the CFG is modified.
1179 SE->forgetLoop(L);
1180
1181 // Always return `true` since we have moved the WidenableBR's condition.
1182 return true;
1183}
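
The condition that `predicateLoopExits` folds into the widenable branch can be modelled with concrete numbers. The sketch below is a simplification under stated assumptions: `widenedEntryCondition` is a hypothetical name, exit counts are concrete integers rather than expanded SCEV values, the zero-extension to the wider type is modelled by a cast to `uint64_t`, and the `freeze` of each new condition is not represented.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// DeoptExitCounts holds the exit count of each in-loop exit that leads to a
// deoptimize call; MinEC is the unsigned minimum over all analyzable exit
// counts. Returns whether the loop is still entered on the fast path after
// widening.
bool widenedEntryCondition(bool OrigCond,
                           const std::vector<uint32_t> &DeoptExitCounts,
                           uint64_t MinEC) {
  bool Cond = OrigCond;
  for (uint32_t EC : DeoptExitCounts) {
    // Widen to a common type, then require that this exit would be reached
    // strictly later than the earliest analyzable exit (unsigned compare).
    Cond = Cond && (static_cast<uint64_t>(EC) > MinEC);
  }
  return Cond;
}
```

When the widened condition holds, none of the deoptimizing exits can be the first exit taken, which is why the pass can then fold their branch conditions to "never exit".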
1184
1185bool LoopPredication::runOnLoop(Loop *Loop) {
1186 L = Loop;
1187
1188 LLVM_DEBUG(dbgs() << "Analyzing ");
1189 LLVM_DEBUG(L->dump());
1190
1191 Module *M = L->getHeader()->getModule();
1192
1193 // There is nothing to do if the module doesn't use guards
1194 auto *GuardDecl =
1195 Intrinsic::getDeclarationIfExists(M, Intrinsic::experimental_guard);
1196 bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
1197 auto *WCDecl = Intrinsic::getDeclarationIfExists(
1198 M, Intrinsic::experimental_widenable_condition);
1199 bool HasWidenableConditions =
1200 PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty();
1201 if (!HasIntrinsicGuards && !HasWidenableConditions)
1202 return false;
1203
1204 DL = &M->getDataLayout();
1205
1206 Preheader = L->getLoopPreheader();
1207 if (!Preheader)
1208 return false;
1209
1210 auto LatchCheckOpt = parseLoopLatchICmp();
1211 if (!LatchCheckOpt)
1212 return false;
1213 LatchCheck = *LatchCheckOpt;
1214
1215 LLVM_DEBUG(dbgs() << "Latch check:\n");
1216 LLVM_DEBUG(LatchCheck.dump());
1217
1218 if (!isLoopProfitableToPredicate()) {
1219 LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
1220 return false;
1221 }
1222 // Collect all the guards into a vector and process later, so as not
1223 // to invalidate the instruction iterator.
1224 SmallVector<IntrinsicInst *, 4> Guards;
1225 SmallVector<BranchInst *, 4> GuardsAsWidenableBranches;
1226 for (const auto BB : L->blocks()) {
1227 for (auto &I : *BB)
1228 if (isGuard(&I))
1229 Guards.push_back(cast<IntrinsicInst>(&I));
1230 if (PredicateWidenableBranchGuards &&
1231 isGuardAsWidenableBranch(BB->getTerminator()))
1232 GuardsAsWidenableBranches.push_back(
1233 cast<BranchInst>(BB->getTerminator()));
1234 }
1235
1236 SCEVExpander Expander(*SE, *DL, "loop-predication");
1237 bool Changed = false;
1238 for (auto *Guard : Guards)
1239 Changed |= widenGuardConditions(Guard, Expander);
1240 for (auto *Guard : GuardsAsWidenableBranches)
1241 Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
1242 Changed |= predicateLoopExits(L, Expander);
1243
1244 if (MSSAU && VerifyMemorySSA)
1245 MSSAU->getMemorySSA()->verifyMemorySSA();
1246 return Changed;
1247}