LLVM  8.0.0svn
LoopPredication.cpp
Go to the documentation of this file.
1 //===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // The LoopPredication pass tries to convert loop variant range checks to loop
11 // invariant by widening checks across loop iterations. For example, it will
12 // convert
13 //
14 // for (i = 0; i < n; i++) {
15 // guard(i < len);
16 // ...
17 // }
18 //
19 // to
20 //
21 // for (i = 0; i < n; i++) {
22 // guard(n - 1 < len);
23 // ...
24 // }
25 //
26 // After this transformation the condition of the guard is loop invariant, so
27 // loop-unswitch can later unswitch the loop by this condition which basically
28 // predicates the loop by the widened condition:
29 //
30 // if (n - 1 < len)
31 // for (i = 0; i < n; i++) {
32 // ...
33 // }
34 // else
35 // deoptimize
36 //
37 // It's tempting to rely on SCEV here, but it has proven to be problematic.
38 // Generally the facts SCEV provides about the increment step of add
39 // recurrences are true if the backedge of the loop is taken, which implicitly
40 // assumes that the guard doesn't fail. Using these facts to optimize the
41 // guard results in a circular logic where the guard is optimized under the
42 // assumption that it never fails.
43 //
44 // For example, in the loop below the induction variable will be marked as nuw
45 // based on the guard. Based on nuw, the guard predicate will be considered
46 // monotonic. Given a monotonic condition it's tempting to replace the induction
47 // variable in the condition with its value on the last iteration. But this
48 // transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
49 //
50 // for (int i = b; i != e; i++)
51 // guard(i u< len)
52 //
53 // One of the ways to reason about this problem is to use an inductive proof
54 // approach. Given the loop:
55 //
56 // if (B(0)) {
57 // do {
58 // I = PHI(0, I.INC)
59 // I.INC = I + Step
60 // guard(G(I));
61 // } while (B(I));
62 // }
63 //
64 // where B(x) and G(x) are predicates that map integers to booleans, we want a
65 // loop invariant expression M such that the following program has the same
66 // semantics as the above:
67 //
68 // if (B(0)) {
69 // do {
70 // I = PHI(0, I.INC)
71 // I.INC = I + Step
72 // guard(G(0) && M);
73 // } while (B(I));
74 // }
75 //
76 // One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
77 //
78 // Informal proof that the transformation above is correct:
79 //
80 // By the definition of guards we can rewrite the guard condition to:
81 // G(I) && G(0) && M
82 //
83 // Let's prove that for each iteration of the loop:
84 // G(0) && M => G(I)
85 // And the condition above can be simplified to G(0) && M.
86 //
87 // Induction base.
88 // G(0) && M => G(0)
89 //
90 // Induction step. Assuming G(0) && M => G(I) on the subsequent
91 // iteration:
92 //
93 // B(I) is true because it's the backedge condition.
94 // G(I) is true because the backedge is guarded by this condition.
95 //
96 // So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
97 //
98 // Note that we can use anything stronger than M, i.e. any condition which
99 // implies M.
100 //
101 // When Step = 1 (i.e. forward iterating loop), the transformation is supported
102 // when:
103 // * The loop has a single latch with the condition of the form:
104 // B(X) = latchStart + X <pred> latchLimit,
105 // where <pred> is u<, u<=, s<, or s<=.
106 // * The guard condition is of the form
107 // G(X) = guardStart + X u< guardLimit
108 //
109 // For the ult latch comparison case M is:
110 // forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
111 // guardStart + X + 1 u< guardLimit
112 //
113 // The only way the antecedent can be true and the consequent can be false is
114 // if
115 // X == guardLimit - 1 - guardStart
116 // (and guardLimit is non-zero, but we won't use this latter fact).
117 // If X == guardLimit - 1 - guardStart then the second half of the antecedent is
118 // latchStart + guardLimit - 1 - guardStart u< latchLimit
119 // and its negation is
120 // latchStart + guardLimit - 1 - guardStart u>= latchLimit
121 //
122 // In other words, if
123 // latchLimit u<= latchStart + guardLimit - 1 - guardStart
124 // then:
125 // (the ranges below are written in ConstantRange notation, where [A, B) is the
126 // set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
127 //
128 // forall X . guardStart + X u< guardLimit &&
129 // latchStart + X u< latchLimit =>
130 // guardStart + X + 1 u< guardLimit
131 // == forall X . guardStart + X u< guardLimit &&
132 // latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
133 // guardStart + X + 1 u< guardLimit
134 // == forall X . (guardStart + X) in [0, guardLimit) &&
135 // (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
136 // (guardStart + X + 1) in [0, guardLimit)
137 // == forall X . X in [-guardStart, guardLimit - guardStart) &&
138 // X in [-latchStart, guardLimit - 1 - guardStart) =>
139 // X in [-guardStart - 1, guardLimit - guardStart - 1)
140 // == true
141 //
142 // So the widened condition is:
143 // guardStart u< guardLimit &&
144 // latchStart + guardLimit - 1 - guardStart u>= latchLimit
145 // Similarly for ule condition the widened condition is:
146 // guardStart u< guardLimit &&
147 // latchStart + guardLimit - 1 - guardStart u> latchLimit
148 // For slt condition the widened condition is:
149 // guardStart u< guardLimit &&
150 // latchStart + guardLimit - 1 - guardStart s>= latchLimit
151 // For sle condition the widened condition is:
152 // guardStart u< guardLimit &&
153 // latchStart + guardLimit - 1 - guardStart s> latchLimit
154 //
155 // When Step = -1 (i.e. reverse iterating loop), the transformation is supported
156 // when:
157 // * The loop has a single latch with the condition of the form:
158 // B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
159 // * The guard condition is of the form
160 // G(X) = X - 1 u< guardLimit
161 //
162 // For the ugt latch comparison case M is:
163 // forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
164 //
165 // The only way the antecedent can be true and the consequent can be false is if
166 // X == 1.
167 // If X == 1 then the second half of the antecedent is
168 // 1 u> latchLimit, and its negation is latchLimit u>= 1.
169 //
170 // So the widened condition is:
171 // guardStart u< guardLimit && latchLimit u>= 1.
172 // Similarly for sgt condition the widened condition is:
173 // guardStart u< guardLimit && latchLimit s>= 1.
174 // For uge condition the widened condition is:
175 // guardStart u< guardLimit && latchLimit u> 1.
176 // For sge condition the widened condition is:
177 // guardStart u< guardLimit && latchLimit s> 1.
178 //===----------------------------------------------------------------------===//
179 
#include "llvm/Transforms/Scalar/LoopPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
197 
198 #define DEBUG_TYPE "loop-predication"
199 
200 STATISTIC(TotalConsidered, "Number of guards considered");
201 STATISTIC(TotalWidened, "Number of checks widened");
202 
203 using namespace llvm;
204 
205 static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
206  cl::Hidden, cl::init(true));
207 
208 static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
209  cl::Hidden, cl::init(true));
210 
211 static cl::opt<bool>
212  SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
213  cl::Hidden, cl::init(false));
214 
215 // This is the scale factor for the latch probability. We use this during
216 // profitability analysis to find other exiting blocks that have a much higher
217 // probability of exiting the loop instead of loop exiting via latch.
218 // This value should be greater than 1 for a sane profitability check.
220  "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
221  cl::desc("scale factor for the latch probability. Value should be greater "
222  "than 1. Lower values are ignored"));
223 
224 namespace {
225 class LoopPredication {
226  /// Represents an induction variable check:
227  /// icmp Pred, <induction variable>, <loop invariant limit>
228  struct LoopICmp {
229  ICmpInst::Predicate Pred;
230  const SCEVAddRecExpr *IV;
231  const SCEV *Limit;
232  LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
233  const SCEV *Limit)
234  : Pred(Pred), IV(IV), Limit(Limit) {}
235  LoopICmp() {}
236  void dump() {
237  dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
238  << ", Limit = " << *Limit << "\n";
239  }
240  };
241 
242  ScalarEvolution *SE;
244 
245  Loop *L;
246  const DataLayout *DL;
247  BasicBlock *Preheader;
248  LoopICmp LatchCheck;
249 
250  bool isSupportedStep(const SCEV* Step);
251  Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) {
252  return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0),
253  ICI->getOperand(1));
254  }
255  Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
256  Value *RHS);
257 
258  Optional<LoopICmp> parseLoopLatchICmp();
259 
260  bool CanExpand(const SCEV* S);
261  Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder,
262  ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
263  Instruction *InsertAt);
264 
265  Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
266  IRBuilder<> &Builder);
267  Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck,
268  LoopICmp RangeCheck,
269  SCEVExpander &Expander,
270  IRBuilder<> &Builder);
271  Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck,
272  LoopICmp RangeCheck,
273  SCEVExpander &Expander,
274  IRBuilder<> &Builder);
275  bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
276 
277  // If the loop always exits through another block in the loop, we should not
278  // predicate based on the latch check. For example, the latch check can be a
279  // very coarse grained check and there can be more fine grained exit checks
280  // within the loop. We identify such unprofitable loops through BPI.
281  bool isLoopProfitableToPredicate();
282 
283  // When the IV type is wider than the range operand type, we can still do loop
284  // predication, by generating SCEVs for the range and latch that are of the
285  // same type. We achieve this by generating a SCEV truncate expression for the
286  // latch IV. This is done iff truncation of the IV is a safe operation,
287  // without loss of information.
288  // Another way to achieve this is by generating a wider type SCEV for the
289  // range check operand, however, this needs a more involved check that
290  // operands do not overflow. This can lead to loss of information when the
291  // range operand is of the form: add i32 %offset, %iv. We need to prove that
292  // sext(x + y) is same as sext(x) + sext(y).
293  // This function returns true if we can safely represent the IV type in
294  // the RangeCheckType without loss of information.
295  bool isSafeToTruncateWideIVType(Type *RangeCheckType);
296  // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do
297  // so.
298  Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType);
299 
300 public:
301  LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI)
302  : SE(SE), BPI(BPI){};
303  bool runOnLoop(Loop *L);
304 };
305 
306 class LoopPredicationLegacyPass : public LoopPass {
307 public:
308  static char ID;
309  LoopPredicationLegacyPass() : LoopPass(ID) {
311  }
312 
313  void getAnalysisUsage(AnalysisUsage &AU) const override {
316  }
317 
318  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
319  if (skipLoop(L))
320  return false;
321  auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
322  BranchProbabilityInfo &BPI =
323  getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
324  LoopPredication LP(SE, &BPI);
325  return LP.runOnLoop(L);
326  }
327 };
328 
330 } // end anonymous namespace
331 
332 INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
333  "Loop predication", false, false)
336 INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
337  "Loop predication", false, false)
338 
340  return new LoopPredicationLegacyPass();
341 }
342 
345  LPMUpdater &U) {
346  const auto &FAM =
347  AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
348  Function *F = L.getHeader()->getParent();
349  auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
350  LoopPredication LP(&AR.SE, BPI);
351  if (!LP.runOnLoop(&L))
352  return PreservedAnalyses::all();
353 
355 }
356 
358 LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
359  Value *RHS) {
360  const SCEV *LHSS = SE->getSCEV(LHS);
361  if (isa<SCEVCouldNotCompute>(LHSS))
362  return None;
363  const SCEV *RHSS = SE->getSCEV(RHS);
364  if (isa<SCEVCouldNotCompute>(RHSS))
365  return None;
366 
367  // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV
368  if (SE->isLoopInvariant(LHSS, L)) {
369  std::swap(LHS, RHS);
370  std::swap(LHSS, RHSS);
371  Pred = ICmpInst::getSwappedPredicate(Pred);
372  }
373 
374  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
375  if (!AR || AR->getLoop() != L)
376  return None;
377 
378  return LoopICmp(Pred, AR, RHSS);
379 }
380 
381 Value *LoopPredication::expandCheck(SCEVExpander &Expander,
382  IRBuilder<> &Builder,
383  ICmpInst::Predicate Pred, const SCEV *LHS,
384  const SCEV *RHS, Instruction *InsertAt) {
385  // TODO: we can check isLoopEntryGuardedByCond before emitting the check
386 
387  Type *Ty = LHS->getType();
388  assert(Ty == RHS->getType() && "expandCheck operands have different types?");
389 
390  if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
391  return Builder.getTrue();
392 
393  Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt);
394  Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt);
395  return Builder.CreateICmp(Pred, LHSV, RHSV);
396 }
397 
399 LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) {
400 
401  auto *LatchType = LatchCheck.IV->getType();
402  if (RangeCheckType == LatchType)
403  return LatchCheck;
404  // For now, bail out if latch type is narrower than range type.
405  if (DL->getTypeSizeInBits(LatchType) < DL->getTypeSizeInBits(RangeCheckType))
406  return None;
407  if (!isSafeToTruncateWideIVType(RangeCheckType))
408  return None;
409  // We can now safely identify the truncated version of the IV and limit for
410  // RangeCheckType.
411  LoopICmp NewLatchCheck;
412  NewLatchCheck.Pred = LatchCheck.Pred;
413  NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
414  SE->getTruncateExpr(LatchCheck.IV, RangeCheckType));
415  if (!NewLatchCheck.IV)
416  return None;
417  NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType);
418  LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
419  << "can be represented as range check type:"
420  << *RangeCheckType << "\n");
421  LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
422  LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
423  return NewLatchCheck;
424 }
425 
426 bool LoopPredication::isSupportedStep(const SCEV* Step) {
427  return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
428 }
429 
430 bool LoopPredication::CanExpand(const SCEV* S) {
431  return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE);
432 }
433 
434 Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
435  LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
436  SCEVExpander &Expander, IRBuilder<> &Builder) {
437  auto *Ty = RangeCheck.IV->getType();
438  // Generate the widened condition for the forward loop:
439  // guardStart u< guardLimit &&
440  // latchLimit <pred> guardLimit - 1 - guardStart + latchStart
441  // where <pred> depends on the latch condition predicate. See the file
442  // header comment for the reasoning.
443  // guardLimit - guardStart + latchStart - 1
444  const SCEV *GuardStart = RangeCheck.IV->getStart();
445  const SCEV *GuardLimit = RangeCheck.Limit;
446  const SCEV *LatchStart = LatchCheck.IV->getStart();
447  const SCEV *LatchLimit = LatchCheck.Limit;
448 
449  // guardLimit - guardStart + latchStart - 1
450  const SCEV *RHS =
451  SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
452  SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
453  if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
454  !CanExpand(LatchLimit) || !CanExpand(RHS)) {
455  LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
456  return None;
457  }
458  auto LimitCheckPred =
460 
461  LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
462  LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
463  LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
464 
465  Instruction *InsertAt = Preheader->getTerminator();
466  auto *LimitCheck =
467  expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt);
468  auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred,
469  GuardStart, GuardLimit, InsertAt);
470  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
471 }
472 
473 Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
474  LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
475  SCEVExpander &Expander, IRBuilder<> &Builder) {
476  auto *Ty = RangeCheck.IV->getType();
477  const SCEV *GuardStart = RangeCheck.IV->getStart();
478  const SCEV *GuardLimit = RangeCheck.Limit;
479  const SCEV *LatchLimit = LatchCheck.Limit;
480  if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
481  !CanExpand(LatchLimit)) {
482  LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
483  return None;
484  }
485  // The decrement of the latch check IV should be the same as the
486  // rangeCheckIV.
487  auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
488  if (RangeCheck.IV != PostDecLatchCheckIV) {
489  LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
490  << *PostDecLatchCheckIV
491  << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
492  return None;
493  }
494 
495  // Generate the widened condition for CountDownLoop:
496  // guardStart u< guardLimit &&
497  // latchLimit <pred> 1.
498  // See the header comment for reasoning of the checks.
499  Instruction *InsertAt = Preheader->getTerminator();
500  auto LimitCheckPred =
502  auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
503  GuardStart, GuardLimit, InsertAt);
504  auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
505  SE->getOne(Ty), InsertAt);
506  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
507 }
508 
509 /// If ICI can be widened to a loop invariant condition emits the loop
510 /// invariant condition in the loop preheader and return it, otherwise
511 /// returns None.
512 Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
513  SCEVExpander &Expander,
514  IRBuilder<> &Builder) {
515  LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
516  LLVM_DEBUG(ICI->dump());
517 
518  // parseLoopStructure guarantees that the latch condition is:
519  // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
520  // We are looking for the range checks of the form:
521  // i u< guardLimit
522  auto RangeCheck = parseLoopICmp(ICI);
523  if (!RangeCheck) {
524  LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
525  return None;
526  }
527  LLVM_DEBUG(dbgs() << "Guard check:\n");
528  LLVM_DEBUG(RangeCheck->dump());
529  if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
530  LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
531  << RangeCheck->Pred << ")!\n");
532  return None;
533  }
534  auto *RangeCheckIV = RangeCheck->IV;
535  if (!RangeCheckIV->isAffine()) {
536  LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
537  return None;
538  }
539  auto *Step = RangeCheckIV->getStepRecurrence(*SE);
540  // We cannot just compare with latch IV step because the latch and range IVs
541  // may have different types.
542  if (!isSupportedStep(Step)) {
543  LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
544  return None;
545  }
546  auto *Ty = RangeCheckIV->getType();
547  auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty);
548  if (!CurrLatchCheckOpt) {
549  LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
550  "corresponding to range type: "
551  << *Ty << "\n");
552  return None;
553  }
554 
555  LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
556  // At this point, the range and latch step should have the same type, but need
557  // not have the same value (we support both 1 and -1 steps).
558  assert(Step->getType() ==
559  CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
560  "Range and latch steps should be of same type!");
561  if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
562  LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
563  return None;
564  }
565 
566  if (Step->isOne())
567  return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
568  Expander, Builder);
569  else {
570  assert(Step->isAllOnesValue() && "Step should be -1!");
571  return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
572  Expander, Builder);
573  }
574 }
575 
576 bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
577  SCEVExpander &Expander) {
578  LLVM_DEBUG(dbgs() << "Processing guard:\n");
579  LLVM_DEBUG(Guard->dump());
580 
581  TotalConsidered++;
582 
583  IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
584 
585  // The guard condition is expected to be in form of:
586  // cond1 && cond2 && cond3 ...
587  // Iterate over subconditions looking for icmp conditions which can be
588  // widened across loop iterations. Widening these conditions remember the
589  // resulting list of subconditions in Checks vector.
590  SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0));
591  SmallPtrSet<Value *, 4> Visited;
592 
594 
595  unsigned NumWidened = 0;
596  do {
597  Value *Condition = Worklist.pop_back_val();
598  if (!Visited.insert(Condition).second)
599  continue;
600 
601  Value *LHS, *RHS;
602  using namespace llvm::PatternMatch;
603  if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
604  Worklist.push_back(LHS);
605  Worklist.push_back(RHS);
606  continue;
607  }
608 
609  if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
610  if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) {
611  Checks.push_back(NewRangeCheck.getValue());
612  NumWidened++;
613  continue;
614  }
615  }
616 
617  // Save the condition as is if we can't widen it
618  Checks.push_back(Condition);
619  } while (Worklist.size() != 0);
620 
621  if (NumWidened == 0)
622  return false;
623 
624  TotalWidened += NumWidened;
625 
626  // Emit the new guard condition
627  Builder.SetInsertPoint(Guard);
628  Value *LastCheck = nullptr;
629  for (auto *Check : Checks)
630  if (!LastCheck)
631  LastCheck = Check;
632  else
633  LastCheck = Builder.CreateAnd(LastCheck, Check);
634  Guard->setOperand(0, LastCheck);
635 
636  LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
637  return true;
638 }
639 
640 Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
641  using namespace PatternMatch;
642 
643  BasicBlock *LoopLatch = L->getLoopLatch();
644  if (!LoopLatch) {
645  LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
646  return None;
647  }
648 
649  ICmpInst::Predicate Pred;
650  Value *LHS, *RHS;
651  BasicBlock *TrueDest, *FalseDest;
652 
653  if (!match(LoopLatch->getTerminator(),
654  m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest,
655  FalseDest))) {
656  LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
657  return None;
658  }
659  assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) &&
660  "One of the latch's destinations must be the header");
661  if (TrueDest != L->getHeader())
662  Pred = ICmpInst::getInversePredicate(Pred);
663 
664  auto Result = parseLoopICmp(Pred, LHS, RHS);
665  if (!Result) {
666  LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
667  return None;
668  }
669 
670  // Check affine first, so if it's not we don't try to compute the step
671  // recurrence.
672  if (!Result->IV->isAffine()) {
673  LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
674  return None;
675  }
676 
677  auto *Step = Result->IV->getStepRecurrence(*SE);
678  if (!isSupportedStep(Step)) {
679  LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
680  return None;
681  }
682 
683  auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
684  if (Step->isOne()) {
685  return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
686  Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
687  } else {
688  assert(Step->isAllOnesValue() && "Step should be -1!");
689  return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
690  Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
691  }
692  };
693 
694  if (IsUnsupportedPredicate(Step, Result->Pred)) {
695  LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
696  << ")!\n");
697  return None;
698  }
699  return Result;
700 }
701 
702 // Returns true if its safe to truncate the IV to RangeCheckType.
703 bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) {
704  if (!EnableIVTruncation)
705  return false;
706  assert(DL->getTypeSizeInBits(LatchCheck.IV->getType()) >
707  DL->getTypeSizeInBits(RangeCheckType) &&
708  "Expected latch check IV type to be larger than range check operand "
709  "type!");
710  // The start and end values of the IV should be known. This is to guarantee
711  // that truncating the wide type will not lose information.
712  auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
713  auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
714  if (!Limit || !Start)
715  return false;
716  // This check makes sure that the IV does not change sign during loop
717  // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
718  // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
719  // IV wraps around, and the truncation of the IV would lose the range of
720  // iterations between 2^32 and 2^64.
721  bool Increasing;
722  if (!SE->isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing))
723  return false;
724  // The active bits should be less than the bits in the RangeCheckType. This
725  // guarantees that truncating the latch check to RangeCheckType is a safe
726  // operation.
727  auto RangeCheckTypeBitSize = DL->getTypeSizeInBits(RangeCheckType);
728  return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
729  Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
730 }
731 
732 bool LoopPredication::isLoopProfitableToPredicate() {
733  if (SkipProfitabilityChecks || !BPI)
734  return true;
735 
737  L->getExitEdges(ExitEdges);
738  // If there is only one exiting edge in the loop, it is always profitable to
739  // predicate the loop.
740  if (ExitEdges.size() == 1)
741  return true;
742 
743  // Calculate the exiting probabilities of all exiting edges from the loop,
744  // starting with the LatchExitProbability.
745  // Heuristic for profitability: If any of the exiting blocks' probability of
746  // exiting the loop is larger than exiting through the latch block, it's not
747  // profitable to predicate the loop.
748  auto *LatchBlock = L->getLoopLatch();
749  assert(LatchBlock && "Should have a single latch at this point!");
750  auto *LatchTerm = LatchBlock->getTerminator();
751  assert(LatchTerm->getNumSuccessors() == 2 &&
752  "expected to be an exiting block with 2 succs!");
753  unsigned LatchBrExitIdx =
754  LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
755  BranchProbability LatchExitProbability =
756  BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx);
757 
758  // Protect against degenerate inputs provided by the user. Providing a value
759  // less than one, can invert the definition of profitable loop predication.
760  float ScaleFactor = LatchExitProbabilityScale;
761  if (ScaleFactor < 1) {
762  LLVM_DEBUG(
763  dbgs()
764  << "Ignored user setting for loop-predication-latch-probability-scale: "
765  << LatchExitProbabilityScale << "\n");
766  LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
767  ScaleFactor = 1.0;
768  }
769  const auto LatchProbabilityThreshold =
770  LatchExitProbability * ScaleFactor;
771 
772  for (const auto &ExitEdge : ExitEdges) {
773  BranchProbability ExitingBlockProbability =
774  BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second);
775  // Some exiting edge has higher probability than the latch exiting edge.
776  // No longer profitable to predicate.
777  if (ExitingBlockProbability > LatchProbabilityThreshold)
778  return false;
779  }
780  // Using BPI, we have concluded that the most probable way to exit from the
781  // loop is through the latch (or there's no profile information and all
782  // exits are equally likely).
783  return true;
784 }
785 
786 bool LoopPredication::runOnLoop(Loop *Loop) {
787  L = Loop;
788 
789  LLVM_DEBUG(dbgs() << "Analyzing ");
790  LLVM_DEBUG(L->dump());
791 
792  Module *M = L->getHeader()->getModule();
793 
794  // There is nothing to do if the module doesn't use guards
795  auto *GuardDecl =
796  M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
797  if (!GuardDecl || GuardDecl->use_empty())
798  return false;
799 
800  DL = &M->getDataLayout();
801 
802  Preheader = L->getLoopPreheader();
803  if (!Preheader)
804  return false;
805 
806  auto LatchCheckOpt = parseLoopLatchICmp();
807  if (!LatchCheckOpt)
808  return false;
809  LatchCheck = *LatchCheckOpt;
810 
811  LLVM_DEBUG(dbgs() << "Latch check:\n");
812  LLVM_DEBUG(LatchCheck.dump());
813 
814  if (!isLoopProfitableToPredicate()) {
815  LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
816  return false;
817  }
818  // Collect all the guards into a vector and process later, so as not
819  // to invalidate the instruction iterator.
821  for (const auto BB : L->blocks())
822  for (auto &I : *BB)
823  if (auto *II = dyn_cast<IntrinsicInst>(&I))
824  if (II->getIntrinsicID() == Intrinsic::experimental_guard)
825  Guards.push_back(II);
826 
827  if (Guards.empty())
828  return false;
829 
830  SCEVExpander Expander(*SE, *DL, "loop-predication");
831 
832  bool Changed = false;
833  for (auto *Guard : Guards)
834  Changed |= widenGuardConditions(Guard, Expander);
835 
836  return Changed;
837 }
Pass interface - Implemented by all &#39;passes&#39;.
Definition: Pass.h:81
static bool Check(DecodeStatus &Out, DecodeStatus In)
BinaryOp_match< LHS, RHS, Instruction::And > m_And(const LHS &L, const RHS &R)
Definition: PatternMatch.h:725
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:111
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:72
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1858
BlockT * getLoopLatch() const
If there is a single latch block for this loop, return it.
Definition: LoopInfoImpl.h:225
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PreservedAnalyses getLoopPassPreservedAnalyses()
Returns the minimum set of Analyses that all loop passes must preserve.
STATISTIC(TotalConsidered, "Number of guards considered")
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:770
Compute iterated dominance frontiers using a linear time algorithm.
Definition: AllocatorList.h:24
loop predication
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:64
The main scalar evolution driver.
BlockT * getLoopPreheader() const
If there is a preheader for this loop, return it.
Definition: LoopInfoImpl.h:174
unsigned less or equal
Definition: InstrTypes.h:683
unsigned less than
Definition: InstrTypes.h:682
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isMonotonicPredicate(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred, bool &Increasing)
Return true if, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or...
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...
F(f)
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:138
INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) INITIALIZE_PASS_END(LoopPredicationLegacyPass
void dump() const
Support for debugging, callable in GDB: V->dump()
Definition: AsmWriter.cpp:4270
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
AnalysisUsage & addRequired()
const Module * getModule() const
Return the module owning the function this basic block belongs to, or nullptr if the function does no...
Definition: BasicBlock.cpp:134
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:51
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:627
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE, etc.
Definition: InstrTypes.h:755
static cl::opt< float > LatchExitProbabilityScale("loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0), cl::desc("scale factor for the latch probability. Value should be greater " "than 1. Lower values are ignored"))
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:743
BlockT * getHeader() const
Definition: LoopInfo.h:100
Analysis pass which computes BranchProbabilityInfo.
This node represents a polynomial recurrence on the trip count of the specified loop.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block...
Definition: IRBuilder.h:127
Legacy analysis pass which computes BranchProbabilityInfo.
Value * getOperand(unsigned i) const
Definition: User.h:170
Value * getOperand(unsigned i_nocapture) const
static cl::opt< bool > EnableIVTruncation("loop-predication-enable-iv-truncation", cl::Hidden, cl::init(true))
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS. ...
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:410
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:154
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
void dump() const
Definition: LoopInfo.cpp:371
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
ConstantInt * getTrue()
Get the constant value for i1 true.
Definition: IRBuilder.h:287
const SCEV * getAddExpr(SmallVectorImpl< const SCEV *> &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
brc_match< Cond_t > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
Represent the analysis usage information of a pass.
Pass * createLoopPredicationPass()
This instruction compares its operands according to the predicate given to the constructor.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:657
Value * expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I)
Insert code to directly compute the specified SCEV expression into the program.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS. Minus is represented in SCEV as A+B*-1.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:160
size_t size() const
Definition: SmallVector.h:53
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
signed greater than
Definition: InstrTypes.h:684
BranchProbability getEdgeProbability(const BasicBlock *Src, unsigned IndexInSuccessors) const
Get an edge's probability, relative to other out-edges of the Src.
void getExitEdges(SmallVectorImpl< Edge > &ExitEdges) const
Return all pairs of (inside_block,outside_block).
Definition: LoopInfoImpl.h:155
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:418
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
Type * getType() const
Return the LLVM type of this SCEV expression.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty)
An analysis over an "inner" IR unit that provides access to an analysis manager over a "outer" IR uni...
Definition: PassManager.h:1154
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:847
Module.h This file contains the declarations for the Module class.
signed less than
Definition: InstrTypes.h:686
Predicate getFlippedStrictnessPredicate() const
For predicate of kind "is X or equal to 0" returns the predicate "is X".
Definition: InstrTypes.h:786
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:133
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
Definition: Module.cpp:175
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:941
signed less or equal
Definition: InstrTypes.h:687
void setOperand(unsigned i_nocapture, Value *Val_nocapture)
static cl::opt< bool > EnableCountDownLoop("loop-predication-enable-count-down-loop", cl::Hidden, cl::init(true))
This class uses information about analyzed scalars to rewrite expressions in canonical form...
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:478
uint64_t getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:560
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:731
Analysis providing branch probability information.
This class represents an analyzed expression in the program.
LLVM_NODISCARD bool empty() const
Definition: SmallVector.h:56
unsigned greater or equal
Definition: InstrTypes.h:681
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:459
#define I(x, y, z)
Definition: MD5.cpp:58
void getLoopAnalysisUsage(AnalysisUsage &AU)
Helper to consistently add the set of standard passes to a loop pass's AnalysisUsage.
Definition: LoopUtils.cpp:128
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
static cl::opt< bool > SkipProfitabilityChecks("loop-predication-skip-profitability-checks", cl::Hidden, cl::init(false))
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1124
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void initializeLoopPredicationLegacyPassPass(PassRegistry &)
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
bool isOne() const
Return true if the expression is a constant one.
LLVM Value Representation.
Definition: Value.h:73
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
unsigned greater than
Definition: InstrTypes.h:680
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:771
A container for analyses that lazily runs them and caches their results.
#define LLVM_DEBUG(X)
Definition: Debug.h:123
bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE)
Return true if the given expression is safe to expand in the sense that all materialized values are s...
iterator_range< block_iterator > blocks() const
Definition: LoopInfo.h:156
signed greater or equal
Definition: InstrTypes.h:685
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:44
This class represents a constant integer value.
CmpClass_match< LHS, RHS, ICmpInst, ICmpInst::Predicate > m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R)
Definition: PatternMatch.h:990