LLVM 20.0.0git
LoopPredication.cpp
//===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The LoopPredication pass tries to convert loop variant range checks to loop
// invariant by widening checks across loop iterations. For example, it will
// convert
//
//   for (i = 0; i < n; i++) {
//     guard(i < len);
//     ...
//   }
//
// to
//
//   for (i = 0; i < n; i++) {
//     guard(n - 1 < len);
//     ...
//   }
//
// After this transformation the condition of the guard is loop invariant, so
// loop-unswitch can later unswitch the loop by this condition which basically
// predicates the loop by the widened condition:
//
//   if (n - 1 < len)
//     for (i = 0; i < n; i++) {
//       ...
//     }
//   else
//     deoptimize
//
// It's tempting to rely on SCEV here, but it has proven to be problematic.
// Generally the facts SCEV provides about the increment step of add
// recurrences are true if the backedge of the loop is taken, which implicitly
// assumes that the guard doesn't fail. Using these facts to optimize the
// guard results in a circular logic where the guard is optimized under the
// assumption that it never fails.
//
// For example, in the loop below the induction variable will be marked as nuw
// based on the guard. Based on nuw the guard predicate will be considered
// monotonic. Given a monotonic condition it's tempting to replace the induction
// variable in the condition with its value on the last iteration. But this
// transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
//
//   for (int i = b; i != e; i++)
//     guard(i u< len)
//
// One of the ways to reason about this problem is to use an inductive proof
// approach. Given the loop:
//
//   if (B(0)) {
//     do {
//       I = PHI(0, I.INC)
//       I.INC = I + Step
//       guard(G(I));
//     } while (B(I));
//   }
//
// where B(x) and G(x) are predicates that map integers to booleans, we want a
// loop invariant expression M such that the following program has the same
// semantics as the above:
//
//   if (B(0)) {
//     do {
//       I = PHI(0, I.INC)
//       I.INC = I + Step
//       guard(G(0) && M);
//     } while (B(I));
//   }
//
// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
//
// Informal proof that the transformation above is correct:
//
//   By the definition of guards we can rewrite the guard condition to:
//     G(I) && G(0) && M
//
//   Let's prove that for each iteration of the loop:
//     G(0) && M => G(I)
//   And the condition above can be simplified to G(0) && M.
//
//   Induction base.
//     G(0) && M => G(0)
//
//   Induction step. Assuming G(0) && M => G(I), on the subsequent
//   iteration:
//
//     B(I) is true because it's the backedge condition.
//     G(I) is true because the backedge is guarded by this condition.
//
//   So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
//
// Note that we can use anything stronger than M, i.e. any condition which
// implies M.
//
// When Step = 1 (i.e. forward iterating loop), the transformation is supported
// when:
//   * The loop has a single latch with the condition of the form:
//     B(X) = latchStart + X <pred> latchLimit,
//     where <pred> is u<, u<=, s<, or s<=.
//   * The guard condition is of the form
//     G(X) = guardStart + X u< guardLimit
//
//   For the ult latch comparison case M is:
//     forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
//        guardStart + X + 1 u< guardLimit
//
//   The only way the antecedent can be true and the consequent can be false is
//   if
//     X == guardLimit - 1 - guardStart
//   (and guardLimit is non-zero, but we won't use this latter fact).
//   If X == guardLimit - 1 - guardStart then the second half of the antecedent is
//     latchStart + guardLimit - 1 - guardStart u< latchLimit
//   and its negation is
//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
//
//   In other words, if
//     latchLimit u<= latchStart + guardLimit - 1 - guardStart
//   then:
//   (the ranges below are written in ConstantRange notation, where [A, B) is the
//   set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
//
//      forall X . guardStart + X u< guardLimit &&
//                 latchStart + X u< latchLimit =>
//        guardStart + X + 1 u< guardLimit
//   == forall X . guardStart + X u< guardLimit &&
//                 latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
//        guardStart + X + 1 u< guardLimit
//   == forall X . (guardStart + X) in [0, guardLimit) &&
//                 (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
//        (guardStart + X + 1) in [0, guardLimit)
//   == forall X . X in [-guardStart, guardLimit - guardStart) &&
//                 X in [-latchStart, guardLimit - 1 - guardStart) =>
//         X in [-guardStart - 1, guardLimit - guardStart - 1)
//   == true
//
// So the widened condition is:
//   guardStart u< guardLimit &&
//   latchStart + guardLimit - 1 - guardStart u>= latchLimit
// Similarly for ule condition the widened condition is:
//   guardStart u< guardLimit &&
//   latchStart + guardLimit - 1 - guardStart u> latchLimit
// For slt condition the widened condition is:
//   guardStart u< guardLimit &&
//   latchStart + guardLimit - 1 - guardStart s>= latchLimit
// For sle condition the widened condition is:
//   guardStart u< guardLimit &&
//   latchStart + guardLimit - 1 - guardStart s> latchLimit
//
// When Step = -1 (i.e. reverse iterating loop), the transformation is supported
// when:
//   * The loop has a single latch with the condition of the form:
//     B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
//   * The guard condition is of the form
//     G(X) = X - 1 u< guardLimit
//
//   For the ugt latch comparison case M is:
//     forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
//
//   The only way the antecedent can be true and the consequent can be false is if
//     X == 1.
//   If X == 1 then the second half of the antecedent is
//     1 u> latchLimit, and its negation is latchLimit u>= 1.
//
//   So the widened condition is:
//     guardStart u< guardLimit && latchLimit u>= 1.
//   Similarly for sgt condition the widened condition is:
//     guardStart u< guardLimit && latchLimit s>= 1.
//   For uge condition the widened condition is:
//     guardStart u< guardLimit && latchLimit u> 1.
//   For sge condition the widened condition is:
//     guardStart u< guardLimit && latchLimit s> 1.
//===----------------------------------------------------------------------===//

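// Illustrative sketch, not part of the pass: the forward-loop derivation in
// the header comment can be sanity-checked by brute force. The block below
// models the abstract loop with 4-bit wrapping arithmetic and counts the
// inputs for which the widened ult condition holds but some guard still
// fails. It also replays the e = 4, b = 5 counterexample. Every name in it
// (guardsPassForward, widenedUlt, countUnsoundForward, naiveWidened,
// guardsPassNe) is hypothetical and chosen only for this example.

```cpp
#include <cassert>

// Model i4 arithmetic: every value is reduced modulo 16.
constexpr unsigned Mask = 0xF;

// Runs the abstract forward loop from the header comment with
//   G(X) = guardStart + X u< guardLimit
//   B(X) = latchStart + X u< latchLimit
// and returns true if no guard fails before the loop exits.
bool guardsPassForward(unsigned GS, unsigned GL, unsigned LS, unsigned LL) {
  if (!((LS & Mask) < LL))
    return true; // if (B(0)) is false: the loop body never runs
  for (unsigned X = 0; X <= Mask; ++X) {
    if (!(((GS + X) & Mask) < GL))
      return false; // guard(G(X)) fails, i.e. deoptimize
    if (!(((LS + X) & Mask) < LL))
      return true; // while (B(X)) is false: loop exits
  }
  return true;
}

// Widened condition for the ult latch case:
//   guardStart u< guardLimit &&
//   latchLimit u<= latchStart + guardLimit - 1 - guardStart
bool widenedUlt(unsigned GS, unsigned GL, unsigned LS, unsigned LL) {
  return GS < GL && LL <= ((LS + GL - 1 - GS) & Mask);
}

// Counts inputs where the widened condition holds but a guard fails.
unsigned countUnsoundForward() {
  unsigned Bad = 0;
  for (unsigned GS = 0; GS <= Mask; ++GS)
    for (unsigned GL = 0; GL <= Mask; ++GL)
      for (unsigned LS = 0; LS <= Mask; ++LS)
        for (unsigned LL = 0; LL <= Mask; ++LL)
          if (widenedUlt(GS, GL, LS, LL) &&
              !guardsPassForward(GS, GL, LS, LL))
            ++Bad;
  return Bad;
}

// The tempting but incorrect rewrite for
//   for (i = b; i != e; i++) guard(i u< len)
// which substitutes the last IV value, e - 1, into the guard.
bool naiveWidened(unsigned E, unsigned Len) { return ((E - 1) & Mask) < Len; }

// Runs that loop for real and reports whether every guard passes.
bool guardsPassNe(unsigned B, unsigned E, unsigned Len) {
  for (unsigned I = B; I != E; I = (I + 1) & Mask)
    if (!(I < Len))
      return false;
  return true;
}
```

// With b = 5, e = 4, len = 4 the naive rewrite accepts the loop
// (3 u< 4) although the very first guard (5 u< 4) fails, whereas the
// exhaustive check finds no input on which the derived widened condition is
// unsound for this 4-bit model.
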
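#include "llvm/Transforms/Scalar/LoopPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/GuardUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <optional>
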
#define DEBUG_TYPE "loop-predication"

STATISTIC(TotalConsidered, "Number of guards considered");
STATISTIC(TotalWidened, "Number of checks widened");

using namespace llvm;

static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
                                        cl::Hidden, cl::init(true));

static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
                                        cl::Hidden, cl::init(true));

static cl::opt<bool>
    SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
                            cl::Hidden, cl::init(false));

// This is the scale factor for the latch probability. We use this during
// profitability analysis to find other exiting blocks that have a much higher
// probability of exiting the loop instead of loop exiting via latch.
// This value should be greater than 1 for a sane profitability check.
static cl::opt<float> LatchExitProbabilityScale(
    "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
    cl::desc("scale factor for the latch probability. Value should be greater "
             "than 1. Lower values are ignored"));

static cl::opt<bool> PredicateWidenableBranchGuards(
    "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden,
    cl::desc("Whether or not we should predicate guards "
             "expressed as widenable branches to deoptimize blocks"),
    cl::init(true));

static cl::opt<bool> InsertAssumesOfPredicatedGuardsConditions(
    "loop-predication-insert-assumes-of-predicated-guards-conditions",
    cl::Hidden,
    cl::desc("Whether or not we should insert assumes of conditions of "
             "predicated guards"),
    cl::init(true));

namespace {
/// Represents an induction variable check:
///   icmp Pred, <induction variable>, <loop invariant limit>
struct LoopICmp {
  ICmpInst::Predicate Pred;
  const SCEVAddRecExpr *IV;
  const SCEV *Limit;
  LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
           const SCEV *Limit)
    : Pred(Pred), IV(IV), Limit(Limit) {}
  LoopICmp() = default;
  void dump() {
    dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
           << ", Limit = " << *Limit << "\n";
  }
};

class LoopPredication {
  AliasAnalysis *AA;
  DominatorTree *DT;
  ScalarEvolution *SE;
  LoopInfo *LI;
  MemorySSAUpdater *MSSAU;

  Loop *L;
  const DataLayout *DL;
  BasicBlock *Preheader;
  LoopICmp LatchCheck;

  bool isSupportedStep(const SCEV* Step);
  std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
  std::optional<LoopICmp> parseLoopLatchICmp();

  /// Return an insertion point suitable for inserting a safe to speculate
  /// instruction whose only user will be 'User' which has operands 'Ops'. A
  /// trivial result would be at the User itself, but we try to return a
  /// loop invariant location if possible.
  Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
  /// Same as above, *except* that this uses the SCEV definition of invariant
  /// which is that an expression *can be made* invariant via SCEVExpander.
  /// Thus, this version is only suitable for finding an insert point to be
  /// passed to SCEVExpander!
  Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User,
                            ArrayRef<const SCEV *> Ops);

  /// Return true if the value is known to produce a single fixed value across
  /// all iterations on which it executes. Note that this does not imply
  /// speculation safety. That must be established separately.
  bool isLoopInvariantValue(const SCEV* S);

  Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
                     ICmpInst::Predicate Pred, const SCEV *LHS,
                     const SCEV *RHS);

  std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI,
                                             SCEVExpander &Expander,
                                             Instruction *Guard);
  std::optional<Value *>
  widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
                                      SCEVExpander &Expander,
                                      Instruction *Guard);
  std::optional<Value *>
  widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
                                      SCEVExpander &Expander,
                                      Instruction *Guard);
  void widenChecks(SmallVectorImpl<Value *> &Checks,
                   SmallVectorImpl<Value *> &WidenedChecks,
                   SCEVExpander &Expander, Instruction *Guard);
  bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
  bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander);
  // If the loop always exits through another block in the loop, we should not
  // predicate based on the latch check. For example, the latch check can be a
  // very coarse grained check and there can be more fine grained exit checks
  // within the loop.
  bool isLoopProfitableToPredicate();

  bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);

public:
  LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
                  LoopInfo *LI, MemorySSAUpdater *MSSAU)
      : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
  bool runOnLoop(Loop *L);
};

} // end namespace

PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
                                           LoopStandardAnalysisResults &AR,
                                           LPMUpdater &U) {
  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (AR.MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
  LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
                     MSSAU ? MSSAU.get() : nullptr);
  if (!LP.runOnLoop(&L))
    return PreservedAnalyses::all();

  auto PA = getLoopPassPreservedAnalyses();
  if (AR.MSSA)
    PA.preserve<MemorySSAAnalysis>();
  return PA;
}

std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) {
  auto Pred = ICI->getPredicate();
  auto *LHS = ICI->getOperand(0);
  auto *RHS = ICI->getOperand(1);

  const SCEV *LHSS = SE->getSCEV(LHS);
  if (isa<SCEVCouldNotCompute>(LHSS))
    return std::nullopt;
  const SCEV *RHSS = SE->getSCEV(RHS);
  if (isa<SCEVCouldNotCompute>(RHSS))
    return std::nullopt;

  // Canonicalize so that RHS is the loop invariant bound and LHS is the loop
  // computable IV.
  if (SE->isLoopInvariant(LHSS, L)) {
    std::swap(LHS, RHS);
    std::swap(LHSS, RHSS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
  if (!AR || AR->getLoop() != L)
    return std::nullopt;

  return LoopICmp(Pred, AR, RHSS);
}

Value *LoopPredication::expandCheck(SCEVExpander &Expander,
                                    Instruction *Guard,
                                    ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS) {
  Type *Ty = LHS->getType();
  assert(Ty == RHS->getType() && "expandCheck operands have different types?");

  if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
    IRBuilder<> Builder(Guard);
    if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
      return Builder.getTrue();
    if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
                                     LHS, RHS))
      return Builder.getFalse();
  }

  Value *LHSV =
      Expander.expandCodeFor(LHS, Ty, findInsertPt(Expander, Guard, {LHS}));
  Value *RHSV =
      Expander.expandCodeFor(RHS, Ty, findInsertPt(Expander, Guard, {RHS}));
  IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
  return Builder.CreateICmp(Pred, LHSV, RHSV);
}

// Returns true if it's safe to truncate the IV to RangeCheckType.
// When the IV type is wider than the range operand type, we can still do loop
// predication, by generating SCEVs for the range and latch that are of the
// same type. We achieve this by generating a SCEV truncate expression for the
// latch IV. This is done iff truncation of the IV is a safe operation,
// without loss of information.
// Another way to achieve this is by generating a wider type SCEV for the
// range check operand, however, this needs a more involved check that
// operands do not overflow. This can lead to loss of information when the
// range operand is of the form: add i32 %offset, %iv. We need to prove that
// sext(x + y) is the same as sext(x) + sext(y).
// This function returns true if we can safely represent the IV type in
// the RangeCheckType without loss of information.
static bool isSafeToTruncateWideIVType(const DataLayout &DL,
                                       ScalarEvolution &SE,
                                       const LoopICmp LatchCheck,
                                       Type *RangeCheckType) {
  if (!EnableIVTruncation)
    return false;
  assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() >
             DL.getTypeSizeInBits(RangeCheckType).getFixedValue() &&
         "Expected latch check IV type to be larger than range check operand "
         "type!");
  // The start and end values of the IV should be known. This is to guarantee
  // that truncating the wide type will not lose information.
  auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
  auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
  if (!Limit || !Start)
    return false;
  // This check makes sure that the IV does not change sign during loop
  // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
  // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
  // IV wraps around, and the truncation of the IV would lose the range of
  // iterations between 2^32 and 2^64.
  if (!SE.getMonotonicPredicateType(LatchCheck.IV, LatchCheck.Pred))
    return false;
  // The active bits should be less than the bits in the RangeCheckType. This
  // guarantees that truncating the latch check to RangeCheckType is a safe
  // operation.
  auto RangeCheckTypeBitSize =
      DL.getTypeSizeInBits(RangeCheckType).getFixedValue();
  return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
         Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
}


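// Illustrative sketch, not part of the pass: the final active-bits test in
// isSafeToTruncateWideIVType can be pictured with plain 64-bit integers
// standing in for the constant start and limit of the latch IV. The helper
// names (activeBits, fitsNarrowType) are hypothetical, and the sketch covers
// only the bit-width comparison. The real function additionally requires the
// truncation option to be enabled, constant start and limit values, and a
// monotonic latch predicate.

```cpp
#include <cassert>
#include <cstdint>

// Number of bits needed to represent V as an unsigned value, i.e. the bit
// width minus the number of leading zero bits.
unsigned activeBits(uint64_t V) {
  unsigned N = 0;
  while (V != 0) {
    ++N;
    V >>= 1;
  }
  return N;
}

// Mirrors the shape of the final check: both constants must need strictly
// fewer bits than the narrow type provides, so each one truncates to the same
// non-negative value in the narrow type.
bool fitsNarrowType(uint64_t Start, uint64_t Limit, unsigned NarrowBits) {
  return activeBits(Start) < NarrowBits && activeBits(Limit) < NarrowBits;
}
```

// For a 32-bit range check type, Start = 5 and Limit = 2 pass the test, while
// a limit of 2^31 needs 32 active bits and is rejected. A negative 64-bit
// constant has all 64 bits active and is rejected as well.
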
// Return a LoopICmp describing a latch check equivalent to LatchCheck but with
// the requested type if safe to do so. May involve the use of a new IV.
static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
                                                      ScalarEvolution &SE,
                                                      const LoopICmp LatchCheck,
                                                      Type *RangeCheckType) {

  auto *LatchType = LatchCheck.IV->getType();
  if (RangeCheckType == LatchType)
    return LatchCheck;
  // For now, bail out if latch type is narrower than range type.
  if (DL.getTypeSizeInBits(LatchType).getFixedValue() <
      DL.getTypeSizeInBits(RangeCheckType).getFixedValue())
    return std::nullopt;
  if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType))
    return std::nullopt;
  // We can now safely identify the truncated version of the IV and limit for
  // RangeCheckType.
  LoopICmp NewLatchCheck;
  NewLatchCheck.Pred = LatchCheck.Pred;
  NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
      SE.getTruncateExpr(LatchCheck.IV, RangeCheckType));
  if (!NewLatchCheck.IV)
    return std::nullopt;
  NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType);
  LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
                    << "can be represented as range check type:"
                    << *RangeCheckType << "\n");
  LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
  LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
  return NewLatchCheck;
}

bool LoopPredication::isSupportedStep(const SCEV* Step) {
  return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
}

Instruction *LoopPredication::findInsertPt(Instruction *Use,
                                           ArrayRef<Value*> Ops) {
  for (Value *Op : Ops)
    if (!L->isLoopInvariant(Op))
      return Use;
  return Preheader->getTerminator();
}

Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander,
                                           Instruction *Use,
                                           ArrayRef<const SCEV *> Ops) {
  // Subtlety: SCEV considers things to be invariant if the value produced is
  // the same across iterations. This is not the same as being able to
  // evaluate outside the loop, which is what we actually need here.
  for (const SCEV *Op : Ops)
    if (!SE->isLoopInvariant(Op, L) ||
        !Expander.isSafeToExpandAt(Op, Preheader->getTerminator()))
      return Use;
  return Preheader->getTerminator();
}

bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
  // Handling expressions which produce invariant results, but *haven't* yet
  // been removed from the loop serves two important purposes.
  // 1) Most importantly, it resolves a pass ordering cycle which would
  // otherwise need us to iterate licm, loop-predication, and either
  // loop-unswitch or loop-peeling to make progress on examples with lots of
  // predicable range checks in a row. (Since, in the general case, we can't
  // hoist the length checks until the dominating checks have been discharged
  // as we can't prove doing so is safe.)
  // 2) As a nice side effect, this exposes the value of peeling or unswitching
  // much more obviously in the IR. Otherwise, the cost modeling for other
  // transforms would end up needing to duplicate all of this logic to model a
  // check which becomes predictable based on a modeled peel or unswitch.
  //
  // The cost of doing so in the worst case is an extra fill from the stack in
  // the loop to materialize the loop invariant test value instead of checking
  // against the original IV which is presumably in a register inside the loop.
  // Such cases are presumably rare, and hint at missing opportunities for
  // other passes.

  if (SE->isLoopInvariant(S, L))
    // Note: This is the SCEV variant, so the original Value* may be within the
    // loop even though SCEV has proven it is loop invariant.
    return true;

  // Handle a particular important case which SCEV doesn't yet know about which
  // shows up in range checks on arrays with immutable lengths.
  // TODO: This should be sunk inside SCEV.
  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
    if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
      if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
        if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) ||
            LI->hasMetadata(LLVMContext::MD_invariant_load))
          return true;
  return false;
}

std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
    LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
    Instruction *Guard) {
  auto *Ty = RangeCheck.IV->getType();
  // Generate the widened condition for the forward loop:
  //   guardStart u< guardLimit &&
  //   latchLimit <pred> guardLimit - 1 - guardStart + latchStart
  // where <pred> depends on the latch condition predicate. See the file
  // header comment for the reasoning.
  // guardLimit - guardStart + latchStart - 1
  const SCEV *GuardStart = RangeCheck.IV->getStart();
  const SCEV *GuardLimit = RangeCheck.Limit;
  const SCEV *LatchStart = LatchCheck.IV->getStart();
  const SCEV *LatchLimit = LatchCheck.Limit;
  // Subtlety: We need all the values to be *invariant* across all iterations,
  // but we only need to check expansion safety for those which *aren't*
  // already guaranteed to dominate the guard.
  if (!isLoopInvariantValue(GuardStart) ||
      !isLoopInvariantValue(GuardLimit) ||
      !isLoopInvariantValue(LatchStart) ||
      !isLoopInvariantValue(LatchLimit)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
      !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }

  // guardLimit - guardStart + latchStart - 1
  const SCEV *RHS =
      SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
                     SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
  auto LimitCheckPred =
      ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);

  LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
  LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");

  auto *LimitCheck =
      expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
  auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
                                          GuardStart, GuardLimit);
  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
  return Builder.CreateFreeze(
      Builder.CreateAnd(FirstIterationCheck, LimitCheck));
}

std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
    LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
    Instruction *Guard) {
  auto *Ty = RangeCheck.IV->getType();
  const SCEV *GuardStart = RangeCheck.IV->getStart();
  const SCEV *GuardLimit = RangeCheck.Limit;
  const SCEV *LatchStart = LatchCheck.IV->getStart();
  const SCEV *LatchLimit = LatchCheck.Limit;
  // Subtlety: We need all the values to be *invariant* across all iterations,
  // but we only need to check expansion safety for those which *aren't*
  // already guaranteed to dominate the guard.
  if (!isLoopInvariantValue(GuardStart) ||
      !isLoopInvariantValue(GuardLimit) ||
      !isLoopInvariantValue(LatchStart) ||
      !isLoopInvariantValue(LatchLimit)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
      !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  // The decrement of the latch check IV should be the same as the
  // rangeCheckIV.
  auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
  if (RangeCheck.IV != PostDecLatchCheckIV) {
    LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
                      << *PostDecLatchCheckIV
                      << "  and RangeCheckIV: " << *RangeCheck.IV << "\n");
    return std::nullopt;
  }

  // Generate the widened condition for CountDownLoop:
  // guardStart u< guardLimit &&
  // latchLimit <pred> 1.
  // See the header comment for reasoning of the checks.
  auto LimitCheckPred =
      ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
  auto *FirstIterationCheck = expandCheck(Expander, Guard,
                                          ICmpInst::ICMP_ULT,
                                          GuardStart, GuardLimit);
  auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
                                 SE->getOne(Ty));
  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
  return Builder.CreateFreeze(
      Builder.CreateAnd(FirstIterationCheck, LimitCheck));
}

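// Illustrative sketch, not part of the pass: the count-down widened condition
// (guardStart u< guardLimit && latchLimit u>= 1 for a ugt latch) can be
// checked exhaustively in a 4-bit model of the loop described in the file
// header, where the guard tests the post-decrement value X - 1 and the latch
// tests X. All names here (guardsPassCountDown, widenedUgt,
// countUnsoundCountDown) are hypothetical.

```cpp
#include <cassert>

// Model i4 arithmetic: every value is reduced modulo 16.
constexpr unsigned Mask = 0xF;

// Count-down loop whose IV X starts at Start and steps by -1, with
//   G(X) = X - 1 u< guardLimit
//   B(X) = X u> latchLimit
// Returns true if no guard fails before the loop exits.
bool guardsPassCountDown(unsigned Start, unsigned GL, unsigned LL) {
  unsigned X = Start;
  for (unsigned Iter = 0; Iter <= Mask; ++Iter) {
    if (!(((X - 1) & Mask) < GL))
      return false; // guard fails, i.e. deoptimize
    if (!(X > LL))
      return true; // backedge not taken: loop exits
    X = (X - 1) & Mask;
  }
  return true;
}

// Widened condition for the ugt latch case. guardStart is the first value
// the guard sees, which is Start - 1 in this model.
bool widenedUgt(unsigned Start, unsigned GL, unsigned LL) {
  unsigned GuardStart = (Start - 1) & Mask;
  return GuardStart < GL && LL >= 1;
}

// Counts inputs where the widened condition holds but a guard fails.
unsigned countUnsoundCountDown() {
  unsigned Bad = 0;
  for (unsigned Start = 0; Start <= Mask; ++Start)
    for (unsigned GL = 0; GL <= Mask; ++GL)
      for (unsigned LL = 0; LL <= Mask; ++LL)
        if (widenedUgt(Start, GL, LL) && !guardsPassCountDown(Start, GL, LL))
          ++Bad;
  return Bad;
}
```

// The latchLimit u>= 1 term matters: with Start = 5, guardLimit = 5 and
// latchLimit = 0 the loop reaches X = 0, where X - 1 wraps to the maximum
// value and the guard fails, and the widened condition correctly rejects
// that input.
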
static void normalizePredicate(ScalarEvolution *SE, Loop *L,
                               LoopICmp& RC) {
  // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the
  // ULT/UGE form for ease of handling by our caller.
  if (ICmpInst::isEquality(RC.Pred) &&
      RC.IV->getStepRecurrence(*SE)->isOne() &&
      SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit))
    RC.Pred = RC.Pred == ICmpInst::ICMP_NE ?
      ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
}

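// Illustrative sketch, not part of the pass: the rewrite in
// normalizePredicate relies on the fact that, for a step-1 IV whose start is
// known to be u<= the limit, `I != Limit` and `I u< Limit` give the same
// answer on every value the IV takes before the loop exits. The 4-bit model
// below checks that claim and shows why the start u<= limit precondition is
// needed. The names neAgreesWithUlt and countDisagreements are hypothetical.

```cpp
#include <cassert>

// Model i4 arithmetic: every value is reduced modulo 16.
constexpr unsigned Mask = 0xF;

// Walks a step-1 IV from Start until `I != Limit` becomes false and reports
// whether `I != Limit` and `I u< Limit` agreed on every value visited.
bool neAgreesWithUlt(unsigned Start, unsigned Limit) {
  unsigned I = Start;
  for (unsigned Iter = 0; Iter <= Mask; ++Iter) {
    bool Ne = I != Limit;
    bool Ult = I < Limit;
    if (Ne != Ult)
      return false;
    if (!Ne)
      return true; // the NE-controlled loop exits here
    I = (I + 1) & Mask;
  }
  return true;
}

// Counts disagreements over all inputs satisfying Start u<= Limit.
unsigned countDisagreements() {
  unsigned Bad = 0;
  for (unsigned Limit = 0; Limit <= Mask; ++Limit)
    for (unsigned Start = 0; Start <= Limit; ++Start)
      if (!neAgreesWithUlt(Start, Limit))
        ++Bad;
  return Bad;
}
```

// Without the precondition the two predicates diverge immediately: for
// Start = 5 and Limit = 4, `5 != 4` is true while `5 u< 4` is false.
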
/// If ICI can be widened to a loop invariant condition emits the loop
/// invariant condition in the loop preheader and return it, otherwise
/// returns std::nullopt.
std::optional<Value *>
LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
                                     Instruction *Guard) {
  LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
  LLVM_DEBUG(ICI->dump());

  // parseLoopStructure guarantees that the latch condition is:
  //   ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
  // We are looking for the range checks of the form:
  //   i u< guardLimit
  auto RangeCheck = parseLoopICmp(ICI);
  if (!RangeCheck) {
    LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
    return std::nullopt;
  }
  LLVM_DEBUG(dbgs() << "Guard check:\n");
  LLVM_DEBUG(RangeCheck->dump());
  if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
    LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
                      << RangeCheck->Pred << ")!\n");
    return std::nullopt;
  }
  auto *RangeCheckIV = RangeCheck->IV;
  if (!RangeCheckIV->isAffine()) {
    LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
    return std::nullopt;
  }
  const SCEV *Step = RangeCheckIV->getStepRecurrence(*SE);
  // We cannot just compare with latch IV step because the latch and range IVs
  // may have different types.
  if (!isSupportedStep(Step)) {
    LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
    return std::nullopt;
  }
  auto *Ty = RangeCheckIV->getType();
  auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty);
  if (!CurrLatchCheckOpt) {
    LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
                         "corresponding to range type: "
                      << *Ty << "\n");
    return std::nullopt;
  }

  LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
  // At this point, the range and latch step should have the same type, but need
  // not have the same value (we support both 1 and -1 steps).
  assert(Step->getType() ==
             CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
         "Range and latch steps should be of same type!");
  if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
    LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
    return std::nullopt;
  }

  if (Step->isOne())
    return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
                                               Expander, Guard);
  else {
    assert(Step->isAllOnesValue() && "Step should be -1!");
    return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
                                               Expander, Guard);
  }
}

void LoopPredication::widenChecks(SmallVectorImpl<Value *> &Checks,
                                  SmallVectorImpl<Value *> &WidenedChecks,
                                  SCEVExpander &Expander, Instruction *Guard) {
  for (auto &Check : Checks)
    if (ICmpInst *ICI = dyn_cast<ICmpInst>(Check))
      if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) {
        WidenedChecks.push_back(Check);
        Check = *NewRangeCheck;
      }
}

bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
                                           SCEVExpander &Expander) {
  LLVM_DEBUG(dbgs() << "Processing guard:\n");
  LLVM_DEBUG(Guard->dump());

  TotalConsidered++;
  SmallVector<Value *, 4> Checks;
  SmallVector<Value *> WidenedChecks;
  parseWidenableGuard(Guard, Checks);
  widenChecks(Checks, WidenedChecks, Expander, Guard);
  if (WidenedChecks.empty())
    return false;

  TotalWidened += WidenedChecks.size();

  // Emit the new guard condition
  IRBuilder<> Builder(findInsertPt(Guard, Checks));
  Value *AllChecks = Builder.CreateAnd(Checks);
  auto *OldCond = Guard->getOperand(0);
  Guard->setOperand(0, AllChecks);
  if (InsertAssumesOfPredicatedGuardsConditions) {
    Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard));
    Builder.CreateAssumption(OldCond);
  }
  RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);

  LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
  return true;
}

bool LoopPredication::widenWidenableBranchGuardConditions(
    BranchInst *BI, SCEVExpander &Expander) {
  assert(isGuardAsWidenableBranch(BI) && "Must be!");
  LLVM_DEBUG(dbgs() << "Processing guard:\n");
  LLVM_DEBUG(BI->dump());

  TotalConsidered++;
  SmallVector<Value *, 4> Checks;
  SmallVector<Value *> WidenedChecks;
  parseWidenableGuard(BI, Checks);
  // At the moment, our matching logic for widenable conditions implicitly
  // assumes we preserve the form: (br (and Cond, WC())). FIXME
  auto WC = extractWidenableCondition(BI);
  Checks.push_back(WC);
  widenChecks(Checks, WidenedChecks, Expander, BI);
  if (WidenedChecks.empty())
    return false;

  TotalWidened += WidenedChecks.size();

  // Emit the new guard condition
  IRBuilder<> Builder(findInsertPt(BI, Checks));
  Value *AllChecks = Builder.CreateAnd(Checks);
  auto *OldCond = BI->getCondition();
  BI->setCondition(AllChecks);
  if (InsertAssumesOfPredicatedGuardsConditions) {
    BasicBlock *IfTrueBB = BI->getSuccessor(0);
    Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
    // If this block has other predecessors, we might not be able to use Cond.
    // In this case, create a Phi where every other input is `true` and input
    // from guard block is Cond.
    Value *AssumeCond = Builder.CreateAnd(WidenedChecks);
    if (!IfTrueBB->getUniquePredecessor()) {
      auto *GuardBB = BI->getParent();
      auto *PN = Builder.CreatePHI(AssumeCond->getType(), pred_size(IfTrueBB),
                                   "assume.cond");
      for (auto *Pred : predecessors(IfTrueBB))
        PN->addIncoming(Pred == GuardBB ? AssumeCond : Builder.getTrue(), Pred);
      AssumeCond = PN;
    }
    Builder.CreateAssumption(AssumeCond);
  }
  RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
  assert(isGuardAsWidenableBranch(BI) &&
         "Stopped being a guard after transform?");

  LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
  return true;
}

808std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
809 using namespace PatternMatch;
810
811 BasicBlock *LoopLatch = L->getLoopLatch();
812 if (!LoopLatch) {
813 LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
814 return std::nullopt;
815 }
816
817 auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
818 if (!BI || !BI->isConditional()) {
819 LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
820 return std::nullopt;
821 }
822 BasicBlock *TrueDest = BI->getSuccessor(0);
823 assert(
824 (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) &&
825 "One of the latch's destinations must be the header");
826
827 auto *ICI = dyn_cast<ICmpInst>(BI->getCondition());
828 if (!ICI) {
829 LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n");
830 return std::nullopt;
831 }
832 auto Result = parseLoopICmp(ICI);
833 if (!Result) {
834 LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
835 return std::nullopt;
836 }
837
838 if (TrueDest != L->getHeader())
839 Result->Pred = ICmpInst::getInversePredicate(Result->Pred);
840
841 // Check affine first, so if it's not we don't try to compute the step
842 // recurrence.
843 if (!Result->IV->isAffine()) {
844 LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
845 return std::nullopt;
846 }
847
848 const SCEV *Step = Result->IV->getStepRecurrence(*SE);
849 if (!isSupportedStep(Step)) {
850 LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
851 return std::nullopt;
852 }
853
854 auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
855 if (Step->isOne()) {
856 return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
857 Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
858 } else {
859 assert(Step->isAllOnesValue() && "Step should be -1!");
860 return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
861 Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
862 }
863 };
864
865 normalizePredicate(SE, L, *Result);
866 if (IsUnsupportedPredicate(Step, Result->Pred)) {
867 LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
868 << ")!\n");
869 return std::nullopt;
870 }
871
872 return Result;
873}
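
// Illustrative sketch only (not part of the pass): the latch predicate filter
// used by parseLoopLatchICmp, restated without LLVM types. `Pred` and
// `isUnsupportedLatchPredicate` are hypothetical names, and the step is
// assumed to be +1 or -1, as the assert in the lambda above expects.

```cpp
#include <cassert>

// Hypothetical stand-in for ICmpInst::Predicate (integer predicates only).
enum class Pred { ULT, ULE, UGT, UGE, SLT, SLE, SGT, SGE, EQ, NE };

// Count-up loops (step +1) accept only "less than" style latch predicates;
// count-down loops (step -1) accept only "greater than" style ones.
bool isUnsupportedLatchPredicate(int Step, Pred P) {
  if (Step == 1)
    return P != Pred::ULT && P != Pred::SLT && P != Pred::ULE &&
           P != Pred::SLE;
  assert(Step == -1 && "Step should be -1!");
  return P != Pred::UGT && P != Pred::SGT && P != Pred::UGE &&
         P != Pred::SGE;
}
```

// For example, a count-up loop with a `!=` latch test is rejected by this
// filter unless normalization has already rewritten the predicate.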
874
875bool LoopPredication::isLoopProfitableToPredicate() {
876 if (SkipProfitabilityChecks)
877 return true;
878
879 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
880 L->getExitEdges(ExitEdges);
881 // If there is only one exiting edge in the loop, it is always profitable to
882 // predicate the loop.
883 if (ExitEdges.size() == 1)
884 return true;
885
886 // Calculate the exiting probabilities of all exiting edges from the loop,
887 // starting with the LatchExitProbability.
888 // Heuristic for profitability: If any of the exiting blocks' probability of
889 // exiting the loop is larger than exiting through the latch block, it's not
890 // profitable to predicate the loop.
891 auto *LatchBlock = L->getLoopLatch();
892 assert(LatchBlock && "Should have a single latch at this point!");
893 auto *LatchTerm = LatchBlock->getTerminator();
894 assert(LatchTerm->getNumSuccessors() == 2 &&
895 "expected to be an exiting block with 2 succs!");
896 unsigned LatchBrExitIdx =
897 LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
898 // We compute branch probabilities without BPI. We do not rely on BPI since
899 // Loop predication is usually run in an LPM and BPI is only preserved
900 // lossily within loop pass managers, while BPI has an inherent notion of
901 // being complete for an entire function.
902
903 // If the latch exits into a deoptimize or an unreachable block, do not
904 // predicate on that latch check.
905 auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
906 if (isa<UnreachableInst>(LatchTerm) ||
907 LatchExitBlock->getTerminatingDeoptimizeCall())
908 return false;
909
910 // Latch terminator has no valid profile data, so nothing to check
911 // profitability on.
912 if (!hasValidBranchWeightMD(*LatchTerm))
913 return true;
914
915 auto ComputeBranchProbability =
916 [&](const BasicBlock *ExitingBlock,
917 const BasicBlock *ExitBlock) -> BranchProbability {
918 auto *Term = ExitingBlock->getTerminator();
919 unsigned NumSucc = Term->getNumSuccessors();
920 if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
921 SmallVector<uint32_t> Weights;
922 extractBranchWeights(ProfileData, Weights);
923 uint64_t Numerator = 0, Denominator = 0;
924 for (auto [i, Weight] : llvm::enumerate(Weights)) {
925 if (Term->getSuccessor(i) == ExitBlock)
926 Numerator += Weight;
927 Denominator += Weight;
928 }
929 // If all weights are zero act as if there was no profile data
930 if (Denominator == 0)
931 return BranchProbability::getBranchProbability(1, NumSucc);
932 return BranchProbability::getBranchProbability(Numerator, Denominator);
933 } else {
934 assert(LatchBlock != ExitingBlock &&
935 "Latch term should always have profile data!");
936 // No profile data, so we choose the weight as 1/num_of_succ(Src)
937 return BranchProbability::getBranchProbability(1, NumSucc);
938 }
939 };
940
941 BranchProbability LatchExitProbability =
942 ComputeBranchProbability(LatchBlock, LatchExitBlock);
943
944 // Protect against degenerate inputs provided by the user. Providing a value
945 // less than one can invert the definition of profitable loop predication.
946 float ScaleFactor = LatchExitProbabilityScale;
947 if (ScaleFactor < 1) {
948 LLVM_DEBUG(
949 dbgs()
950 << "Ignored user setting for loop-predication-latch-probability-scale: "
951 << LatchExitProbabilityScale << "\n");
952 LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
953 ScaleFactor = 1.0;
954 }
955 const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;
956
957 for (const auto &ExitEdge : ExitEdges) {
958 BranchProbability ExitingBlockProbability =
959 ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
960 // Some exiting edge has higher probability than the latch exiting edge.
961 // No longer profitable to predicate.
962 if (ExitingBlockProbability > LatchProbabilityThreshold)
963 return false;
964 }
965
966 // We have concluded that the most probable way to exit from the
967 // loop is through the latch (or there's no profile information and all
968 // exits are equally likely).
969 return true;
970}
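
// Illustrative sketch only (not part of the pass): the profitability
// heuristic of isLoopProfitableToPredicate, restated over plain numbers.
// `ExitWeights`, `exitProbability` and `isProfitableToPredicate` are
// hypothetical names, and `double` is an assumption made for brevity; the
// pass itself works with BranchProbability values.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one exiting terminator: the branch weight of each
// successor, the index of the successor that leaves the loop, and the number
// of successors.
struct ExitWeights {
  std::vector<uint32_t> Weights;
  unsigned ExitIdx;
  unsigned NumSucc;
};

// Probability of leaving the loop through this terminator. Missing or
// all-zero weights are treated as "every successor equally likely".
double exitProbability(const ExitWeights &E) {
  uint64_t Numerator = 0, Denominator = 0;
  for (unsigned I = 0; I < E.Weights.size(); ++I) {
    if (I == E.ExitIdx)
      Numerator += E.Weights[I];
    Denominator += E.Weights[I];
  }
  if (Denominator == 0)
    return 1.0 / E.NumSucc;
  return static_cast<double>(Numerator) / static_cast<double>(Denominator);
}

// Predication is considered unprofitable as soon as some exit is more likely
// than the latch exit scaled by `Scale`. Scales below 1 are clamped to 1.
bool isProfitableToPredicate(const ExitWeights &Latch,
                             const std::vector<ExitWeights> &Exits,
                             double Scale) {
  if (Scale < 1.0)
    Scale = 1.0;
  const double Threshold = exitProbability(Latch) * Scale;
  for (const ExitWeights &E : Exits)
    if (exitProbability(E) > Threshold)
      return false;
  return true;
}
```

// With a latch that exits 1% of the time and a scale of 2, any other exit
// taken more than 2% of the time makes the loop unprofitable to predicate.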
971
972/// If we can (cheaply) find a widenable branch which controls entry into the
973/// loop, return it.
974static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
975 // Walk back through any unconditionally executed blocks and see if we can find
976 // a widenable condition which seems to control execution of this loop. Note
977 // that we predict that maythrow calls are likely untaken and thus that it's
978 // profitable to widen a branch before a maythrow call with a condition
979 // afterwards even though that may cause the slow path to run in a case where
980 // it wouldn't have otherwise.
981 BasicBlock *BB = L->getLoopPreheader();
982 if (!BB)
983 return nullptr;
984 do {
985 if (BasicBlock *Pred = BB->getSinglePredecessor())
986 if (BB == Pred->getSingleSuccessor()) {
987 BB = Pred;
988 continue;
989 }
990 break;
991 } while (true);
992
993 if (BasicBlock *Pred = BB->getSinglePredecessor()) {
994 if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()))
995 if (BI->getSuccessor(0) == BB && isWidenableBranch(BI))
996 return BI;
997 }
998 return nullptr;
999}
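
// Illustrative sketch only (not part of the pass): the backwards walk in
// FindWidenableTerminatorAboveLoop, restated on a toy CFG. `ToyBlock` and
// `walkToChainHead` are hypothetical names, blocks are identified by their
// index, and the chain is assumed to be acyclic.

```cpp
#include <cassert>
#include <vector>

// Hypothetical CFG node: indices of predecessor and successor blocks.
struct ToyBlock {
  std::vector<int> Preds;
  std::vector<int> Succs;
};

// Starting at BB (think: the loop preheader), step to the predecessor for as
// long as BB has exactly one predecessor and that predecessor's only
// successor is BB. The block returned is the top of that straight-line chain.
int walkToChainHead(const std::vector<ToyBlock> &CFG, int BB) {
  while (true) {
    if (CFG[BB].Preds.size() != 1)
      break;
    int Pred = CFG[BB].Preds[0];
    if (CFG[Pred].Succs.size() != 1 || CFG[Pred].Succs[0] != BB)
      break;
    BB = Pred;
  }
  return BB;
}
```

// The pass then inspects the single predecessor of the returned block and
// accepts its terminator only if it is a widenable branch whose first
// successor is that block.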
1000
1001/// Return the minimum of all analyzeable exit counts. This is an upper bound
1002/// on the actual exit count. If there are not at least two analyzeable exits,
1003/// returns SCEVCouldNotCompute.
1004static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
1005 DominatorTree &DT,
1006 Loop *L) {
1007 SmallVector<BasicBlock *, 16> ExitingBlocks;
1008 L->getExitingBlocks(ExitingBlocks);
1009
1010 SmallVector<const SCEV *, 4> ExitCounts;
1011 for (BasicBlock *ExitingBB : ExitingBlocks) {
1012 const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
1013 if (isa<SCEVCouldNotCompute>(ExitCount))
1014 continue;
1015 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
1016 "We should only have known counts for exiting blocks that "
1017 "dominate latch!");
1018 ExitCounts.push_back(ExitCount);
1019 }
1020 if (ExitCounts.size() < 2)
1021 return SE.getCouldNotCompute();
1022 return SE.getUMinFromMismatchedTypes(ExitCounts);
1023}
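
// Illustrative sketch only (not part of the pass): the shape of
// getMinAnalyzeableBackedgeTakenCount with concrete integers in place of SCEV
// expressions. `minAnalyzableExitCount` is a hypothetical name, and
// std::nullopt stands in for SCEVCouldNotCompute.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Each entry is the exit count of one exiting block, or std::nullopt when the
// count could not be computed. The result is the unsigned minimum of the
// known counts, or std::nullopt when fewer than two counts are known.
std::optional<uint64_t>
minAnalyzableExitCount(const std::vector<std::optional<uint64_t>> &Counts) {
  std::vector<uint64_t> Known;
  for (const std::optional<uint64_t> &EC : Counts)
    if (EC)
      Known.push_back(*EC);
  if (Known.size() < 2)
    return std::nullopt;
  return *std::min_element(Known.begin(), Known.end());
}
```

// Requiring at least two known counts matches the check above: with a single
// analyzable exit there is nothing to compare the latch against.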
1024
1025/// This implements an analogous, but entirely distinct transform from the main
1026/// loop predication transform. This one is phrased in terms of using a
1027/// widenable branch *outside* the loop to allow us to simplify loop exits in a
1028/// following loop. This is close in spirit to the IndVarSimplify transform
1029/// of the same name, but is materially different: widening loosens legality
1030/// sharply.
1031bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
1032 // The transformation performed here aims to widen a widenable condition
1033 // above the loop such that all analyzeable exits leading to deopt are dead.
1034 // It assumes that the latch is the dominant exit for profitability and that
1035 // exits branching to deoptimizing blocks are rarely taken. It relies on the
1036 // semantics of widenable expressions for legality. (i.e. being able to fall
1037 // down the widenable path spuriously allows us to ignore exit order,
1038 // unanalyzeable exits, side effects, exceptional exits, and other challenges
1039 // which restrict the applicability of the non-WC based version of this
1040 // transform in IndVarSimplify.)
1041 //
1042 // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may
1043 // imply flags on the expression being hoisted and inserting new uses (flags
1044 // are only correct for current uses). The result is that we may be
1045 // inserting a branch on the value which can be either poison or undef. In
1046 // this case, the branch can legally go either way; we just need to avoid
1047 // introducing UB. This is achieved through the use of the freeze
1048 // instruction.
1049
1050 SmallVector<BasicBlock *, 16> ExitingBlocks;
1051 L->getExitingBlocks(ExitingBlocks);
1052
1053 if (ExitingBlocks.empty())
1054 return false; // Nothing to do.
1055
1056 auto *Latch = L->getLoopLatch();
1057 if (!Latch)
1058 return false;
1059
1060 auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, *LI);
1061 if (!WidenableBR)
1062 return false;
1063
1064 const SCEV *LatchEC = SE->getExitCount(L, Latch);
1065 if (isa<SCEVCouldNotCompute>(LatchEC))
1066 return false; // profitability - want hot exit in analyzeable set
1067
1068 // At this point, we have found an analyzeable latch, and a widenable
1069 // condition above the loop. If we have a widenable exit within the loop
1070 // (for which we can't compute exit counts), drop the ability to further
1071 // widen so that we gain ability to analyze its exit count and perform this
1072 // transform. TODO: It'd be nice to know for sure the exit became
1073 // analyzeable after dropping widenability.
1074 bool ChangedLoop = false;
1075
1076 for (auto *ExitingBB : ExitingBlocks) {
1077 if (LI->getLoopFor(ExitingBB) != L)
1078 continue;
1079
1080 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1081 if (!BI)
1082 continue;
1083
1084 if (auto WC = extractWidenableCondition(BI))
1085 if (L->contains(BI->getSuccessor(0))) {
1086 assert(WC->hasOneUse() && "Not appropriate widenable branch!");
1087 WC->user_back()->replaceUsesOfWith(
1088 WC, ConstantInt::getTrue(BI->getContext()));
1089 ChangedLoop = true;
1090 }
1091 }
1092 if (ChangedLoop)
1093 SE->forgetLoop(L);
1094
1095 // The insertion point for the widening should be at the widenable call, not
1096 // at the WidenableBR. If we do this at the WidenableBR, we can incorrectly
1097 // change a loop-invariant condition to a loop-varying one.
1098 auto *IP = cast<Instruction>(WidenableBR->getCondition());
1099
1100 // The use of umin(all analyzeable exits) instead of latch is subtle, but
1101 // important for profitability. We may have a loop which hasn't been fully
1102 // canonicalized just yet. If the exit we chose to widen is provably never
1103 // taken, we want the widened form to *also* be provably never taken. We
1104 // can't guarantee this as a current unanalyzeable exit may later become
1105 // analyzeable, but we can at least avoid the obvious cases.
1106 const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
1107 if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
1108 !SE->isLoopInvariant(MinEC, L) ||
1109 !Rewriter.isSafeToExpandAt(MinEC, IP))
1110 return ChangedLoop;
1111
1112 Rewriter.setInsertPoint(IP);
1113 IRBuilder<> B(IP);
1114
1115 bool InvalidateLoop = false;
1116 Value *MinECV = nullptr; // lazily generated if needed
1117 for (BasicBlock *ExitingBB : ExitingBlocks) {
1118 // If our exiting block exits multiple loops, we can only rewrite the
1119 // innermost one. Otherwise, we're changing how many times the innermost
1120 // loop runs before it exits.
1121 if (LI->getLoopFor(ExitingBB) != L)
1122 continue;
1123
1124 // Can't rewrite non-branch yet.
1125 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1126 if (!BI)
1127 continue;
1128
1129 // If already constant, nothing to do.
1130 if (isa<Constant>(BI->getCondition()))
1131 continue;
1132
1133 const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
1134 if (isa<SCEVCouldNotCompute>(ExitCount) ||
1135 ExitCount->getType()->isPointerTy() ||
1136 !Rewriter.isSafeToExpandAt(ExitCount, WidenableBR))
1137 continue;
1138
1139 const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
1140 BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
1141 if (!ExitBB->getPostdominatingDeoptimizeCall())
1142 continue;
1143
1144 /// Here we can be fairly sure that executing this exit will most likely
1145 /// lead to executing llvm.experimental.deoptimize.
1146 /// This is a profitability heuristic, not a legality constraint.
1147
1148 // If we found a widenable exit condition, do two things:
1149 // 1) fold the widened exit test into the widenable condition
1150 // 2) fold the branch to untaken - avoids infinite looping
1151
1152 Value *ECV = Rewriter.expandCodeFor(ExitCount);
1153 if (!MinECV)
1154 MinECV = Rewriter.expandCodeFor(MinEC);
1155 Value *RHS = MinECV;
1156 if (ECV->getType() != RHS->getType()) {
1157 Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
1158 ECV = B.CreateZExt(ECV, WiderTy);
1159 RHS = B.CreateZExt(RHS, WiderTy);
1160 }
1161 assert(!Latch || DT->dominates(ExitingBB, Latch));
1162 Value *NewCond = B.CreateICmp(ICmpInst::ICMP_UGT, ECV, RHS);
1163 // Freeze poison or undef to an arbitrary bit pattern to ensure we can
1164 // branch without introducing UB. See NOTE ON POISON/UNDEF above for
1165 // context.
1166 NewCond = B.CreateFreeze(NewCond);
1167
1168 widenWidenableBranch(WidenableBR, NewCond);
1169
1170 Value *OldCond = BI->getCondition();
1171 BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
1172 InvalidateLoop = true;
1173 }
1174
1175 if (InvalidateLoop)
1176 // We just mutated a bunch of loop exits changing their exit counts
1177 // widely. We need to force recomputation of the exit counts given these
1178 // changes. Note that all of the inserted exits are never taken, and
1179 // should be removed next time the CFG is modified.
1180 SE->forgetLoop(L);
1181
1182 // Always return `true` since we have moved the WidenableBR's condition.
1183 return true;
1184}
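
// Illustrative sketch only (not part of the pass): the condition that
// predicateLoopExits folds into the widenable branch above the loop, using
// concrete integers in place of expanded SCEV values. `widenedEntryCondition`
// is a hypothetical name. Each rewritten deopt exit contributes the check
// "exit count u> minimum analyzable exit count".

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// OriginalCond is the condition already guarding entry to the loop. The
// result is true only if, in addition, every rewritten deopt exit would fire
// strictly later than the earliest analyzable exit.
bool widenedEntryCondition(bool OriginalCond,
                           const std::vector<uint64_t> &DeoptExitCounts,
                           uint64_t MinExitCount) {
  bool Cond = OriginalCond;
  for (uint64_t EC : DeoptExitCounts)
    Cond = Cond && (EC > MinExitCount);
  return Cond;
}
```

// When the widened condition is false, execution takes the widenable
// branch's slow path before the loop instead of deoptimizing from inside it.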
1185
1186bool LoopPredication::runOnLoop(Loop *Loop) {
1187 L = Loop;
1188
1189 LLVM_DEBUG(dbgs() << "Analyzing ");
1190 LLVM_DEBUG(L->dump());
1191
1192 Module *M = L->getHeader()->getModule();
1193
1194 // There is nothing to do if the module doesn't use guards
1195 auto *GuardDecl =
1196 M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
1197 bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
1198 auto *WCDecl = M->getFunction(
1199 Intrinsic::getName(Intrinsic::experimental_widenable_condition));
1200 bool HasWidenableConditions =
1201 PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty();
1202 if (!HasIntrinsicGuards && !HasWidenableConditions)
1203 return false;
1204
1205 DL = &M->getDataLayout();
1206
1207 Preheader = L->getLoopPreheader();
1208 if (!Preheader)
1209 return false;
1210
1211 auto LatchCheckOpt = parseLoopLatchICmp();
1212 if (!LatchCheckOpt)
1213 return false;
1214 LatchCheck = *LatchCheckOpt;
1215
1216 LLVM_DEBUG(dbgs() << "Latch check:\n");
1217 LLVM_DEBUG(LatchCheck.dump());
1218
1219 if (!isLoopProfitableToPredicate()) {
1220 LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
1221 return false;
1222 }
1223 // Collect all the guards into a vector and process later, so as not
1224 // to invalidate the instruction iterator.
1225 SmallVector<IntrinsicInst *, 4> Guards;
1226 SmallVector<BranchInst *, 4> GuardsAsWidenableBranches;
1227 for (const auto BB : L->blocks()) {
1228 for (auto &I : *BB)
1229 if (isGuard(&I))
1230 Guards.push_back(cast<IntrinsicInst>(&I));
1231 if (PredicateWidenableBranchGuards &&
1232 isGuardAsWidenableBranch(BB->getTerminator()))
1233 GuardsAsWidenableBranches.push_back(
1234 cast<BranchInst>(BB->getTerminator()));
1235 }
1236
1237 SCEVExpander Expander(*SE, *DL, "loop-predication");
1238 bool Changed = false;
1239 for (auto *Guard : Guards)
1240 Changed |= widenGuardConditions(Guard, Expander);
1241 for (auto *Guard : GuardsAsWidenableBranches)
1242 Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
1243 Changed |= predicateLoopExits(L, Expander);
1244
1245 if (MSSAU && VerifyMemorySSA)
1246 MSSAU->getMemorySSA()->verifyMemorySSA();
1247 return Changed;
1248}