ScalarEvolution.cpp
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains the implementation of the scalar evolution analysis
11 // engine, which is used primarily to analyze expressions involving induction
12 // variables in loops.
13 //
14 // There are several aspects to this library. First is the representation of
15 // scalar expressions, which are represented as subclasses of the SCEV class.
16 // These classes are used to represent certain types of subexpressions that we
17 // can handle. We only create one SCEV of a particular shape, so
18 // pointer-comparisons for equality are legal.
19 //
20 // One important aspect of the SCEV objects is that they are never cyclic, even
21 // if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
22 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
23 // recurrence) then we represent it directly as a recurrence node, otherwise we
24 // represent it as a SCEVUnknown node.
25 //
26 // In addition to being able to represent expressions of various types, we also
27 // have folders that are used to build the *canonical* representation for a
28 // particular expression. These folders are capable of using a variety of
29 // rewrite rules to simplify the expressions.
30 //
31 // Once the folders are defined, we can implement the more interesting
32 // higher-level code, such as the code that recognizes PHI nodes of various
33 // types, computes the execution count of a loop, etc.
34 //
35 // TODO: We should use these routines and value representations to implement
36 // dependence analysis!
37 //
38 //===----------------------------------------------------------------------===//
39 //
40 // There are several good references for the techniques used in this analysis.
41 //
42 // Chains of recurrences -- a method to expedite the evaluation
43 // of closed-form functions
44 // Olaf Bachmann, Paul S. Wang, Eugene V. Zima
45 //
46 // On computational properties of chains of recurrences
47 // Eugene V. Zima
48 //
49 // Symbolic Evaluation of Chains of Recurrences for Loop Optimization
50 // Robert A. van Engelen
51 //
52 // Efficient Symbolic Analysis for Optimizing Compilers
53 // Robert A. van Engelen
54 //
55 // Using the chains of recurrences algebra for data dependence testing and
56 // induction variable substitution
57 // MS Thesis, Johnie Birch
58 //
59 //===----------------------------------------------------------------------===//
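For orientation, here is a minimal sketch of how a client might query the analysis implemented in this file. It is illustrative only (the function name inspectLoop is made up), and it assumes the usual LLVM setup in which a ScalarEvolution instance and a Loop are already available, for example through ScalarEvolutionWrapperPass and LoopInfo:

  // Illustrative client code, not part of this file.
  void inspectLoop(Loop *L, ScalarEvolution &SE) {
    if (PHINode *IV = L->getCanonicalInductionVariable()) {
      const SCEV *S = SE.getSCEV(IV);   // e.g. prints as {0,+,1}<%loop.header>
      S->print(errs());
      errs() << '\n';
    }
    const SCEV *BTC = SE.getBackedgeTakenCount(L);
    if (!isa<SCEVCouldNotCompute>(BTC))
      BTC->print(errs());               // backedge-taken count, when computable
  }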
60 
61 #include "llvm/Analysis/ScalarEvolution.h"
62 #include "llvm/ADT/Optional.h"
63 #include "llvm/ADT/STLExtras.h"
64 #include "llvm/ADT/ScopeExit.h"
65 #include "llvm/ADT/Sequence.h"
66 #include "llvm/ADT/SmallPtrSet.h"
67 #include "llvm/ADT/Statistic.h"
71 #include "llvm/Analysis/LoopInfo.h"
75 #include "llvm/IR/ConstantRange.h"
76 #include "llvm/IR/Constants.h"
77 #include "llvm/IR/DataLayout.h"
78 #include "llvm/IR/DerivedTypes.h"
79 #include "llvm/IR/Dominators.h"
81 #include "llvm/IR/GlobalAlias.h"
82 #include "llvm/IR/GlobalVariable.h"
83 #include "llvm/IR/InstIterator.h"
84 #include "llvm/IR/Instructions.h"
85 #include "llvm/IR/LLVMContext.h"
86 #include "llvm/IR/Metadata.h"
87 #include "llvm/IR/Operator.h"
88 #include "llvm/IR/PatternMatch.h"
90 #include "llvm/Support/Debug.h"
92 #include "llvm/Support/KnownBits.h"
96 #include <algorithm>
97 using namespace llvm;
98 
99 #define DEBUG_TYPE "scalar-evolution"
100 
101 STATISTIC(NumArrayLenItCounts,
102  "Number of trip counts computed with array length");
103 STATISTIC(NumTripCountsComputed,
104  "Number of loops with predictable loop counts");
105 STATISTIC(NumTripCountsNotComputed,
106  "Number of loops without predictable loop counts");
107 STATISTIC(NumBruteForceTripCountsComputed,
108  "Number of loops with trip counts computed by force");
109 
110 static cl::opt<unsigned>
111 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
112  cl::desc("Maximum number of iterations SCEV will "
113  "symbolically execute a constant "
114  "derived loop"),
115  cl::init(100));
116 
117 // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
118 static cl::opt<bool>
119 VerifySCEV("verify-scev",
120  cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
121 static cl::opt<bool>
122  VerifySCEVMap("verify-scev-maps",
123  cl::desc("Verify no dangling value in ScalarEvolution's "
124  "ExprValueMap (slow)"));
125 
127  "scev-mulops-inline-threshold", cl::Hidden,
128  cl::desc("Threshold for inlining multiplication operands into a SCEV"),
129  cl::init(32));
130 
132  "scev-addops-inline-threshold", cl::Hidden,
133  cl::desc("Threshold for inlining addition operands into a SCEV"),
134  cl::init(500));
135 
137  "scalar-evolution-max-scev-compare-depth", cl::Hidden,
138  cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
139  cl::init(32));
140 
142  "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
143  cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
144  cl::init(2));
145 
147  "scalar-evolution-max-value-compare-depth", cl::Hidden,
148  cl::desc("Maximum depth of recursive value complexity comparisons"),
149  cl::init(2));
150 
151 static cl::opt<unsigned>
152  MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
153  cl::desc("Maximum depth of recursive arithmetics"),
154  cl::init(32));
155 
157  "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
158  cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
159 
160 static cl::opt<unsigned>
161  MaxExtDepth("scalar-evolution-max-ext-depth", cl::Hidden,
162  cl::desc("Maximum depth of recursive SExt/ZExt"),
163  cl::init(8));
164 
165 //===----------------------------------------------------------------------===//
166 // SCEV class definitions
167 //===----------------------------------------------------------------------===//
168 
169 //===----------------------------------------------------------------------===//
170 // Implementation of the SCEV class.
171 //
172 
173 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
174 LLVM_DUMP_METHOD void SCEV::dump() const {
175  print(dbgs());
176  dbgs() << '\n';
177 }
178 #endif
179 
180 void SCEV::print(raw_ostream &OS) const {
181  switch (static_cast<SCEVTypes>(getSCEVType())) {
182  case scConstant:
183  cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
184  return;
185  case scTruncate: {
186  const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
187  const SCEV *Op = Trunc->getOperand();
188  OS << "(trunc " << *Op->getType() << " " << *Op << " to "
189  << *Trunc->getType() << ")";
190  return;
191  }
192  case scZeroExtend: {
193  const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
194  const SCEV *Op = ZExt->getOperand();
195  OS << "(zext " << *Op->getType() << " " << *Op << " to "
196  << *ZExt->getType() << ")";
197  return;
198  }
199  case scSignExtend: {
200  const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
201  const SCEV *Op = SExt->getOperand();
202  OS << "(sext " << *Op->getType() << " " << *Op << " to "
203  << *SExt->getType() << ")";
204  return;
205  }
206  case scAddRecExpr: {
207  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
208  OS << "{" << *AR->getOperand(0);
209  for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
210  OS << ",+," << *AR->getOperand(i);
211  OS << "}<";
212  if (AR->hasNoUnsignedWrap())
213  OS << "nuw><";
214  if (AR->hasNoSignedWrap())
215  OS << "nsw><";
216  if (AR->hasNoSelfWrap() &&
217  !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
218  OS << "nw><";
219  AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
220  OS << ">";
221  return;
222  }
223  case scAddExpr:
224  case scMulExpr:
225  case scUMaxExpr:
226  case scSMaxExpr: {
227  const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
228  const char *OpStr = nullptr;
229  switch (NAry->getSCEVType()) {
230  case scAddExpr: OpStr = " + "; break;
231  case scMulExpr: OpStr = " * "; break;
232  case scUMaxExpr: OpStr = " umax "; break;
233  case scSMaxExpr: OpStr = " smax "; break;
234  }
235  OS << "(";
236  for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
237  I != E; ++I) {
238  OS << **I;
239  if (std::next(I) != E)
240  OS << OpStr;
241  }
242  OS << ")";
243  switch (NAry->getSCEVType()) {
244  case scAddExpr:
245  case scMulExpr:
246  if (NAry->hasNoUnsignedWrap())
247  OS << "<nuw>";
248  if (NAry->hasNoSignedWrap())
249  OS << "<nsw>";
250  }
251  return;
252  }
253  case scUDivExpr: {
254  const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
255  OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
256  return;
257  }
258  case scUnknown: {
259  const SCEVUnknown *U = cast<SCEVUnknown>(this);
260  Type *AllocTy;
261  if (U->isSizeOf(AllocTy)) {
262  OS << "sizeof(" << *AllocTy << ")";
263  return;
264  }
265  if (U->isAlignOf(AllocTy)) {
266  OS << "alignof(" << *AllocTy << ")";
267  return;
268  }
269 
270  Type *CTy;
271  Constant *FieldNo;
272  if (U->isOffsetOf(CTy, FieldNo)) {
273  OS << "offsetof(" << *CTy << ", ";
274  FieldNo->printAsOperand(OS, false);
275  OS << ")";
276  return;
277  }
278 
279  // Otherwise just print it normally.
280  U->getValue()->printAsOperand(OS, false);
281  return;
282  }
283  case scCouldNotCompute:
284  OS << "***COULDNOTCOMPUTE***";
285  return;
286  }
287  llvm_unreachable("Unknown SCEV kind!");
288 }
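As a concrete illustration of the forms printed above (example values, not from the source): an affine induction variable that starts at 0 and steps by 4 without unsigned wrap in a loop whose header is %for.body prints as {0,+,4}<nuw><%for.body>, and its zero-extension to i64 prints as (zext i32 {0,+,4}<nuw><%for.body> to i64); a value the analysis cannot break down falls through to the SCEVUnknown case and prints as the IR value itself.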
289 
290 Type *SCEV::getType() const {
291  switch (static_cast<SCEVTypes>(getSCEVType())) {
292  case scConstant:
293  return cast<SCEVConstant>(this)->getType();
294  case scTruncate:
295  case scZeroExtend:
296  case scSignExtend:
297  return cast<SCEVCastExpr>(this)->getType();
298  case scAddRecExpr:
299  case scMulExpr:
300  case scUMaxExpr:
301  case scSMaxExpr:
302  return cast<SCEVNAryExpr>(this)->getType();
303  case scAddExpr:
304  return cast<SCEVAddExpr>(this)->getType();
305  case scUDivExpr:
306  return cast<SCEVUDivExpr>(this)->getType();
307  case scUnknown:
308  return cast<SCEVUnknown>(this)->getType();
309  case scCouldNotCompute:
310  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
311  }
312  llvm_unreachable("Unknown SCEV kind!");
313 }
314 
315 bool SCEV::isZero() const {
316  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
317  return SC->getValue()->isZero();
318  return false;
319 }
320 
321 bool SCEV::isOne() const {
322  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
323  return SC->getValue()->isOne();
324  return false;
325 }
326 
327 bool SCEV::isAllOnesValue() const {
328  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
329  return SC->getValue()->isMinusOne();
330  return false;
331 }
332 
333 bool SCEV::isNonConstantNegative() const {
334  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
335  if (!Mul) return false;
336 
337  // If there is a constant factor, it will be first.
338  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
339  if (!SC) return false;
340 
341  // Return true if the value is negative, this matches things like (-42 * V).
342  return SC->getAPInt().isNegative();
343 }
344 
345 SCEVCouldNotCompute::SCEVCouldNotCompute() :
346  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}
347 
348 bool SCEVCouldNotCompute::classof(const SCEV *S) {
349  return S->getSCEVType() == scCouldNotCompute;
350 }
351 
352 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
353  FoldingSetNodeID ID;
354  ID.AddInteger(scConstant);
355  ID.AddPointer(V);
356  void *IP = nullptr;
357  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
358  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
359  UniqueSCEVs.InsertNode(S, IP);
360  return S;
361 }
362 
363 const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
364  return getConstant(ConstantInt::get(getContext(), Val));
365 }
366 
367 const SCEV *
368 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
369  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
370  return getConstant(ConstantInt::get(ITy, V, isSigned));
371 }
372 
373 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
374  unsigned SCEVTy, const SCEV *op, Type *ty)
375  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}
376 
377 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
378  const SCEV *op, Type *ty)
379  : SCEVCastExpr(ID, scTruncate, op, ty) {
380  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
381  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
382  "Cannot truncate non-integer value!");
383 }
384 
385 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
386  const SCEV *op, Type *ty)
387  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
388  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
389  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
390  "Cannot zero extend non-integer value!");
391 }
392 
393 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
394  const SCEV *op, Type *ty)
395  : SCEVCastExpr(ID, scSignExtend, op, ty) {
396  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
397  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
398  "Cannot sign extend non-integer value!");
399 }
400 
401 void SCEVUnknown::deleted() {
402  // Clear this SCEVUnknown from various maps.
403  SE->forgetMemoizedResults(this);
404 
405  // Remove this SCEVUnknown from the uniquing map.
406  SE->UniqueSCEVs.RemoveNode(this);
407 
408  // Release the value.
409  setValPtr(nullptr);
410 }
411 
412 void SCEVUnknown::allUsesReplacedWith(Value *New) {
413  // Clear this SCEVUnknown from various maps.
414  SE->forgetMemoizedResults(this);
415 
416  // Remove this SCEVUnknown from the uniquing map.
417  SE->UniqueSCEVs.RemoveNode(this);
418 
419  // Update this SCEVUnknown to point to the new value. This is needed
420  // because there may still be outstanding SCEVs which still point to
421  // this SCEVUnknown.
422  setValPtr(New);
423 }
424 
425 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
426  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
427  if (VCE->getOpcode() == Instruction::PtrToInt)
428  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
429  if (CE->getOpcode() == Instruction::GetElementPtr &&
430  CE->getOperand(0)->isNullValue() &&
431  CE->getNumOperands() == 2)
432  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
433  if (CI->isOne()) {
434  AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
435  ->getElementType();
436  return true;
437  }
438 
439  return false;
440 }
441 
442 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
443  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
444  if (VCE->getOpcode() == Instruction::PtrToInt)
445  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
446  if (CE->getOpcode() == Instruction::GetElementPtr &&
447  CE->getOperand(0)->isNullValue()) {
448  Type *Ty =
449  cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
450  if (StructType *STy = dyn_cast<StructType>(Ty))
451  if (!STy->isPacked() &&
452  CE->getNumOperands() == 3 &&
453  CE->getOperand(1)->isNullValue()) {
454  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
455  if (CI->isOne() &&
456  STy->getNumElements() == 2 &&
457  STy->getElementType(0)->isIntegerTy(1)) {
458  AllocTy = STy->getElementType(1);
459  return true;
460  }
461  }
462  }
463 
464  return false;
465 }
466 
467 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
468  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
469  if (VCE->getOpcode() == Instruction::PtrToInt)
470  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
471  if (CE->getOpcode() == Instruction::GetElementPtr &&
472  CE->getNumOperands() == 3 &&
473  CE->getOperand(0)->isNullValue() &&
474  CE->getOperand(1)->isNullValue()) {
475  Type *Ty =
476  cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
477  // Ignore vector types here so that ScalarEvolutionExpander doesn't
478  // emit getelementptrs that index into vectors.
479  if (Ty->isStructTy() || Ty->isArrayTy()) {
480  CTy = Ty;
481  FieldNo = CE->getOperand(2);
482  return true;
483  }
484  }
485 
486  return false;
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // SCEV Utilities
491 //===----------------------------------------------------------------------===//
492 
493 /// Compare the two values \p LV and \p RV in terms of their "complexity" where
494 /// "complexity" is a partial (and somewhat ad-hoc) relation used to order
495 /// operands in SCEV expressions. \p EqCache is a set of pairs of values that
496 /// have been previously deemed to be "equally complex" by this routine. It is
497 /// intended to avoid exponential time complexity in cases like:
498 ///
499 /// %a = f(%x, %y)
500 /// %b = f(%a, %a)
501 /// %c = f(%b, %b)
502 ///
503 /// %d = f(%x, %y)
504 /// %e = f(%d, %d)
505 /// %f = f(%e, %e)
506 ///
507 /// CompareValueComplexity(%f, %c)
508 ///
509 /// Since we do not continue running this routine on expression trees once we
510 /// have seen unequal values, there is no need to track them in the cache.
511 static int
512 CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
513  const LoopInfo *const LI, Value *LV, Value *RV,
514  unsigned Depth) {
515  if (Depth > MaxValueCompareDepth || EqCache.count({LV, RV}))
516  return 0;
517 
518  // Order pointer values after integer values. This helps SCEVExpander form
519  // GEPs.
520  bool LIsPointer = LV->getType()->isPointerTy(),
521  RIsPointer = RV->getType()->isPointerTy();
522  if (LIsPointer != RIsPointer)
523  return (int)LIsPointer - (int)RIsPointer;
524 
525  // Compare getValueID values.
526  unsigned LID = LV->getValueID(), RID = RV->getValueID();
527  if (LID != RID)
528  return (int)LID - (int)RID;
529 
530  // Sort arguments by their position.
531  if (const auto *LA = dyn_cast<Argument>(LV)) {
532  const auto *RA = cast<Argument>(RV);
533  unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
534  return (int)LArgNo - (int)RArgNo;
535  }
536 
537  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
538  const auto *RGV = cast<GlobalValue>(RV);
539 
540  const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
541  auto LT = GV->getLinkage();
542  return !(GlobalValue::isPrivateLinkage(LT) ||
543  GlobalValue::isInternalLinkage(LT));
544  };
545 
546  // Use the names to distinguish the two values, but only if the
547  // names are semantically important.
548  if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
549  return LGV->getName().compare(RGV->getName());
550  }
551 
552  // For instructions, compare their loop depth, and their operand count. This
553  // is pretty loose.
554  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
555  const auto *RInst = cast<Instruction>(RV);
556 
557  // Compare loop depths.
558  const BasicBlock *LParent = LInst->getParent(),
559  *RParent = RInst->getParent();
560  if (LParent != RParent) {
561  unsigned LDepth = LI->getLoopDepth(LParent),
562  RDepth = LI->getLoopDepth(RParent);
563  if (LDepth != RDepth)
564  return (int)LDepth - (int)RDepth;
565  }
566 
567  // Compare the number of operands.
568  unsigned LNumOps = LInst->getNumOperands(),
569  RNumOps = RInst->getNumOperands();
570  if (LNumOps != RNumOps)
571  return (int)LNumOps - (int)RNumOps;
572 
573  for (unsigned Idx : seq(0u, LNumOps)) {
574  int Result =
575  CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
576  RInst->getOperand(Idx), Depth + 1);
577  if (Result != 0)
578  return Result;
579  }
580  }
581 
582  EqCache.insert({LV, RV});
583  return 0;
584 }
585 
586 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
587 // than RHS, respectively. A three-way result allows recursive comparisons to be
588 // more efficient.
589 static int CompareSCEVComplexity(
590  SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV,
591  const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
592  DominatorTree &DT, unsigned Depth = 0) {
593  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
594  if (LHS == RHS)
595  return 0;
596 
597  // Primarily, sort the SCEVs by their getSCEVType().
598  unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
599  if (LType != RType)
600  return (int)LType - (int)RType;
601 
602  if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.count({LHS, RHS}))
603  return 0;
604  // Aside from the getSCEVType() ordering, the particular ordering
605  // isn't very important except that it's beneficial to be consistent,
606  // so that (a + b) and (b + a) don't end up as different expressions.
607  switch (static_cast<SCEVTypes>(LType)) {
608  case scUnknown: {
609  const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
610  const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
611 
612  SmallSet<std::pair<Value *, Value *>, 8> EqCache;
613  int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(),
614  Depth + 1);
615  if (X == 0)
616  EqCacheSCEV.insert({LHS, RHS});
617  return X;
618  }
619 
620  case scConstant: {
621  const SCEVConstant *LC = cast<SCEVConstant>(LHS);
622  const SCEVConstant *RC = cast<SCEVConstant>(RHS);
623 
624  // Compare constant values.
625  const APInt &LA = LC->getAPInt();
626  const APInt &RA = RC->getAPInt();
627  unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
628  if (LBitWidth != RBitWidth)
629  return (int)LBitWidth - (int)RBitWidth;
630  return LA.ult(RA) ? -1 : 1;
631  }
632 
633  case scAddRecExpr: {
634  const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
635  const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
636 
637  // There is always a dominance between two recs that are used by one SCEV,
638  // so we can safely sort recs by loop header dominance. We require such
639  // order in getAddExpr.
640  const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
641  if (LLoop != RLoop) {
642  const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
643  assert(LHead != RHead && "Two loops share the same header?");
644  if (DT.dominates(LHead, RHead))
645  return 1;
646  else
647  assert(DT.dominates(RHead, LHead) &&
648  "No dominance between recurrences used by one SCEV?");
649  return -1;
650  }
651 
652  // Addrec complexity grows with operand count.
653  unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
654  if (LNumOps != RNumOps)
655  return (int)LNumOps - (int)RNumOps;
656 
657  // Lexicographically compare.
658  for (unsigned i = 0; i != LNumOps; ++i) {
659  int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i),
660  RA->getOperand(i), DT, Depth + 1);
661  if (X != 0)
662  return X;
663  }
664  EqCacheSCEV.insert({LHS, RHS});
665  return 0;
666  }
667 
668  case scAddExpr:
669  case scMulExpr:
670  case scSMaxExpr:
671  case scUMaxExpr: {
672  const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
673  const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
674 
675  // Lexicographically compare n-ary expressions.
676  unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
677  if (LNumOps != RNumOps)
678  return (int)LNumOps - (int)RNumOps;
679 
680  for (unsigned i = 0; i != LNumOps; ++i) {
681  if (i >= RNumOps)
682  return 1;
683  int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i),
684  RC->getOperand(i), DT, Depth + 1);
685  if (X != 0)
686  return X;
687  }
688  EqCacheSCEV.insert({LHS, RHS});
689  return 0;
690  }
691 
692  case scUDivExpr: {
693  const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
694  const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
695 
696  // Lexicographically compare udiv expressions.
697  int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(),
698  DT, Depth + 1);
699  if (X != 0)
700  return X;
701  X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), DT,
702  Depth + 1);
703  if (X == 0)
704  EqCacheSCEV.insert({LHS, RHS});
705  return X;
706  }
707 
708  case scTruncate:
709  case scZeroExtend:
710  case scSignExtend: {
711  const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
712  const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
713 
714  // Compare cast expressions by operand.
715  int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(),
716  RC->getOperand(), DT, Depth + 1);
717  if (X == 0)
718  EqCacheSCEV.insert({LHS, RHS});
719  return X;
720  }
721 
722  case scCouldNotCompute:
723  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
724  }
725  llvm_unreachable("Unknown SCEV kind!");
726 }
727 
728 /// Given a list of SCEV objects, order them by their complexity, and group
729 /// objects of the same complexity together by value. When this routine is
730 /// finished, we know that any duplicates in the vector are consecutive and that
731 /// complexity is monotonically increasing.
732 ///
733 /// Note that we take special precautions to ensure that we get deterministic
734 /// results from this routine. In other words, we don't want the results of
735 /// this to depend on where the addresses of various SCEV objects happened to
736 /// land in memory.
737 ///
738 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
739  LoopInfo *LI, DominatorTree &DT) {
740  if (Ops.size() < 2) return; // Noop
741 
742  SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache;
743  if (Ops.size() == 2) {
744  // This is the common case, which also happens to be trivially simple.
745  // Special case it.
746  const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
747  if (CompareSCEVComplexity(EqCache, LI, RHS, LHS, DT) < 0)
748  std::swap(LHS, RHS);
749  return;
750  }
751 
752  // Do the rough sort by complexity.
753  std::stable_sort(Ops.begin(), Ops.end(),
754  [&EqCache, LI, &DT](const SCEV *LHS, const SCEV *RHS) {
755  return
756  CompareSCEVComplexity(EqCache, LI, LHS, RHS, DT) < 0;
757  });
758 
759  // Now that we are sorted by complexity, group elements of the same
760  // complexity. Note that this is, at worst, N^2, but the vector is likely to
761  // be extremely short in practice. Note that we take this approach because we
762  // do not want to depend on the addresses of the objects we are grouping.
763  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
764  const SCEV *S = Ops[i];
765  unsigned Complexity = S->getSCEVType();
766 
767  // If there are any objects of the same complexity and same value as this
768  // one, group them.
769  for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
770  if (Ops[j] == S) { // Found a duplicate.
771  // Move it to immediately after i'th element.
772  std::swap(Ops[i+1], Ops[j]);
773  ++i; // no need to rescan it.
774  if (i == e-2) return; // Done!
775  }
776  }
777  }
778 }
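As a small worked example (operand names are hypothetical): if %a and %b are distinct SCEVUnknowns that compare as equally complex, a sort alone may leave the operand list as (%a, %b, %a); the grouping pass above then swaps the second %a forward to give (%a, %a, %b), so that later folding sees the duplicates side by side without ever consulting object addresses.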
779 
780 // Returns the size of the SCEV S.
781 static inline int sizeOfSCEV(const SCEV *S) {
782  struct FindSCEVSize {
783  int Size;
784  FindSCEVSize() : Size(0) {}
785 
786  bool follow(const SCEV *S) {
787  ++Size;
788  // Keep looking at all operands of S.
789  return true;
790  }
791  bool isDone() const {
792  return false;
793  }
794  };
795 
796  FindSCEVSize F;
797  SCEVTraversal<FindSCEVSize> ST(F);
798  ST.visitAll(S);
799  return F.Size;
800 }
801 
802 namespace {
803 
804 struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
805 public:
806  // Computes the Quotient and Remainder of the division of Numerator by
807  // Denominator.
808  static void divide(ScalarEvolution &SE, const SCEV *Numerator,
809  const SCEV *Denominator, const SCEV **Quotient,
810  const SCEV **Remainder) {
811  assert(Numerator && Denominator && "Uninitialized SCEV");
812 
813  SCEVDivision D(SE, Numerator, Denominator);
814 
815  // Check for the trivial case here to avoid having to check for it in the
816  // rest of the code.
817  if (Numerator == Denominator) {
818  *Quotient = D.One;
819  *Remainder = D.Zero;
820  return;
821  }
822 
823  if (Numerator->isZero()) {
824  *Quotient = D.Zero;
825  *Remainder = D.Zero;
826  return;
827  }
828 
829  // A simple case when N/1. The quotient is N.
830  if (Denominator->isOne()) {
831  *Quotient = Numerator;
832  *Remainder = D.Zero;
833  return;
834  }
835 
836  // Split the Denominator when it is a product.
837  if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
838  const SCEV *Q, *R;
839  *Quotient = Numerator;
840  for (const SCEV *Op : T->operands()) {
841  divide(SE, *Quotient, Op, &Q, &R);
842  *Quotient = Q;
843 
844  // Bail out when the Numerator is not divisible by one of the terms of
845  // the Denominator.
846  if (!R->isZero()) {
847  *Quotient = D.Zero;
848  *Remainder = Numerator;
849  return;
850  }
851  }
852  *Remainder = D.Zero;
853  return;
854  }
855 
856  D.visit(Numerator);
857  *Quotient = D.Quotient;
858  *Remainder = D.Remainder;
859  }
860 
861  // Except in the trivial case described above, we do not know how to divide
862  // Expr by Denominator for the following functions with empty implementation.
863  void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
864  void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
865  void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {}
866  void visitUDivExpr(const SCEVUDivExpr *Numerator) {}
867  void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {}
868  void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
869  void visitUnknown(const SCEVUnknown *Numerator) {}
870  void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}
871 
872  void visitConstant(const SCEVConstant *Numerator) {
873  if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
874  APInt NumeratorVal = Numerator->getAPInt();
875  APInt DenominatorVal = D->getAPInt();
876  uint32_t NumeratorBW = NumeratorVal.getBitWidth();
877  uint32_t DenominatorBW = DenominatorVal.getBitWidth();
878 
879  if (NumeratorBW > DenominatorBW)
880  DenominatorVal = DenominatorVal.sext(NumeratorBW);
881  else if (NumeratorBW < DenominatorBW)
882  NumeratorVal = NumeratorVal.sext(DenominatorBW);
883 
884  APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
885  APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
886  APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
887  Quotient = SE.getConstant(QuotientVal);
888  Remainder = SE.getConstant(RemainderVal);
889  return;
890  }
891  }
892 
893  void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
894  const SCEV *StartQ, *StartR, *StepQ, *StepR;
895  if (!Numerator->isAffine())
896  return cannotDivide(Numerator);
897  divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
898  divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
899  // Bail out if the types do not match.
900  Type *Ty = Denominator->getType();
901  if (Ty != StartQ->getType() || Ty != StartR->getType() ||
902  Ty != StepQ->getType() || Ty != StepR->getType())
903  return cannotDivide(Numerator);
904  Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
905  Numerator->getNoWrapFlags());
906  Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
907  Numerator->getNoWrapFlags());
908  }
909 
910  void visitAddExpr(const SCEVAddExpr *Numerator) {
911  SmallVector<const SCEV *, 2> Qs, Rs;
912  Type *Ty = Denominator->getType();
913 
914  for (const SCEV *Op : Numerator->operands()) {
915  const SCEV *Q, *R;
916  divide(SE, Op, Denominator, &Q, &R);
917 
918  // Bail out if types do not match.
919  if (Ty != Q->getType() || Ty != R->getType())
920  return cannotDivide(Numerator);
921 
922  Qs.push_back(Q);
923  Rs.push_back(R);
924  }
925 
926  if (Qs.size() == 1) {
927  Quotient = Qs[0];
928  Remainder = Rs[0];
929  return;
930  }
931 
932  Quotient = SE.getAddExpr(Qs);
933  Remainder = SE.getAddExpr(Rs);
934  }
935 
936  void visitMulExpr(const SCEVMulExpr *Numerator) {
937  SmallVector<const SCEV *, 2> Qs;
938  Type *Ty = Denominator->getType();
939 
940  bool FoundDenominatorTerm = false;
941  for (const SCEV *Op : Numerator->operands()) {
942  // Bail out if types do not match.
943  if (Ty != Op->getType())
944  return cannotDivide(Numerator);
945 
946  if (FoundDenominatorTerm) {
947  Qs.push_back(Op);
948  continue;
949  }
950 
951  // Check whether Denominator divides one of the product operands.
952  const SCEV *Q, *R;
953  divide(SE, Op, Denominator, &Q, &R);
954  if (!R->isZero()) {
955  Qs.push_back(Op);
956  continue;
957  }
958 
959  // Bail out if types do not match.
960  if (Ty != Q->getType())
961  return cannotDivide(Numerator);
962 
963  FoundDenominatorTerm = true;
964  Qs.push_back(Q);
965  }
966 
967  if (FoundDenominatorTerm) {
968  Remainder = Zero;
969  if (Qs.size() == 1)
970  Quotient = Qs[0];
971  else
972  Quotient = SE.getMulExpr(Qs);
973  return;
974  }
975 
976  if (!isa<SCEVUnknown>(Denominator))
977  return cannotDivide(Numerator);
978 
979  // The Remainder is obtained by replacing Denominator by 0 in Numerator.
980  ValueToValueMap RewriteMap;
981  RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
982  cast<SCEVConstant>(Zero)->getValue();
983  Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
984 
985  if (Remainder->isZero()) {
986  // The Quotient is obtained by replacing Denominator by 1 in Numerator.
987  RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
988  cast<SCEVConstant>(One)->getValue();
989  Quotient =
990  SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
991  return;
992  }
993 
994  // Quotient is (Numerator - Remainder) divided by Denominator.
995  const SCEV *Q, *R;
996  const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
997  // This SCEV does not seem to simplify: fail the division here.
998  if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
999  return cannotDivide(Numerator);
1000  divide(SE, Diff, Denominator, &Q, &R);
1001  if (R != Zero)
1002  return cannotDivide(Numerator);
1003  Quotient = Q;
1004  }
1005 
1006 private:
1007  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
1008  const SCEV *Denominator)
1009  : SE(S), Denominator(Denominator) {
1010  Zero = SE.getZero(Denominator->getType());
1011  One = SE.getOne(Denominator->getType());
1012 
1013  // We generally do not know how to divide Expr by Denominator. We
1014  // initialize the division to a "cannot divide" state to simplify the rest
1015  // of the code.
1016  cannotDivide(Numerator);
1017  }
1018 
1019  // Convenience function for giving up on the division. We set the quotient to
1020  // be equal to zero and the remainder to be equal to the numerator.
1021  void cannotDivide(const SCEV *Numerator) {
1022  Quotient = Zero;
1023  Remainder = Numerator;
1024  }
1025 
1026  ScalarEvolution &SE;
1027  const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
1028 };
1029 
1030 }
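A minimal sketch of how this helper is used elsewhere in ScalarEvolution (the names SE, Expr and Factor are illustrative):

  // Divide an expression by a candidate factor and only trust the quotient
  // when the division is exact.
  const SCEV *Q, *R;
  SCEVDivision::divide(SE, Expr, Factor, &Q, &R);
  if (R->isZero()) {
    // Exact: Expr == Factor * Q, so Q can stand in for Expr / Factor.
  }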
1031 
1032 //===----------------------------------------------------------------------===//
1033 // Simple SCEV method implementations
1034 //===----------------------------------------------------------------------===//
1035 
1036 /// Compute BC(It, K). The result has width W. Assume, K > 0.
1037 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
1038  ScalarEvolution &SE,
1039  Type *ResultTy) {
1040  // Handle the simplest case efficiently.
1041  if (K == 1)
1042  return SE.getTruncateOrZeroExtend(It, ResultTy);
1043 
1044  // We are using the following formula for BC(It, K):
1045  //
1046  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
1047  //
1048  // Suppose, W is the bitwidth of the return value. We must be prepared for
1049  // overflow. Hence, we must assure that the result of our computation is
1050  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
1051  // safe in modular arithmetic.
1052  //
1053  // However, this code doesn't use exactly that formula; the formula it uses
1054  // is something like the following, where T is the number of factors of 2 in
1055  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
1056  // exponentiation:
1057  //
1058  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
1059  //
1060  // This formula is trivially equivalent to the previous formula. However,
1061  // this formula can be implemented much more efficiently. The trick is that
1062  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
1063  // arithmetic. To do exact division in modular arithmetic, all we have
1064  // to do is multiply by the inverse. Therefore, this step can be done at
1065  // width W.
1066  //
1067  // The next issue is how to safely do the division by 2^T. The way this
1068  // is done is by doing the multiplication step at a width of at least W + T
1069  // bits. This way, the bottom W+T bits of the product are accurate. Then,
1070  // when we perform the division by 2^T (which is equivalent to a right shift
1071  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
1072  // truncated out after the division by 2^T.
1073  //
1074  // In comparison to just directly using the first formula, this technique
1075  // is much more efficient; using the first formula requires W * K bits,
1076 // but this formula requires less than W + K bits. Also, the first formula requires
1077  // a division step, whereas this formula only requires multiplies and shifts.
1078  //
1079  // It doesn't matter whether the subtraction step is done in the calculation
1080  // width or the input iteration count's width; if the subtraction overflows,
1081  // the result must be zero anyway. We prefer here to do it in the width of
1082  // the induction variable because it helps a lot for certain cases; CodeGen
1083  // isn't smart enough to ignore the overflow, which leads to much less
1084  // efficient code if the width of the subtraction is wider than the native
1085  // register width.
1086  //
1087  // (It's possible to not widen at all by pulling out factors of 2 before
1088  // the multiplication; for example, K=2 can be calculated as
1089  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1090  // extra arithmetic, so it's not an obvious win, and it gets
1091  // much more complicated for K > 3.)
1092 
1093  // Protection from insane SCEVs; this bound is conservative,
1094  // but it probably doesn't matter.
1095  if (K > 1000)
1096  return SE.getCouldNotCompute();
1097 
1098  unsigned W = SE.getTypeSizeInBits(ResultTy);
1099 
1100  // Calculate K! / 2^T and T; we divide out the factors of two before
1101  // multiplying for calculating K! / 2^T to avoid overflow.
1102  // Other overflow doesn't matter because we only care about the bottom
1103  // W bits of the result.
1104  APInt OddFactorial(W, 1);
1105  unsigned T = 1;
1106  for (unsigned i = 3; i <= K; ++i) {
1107  APInt Mult(W, i);
1108  unsigned TwoFactors = Mult.countTrailingZeros();
1109  T += TwoFactors;
1110  Mult.lshrInPlace(TwoFactors);
1111  OddFactorial *= Mult;
1112  }
1113 
1114  // We need at least W + T bits for the multiplication step
1115  unsigned CalculationBits = W + T;
1116 
1117  // Calculate 2^T, at width T+W.
1118  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1119 
1120  // Calculate the multiplicative inverse of K! / 2^T;
1121  // this multiplication factor will perform the exact division by
1122  // K! / 2^T.
1123  APInt Mod = APInt::getSignedMinValue(W+1);
1124  APInt MultiplyFactor = OddFactorial.zext(W+1);
1125  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
1126  MultiplyFactor = MultiplyFactor.trunc(W);
1127 
1128  // Calculate the product, at width T+W
1129  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1130  CalculationBits);
1131  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1132  for (unsigned i = 1; i != K; ++i) {
1133  const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1134  Dividend = SE.getMulExpr(Dividend,
1135  SE.getTruncateOrZeroExtend(S, CalculationTy));
1136  }
1137 
1138  // Divide by 2^T
1139  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1140 
1141  // Truncate the result, and divide by K! / 2^T.
1142 
1143  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1144  SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1145 }
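To make the K! / 2^T trick above concrete, here is a small standalone sketch (not part of this file) that computes BC(It, K) modulo 2^32 in the same way; it assumes 1 <= K <= 32 and It >= K, so the W + T low bits of the product fit in 64 bits:

  #include <cstdint>

  uint32_t binomialMod32(uint32_t It, unsigned K) {
    // Split K! into 2^T times an odd factor.
    uint64_t OddFactorial = 1;
    unsigned T = 0;
    for (unsigned i = 2; i <= K; ++i) {
      unsigned Mult = i;
      while ((Mult & 1) == 0) { Mult >>= 1; ++T; }
      OddFactorial *= Mult;
    }
    // Multiplicative inverse of the odd factor modulo 2^32 (Newton's
    // iteration; each step doubles the number of correct low bits).
    uint32_t Inv = 1;
    for (int Step = 0; Step < 5; ++Step)
      Inv *= 2 - (uint32_t)OddFactorial * Inv;
    // It * (It-1) * ... * (It-K+1): 64 bits keep the W+T low bits we need.
    uint64_t Dividend = It;
    for (unsigned i = 1; i < K; ++i)
      Dividend *= It - i;
    // Divide by 2^T with a shift, then by the odd factor via its inverse.
    return (uint32_t)(Dividend >> T) * Inv;
  }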
1146 
1147 /// Return the value of this chain of recurrences at the specified iteration
1148 /// number. We can evaluate this recurrence by multiplying each element in the
1149 /// chain by the binomial coefficient corresponding to it. In other words, we
1150 /// can evaluate {A,+,B,+,C,+,D} as:
1151 ///
1152 /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1153 ///
1154 /// where BC(It, k) stands for binomial coefficient.
1155 ///
1156 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1157  ScalarEvolution &SE) const {
1158  const SCEV *Result = getStart();
1159  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
1160  // The computation is correct in the face of overflow provided that the
1161  // multiplication is performed _after_ the evaluation of the binomial
1162  // coefficient.
1163  const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
1164  if (isa<SCEVCouldNotCompute>(Coeff))
1165  return Coeff;
1166 
1167  Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
1168  }
1169  return Result;
1170 }
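Worked example (illustrative): the chrec {0,+,1,+,1} evaluated at iteration n is 0*BC(n,0) + 1*BC(n,1) + 1*BC(n,2) = n + n*(n-1)/2 = n*(n+1)/2, which is exactly what a loop accumulates when it adds a step that itself grows by 1 every iteration.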
1171 
1172 //===----------------------------------------------------------------------===//
1173 // SCEV Expression folder implementations
1174 //===----------------------------------------------------------------------===//
1175 
1176 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
1177  Type *Ty) {
1178  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1179  "This is not a truncating conversion!");
1180  assert(isSCEVable(Ty) &&
1181  "This is not a conversion to a SCEVable type!");
1182  Ty = getEffectiveSCEVType(Ty);
1183 
1184  FoldingSetNodeID ID;
1185  ID.AddInteger(scTruncate);
1186  ID.AddPointer(Op);
1187  ID.AddPointer(Ty);
1188  void *IP = nullptr;
1189  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1190 
1191  // Fold if the operand is constant.
1192  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1193  return getConstant(
1194  cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1195 
1196  // trunc(trunc(x)) --> trunc(x)
1197  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1198  return getTruncateExpr(ST->getOperand(), Ty);
1199 
1200  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1201  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1202  return getTruncateOrSignExtend(SS->getOperand(), Ty);
1203 
1204  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1205  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1206  return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
1207 
1208  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
1209  // eliminate all the truncates, or we replace other casts with truncates.
1210  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
1211  SmallVector<const SCEV *, 4> Operands;
1212  bool hasTrunc = false;
1213  for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
1214  const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
1215  if (!isa<SCEVCastExpr>(SA->getOperand(i)))
1216  hasTrunc = isa<SCEVTruncateExpr>(S);
1217  Operands.push_back(S);
1218  }
1219  if (!hasTrunc)
1220  return getAddExpr(Operands);
1221  UniqueSCEVs.FindNodeOrInsertPos(ID, IP); // Mutates IP, returns NULL.
1222  }
1223 
1224  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
1225  // eliminate all the truncates, or we replace other casts with truncates.
1226  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
1227  SmallVector<const SCEV *, 4> Operands;
1228  bool hasTrunc = false;
1229  for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
1230  const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
1231  if (!isa<SCEVCastExpr>(SM->getOperand(i)))
1232  hasTrunc = isa<SCEVTruncateExpr>(S);
1233  Operands.push_back(S);
1234  }
1235  if (!hasTrunc)
1236  return getMulExpr(Operands);
1237  UniqueSCEVs.FindNodeOrInsertPos(ID, IP); // Mutates IP, returns NULL.
1238  }
1239 
1240  // If the input value is a chrec scev, truncate the chrec's operands.
1241  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1242  SmallVector<const SCEV *, 4> Operands;
1243  for (const SCEV *Op : AddRec->operands())
1244  Operands.push_back(getTruncateExpr(Op, Ty));
1245  return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1246  }
1247 
1248  // The cast wasn't folded; create an explicit cast node. We can reuse
1249  // the existing insert position since if we get here, we won't have
1250  // made any changes which would invalidate it.
1251  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1252  Op, Ty);
1253  UniqueSCEVs.InsertNode(S, IP);
1254  return S;
1255 }
1256 
1257 // Get the limit of a recurrence such that incrementing by Step cannot cause
1258 // signed overflow as long as the value of the recurrence within the
1259 // loop does not exceed this limit before incrementing.
1260 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1261  ICmpInst::Predicate *Pred,
1262  ScalarEvolution *SE) {
1263  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1264  if (SE->isKnownPositive(Step)) {
1265  *Pred = ICmpInst::ICMP_SLT;
1266  return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1267  SE->getSignedRangeMax(Step));
1268  }
1269  if (SE->isKnownNegative(Step)) {
1270  *Pred = ICmpInst::ICMP_SGT;
1271  return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1272  SE->getSignedRangeMin(Step));
1273  }
1274  return nullptr;
1275 }
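Worked example (illustrative, 8-bit types): if Step is known to lie in [1, 3], the signed limit returned above is SINT_MIN - 3, which wraps to 125, with predicate SLT; any recurrence value that is signed-less-than 125 is at most 124, so adding at most 3 cannot exceed 127 and the increment cannot sign-overflow.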
1276 
1277 // Get the limit of a recurrence such that incrementing by Step cannot cause
1278 // unsigned overflow as long as the value of the recurrence within the loop does
1279 // not exceed this limit before incrementing.
1280 static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1281  ICmpInst::Predicate *Pred,
1282  ScalarEvolution *SE) {
1283  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1284  *Pred = ICmpInst::ICMP_ULT;
1285 
1286  return SE->getConstant(APInt::getMinValue(BitWidth) -
1287  SE->getUnsignedRangeMax(Step));
1288 }
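The unsigned case works the same way (illustrative, 8-bit types): for Step in [1, 3] the limit is 0 - 3, i.e. 253, with predicate ULT; a value unsigned-less-than 253 is at most 252, so adding at most 3 stays within 255 and cannot wrap.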
1289 
1290 namespace {
1291 
1292 struct ExtendOpTraitsBase {
1293  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1294  unsigned);
1295 };
1296 
1297 // Used to make code generic over signed and unsigned overflow.
1298 template <typename ExtendOp> struct ExtendOpTraits {
1299  // Members present:
1300  //
1301  // static const SCEV::NoWrapFlags WrapType;
1302  //
1303  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1304  //
1305  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1306  // ICmpInst::Predicate *Pred,
1307  // ScalarEvolution *SE);
1308 };
1309 
1310 template <>
1311 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1312  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1313 
1314  static const GetExtendExprTy GetExtendExpr;
1315 
1316  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1317  ICmpInst::Predicate *Pred,
1318  ScalarEvolution *SE) {
1319  return getSignedOverflowLimitForStep(Step, Pred, SE);
1320  }
1321 };
1322 
1323 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1324  SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1325 
1326 template <>
1327 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1328  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1329 
1330  static const GetExtendExprTy GetExtendExpr;
1331 
1332  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1333  ICmpInst::Predicate *Pred,
1334  ScalarEvolution *SE) {
1335  return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1336  }
1337 };
1338 
1339 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1340  SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1341 }
1342 
1343 // The recurrence AR has been shown to have no signed/unsigned wrap or something
1344 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1345 // easily prove NSW/NUW for its preincrement or postincrement sibling. This
1346 // allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1347 // Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1348 // expression "Step + sext/zext(PreIncAR)" is congruent with
1349 // "sext/zext(PostIncAR)"
1350 template <typename ExtendOpTy>
1351 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1352  ScalarEvolution *SE, unsigned Depth) {
1353  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1354  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1355 
1356  const Loop *L = AR->getLoop();
1357  const SCEV *Start = AR->getStart();
1358  const SCEV *Step = AR->getStepRecurrence(*SE);
1359 
1360  // Check for a simple looking step prior to loop entry.
1361  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1362  if (!SA)
1363  return nullptr;
1364 
1365  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1366  // subtraction is expensive. For this purpose, perform a quick and dirty
1367  // difference, by checking for Step in the operand list.
1368  SmallVector<const SCEV *, 4> DiffOps;
1369  for (const SCEV *Op : SA->operands())
1370  if (Op != Step)
1371  DiffOps.push_back(Op);
1372 
1373  if (DiffOps.size() == SA->getNumOperands())
1374  return nullptr;
1375 
1376  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1377  // `Step`:
1378 
1379  // 1. NSW/NUW flags on the step increment.
1380  auto PreStartFlags =
1381  ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1382  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1383  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1384  SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1385 
1386  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1387  // "S+X does not sign/unsign-overflow".
1388  //
1389 
1390  const SCEV *BECount = SE->getBackedgeTakenCount(L);
1391  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1392  !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1393  return PreStart;
1394 
1395  // 2. Direct overflow check on the step operation's expression.
1396  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1397  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1398  const SCEV *OperandExtendedStart =
1399  SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1400  (SE->*GetExtendExpr)(Step, WideTy, Depth));
1401  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1402  if (PreAR && AR->getNoWrapFlags(WrapType)) {
1403  // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1404  // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1405  // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1406  const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType);
1407  }
1408  return PreStart;
1409  }
1410 
1411  // 3. Loop precondition.
1412  ICmpInst::Predicate Pred;
1413  const SCEV *OverflowLimit =
1414  ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1415 
1416  if (OverflowLimit &&
1417  SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1418  return PreStart;
1419 
1420  return nullptr;
1421 }
1422 
1423 // Get the normalized zero or sign extended expression for this AddRec's Start.
1424 template <typename ExtendOpTy>
1425 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1426  ScalarEvolution *SE,
1427  unsigned Depth) {
1428  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1429 
1430  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1431  if (!PreStart)
1432  return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1433 
1434  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1435  Depth),
1436  (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1437 }
1438 
1439 // Try to prove away overflow by looking at "nearby" add recurrences. A
1440 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1441 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1442 //
1443 // Formally:
1444 //
1445 // {S,+,X} == {S-T,+,X} + T
1446 // => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1447 //
1448 // If ({S-T,+,X} + T) does not overflow ... (1)
1449 //
1450 // RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1451 //
1452 // If {S-T,+,X} does not overflow ... (2)
1453 //
1454 // RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1455 // == {Ext(S-T)+Ext(T),+,Ext(X)}
1456 //
1457 // If (S-T)+T does not overflow ... (3)
1458 //
1459 // RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1460 // == {Ext(S),+,Ext(X)} == LHS
1461 //
1462 // Thus, if (1), (2) and (3) are true for some T, then
1463 // Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1464 //
1465 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1466 // does not overflow" restricted to the 0th iteration. Therefore we only need
1467 // to check for (1) and (2).
1468 //
1469 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1470 // is `Delta` (defined below).
1471 //
1472 template <typename ExtendOpTy>
1473 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1474  const SCEV *Step,
1475  const Loop *L) {
1476  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1477 
1478  // We restrict `Start` to a constant to prevent SCEV from spending too much
1479  // time here. It is correct (but more expensive) to continue with a
1480  // non-constant `Start` and do a general SCEV subtraction to compute
1481  // `PreStart` below.
1482  //
1483  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1484  if (!StartC)
1485  return false;
1486 
1487  APInt StartAI = StartC->getAPInt();
1488 
1489  for (unsigned Delta : {-2, -1, 1, 2}) {
1490  const SCEV *PreStart = getConstant(StartAI - Delta);
1491 
1492  FoldingSetNodeID ID;
1493  ID.AddInteger(scAddRecExpr);
1494  ID.AddPointer(PreStart);
1495  ID.AddPointer(Step);
1496  ID.AddPointer(L);
1497  void *IP = nullptr;
1498  const auto *PreAR =
1499  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1500 
1501  // Give up if we don't already have the add recurrence we need because
1502  // actually constructing an add recurrence is relatively expensive.
1503  if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1504  const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1505  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1506  const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1507  DeltaS, &Pred, this);
1508  if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1509  return true;
1510  }
1511  }
1512 
1513  return false;
1514 }
1515 
1516 const SCEV *
1517 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1518  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1519  "This is not an extending conversion!");
1520  assert(isSCEVable(Ty) &&
1521  "This is not a conversion to a SCEVable type!");
1522  Ty = getEffectiveSCEVType(Ty);
1523 
1524  // Fold if the operand is constant.
1525  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1526  return getConstant(
1527  cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1528 
1529  // zext(zext(x)) --> zext(x)
1530  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1531  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1532 
1533  // Before doing any expensive analysis, check to see if we've already
1534  // computed a SCEV for this Op and Ty.
1535  FoldingSetNodeID ID;
1536  ID.AddInteger(scZeroExtend);
1537  ID.AddPointer(Op);
1538  ID.AddPointer(Ty);
1539  void *IP = nullptr;
1540  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1541  if (Depth > MaxExtDepth) {
1542  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1543  Op, Ty);
1544  UniqueSCEVs.InsertNode(S, IP);
1545  return S;
1546  }
1547 
1548  // zext(trunc(x)) --> zext(x) or x or trunc(x)
1549  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1550  // It's possible the bits taken off by the truncate were all zero bits. If
1551  // so, we should be able to simplify this further.
1552  const SCEV *X = ST->getOperand();
1553  ConstantRange CR = getUnsignedRange(X);
1554  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1555  unsigned NewBits = getTypeSizeInBits(Ty);
1556  if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1557  CR.zextOrTrunc(NewBits)))
1558  return getTruncateOrZeroExtend(X, Ty);
1559  }
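  // Illustrative example: if X is an i16 whose unsigned range is known to be
  // [0, 100), the truncate in (zext i32 (trunc X to i8)) discards only zero
  // bits, so the expression folds to a plain zero extend of X to i32.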
1560 
1561  // If the input value is a chrec scev, and we can prove that the value
1562  // did not overflow the old, smaller, value, we can zero extend all of the
1563  // operands (often constants). This allows analysis of something like
1564  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1565  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1566  if (AR->isAffine()) {
1567  const SCEV *Start = AR->getStart();
1568  const SCEV *Step = AR->getStepRecurrence(*this);
1569  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1570  const Loop *L = AR->getLoop();
1571 
1572  if (!AR->hasNoUnsignedWrap()) {
1573  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1574  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
1575  }
1576 
1577  // If we have special knowledge that this addrec won't overflow,
1578  // we don't need to do any further analysis.
1579  if (AR->hasNoUnsignedWrap())
1580  return getAddRecExpr(
1581  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1582  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1583 
1584  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1585  // Note that this serves two purposes: It filters out loops that are
1586  // simply not analyzable, and it covers the case where this code is
1587  // being called from within backedge-taken count analysis, such that
1588  // attempting to ask for the backedge-taken count would likely result
1589  // in infinite recursion. In the latter case, the analysis code will
1590  // cope with a conservative value, and it will take care to purge
1591  // that value once it has finished.
1592  const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1593  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1594  // Manually compute the final value for AR, checking for
1595  // overflow.
1596 
1597  // Check whether the backedge-taken count can be losslessly cast to
1598  // the addrec's type. The count is always unsigned.
1599  const SCEV *CastedMaxBECount =
1600  getTruncateOrZeroExtend(MaxBECount, Start->getType());
1601  const SCEV *RecastedMaxBECount =
1602  getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1603  if (MaxBECount == RecastedMaxBECount) {
1604  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1605  // Check whether Start+Step*MaxBECount has no unsigned overflow.
1606  const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1607  SCEV::FlagAnyWrap, Depth + 1);
1608  const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1609  SCEV::FlagAnyWrap,
1610  Depth + 1),
1611  WideTy, Depth + 1);
1612  const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1613  const SCEV *WideMaxBECount =
1614  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1615  const SCEV *OperandExtendedAdd =
1616  getAddExpr(WideStart,
1617  getMulExpr(WideMaxBECount,
1618  getZeroExtendExpr(Step, WideTy, Depth + 1),
1619  SCEV::FlagAnyWrap, Depth + 1),
1620  SCEV::FlagAnyWrap, Depth + 1);
1621  if (ZAdd == OperandExtendedAdd) {
1622  // Cache knowledge of AR NUW, which is propagated to this AddRec.
1623  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1624  // Return the expression with the addrec on the outside.
1625  return getAddRecExpr(
1626  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1627  Depth + 1),
1628  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1629  AR->getNoWrapFlags());
1630  }
1631  // Similar to above, only this time treat the step value as signed.
1632  // This covers loops that count down.
1633  OperandExtendedAdd =
1634  getAddExpr(WideStart,
1635  getMulExpr(WideMaxBECount,
1636  getSignExtendExpr(Step, WideTy, Depth + 1),
1637  SCEV::FlagAnyWrap, Depth + 1),
1638  SCEV::FlagAnyWrap, Depth + 1);
1639  if (ZAdd == OperandExtendedAdd) {
1640  // Cache knowledge of AR NW, which is propagated to this AddRec.
1641  // Negative step causes unsigned wrap, but it still can't self-wrap.
1642  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1643  // Return the expression with the addrec on the outside.
1644  return getAddRecExpr(
1645  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1646  Depth + 1),
1647  getSignExtendExpr(Step, Ty, Depth + 1), L,
1648  AR->getNoWrapFlags());
1649  }
1650  }
1651  }
1652 
1653  // Normally, in the cases we can prove no-overflow via a
1654  // backedge guarding condition, we can also compute a backedge
1655  // taken count for the loop. The exceptions are assumptions and
1656  // guards present in the loop -- SCEV is not great at exploiting
1657  // these to compute max backedge taken counts, but can still use
1658  // these to prove lack of overflow. Use this fact to avoid
1659  // doing extra work that may not pay off.
1660  if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1661  !AC.assumptions().empty()) {
1662  // If the backedge is guarded by a comparison with the pre-inc
1663  // value the addrec is safe. Also, if the entry is guarded by
1664  // a comparison with the start value and the backedge is
1665  // guarded by a comparison with the post-inc value, the addrec
1666  // is safe.
1667  if (isKnownPositive(Step)) {
1668  const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
1669  getUnsignedRangeMax(Step));
1670  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
1671  (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
1672  isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
1673  AR->getPostIncExpr(*this), N))) {
1674  // Cache knowledge of AR NUW, which is propagated to this
1675  // AddRec.
1676  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1677  // Return the expression with the addrec on the outside.
1678  return getAddRecExpr(
1679  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1680  Depth + 1),
1681  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1682  AR->getNoWrapFlags());
1683  }
1684  } else if (isKnownNegative(Step)) {
1685  const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1686  getSignedRangeMin(Step));
1687  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1688  (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
1689  isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
1690  AR->getPostIncExpr(*this), N))) {
1691  // Cache knowledge of AR NW, which is propagated to this
1692  // AddRec. Negative step causes unsigned wrap, but it
1693  // still can't self-wrap.
1694  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1695  // Return the expression with the addrec on the outside.
1696  return getAddRecExpr(
1697  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1698  Depth + 1),
1699  getSignExtendExpr(Step, Ty, Depth + 1), L,
1700  AR->getNoWrapFlags());
1701  }
1702  }
1703  }
1704 
1705  if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1706  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1707  return getAddRecExpr(
1708  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1709  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1710  }
1711  }
1712 
1713  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1714  // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1715  if (SA->hasNoUnsignedWrap()) {
1716  // If the addition does not unsigned overflow then we can, by definition,
1717  // commute the zero extension with the addition operation.
1718  SmallVector<const SCEV *, 4> Ops;
1719  for (const auto *Op : SA->operands())
1720  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1721  return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1722  }
1723  }
1724 
1725  // The cast wasn't folded; create an explicit cast node.
1726  // Recompute the insert position, as it may have been invalidated.
1727  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1728  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1729  Op, Ty);
1730  UniqueSCEVs.InsertNode(S, IP);
1731  return S;
1732 }
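// Rough usage sketch (illustrative): for `for (unsigned char X = 0; X < 100;
// ++X) { int Y = X; }` the operand is the i8 recurrence {0,+,1}.  Once one of
// the no-unsigned-wrap proofs above succeeds, zext i8 {0,+,1} to i32 folds to
// the i32 recurrence {0,+,1}, so Y is analyzed as an ordinary affine
// induction variable instead of an opaque zero-extension.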
1733 
1734 const SCEV *
1735 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1736  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1737  "This is not an extending conversion!");
1738  assert(isSCEVable(Ty) &&
1739  "This is not a conversion to a SCEVable type!");
1740  Ty = getEffectiveSCEVType(Ty);
1741 
1742  // Fold if the operand is constant.
1743  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1744  return getConstant(
1745  cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1746 
1747  // sext(sext(x)) --> sext(x)
1748  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1749  return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1750 
1751  // sext(zext(x)) --> zext(x)
1752  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1753  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1754 
1755  // Before doing any expensive analysis, check to see if we've already
1756  // computed a SCEV for this Op and Ty.
1757  FoldingSetNodeID ID;
1758  ID.AddInteger(scSignExtendExpr);
1759  ID.AddPointer(Op);
1760  ID.AddPointer(Ty);
1761  void *IP = nullptr;
1762  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1763  // Limit recursion depth.
1764  if (Depth > MaxExtDepth) {
1765  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1766  Op, Ty);
1767  UniqueSCEVs.InsertNode(S, IP);
1768  return S;
1769  }
1770 
1771  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1772  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1773  // It's possible the bits taken off by the truncate were all sign bits. If
1774  // so, we should be able to simplify this further.
1775  const SCEV *X = ST->getOperand();
1776  ConstantRange CR = getSignedRange(X);
1777  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1778  unsigned NewBits = getTypeSizeInBits(Ty);
1779  if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1780  CR.sextOrTrunc(NewBits)))
1781  return getTruncateOrSignExtend(X, Ty);
1782  }
1783 
1784  // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2
1785  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1786  if (SA->getNumOperands() == 2) {
1787  auto *SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0));
1788  auto *SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
1789  if (SMul && SC1) {
1790  if (auto *SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
1791  const APInt &C1 = SC1->getAPInt();
1792  const APInt &C2 = SC2->getAPInt();
1793  if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
1794  C2.ugt(C1) && C2.isPowerOf2())
1795  return getAddExpr(getSignExtendExpr(SC1, Ty, Depth + 1),
1796  getSignExtendExpr(SMul, Ty, Depth + 1),
1797  SCEV::FlagAnyWrap, Depth + 1);
1798  }
1799  }
1800  }
1801 
1802  // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1803  if (SA->hasNoSignedWrap()) {
1804  // If the addition does not sign overflow then we can, by definition,
1805  // commute the sign extension with the addition operation.
1806  SmallVector<const SCEV *, 4> Ops;
1807  for (const auto *Op : SA->operands())
1808  Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1809  return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1810  }
1811  }
1812  // If the input value is a chrec scev, and we can prove that the value
1813  // did not overflow the old, smaller, value, we can sign extend all of the
1814  // operands (often constants). This allows analysis of something like
1815  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1816  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1817  if (AR->isAffine()) {
1818  const SCEV *Start = AR->getStart();
1819  const SCEV *Step = AR->getStepRecurrence(*this);
1820  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1821  const Loop *L = AR->getLoop();
1822 
1823  if (!AR->hasNoSignedWrap()) {
1824  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1825  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
1826  }
1827 
1828  // If we have special knowledge that this addrec won't overflow,
1829  // we don't need to do any further analysis.
1830  if (AR->hasNoSignedWrap())
1831  return getAddRecExpr(
1832  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1833  getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
1834 
1835  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1836  // Note that this serves two purposes: It filters out loops that are
1837  // simply not analyzable, and it covers the case where this code is
1838  // being called from within backedge-taken count analysis, such that
1839  // attempting to ask for the backedge-taken count would likely result
1840  // in infinite recursion. In the latter case, the analysis code will
1841  // cope with a conservative value, and it will take care to purge
1842  // that value once it has finished.
1843  const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1844  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1845  // Manually compute the final value for AR, checking for
1846  // overflow.
1847 
1848  // Check whether the backedge-taken count can be losslessly cast to
1849  // the addrec's type. The count is always unsigned.
1850  const SCEV *CastedMaxBECount =
1851  getTruncateOrZeroExtend(MaxBECount, Start->getType());
1852  const SCEV *RecastedMaxBECount =
1853  getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1854  if (MaxBECount == RecastedMaxBECount) {
1855  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1856  // Check whether Start+Step*MaxBECount has no signed overflow.
1857  const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
1858  SCEV::FlagAnyWrap, Depth + 1);
1859  const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
1860  SCEV::FlagAnyWrap,
1861  Depth + 1),
1862  WideTy, Depth + 1);
1863  const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
1864  const SCEV *WideMaxBECount =
1865  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1866  const SCEV *OperandExtendedAdd =
1867  getAddExpr(WideStart,
1868  getMulExpr(WideMaxBECount,
1869  getSignExtendExpr(Step, WideTy, Depth + 1),
1870  SCEV::FlagAnyWrap, Depth + 1),
1871  SCEV::FlagAnyWrap, Depth + 1);
1872  if (SAdd == OperandExtendedAdd) {
1873  // Cache knowledge of AR NSW, which is propagated to this AddRec.
1874  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1875  // Return the expression with the addrec on the outside.
1876  return getAddRecExpr(
1877  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
1878  Depth + 1),
1879  getSignExtendExpr(Step, Ty, Depth + 1), L,
1880  AR->getNoWrapFlags());
1881  }
1882  // Similar to above, only this time treat the step value as unsigned.
1883  // This covers loops that count up with an unsigned step.
1884  OperandExtendedAdd =
1885  getAddExpr(WideStart,
1886  getMulExpr(WideMaxBECount,
1887  getZeroExtendExpr(Step, WideTy, Depth + 1),
1888  SCEV::FlagAnyWrap, Depth + 1),
1889  SCEV::FlagAnyWrap, Depth + 1);
1890  if (SAdd == OperandExtendedAdd) {
1891  // If AR wraps around then
1892  //
1893  // abs(Step) * MaxBECount > unsigned-max(AR->getType())
1894  // => SAdd != OperandExtendedAdd
1895  //
1896  // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
1897  // (SAdd == OperandExtendedAdd => AR is NW)
1898 
1899  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1900 
1901  // Return the expression with the addrec on the outside.
1902  return getAddRecExpr(
1903  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
1904  Depth + 1),
1905  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1906  AR->getNoWrapFlags());
1907  }
1908  }
1909  }
1910 
1911  // Normally, in the cases we can prove no-overflow via a
1912  // backedge guarding condition, we can also compute a backedge
1913  // taken count for the loop. The exceptions are assumptions and
1914  // guards present in the loop -- SCEV is not great at exploiting
1915  // these to compute max backedge taken counts, but can still use
1916  // these to prove lack of overflow. Use this fact to avoid
1917  // doing extra work that may not pay off.
1918 
1919  if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1920  !AC.assumptions().empty()) {
1921  // If the backedge is guarded by a comparison with the pre-inc
1922  // value the addrec is safe. Also, if the entry is guarded by
1923  // a comparison with the start value and the backedge is
1924  // guarded by a comparison with the post-inc value, the addrec
1925  // is safe.
1926  ICmpInst::Predicate Pred;
1927  const SCEV *OverflowLimit =
1928  getSignedOverflowLimitForStep(Step, &Pred, this);
1929  if (OverflowLimit &&
1930  (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
1931  (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
1932  isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
1933  OverflowLimit)))) {
1934  // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
1935  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1936  return getAddRecExpr(
1937  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1938  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1939  }
1940  }
1941 
1942  // If Start and Step are constants, check if we can apply this
1943  // transformation:
1944  // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2
1945  auto *SC1 = dyn_cast<SCEVConstant>(Start);
1946  auto *SC2 = dyn_cast<SCEVConstant>(Step);
1947  if (SC1 && SC2) {
1948  const APInt &C1 = SC1->getAPInt();
1949  const APInt &C2 = SC2->getAPInt();
1950  if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
1951  C2.isPowerOf2()) {
1952  Start = getSignExtendExpr(Start, Ty, Depth + 1);
1953  const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L,
1954  AR->getNoWrapFlags());
1955  return getAddExpr(Start, getSignExtendExpr(NewAR, Ty, Depth + 1),
1956  SCEV::FlagAnyWrap, Depth + 1);
1957  }
1958  }
1959 
1960  if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
1961  const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1962  return getAddRecExpr(
1963  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1964  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1965  }
1966  }
1967 
1968  // If the input value is provably non-negative and we could not simplify
1969  // away the sext, build a zext instead.
1970  if (isKnownNonNegative(Op))
1971  return getZeroExtendExpr(Op, Ty, Depth + 1);
1972 
1973  // The cast wasn't folded; create an explicit cast node.
1974  // Recompute the insert position, as it may have been invalidated.
1975  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1976  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1977  Op, Ty);
1978  UniqueSCEVs.InsertNode(S, IP);
1979  return S;
1980 }
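// Illustrative example of the "sext{C1,+,C2} --> C1 + sext{0,+,C2}" rule
// above: sext i8 {1,+,4} to i32 becomes 1 + (sext i8 {0,+,4} to i32), since
// both constants are strictly positive, 4 > 1 and 4 is a power of two;
// rebasing the recurrence at zero often makes later no-wrap reasoning simpler.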
1981 
1982 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
1983 /// unspecified bits out to the given type.
1984 ///
1985 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
1986  Type *Ty) {
1987  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1988  "This is not an extending conversion!");
1989  assert(isSCEVable(Ty) &&
1990  "This is not a conversion to a SCEVable type!");
1991  Ty = getEffectiveSCEVType(Ty);
1992 
1993  // Sign-extend negative constants.
1994  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1995  if (SC->getAPInt().isNegative())
1996  return getSignExtendExpr(Op, Ty);
1997 
1998  // Peel off a truncate cast.
1999  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2000  const SCEV *NewOp = T->getOperand();
2001  if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2002  return getAnyExtendExpr(NewOp, Ty);
2003  return getTruncateOrNoop(NewOp, Ty);
2004  }
2005 
2006  // Next try a zext cast. If the cast is folded, use it.
2007  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2008  if (!isa<SCEVZeroExtendExpr>(ZExt))
2009  return ZExt;
2010 
2011  // Next try a sext cast. If the cast is folded, use it.
2012  const SCEV *SExt = getSignExtendExpr(Op, Ty);
2013  if (!isa<SCEVSignExtendExpr>(SExt))
2014  return SExt;
2015 
2016  // Force the cast to be folded into the operands of an addrec.
2017  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2018  SmallVector<const SCEV *, 4> Ops;
2019  for (const SCEV *Op : AR->operands())
2020  Ops.push_back(getAnyExtendExpr(Op, Ty));
2021  return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2022  }
2023 
2024  // If the expression is obviously signed, use the sext cast value.
2025  if (isa<SCEVSMaxExpr>(Op))
2026  return SExt;
2027 
2028  // Absent any other information, use the zext cast value.
2029  return ZExt;
2030 }
2031 
2032 /// Process the given Ops list, which is a list of operands to be added under
2033 /// the given scale, update the given map. This is a helper function for
2034 /// getAddRecExpr. As an example of what it does, given a sequence of operands
2035 /// that would form an add expression like this:
2036 ///
2037 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2038 ///
2039 /// where A and B are constants, update the map with these values:
2040 ///
2041 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2042 ///
2043 /// and add 13 + A*B*29 to AccumulatedConstant.
2044 /// This will allow getAddRecExpr to produce this:
2045 ///
2046 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2047 ///
2048 /// This form often exposes folding opportunities that are hidden in
2049 /// the original operand list.
2050 ///
2051 /// Return true iff it appears that any interesting folding opportunities
2052 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
2053 /// the common case where no interesting opportunities are present, and
2054 /// is also used as a check to avoid infinite recursion.
2055 ///
2056 static bool
2057 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2058  SmallVectorImpl<const SCEV *> &NewOps,
2059  APInt &AccumulatedConstant,
2060  const SCEV *const *Ops, size_t NumOperands,
2061  const APInt &Scale,
2062  ScalarEvolution &SE) {
2063  bool Interesting = false;
2064 
2065  // Iterate over the add operands. They are sorted, with constants first.
2066  unsigned i = 0;
2067  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2068  ++i;
2069  // Pull a buried constant out to the outside.
2070  if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2071  Interesting = true;
2072  AccumulatedConstant += Scale * C->getAPInt();
2073  }
2074 
2075  // Next comes everything else. We're especially interested in multiplies
2076  // here, but they're in the middle, so just visit the rest with one loop.
2077  for (; i != NumOperands; ++i) {
2078  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2079  if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2080  APInt NewScale =
2081  Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2082  if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2083  // A multiplication of a constant with another add; recurse.
2084  const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2085  Interesting |=
2086  CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2087  Add->op_begin(), Add->getNumOperands(),
2088  NewScale, SE);
2089  } else {
2090  // A multiplication of a constant with some other value. Update
2091  // the map.
2092  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
2093  const SCEV *Key = SE.getMulExpr(MulOps);
2094  auto Pair = M.insert({Key, NewScale});
2095  if (Pair.second) {
2096  NewOps.push_back(Pair.first->first);
2097  } else {
2098  Pair.first->second += NewScale;
2099  // The map already had an entry for this value, which may indicate
2100  // a folding opportunity.
2101  Interesting = true;
2102  }
2103  }
2104  } else {
2105  // An ordinary operand. Update the map.
2106  std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2107  M.insert({Ops[i], Scale});
2108  if (Pair.second) {
2109  NewOps.push_back(Pair.first->first);
2110  } else {
2111  Pair.first->second += Scale;
2112  // The map already had an entry for this value, which may indicate
2113  // a folding opportunity.
2114  Interesting = true;
2115  }
2116  }
2117  }
2118 
2119  return Interesting;
2120 }
2121 
2122 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2123 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2124 // can't-overflow flags for the operation if possible.
2125 static SCEV::NoWrapFlags
2126 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2127  const SmallVectorImpl<const SCEV *> &Ops,
2128  SCEV::NoWrapFlags Flags) {
2129  using namespace std::placeholders;
2130  typedef OverflowingBinaryOperator OBO;
2131 
2132  bool CanAnalyze =
2133  Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2134  (void)CanAnalyze;
2135  assert(CanAnalyze && "don't call from other places!");
2136 
2137  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2138  SCEV::NoWrapFlags SignOrUnsignWrap =
2139  ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2140 
2141  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2142  auto IsKnownNonNegative = [&](const SCEV *S) {
2143  return SE->isKnownNonNegative(S);
2144  };
2145 
2146  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2147  Flags =
2148  ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2149 
2150  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2151 
2152  if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr &&
2153  Ops.size() == 2 && isa<SCEVConstant>(Ops[0])) {
2154 
2155  // (A + C) --> (A + C)<nsw> if the addition does not sign overflow
2156  // (A + C) --> (A + C)<nuw> if the addition does not unsign overflow
2157 
2158  const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2159  if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2160  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2161  Instruction::Add, C, OBO::NoSignedWrap);
2162  if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2163  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2164  }
2165  if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2166  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2167  Instruction::Add, C, OBO::NoUnsignedWrap);
2168  if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2169  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2170  }
2171  }
2172 
2173  return Flags;
2174 }
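// Worked example (illustrative): when building (5 + %x) with no flags, if the
// signed range of %x is known to be [0, 100], that range lies inside the
// guaranteed-no-signed-wrap region for "add 5", so FlagNSW is added; the
// analogous unsigned-range check can add FlagNUW.  Separately, if NSW was
// already present and every operand is known non-negative, the all_of check
// above promotes the flags to NSW|NUW.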
2175 
2176 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2177  if (!isLoopInvariant(S, L))
2178  return false;
2179  // If a value depends on a SCEVUnknown which is defined after the loop, we
2180  // conservatively assume that we cannot calculate it at the loop's entry.
2181  struct FindDominatedSCEVUnknown {
2182  bool Found = false;
2183  const Loop *L;
2184  DominatorTree &DT;
2185  LoopInfo &LI;
2186 
2187  FindDominatedSCEVUnknown(const Loop *L, DominatorTree &DT, LoopInfo &LI)
2188  : L(L), DT(DT), LI(LI) {}
2189 
2190  bool checkSCEVUnknown(const SCEVUnknown *SU) {
2191  if (auto *I = dyn_cast<Instruction>(SU->getValue())) {
2192  if (DT.dominates(L->getHeader(), I->getParent()))
2193  Found = true;
2194  else
2195  assert(DT.dominates(I->getParent(), L->getHeader()) &&
2196  "No dominance relationship between SCEV and loop?");
2197  }
2198  return false;
2199  }
2200 
2201  bool follow(const SCEV *S) {
2202  switch (static_cast<SCEVTypes>(S->getSCEVType())) {
2203  case scConstant:
2204  return false;
2205  case scAddRecExpr:
2206  case scTruncate:
2207  case scZeroExtend:
2208  case scSignExtend:
2209  case scAddExpr:
2210  case scMulExpr:
2211  case scUMaxExpr:
2212  case scSMaxExpr:
2213  case scUDivExpr:
2214  return true;
2215  case scUnknown:
2216  return checkSCEVUnknown(cast<SCEVUnknown>(S));
2217  case scCouldNotCompute:
2218  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
2219  }
2220  return false;
2221  }
2222 
2223  bool isDone() { return Found; }
2224  };
2225 
2226  FindDominatedSCEVUnknown FSU(L, DT, LI);
2227  SCEVTraversal<FindDominatedSCEVUnknown> ST(FSU);
2228  ST.visitAll(S);
2229  return !FSU.Found;
2230 }
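// Illustrative example: for S = (%a + %b) where %b is an instruction placed in
// a block dominated by the loop header (e.g. after the loop), the traversal
// reaches the SCEVUnknown for %b and the query returns false, even though S is
// formally loop-invariant, because %b has no value yet at the loop's entry.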
2231 
2232 /// Get a canonical add expression, or something simpler if possible.
2233 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2234  SCEV::NoWrapFlags Flags,
2235  unsigned Depth) {
2236  assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2237  "only nuw or nsw allowed");
2238  assert(!Ops.empty() && "Cannot get empty add!");
2239  if (Ops.size() == 1) return Ops[0];
2240 #ifndef NDEBUG
2241  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2242  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2243  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2244  "SCEVAddExpr operand types don't match!");
2245 #endif
2246 
2247  // Sort by complexity, this groups all similar expression types together.
2248  GroupByComplexity(Ops, &LI, DT);
2249 
2250  Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
2251 
2252  // If there are any constants, fold them together.
2253  unsigned Idx = 0;
2254  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2255  ++Idx;
2256  assert(Idx < Ops.size());
2257  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2258  // We found two constants, fold them together!
2259  Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2260  if (Ops.size() == 2) return Ops[0];
2261  Ops.erase(Ops.begin()+1); // Erase the folded element
2262  LHSC = cast<SCEVConstant>(Ops[0]);
2263  }
2264 
2265  // If we are left with a constant zero being added, strip it off.
2266  if (LHSC->getValue()->isZero()) {
2267  Ops.erase(Ops.begin());
2268  --Idx;
2269  }
2270 
2271  if (Ops.size() == 1) return Ops[0];
2272  }
2273 
2274  // Limit recursion calls depth.
2275  if (Depth > MaxArithDepth)
2276  return getOrCreateAddExpr(Ops, Flags);
2277 
2278  // Okay, check to see if the same value occurs in the operand list more than
2279  // once. If so, merge them together into a multiply expression. Since we
2280  // sorted the list, these values are required to be adjacent.
2281  Type *Ty = Ops[0]->getType();
2282  bool FoundMatch = false;
2283  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2284  if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2285  // Scan ahead to count how many equal operands there are.
2286  unsigned Count = 2;
2287  while (i+Count != e && Ops[i+Count] == Ops[i])
2288  ++Count;
2289  // Merge the values into a multiply.
2290  const SCEV *Scale = getConstant(Ty, Count);
2291  const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2292  if (Ops.size() == Count)
2293  return Mul;
2294  Ops[i] = Mul;
2295  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2296  --i; e -= Count - 1;
2297  FoundMatch = true;
2298  }
2299  if (FoundMatch)
2300  return getAddExpr(Ops, Flags);
2301 
2302  // Check for truncates. If all the operands are truncated from the same
2303  // type, see if factoring out the truncate would permit the result to be
2304  // folded. e.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
2305  // if the contents of the resulting outer trunc fold to something simple.
2306  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
2307  const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
2308  Type *DstType = Trunc->getType();
2309  Type *SrcType = Trunc->getOperand()->getType();
2310  SmallVector<const SCEV *, 8> LargeOps;
2311  bool Ok = true;
2312  // Check all the operands to see if they can be represented in the
2313  // source type of the truncate.
2314  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2315  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2316  if (T->getOperand()->getType() != SrcType) {
2317  Ok = false;
2318  break;
2319  }
2320  LargeOps.push_back(T->getOperand());
2321  } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2322  LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2323  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2324  SmallVector<const SCEV *, 8> LargeMulOps;
2325  for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2326  if (const SCEVTruncateExpr *T =
2327  dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2328  if (T->getOperand()->getType() != SrcType) {
2329  Ok = false;
2330  break;
2331  }
2332  LargeMulOps.push_back(T->getOperand());
2333  } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2334  LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2335  } else {
2336  Ok = false;
2337  break;
2338  }
2339  }
2340  if (Ok)
2341  LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2342  } else {
2343  Ok = false;
2344  break;
2345  }
2346  }
2347  if (Ok) {
2348  // Evaluate the expression in the larger type.
2349  const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
2350  // If it folds to something simple, use it. Otherwise, don't.
2351  if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2352  return getTruncateExpr(Fold, DstType);
2353  }
2354  }
2355 
2356  // Skip past any other cast SCEVs.
2357  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2358  ++Idx;
2359 
2360  // If there are add operands they would be next.
2361  if (Idx < Ops.size()) {
2362  bool DeletedAdd = false;
2363  while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2364  if (Ops.size() > AddOpsInlineThreshold ||
2365  Add->getNumOperands() > AddOpsInlineThreshold)
2366  break;
2367  // If we have an add, expand the add operands onto the end of the operands
2368  // list.
2369  Ops.erase(Ops.begin()+Idx);
2370  Ops.append(Add->op_begin(), Add->op_end());
2371  DeletedAdd = true;
2372  }
2373 
2374  // If we deleted at least one add, we added operands to the end of the list,
2375  // and they are not necessarily sorted. Recurse to resort and resimplify
2376  // any operands we just acquired.
2377  if (DeletedAdd)
2378  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2379  }
2380 
2381  // Skip over the add expression until we get to a multiply.
2382  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2383  ++Idx;
2384 
2385  // Check to see if there are any folding opportunities present with
2386  // operands multiplied by constant values.
2387  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2388  uint64_t BitWidth = getTypeSizeInBits(Ty);
2389  DenseMap<const SCEV *, APInt> M;
2390  SmallVector<const SCEV *, 8> NewOps;
2391  APInt AccumulatedConstant(BitWidth, 0);
2392  if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2393  Ops.data(), Ops.size(),
2394  APInt(BitWidth, 1), *this)) {
2395  struct APIntCompare {
2396  bool operator()(const APInt &LHS, const APInt &RHS) const {
2397  return LHS.ult(RHS);
2398  }
2399  };
2400 
2401  // Some interesting folding opportunity is present, so it's worthwhile to
2402  // re-generate the operands list. Group the operands by constant scale,
2403  // to avoid multiplying by the same constant scale multiple times.
2404  std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2405  for (const SCEV *NewOp : NewOps)
2406  MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2407  // Re-generate the operands list.
2408  Ops.clear();
2409  if (AccumulatedConstant != 0)
2410  Ops.push_back(getConstant(AccumulatedConstant));
2411  for (auto &MulOp : MulOpLists)
2412  if (MulOp.first != 0)
2413  Ops.push_back(getMulExpr(
2414  getConstant(MulOp.first),
2415  getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2416  SCEV::FlagAnyWrap, Depth + 1));
2417  if (Ops.empty())
2418  return getZero(Ty);
2419  if (Ops.size() == 1)
2420  return Ops[0];
2421  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2422  }
2423  }
2424 
2425  // If we are adding something to a multiply expression, make sure the
2426  // something is not already an operand of the multiply. If so, merge it into
2427  // the multiply.
2428  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2429  const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2430  for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2431  const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2432  if (isa<SCEVConstant>(MulOpSCEV))
2433  continue;
2434  for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2435  if (MulOpSCEV == Ops[AddOp]) {
2436  // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2437  const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2438  if (Mul->getNumOperands() != 2) {
2439  // If the multiply has more than two operands, we must get the
2440  // Y*Z term.
2441  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2442  Mul->op_begin()+MulOp);
2443  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2444  InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2445  }
2446  SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2447  const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2448  const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2449  SCEV::FlagAnyWrap, Depth + 1);
2450  if (Ops.size() == 2) return OuterMul;
2451  if (AddOp < Idx) {
2452  Ops.erase(Ops.begin()+AddOp);
2453  Ops.erase(Ops.begin()+Idx-1);
2454  } else {
2455  Ops.erase(Ops.begin()+Idx);
2456  Ops.erase(Ops.begin()+AddOp-1);
2457  }
2458  Ops.push_back(OuterMul);
2459  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2460  }
2461 
2462  // Check this multiply against other multiplies being added together.
2463  for (unsigned OtherMulIdx = Idx+1;
2464  OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2465  ++OtherMulIdx) {
2466  const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2467  // If MulOp occurs in OtherMul, we can fold the two multiplies
2468  // together.
2469  for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2470  OMulOp != e; ++OMulOp)
2471  if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2472  // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2473  const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2474  if (Mul->getNumOperands() != 2) {
2475  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2476  Mul->op_begin()+MulOp);
2477  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2478  InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2479  }
2480  const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2481  if (OtherMul->getNumOperands() != 2) {
2482  SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2483  OtherMul->op_begin()+OMulOp);
2484  MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2485  InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2486  }
2487  SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2488  const SCEV *InnerMulSum =
2489  getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2490  const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2491  SCEV::FlagAnyWrap, Depth + 1);
2492  if (Ops.size() == 2) return OuterMul;
2493  Ops.erase(Ops.begin()+Idx);
2494  Ops.erase(Ops.begin()+OtherMulIdx-1);
2495  Ops.push_back(OuterMul);
2496  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2497  }
2498  }
2499  }
2500  }
2501 
2502  // If there are any add recurrences in the operands list, see if any other
2503  // added values are loop invariant. If so, we can fold them into the
2504  // recurrence.
2505  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2506  ++Idx;
2507 
2508  // Scan over all recurrences, trying to fold loop invariants into them.
2509  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2510  // Scan all of the other operands to this add and add them to the vector if
2511  // they are loop invariant w.r.t. the recurrence.
2512  SmallVector<const SCEV *, 8> LIOps;
2513  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2514  const Loop *AddRecLoop = AddRec->getLoop();
2515  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2516  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2517  LIOps.push_back(Ops[i]);
2518  Ops.erase(Ops.begin()+i);
2519  --i; --e;
2520  }
2521 
2522  // If we found some loop invariants, fold them into the recurrence.
2523  if (!LIOps.empty()) {
2524  // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2525  LIOps.push_back(AddRec->getStart());
2526 
2527  SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
2528  AddRec->op_end());
2529  // This follows from the fact that the no-wrap flags on the outer add
2530  // expression are applicable on the 0th iteration, when the add recurrence
2531  // will be equal to its start value.
2532  AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1);
2533 
2534  // Build the new addrec. Propagate the NUW and NSW flags if both the
2535  // outer add and the inner addrec are guaranteed to have no overflow.
2536  // Always propagate NW.
2537  Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2538  const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2539 
2540  // If all of the other operands were loop invariant, we are done.
2541  if (Ops.size() == 1) return NewRec;
2542 
2543  // Otherwise, add the folded AddRec by the non-invariant parts.
2544  for (unsigned i = 0;; ++i)
2545  if (Ops[i] == AddRec) {
2546  Ops[i] = NewRec;
2547  break;
2548  }
2549  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2550  }
2551 
2552  // Okay, if there weren't any loop invariants to be folded, check to see if
2553  // there are multiple AddRec's with the same loop induction variable being
2554  // added together. If so, we can fold them.
2555  for (unsigned OtherIdx = Idx+1;
2556  OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2557  ++OtherIdx) {
2558  // We expect the AddRecExpr's to be sorted in reverse dominance order,
2559  // so that the 1st found AddRecExpr is dominated by all others.
2560  assert(DT.dominates(
2561  cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2562  AddRec->getLoop()->getHeader()) &&
2563  "AddRecExprs are not sorted in reverse dominance order?");
2564  if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2565  // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2566  SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
2567  AddRec->op_end());
2568  for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2569  ++OtherIdx) {
2570  const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2571  if (OtherAddRec->getLoop() == AddRecLoop) {
2572  for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2573  i != e; ++i) {
2574  if (i >= AddRecOps.size()) {
2575  AddRecOps.append(OtherAddRec->op_begin()+i,
2576  OtherAddRec->op_end());
2577  break;
2578  }
2579  SmallVector<const SCEV *, 2> TwoOps = {
2580  AddRecOps[i], OtherAddRec->getOperand(i)};
2581  AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2582  }
2583  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2584  }
2585  }
2586  // Step size has changed, so we cannot guarantee no self-wraparound.
2587  Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2588  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2589  }
2590  }
2591 
2592  // Otherwise couldn't fold anything into this recurrence. Move onto the
2593  // next one.
2594  }
2595 
2596  // Okay, it looks like we really DO need an add expr. Check to see if we
2597  // already have one, otherwise create a new one.
2598  return getOrCreateAddExpr(Ops, Flags);
2599 }
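// Small worked example (illustrative, assuming %n is defined outside L):
// getAddExpr({7, %n, {5,+,1}<L>}) folds the loop-invariant operands 7 and %n
// into the recurrence start via the NLI + LI + {Start,+,Step} rule above,
// yielding the canonical form {(12 + %n),+,1}<L>.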
2600 
2601 const SCEV *
2602 ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2603  SCEV::NoWrapFlags Flags) {
2604  FoldingSetNodeID ID;
2605  ID.AddInteger(scAddExpr);
2606  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2607  ID.AddPointer(Ops[i]);
2608  void *IP = nullptr;
2609  SCEVAddExpr *S =
2610  static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2611  if (!S) {
2612  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2613  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2614  S = new (SCEVAllocator)
2615  SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2616  UniqueSCEVs.InsertNode(S, IP);
2617  }
2618  S->setNoWrapFlags(Flags);
2619  return S;
2620 }
2621 
2622 const SCEV *
2623 ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops,
2624  SCEV::NoWrapFlags Flags) {
2625  FoldingSetNodeID ID;
2626  ID.AddInteger(scMulExpr);
2627  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2628  ID.AddPointer(Ops[i]);
2629  void *IP = nullptr;
2630  SCEVMulExpr *S =
2631  static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2632  if (!S) {
2633  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2634  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2635  S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2636  O, Ops.size());
2637  UniqueSCEVs.InsertNode(S, IP);
2638  }
2639  S->setNoWrapFlags(Flags);
2640  return S;
2641 }
2642 
2643 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2644  uint64_t k = i*j;
2645  if (j > 1 && k / j != i) Overflow = true;
2646  return k;
2647 }
2648 
2649 /// Compute the result of "n choose k", the binomial coefficient. If an
2650 /// intermediate computation overflows, Overflow will be set and the return will
2651 /// be garbage. Overflow is not cleared on absence of overflow.
2652 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2653  // We use the multiplicative formula:
2654  // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2655  // At each iteration, we multiply the running result by the next numerator
2656  // term and then divide by the loop index i. This division always produces an
2657  // integral result, and helps reduce the chance of overflow in the
2658  // intermediate computations. However, we can still overflow even when the
2659  // final result would fit.
2660 
2661  if (n == 0 || n == k) return 1;
2662  if (k > n) return 0;
2663 
2664  if (k > n/2)
2665  k = n-k;
2666 
2667  uint64_t r = 1;
2668  for (uint64_t i = 1; i <= k; ++i) {
2669  r = umul_ov(r, n-(i-1), Overflow);
2670  r /= i;
2671  }
2672  return r;
2673 }
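// Arithmetic check (illustrative): Choose(6, 2, Overflow) iterates twice:
// r = 1*6 = 6, r /= 1; then r = 6*5 = 30, r /= 2 = 15 = C(6,2).  Every
// intermediate r is itself a binomial coefficient, which is why the per-step
// division is always exact.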
2674 
2675 /// Determine if any of the operands in this SCEV are a constant or if
2676 /// any of the add or multiply expressions in this SCEV contain a constant.
2677 static bool containsConstantSomewhere(const SCEV *StartExpr) {
2678  SmallVector<const SCEV *, 4> Ops;
2679  Ops.push_back(StartExpr);
2680  while (!Ops.empty()) {
2681  const SCEV *CurrentExpr = Ops.pop_back_val();
2682  if (isa<SCEVConstant>(*CurrentExpr))
2683  return true;
2684 
2685  if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) {
2686  const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
2687  Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end());
2688  }
2689  }
2690  return false;
2691 }
2692 
2693 /// Get a canonical multiply expression, or something simpler if possible.
2694 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
2695  SCEV::NoWrapFlags Flags,
2696  unsigned Depth) {
2697  assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
2698  "only nuw or nsw allowed");
2699  assert(!Ops.empty() && "Cannot get empty mul!");
2700  if (Ops.size() == 1) return Ops[0];
2701 #ifndef NDEBUG
2702  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2703  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2704  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2705  "SCEVMulExpr operand types don't match!");
2706 #endif
2707 
2708  // Sort by complexity, this groups all similar expression types together.
2709  GroupByComplexity(Ops, &LI, DT);
2710 
2711  Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
2712 
2713  // Limit recursion calls depth.
2714  if (Depth > MaxArithDepth)
2715  return getOrCreateMulExpr(Ops, Flags);
2716 
2717  // If there are any constants, fold them together.
2718  unsigned Idx = 0;
2719  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2720 
2721  // C1*(C2+V) -> C1*C2 + C1*V
2722  if (Ops.size() == 2)
2723  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
2724  // If any of Add's ops are Adds or Muls with a constant,
2725  // apply this transformation as well.
2726  if (Add->getNumOperands() == 2)
2727  if (containsConstantSomewhere(Add))
2728  return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
2729  SCEV::FlagAnyWrap, Depth + 1),
2730  getMulExpr(LHSC, Add->getOperand(1),
2731  SCEV::FlagAnyWrap, Depth + 1),
2732  SCEV::FlagAnyWrap, Depth + 1);
2733 
2734  ++Idx;
2735  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2736  // We found two constants, fold them together!
2737  ConstantInt *Fold =
2738  ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt());
2739  Ops[0] = getConstant(Fold);
2740  Ops.erase(Ops.begin()+1); // Erase the folded element
2741  if (Ops.size() == 1) return Ops[0];
2742  LHSC = cast<SCEVConstant>(Ops[0]);
2743  }
2744 
2745  // If we are left with a constant one being multiplied, strip it off.
2746  if (cast<SCEVConstant>(Ops[0])->getValue()->isOne()) {
2747  Ops.erase(Ops.begin());
2748  --Idx;
2749  } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
2750  // If we have a multiply of zero, it will always be zero.
2751  return Ops[0];
2752  } else if (Ops[0]->isAllOnesValue()) {
2753  // If we have a mul by -1 of an add, try distributing the -1 among the
2754  // add operands.
2755  if (Ops.size() == 2) {
2756  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
2757  SmallVector<const SCEV *, 4> NewOps;
2758  bool AnyFolded = false;
2759  for (const SCEV *AddOp : Add->operands()) {
2760  const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
2761  Depth + 1);
2762  if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
2763  NewOps.push_back(Mul);
2764  }
2765  if (AnyFolded)
2766  return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
2767  } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
2768  // Negation preserves a recurrence's no self-wrap property.
2769  SmallVector<const SCEV *, 4> Operands;
2770  for (const SCEV *AddRecOp : AddRec->operands())
2771  Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
2772  Depth + 1));
2773 
2774  return getAddRecExpr(Operands, AddRec->getLoop(),
2775  AddRec->getNoWrapFlags(SCEV::FlagNW));
2776  }
2777  }
2778  }
2779 
2780  if (Ops.size() == 1)
2781  return Ops[0];
2782  }
2783 
2784  // Skip over the add expression until we get to a multiply.
2785  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2786  ++Idx;
2787 
2788  // If there are mul operands inline them all into this expression.
2789  if (Idx < Ops.size()) {
2790  bool DeletedMul = false;
2791  while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2792  if (Ops.size() > MulOpsInlineThreshold)
2793  break;
2794  // If we have a mul, expand the mul operands onto the end of the
2795  // operands list.
2796  Ops.erase(Ops.begin()+Idx);
2797  Ops.append(Mul->op_begin(), Mul->op_end());
2798  DeletedMul = true;
2799  }
2800 
2801  // If we deleted at least one mul, we added operands to the end of the
2802  // list, and they are not necessarily sorted. Recurse to resort and
2803  // resimplify any operands we just acquired.
2804  if (DeletedMul)
2805  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2806  }
2807 
2808  // If there are any add recurrences in the operands list, see if any other
2809  // added values are loop invariant. If so, we can fold them into the
2810  // recurrence.
2811  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2812  ++Idx;
2813 
2814  // Scan over all recurrences, trying to fold loop invariants into them.
2815  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2816  // Scan all of the other operands to this mul and add them to the vector
2817  // if they are loop invariant w.r.t. the recurrence.
2818  SmallVector<const SCEV *, 8> LIOps;
2819  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2820  const Loop *AddRecLoop = AddRec->getLoop();
2821  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2822  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2823  LIOps.push_back(Ops[i]);
2824  Ops.erase(Ops.begin()+i);
2825  --i; --e;
2826  }
2827 
2828  // If we found some loop invariants, fold them into the recurrence.
2829  if (!LIOps.empty()) {
2830  // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
2831  SmallVector<const SCEV *, 4> NewOps;
2832  NewOps.reserve(AddRec->getNumOperands());
2833  const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
2834  for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
2835  NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
2836  SCEV::FlagAnyWrap, Depth + 1));
2837 
2838  // Build the new addrec. Propagate the NUW and NSW flags if both the
2839  // outer mul and the inner addrec are guaranteed to have no overflow.
2840  //
2841  // No self-wrap cannot be guaranteed after changing the step size, but
2842  // will be inferred if either NUW or NSW is true.
2843  Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
2844  const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
2845 
2846  // If all of the other operands were loop invariant, we are done.
2847  if (Ops.size() == 1) return NewRec;
2848 
2849  // Otherwise, multiply the folded AddRec by the non-invariant parts.
2850  for (unsigned i = 0;; ++i)
2851  if (Ops[i] == AddRec) {
2852  Ops[i] = NewRec;
2853  break;
2854  }
2855  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2856  }
2857 
2858  // Okay, if there weren't any loop invariants to be folded, check to see
2859  // if there are multiple AddRec's with the same loop induction variable
2860  // being multiplied together. If so, we can fold them.
2861 
2862  // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
2863  // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
2864  // choose(x, 2x-y)*choose(2x-y, x-z)*A_{y-z}*B_z
2865  // ]]],+,...up to x=2n}.
2866  // Note that the arguments to choose() are always integers with values
2867  // known at compile time, never SCEV objects.
2868  //
2869  // The implementation avoids pointless extra computations when the two
2870  // addrec's are of different length (mathematically, it's equivalent to
2871  // an infinite stream of zeros on the right).
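  // For instance (illustrative, affine case): {A1,+,A2}<L> * {B1,+,B2}<L>
  // expands to the quadratic recurrence
  //   {A1*B1,+,A1*B2+A2*B1+A2*B2,+,2*A2*B2}<L>,
  // the chains-of-recurrences form of (A1 + A2*i)*(B1 + B2*i).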
2872  bool OpsModified = false;
2873  for (unsigned OtherIdx = Idx+1;
2874  OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2875  ++OtherIdx) {
2876  const SCEVAddRecExpr *OtherAddRec =
2877  dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2878  if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
2879  continue;
2880 
2881  bool Overflow = false;
2882  Type *Ty = AddRec->getType();
2883  bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
2884  SmallVector<const SCEV*, 7> AddRecOps;
2885  for (int x = 0, xe = AddRec->getNumOperands() +
2886  OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
2887  const SCEV *Term = getZero(Ty);
2888  for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
2889  uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
2890  for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
2891  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
2892  z < ze && !Overflow; ++z) {
2893  uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
2894  uint64_t Coeff;
2895  if (LargerThan64Bits)
2896  Coeff = umul_ov(Coeff1, Coeff2, Overflow);
2897  else
2898  Coeff = Coeff1*Coeff2;
2899  const SCEV *CoeffTerm = getConstant(Ty, Coeff);
2900  const SCEV *Term1 = AddRec->getOperand(y-z);
2901  const SCEV *Term2 = OtherAddRec->getOperand(z);
2902  Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1, Term2,
2903  SCEV::FlagAnyWrap, Depth + 1),
2904  SCEV::FlagAnyWrap, Depth + 1);
2905  }
2906  }
2907  AddRecOps.push_back(Term);
2908  }
2909  if (!Overflow) {
2910  const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
2911  SCEV::FlagAnyWrap);
2912  if (Ops.size() == 2) return NewAddRec;
2913  Ops[Idx] = NewAddRec;
2914  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2915  OpsModified = true;
2916  AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
2917  if (!AddRec)
2918  break;
2919  }
2920  }
2921  if (OpsModified)
2922  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2923 
2924  // Otherwise couldn't fold anything into this recurrence. Move onto the
2925  // next one.
2926  }
2927 
2928  // Okay, it looks like we really DO need a mul expr. Check to see if we
2929  // already have one, otherwise create a new one.
2930  return getOrCreateMulExpr(Ops, Flags);
2931 }
2932 
2933 /// Get a canonical unsigned division expression, or something simpler if
2934 /// possible.
2935 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
2936  const SCEV *RHS) {
2937  assert(getEffectiveSCEVType(LHS->getType()) ==
2938  getEffectiveSCEVType(RHS->getType()) &&
2939  "SCEVUDivExpr operand types don't match!");
2940 
2941  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
2942  if (RHSC->getValue()->isOne())
2943  return LHS; // X udiv 1 --> x
2944  // If the denominator is zero, the result of the udiv is undefined. Don't
2945  // try to analyze it, because the resolution chosen here may differ from
2946  // the resolution chosen in other parts of the compiler.
2947  if (!RHSC->getValue()->isZero()) {
2948  // Determine if the division can be folded into the operands of
2949  // the dividend (LHS).
2950  // TODO: Generalize this to non-constants by using known-bits information.
2951  Type *Ty = LHS->getType();
2952  unsigned LZ = RHSC->getAPInt().countLeadingZeros();
2953  unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
2954  // For non-power-of-two values, effectively round the value up to the
2955  // nearest power of two.
2956  if (!RHSC->getAPInt().isPowerOf2())
2957  ++MaxShiftAmt;
2958  IntegerType *ExtTy =
2959  IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
2960  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
2961  if (const SCEVConstant *Step =
2962  dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
2963  // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
2964  const APInt &StepInt = Step->getAPInt();
2965  const APInt &DivInt = RHSC->getAPInt();
2966  if (!StepInt.urem(DivInt) &&
2967  getZeroExtendExpr(AR, ExtTy) ==
2968  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2969  getZeroExtendExpr(Step, ExtTy),
2970  AR->getLoop(), SCEV::FlagAnyWrap)) {
2971  SmallVector<const SCEV *, 4> Operands;
2972  for (const SCEV *Op : AR->operands())
2973  Operands.push_back(getUDivExpr(Op, RHS));
2974  return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
2975  }
2976  /// Get a canonical UDivExpr for a recurrence.
2977  /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
2978  // We can currently only fold X%N if X is constant.
2979  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
2980  if (StartC && !DivInt.urem(StepInt) &&
2981  getZeroExtendExpr(AR, ExtTy) ==
2982  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2983  getZeroExtendExpr(Step, ExtTy),
2984  AR->getLoop(), SCEV::FlagAnyWrap)) {
2985  const APInt &StartInt = StartC->getAPInt();
2986  const APInt &StartRem = StartInt.urem(StepInt);
2987  if (StartRem != 0)
2988  LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
2989  AR->getLoop(), SCEV::FlagNW);
2990  }
2991  }
2992  // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
2993  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
2994  SmallVector<const SCEV *, 4> Operands;
2995  for (const SCEV *Op : M->operands())
2996  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
2997  if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
2998  // Find an operand that's safely divisible.
2999  for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3000  const SCEV *Op = M->getOperand(i);
3001  const SCEV *Div = getUDivExpr(Op, RHSC);
3002  if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3003  Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
3004  M->op_end());
3005  Operands[i] = Div;
3006  return getMulExpr(Operands);
3007  }
3008  }
3009  }
3010  // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3011  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3012  SmallVector<const SCEV *, 4> Operands;
3013  for (const SCEV *Op : A->operands())
3014  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3015  if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3016  Operands.clear();
3017  for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3018  const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3019  if (isa<SCEVUDivExpr>(Op) ||
3020  getMulExpr(Op, RHS) != A->getOperand(i))
3021  break;
3022  Operands.push_back(Op);
3023  }
3024  if (Operands.size() == A->getNumOperands())
3025  return getAddExpr(Operands);
3026  }
3027  }
3028 
3029  // Fold if both operands are constant.
3030  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
3031  Constant *LHSCV = LHSC->getValue();
3032  Constant *RHSCV = RHSC->getValue();
3033  return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
3034  RHSCV)));
3035  }
3036  }
3037  }
3038 
3039  FoldingSetNodeID ID;
3040  ID.AddInteger(scUDivExpr);
3041  ID.AddPointer(LHS);
3042  ID.AddPointer(RHS);
3043  void *IP = nullptr;
3044  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3045  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3046  LHS, RHS);
3047  UniqueSCEVs.InsertNode(S, IP);
3048  return S;
3049 }
3050 
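/// Compute the GCD of the absolute values of two SCEV constants, zero-extending
/// the narrower one so both operands share a bit width.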
3051 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3052  APInt A = C1->getAPInt().abs();
3053  APInt B = C2->getAPInt().abs();
3054  uint32_t ABW = A.getBitWidth();
3055  uint32_t BBW = B.getBitWidth();
3056 
3057  if (ABW > BBW)
3058  B = B.zext(ABW);
3059  else if (ABW < BBW)
3060  A = A.zext(BBW);
3061 
3062  return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3063 }
3064 
3065 /// Get a canonical unsigned division expression, or something simpler if
3066 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3067 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3068 /// it's not exact because the udiv may be clearing bits.
3069 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3070  const SCEV *RHS) {
3071  // TODO: we could try to find factors in all sorts of things, but for now we
3072  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3073  // end of this file for inspiration.
3074 
3075  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3076  if (!Mul || !Mul->hasNoUnsignedWrap())
3077  return getUDivExpr(LHS, RHS);
3078 
3079  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3080  // If the mulexpr multiplies by a constant, then that constant must be the
3081  // first element of the mulexpr.
3082  if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3083  if (LHSCst == RHSCst) {
3084  SmallVector<const SCEV *, 2> Operands;
3085  Operands.append(Mul->op_begin() + 1, Mul->op_end());
3086  return getMulExpr(Operands);
3087  }
3088 
3089  // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3090  // that there's a factor provided by one of the other terms. We need to
3091  // check.
3092  APInt Factor = gcd(LHSCst, RHSCst);
3093  if (!Factor.isIntN(1)) {
3094  LHSCst =
3095  cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3096  RHSCst =
3097  cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3098  SmallVector<const SCEV *, 2> Operands;
3099  Operands.push_back(LHSCst);
3100  Operands.append(Mul->op_begin() + 1, Mul->op_end());
3101  LHS = getMulExpr(Operands);
3102  RHS = RHSCst;
3103  Mul = dyn_cast<SCEVMulExpr>(LHS);
3104  if (!Mul)
3105  return getUDivExactExpr(LHS, RHS);
3106  }
3107  }
3108  }
3109 
3110  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3111  if (Mul->getOperand(i) == RHS) {
3112  SmallVector<const SCEV *, 2> Operands;
3113  Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3114  Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3115  return getMulExpr(Operands);
3116  }
3117  }
3118 
3119  return getUDivExpr(LHS, RHS);
3120 }
3121 
3122 /// Get an add recurrence expression for the specified loop. Simplify the
3123 /// expression as much as possible.
3124 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3125  const Loop *L,
3126  SCEV::NoWrapFlags Flags) {
3127  SmallVector<const SCEV *, 4> Operands;
3128  Operands.push_back(Start);
3129  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3130  if (StepChrec->getLoop() == L) {
3131  Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3132  return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3133  }
3134 
3135  Operands.push_back(Step);
3136  return getAddRecExpr(Operands, L, Flags);
3137 }
3138 
3139 /// Get an add recurrence expression for the specified loop. Simplify the
3140 /// expression as much as possible.
3141 const SCEV *
3142 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3143  const Loop *L, SCEV::NoWrapFlags Flags) {
3144  if (Operands.size() == 1) return Operands[0];
3145 #ifndef NDEBUG
3146  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3147  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
3148  assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3149  "SCEVAddRecExpr operand types don't match!");
3150  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3151  assert(isLoopInvariant(Operands[i], L) &&
3152  "SCEVAddRecExpr operand is not loop-invariant!");
3153 #endif
3154 
3155  if (Operands.back()->isZero()) {
3156  Operands.pop_back();
3157  return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3158  }
3159 
3160  // It's tempting to want to call getMaxBackedgeTakenCount here and
3161  // use that information to infer NUW and NSW flags. However, computing a
3162  // BE count requires calling getAddRecExpr, so we may not yet have a
3163  // meaningful BE count at this point (and if we don't, we'd be stuck
3164  // with a SCEVCouldNotCompute as the cached BE count).
3165 
3166  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3167 
3168  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3169  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3170  const Loop *NestedLoop = NestedAR->getLoop();
3171  if (L->contains(NestedLoop)
3172  ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3173  : (!NestedLoop->contains(L) &&
3174  DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3175  SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
3176  NestedAR->op_end());
3177  Operands[0] = NestedAR->getStart();
3178  // AddRecs require their operands be loop-invariant with respect to their
3179  // loops. Don't perform this transformation if it would break this
3180  // requirement.
3181  bool AllInvariant = all_of(
3182  Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3183 
3184  if (AllInvariant) {
3185  // Create a recurrence for the outer loop with the same step size.
3186  //
3187  // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3188  // inner recurrence has the same property.
3189  SCEV::NoWrapFlags OuterFlags =
3190  maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3191 
3192  NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3193  AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3194  return isLoopInvariant(Op, NestedLoop);
3195  });
3196 
3197  if (AllInvariant) {
3198  // Ok, both add recurrences are valid after the transformation.
3199  //
3200  // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3201  // the outer recurrence has the same property.
3202  SCEV::NoWrapFlags InnerFlags =
3203  maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3204  return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3205  }
3206  }
3207  // Reset Operands to its original state.
3208  Operands[0] = NestedAR;
3209  }
3210  }
3211 
3212  // Okay, it looks like we really DO need an addrec expr. Check to see if we
3213  // already have one, otherwise create a new one.
3214  FoldingSetNodeID ID;
3215  ID.AddInteger(scAddRecExpr);
3216  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3217  ID.AddPointer(Operands[i]);
3218  ID.AddPointer(L);
3219  void *IP = nullptr;
3220  SCEVAddRecExpr *S =
3221  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3222  if (!S) {
3223  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
3224  std::uninitialized_copy(Operands.begin(), Operands.end(), O);
3225  S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
3226  O, Operands.size(), L);
3227  UniqueSCEVs.InsertNode(S, IP);
3228  }
3229  S->setNoWrapFlags(Flags);
3230  return S;
3231 }
3232 
3233 const SCEV *
3234 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3235  const SmallVectorImpl<const SCEV *> &IndexExprs) {
3236  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3237  // getSCEV(Base)->getType() has the same address space as Base->getType()
3238  // because SCEV::getType() preserves the address space.
3239  Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType());
3240  // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP
3241  // instruction to its SCEV, because the Instruction may be guarded by control
3242  // flow and the no-overflow bits may not be valid for the expression in any
3243  // context. This can be fixed similarly to how these flags are handled for
3244  // adds.
3245  SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW
3246  : SCEV::FlagAnyWrap;
3247 
3248  const SCEV *TotalOffset = getZero(IntPtrTy);
3249  // The array size is unimportant. The first thing we do with CurTy is get
3250  // its element type.
3251  Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0);
3252  for (const SCEV *IndexExpr : IndexExprs) {
3253  // Compute the (potentially symbolic) offset in bytes for this index.
3254  if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3255  // For a struct, add the member offset.
3256  ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3257  unsigned FieldNo = Index->getZExtValue();
3258  const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo);
3259 
3260  // Add the field offset to the running total offset.
3261  TotalOffset = getAddExpr(TotalOffset, FieldOffset);
3262 
3263  // Update CurTy to the type of the field at Index.
3264  CurTy = STy->getTypeAtIndex(Index);
3265  } else {
3266  // Update CurTy to its element type.
3267  CurTy = cast<SequentialType>(CurTy)->getElementType();
3268  // For an array, add the element offset, explicitly scaled.
3269  const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy);
3270  // Getelementptr indices are signed.
3271  IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy);
3272 
3273  // Multiply the index by the element size to compute the element offset.
3274  const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap);
3275 
3276  // Add the element offset to the running total offset.
3277  TotalOffset = getAddExpr(TotalOffset, LocalOffset);
3278  }
3279  }
3280 
3281  // Add the total offset from all the GEP indices to the base.
3282  return getAddExpr(BaseExpr, TotalOffset, Wrap);
3283 }
3284 
3285 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
3286  const SCEV *RHS) {
3287  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3288  return getSMaxExpr(Ops);
3289 }
3290 
3291 const SCEV *
3292 ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3293  assert(!Ops.empty() && "Cannot get empty smax!");
3294  if (Ops.size() == 1) return Ops[0];
3295 #ifndef NDEBUG
3296  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3297  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3298  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3299  "SCEVSMaxExpr operand types don't match!");
3300 #endif
3301 
3302  // Sort by complexity; this groups all similar expression types together.
3303  GroupByComplexity(Ops, &LI, DT);
3304 
3305  // If there are any constants, fold them together.
3306  unsigned Idx = 0;
3307  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3308  ++Idx;
3309  assert(Idx < Ops.size());
3310  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3311  // We found two constants, fold them together!
3312  ConstantInt *Fold = ConstantInt::get(
3313  getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt()));
3314  Ops[0] = getConstant(Fold);
3315  Ops.erase(Ops.begin()+1); // Erase the folded element
3316  if (Ops.size() == 1) return Ops[0];
3317  LHSC = cast<SCEVConstant>(Ops[0]);
3318  }
3319 
3320  // If we are left with a constant minimum-int, strip it off.
3321  if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
3322  Ops.erase(Ops.begin());
3323  --Idx;
3324  } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
3325  // If we have an smax with a constant maximum-int, it will always be
3326  // maximum-int.
3327  return Ops[0];
3328  }
3329 
3330  if (Ops.size() == 1) return Ops[0];
3331  }
3332 
3333  // Find the first SMax
3334  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
3335  ++Idx;
3336 
3337  // Check to see if one of the operands is an SMax. If so, expand its operands
3338  // onto our operand list, and recurse to simplify.
3339  if (Idx < Ops.size()) {
3340  bool DeletedSMax = false;
3341  while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
3342  Ops.erase(Ops.begin()+Idx);
3343  Ops.append(SMax->op_begin(), SMax->op_end());
3344  DeletedSMax = true;
3345  }
3346 
3347  if (DeletedSMax)
3348  return getSMaxExpr(Ops);
3349  }
3350 
3351  // Okay, check to see if the same value occurs in the operand list twice. If
3352  // so, delete one. Since we sorted the list, these values are required to
3353  // be adjacent.
3354  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
3355  // X smax Y smax Y --> X smax Y
3356  // X smax Y --> X, if X is always greater than Y
3357  if (Ops[i] == Ops[i+1] ||
3358  isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) {
3359  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
3360  --i; --e;
3361  } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) {
3362  Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
3363  --i; --e;
3364  }
3365 
3366  if (Ops.size() == 1) return Ops[0];
3367 
3368  assert(!Ops.empty() && "Reduced smax down to nothing!");
3369 
3370  // Okay, it looks like we really DO need an smax expr. Check to see if we
3371  // already have one, otherwise create a new one.
3372  FoldingSetNodeID ID;
3373  ID.AddInteger(scSMaxExpr);
3374  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3375  ID.AddPointer(Ops[i]);
3376  void *IP = nullptr;
3377  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3378  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3379  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3380  SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
3381  O, Ops.size());
3382  UniqueSCEVs.InsertNode(S, IP);
3383  return S;
3384 }
3385 
3386 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
3387  const SCEV *RHS) {
3388  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3389  return getUMaxExpr(Ops);
3390 }
3391 
3392 const SCEV *
3393 ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3394  assert(!Ops.empty() && "Cannot get empty umax!");
3395  if (Ops.size() == 1) return Ops[0];
3396 #ifndef NDEBUG
3397  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3398  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3399  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3400  "SCEVUMaxExpr operand types don't match!");
3401 #endif
3402 
3403  // Sort by complexity; this groups all similar expression types together.
3404  GroupByComplexity(Ops, &LI, DT);
3405 
3406  // If there are any constants, fold them together.
3407  unsigned Idx = 0;
3408  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3409  ++Idx;
3410  assert(Idx < Ops.size());
3411  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3412  // We found two constants, fold them together!
3413  ConstantInt *Fold = ConstantInt::get(
3414  getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt()));
3415  Ops[0] = getConstant(Fold);
3416  Ops.erase(Ops.begin()+1); // Erase the folded element
3417  if (Ops.size() == 1) return Ops[0];
3418  LHSC = cast<SCEVConstant>(Ops[0]);
3419  }
3420 
3421  // If we are left with a constant minimum-int, strip it off.
3422  if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
3423  Ops.erase(Ops.begin());
3424  --Idx;
3425  } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
3426  // If we have an umax with a constant maximum-int, it will always be
3427  // maximum-int.
3428  return Ops[0];
3429  }
3430 
3431  if (Ops.size() == 1) return Ops[0];
3432  }
3433 
3434  // Find the first UMax
3435  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
3436  ++Idx;
3437 
3438  // Check to see if one of the operands is a UMax. If so, expand its operands
3439  // onto our operand list, and recurse to simplify.
3440  if (Idx < Ops.size()) {
3441  bool DeletedUMax = false;
3442  while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
3443  Ops.erase(Ops.begin()+Idx);
3444  Ops.append(UMax->op_begin(), UMax->op_end());
3445  DeletedUMax = true;
3446  }
3447 
3448  if (DeletedUMax)
3449  return getUMaxExpr(Ops);
3450  }
3451 
3452  // Okay, check to see if the same value occurs in the operand list twice. If
3453  // so, delete one. Since we sorted the list, these values are required to
3454  // be adjacent.
3455  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
3456  // X umax Y umax Y --> X umax Y
3457  // X umax Y --> X, if X is always greater than Y
3458  if (Ops[i] == Ops[i+1] ||
3459  isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
3460  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
3461  --i; --e;
3462  } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
3463  Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
3464  --i; --e;
3465  }
3466 
3467  if (Ops.size() == 1) return Ops[0];
3468 
3469  assert(!Ops.empty() && "Reduced umax down to nothing!");
3470 
3471  // Okay, it looks like we really DO need a umax expr. Check to see if we
3472  // already have one, otherwise create a new one.
3473  FoldingSetNodeID ID;
3474  ID.AddInteger(scUMaxExpr);
3475  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3476  ID.AddPointer(Ops[i]);
3477  void *IP = nullptr;
3478  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3479  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3480  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3481  SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
3482  O, Ops.size());
3483  UniqueSCEVs.InsertNode(S, IP);
3484  return S;
3485 }
3486 
3487 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
3488  const SCEV *RHS) {
3489  // ~smax(~x, ~y) == smin(x, y).
3490  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
3491 }
3492 
3493 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
3494  const SCEV *RHS) {
3495  // ~umax(~x, ~y) == umin(x, y)
3496  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
3497 }
3498 
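/// Return a SCEV constant for the allocation size, in bytes, of AllocTy,
/// expressed in the integer type IntTy.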
3499 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
3500  // We can bypass creating a target-independent
3501  // constant expression and then folding it back into a ConstantInt.
3502  // This is just a compile-time optimization.
3503  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
3504 }
3505 
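/// Return a SCEV constant for the byte offset of field FieldNo within the
/// struct type STy, expressed in the integer type IntTy.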
3506 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
3507  StructType *STy,
3508  unsigned FieldNo) {
3509  // We can bypass creating a target-independent
3510  // constant expression and then folding it back into a ConstantInt.
3511  // This is just a compile-time optimization.
3512  return getConstant(
3513  IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
3514 }
3515 
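/// Wrap a value SCEV cannot analyze further in a SCEVUnknown node, uniqued in
/// this ScalarEvolution instance.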
3516 const SCEV *ScalarEvolution::getUnknown(Value *V) {
3517  // Don't attempt to do anything other than create a SCEVUnknown object
3518  // here. createSCEV only calls getUnknown after checking for all other
3519  // interesting possibilities, and any other code that calls getUnknown
3520  // is doing so in order to hide a value from SCEV canonicalization.
3521 
3522  FoldingSetNodeID ID;
3523  ID.AddInteger(scUnknown);
3524  ID.AddPointer(V);
3525  void *IP = nullptr;
3526  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3527  assert(cast<SCEVUnknown>(S)->getValue() == V &&
3528  "Stale SCEVUnknown in uniquing map!");
3529  return S;
3530  }
3531  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
3532  FirstUnknown);
3533  FirstUnknown = cast<SCEVUnknown>(S);
3534  UniqueSCEVs.InsertNode(S, IP);
3535  return S;
3536 }
3537 
3538 //===----------------------------------------------------------------------===//
3539 // Basic SCEV Analysis and PHI Idiom Recognition Code
3540 //
3541 
3542 /// Test if values of the given type are analyzable within the SCEV
3543 /// framework. This primarily includes integer types, and it can optionally
3544 /// include pointer types if the ScalarEvolution class has access to
3545 /// target-specific information.
3546 bool ScalarEvolution::isSCEVable(Type *Ty) const {
3547  // Integers and pointers are always SCEVable.
3548  return Ty->isIntegerTy() || Ty->isPointerTy();
3549 }
3550 
3551 /// Return the size in bits of the specified type, for which isSCEVable must
3552 /// return true.
3553 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
3554  assert(isSCEVable(Ty) && "Type is not SCEVable!");
3555  return getDataLayout().getTypeSizeInBits(Ty);
3556 }
3557 
3558 /// Return a type with the same bitwidth as the given type and which represents
3559 /// how SCEV will treat the given type, for which isSCEVable must return
3560 /// true. For pointer types, this is the pointer-sized integer type.
3561 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
3562  assert(isSCEVable(Ty) && "Type is not SCEVable!");
3563 
3564  if (Ty->isIntegerTy())
3565  return Ty;
3566 
3567  // The only other supported type is pointer.
3568  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
3569  return getDataLayout().getIntPtrType(Ty);
3570 }
3571 
3572 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
3573  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
3574 }
3575 
3576 const SCEV *ScalarEvolution::getCouldNotCompute() {
3577  return CouldNotCompute.get();
3578 }
3579 
3580 bool ScalarEvolution::checkValidity(const SCEV *S) const {
3581  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
3582  auto *SU = dyn_cast<SCEVUnknown>(S);
3583  return SU && SU->getValue() == nullptr;
3584  });
3585 
3586  return !ContainsNulls;
3587 }
3588 
3589 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
3590  HasRecMapType::iterator I = HasRecMap.find(S);
3591  if (I != HasRecMap.end())
3592  return I->second;
3593 
3594  bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>);
3595  HasRecMap.insert({S, FoundAddRec});
3596  return FoundAddRec;
3597 }
3598 
3599 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
3600 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
3601 /// offset I, then return {S', I}, else return {\p S, nullptr}.
3602 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
3603  const auto *Add = dyn_cast<SCEVAddExpr>(S);
3604  if (!Add)
3605  return {S, nullptr};
3606 
3607  if (Add->getNumOperands() != 2)
3608  return {S, nullptr};
3609 
3610  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
3611  if (!ConstOp)
3612  return {S, nullptr};
3613 
3614  return {Add->getOperand(1), ConstOp->getValue()};
3615 }
3616 
3617 /// Return the ValueOffsetPair set for \p S. \p S can be represented
3618 /// by the value and offset from any ValueOffsetPair in the set.
3619 SetVector<ScalarEvolution::ValueOffsetPair> *
3620 ScalarEvolution::getSCEVValues(const SCEV *S) {
3621  ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
3622  if (SI == ExprValueMap.end())
3623  return nullptr;
3624 #ifndef NDEBUG
3625  if (VerifySCEVMap) {
3626  // Check there is no dangling Value in the set returned.
3627  for (const auto &VE : SI->second)
3628  assert(ValueExprMap.count(VE.first));
3629  }
3630 #endif
3631  return &SI->second;
3632 }
3633 
3634 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
3635 /// cannot be used separately. eraseValueFromMap should be used to remove
3636 /// V from ValueExprMap and ExprValueMap at the same time.
3637 void ScalarEvolution::eraseValueFromMap(Value *V) {
3638  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3639  if (I != ValueExprMap.end()) {
3640  const SCEV *S = I->second;
3641  // Remove {V, 0} from the set of ExprValueMap[S]
3642  if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S))
3643  SV->remove({V, nullptr});
3644 
3645  // Remove {V, Offset} from the set of ExprValueMap[Stripped]
3646  const SCEV *Stripped;
3647  ConstantInt *Offset;
3648  std::tie(Stripped, Offset) = splitAddExpr(S);
3649  if (Offset != nullptr) {
3650  if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped))
3651  SV->remove({V, Offset});
3652  }
3653  ValueExprMap.erase(V);
3654  }
3655 }
3656 
3657 /// Return an existing SCEV if it exists, otherwise analyze the expression and
3658 /// create a new one.
3659 const SCEV *ScalarEvolution::getSCEV(Value *V) {
3660  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3661 
3662  const SCEV *S = getExistingSCEV(V);
3663  if (S == nullptr) {
3664  S = createSCEV(V);
3665  // During PHI resolution, it is possible to create two SCEVs for the same
3666  // V, so it is needed to double check whether V->S is inserted into
3667  // ValueExprMap before insert S->{V, 0} into ExprValueMap.
3668  std::pair<ValueExprMapType::iterator, bool> Pair =
3669  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
3670  if (Pair.second) {
3671  ExprValueMap[S].insert({V, nullptr});
3672 
3673  // If S == Stripped + Offset, add Stripped -> {V, Offset} into
3674  // ExprValueMap.
3675  const SCEV *Stripped = S;
3676  ConstantInt *Offset = nullptr;
3677  std::tie(Stripped, Offset) = splitAddExpr(S);
3678  // If stripped is SCEVUnknown, don't bother to save
3679  // Stripped -> {V, offset}. It doesn't simplify and sometimes even
3680  // increase the complexity of the expansion code.
3681  // If V is GetElementPtrInst, don't save Stripped -> {V, offset}
3682  // because it may generate add/sub instead of GEP in SCEV expansion.
3683  if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
3684  !isa<GetElementPtrInst>(V))
3685  ExprValueMap[Stripped].insert({V, Offset});
3686  }
3687  }
3688  return S;
3689 }
3690 
3691 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
3692  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3693 
3694  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3695  if (I != ValueExprMap.end()) {
3696  const SCEV *S = I->second;
3697  if (checkValidity(S))
3698  return S;
3699  eraseValueFromMap(V);
3700  forgetMemoizedResults(S);
3701  }
3702  return nullptr;
3703 }
3704 
3705 /// Return a SCEV corresponding to -V = -1*V
3706 ///
3707 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
3708  SCEV::NoWrapFlags Flags) {
3709  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3710  return getConstant(
3711  cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
3712 
3713  Type *Ty = V->getType();
3714  Ty = getEffectiveSCEVType(Ty);
3715  return getMulExpr(
3716  V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags);
3717 }
3718 
3719 /// Return a SCEV corresponding to ~V = -1-V
3720 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
3721  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3722  return getConstant(
3723  cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
3724 
3725  Type *Ty = V->getType();
3726  Ty = getEffectiveSCEVType(Ty);
3727  const SCEV *AllOnes =
3728  getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
3729  return getMinusSCEV(AllOnes, V);
3730 }
3731 
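/// Return LHS-RHS, represented as LHS + (-1)*RHS, transferring NSW to the add
/// only when it can be shown safe.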
3732 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
3733  SCEV::NoWrapFlags Flags,
3734  unsigned Depth) {
3735  // Fast path: X - X --> 0.
3736  if (LHS == RHS)
3737  return getZero(LHS->getType());
3738 
3739  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
3740  // makes it so that we cannot make much use of NUW.
3741  auto AddFlags = SCEV::FlagAnyWrap;
3742  const bool RHSIsNotMinSigned =
3743  !getSignedRangeMin(RHS).isMinSignedValue();
3744  if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) {
3745  // Let M be the minimum representable signed value. Then (-1)*RHS
3746  // signed-wraps if and only if RHS is M. That can happen even for
3747  // a NSW subtraction because e.g. (-1)*M signed-wraps even though
3748  // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
3749  // (-1)*RHS, we need to prove that RHS != M.
3750  //
3751  // If LHS is non-negative and we know that LHS - RHS does not
3752  // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
3753  // either by proving that RHS > M or that LHS >= 0.
3754  if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
3755  AddFlags = SCEV::FlagNSW;
3756  }
3757  }
3758 
3759  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
3760  // RHS is NSW and LHS >= 0.
3761  //
3762  // The difficulty here is that the NSW flag may have been proven
3763  // relative to a loop that is to be found in a recurrence in LHS and
3764  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
3765  // larger scope than intended.
3766  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3767 
3768  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
3769 }
3770 
3771 const SCEV *
3772 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
3773  Type *SrcTy = V->getType();
3774  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3775  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3776  "Cannot truncate or zero extend with non-integer arguments!");
3777  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3778  return V; // No conversion
3779  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3780  return getTruncateExpr(V, Ty);
3781  return getZeroExtendExpr(V, Ty);
3782 }
3783 
3784 const SCEV *
3785 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
3786  Type *Ty) {
3787  Type *SrcTy = V->getType();
3788  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3789  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3790  "Cannot truncate or zero extend with non-integer arguments!");
3791  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3792  return V; // No conversion
3793  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3794  return getTruncateExpr(V, Ty);
3795  return getSignExtendExpr(V, Ty);
3796 }
3797 
3798 const SCEV *
3799 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
3800  Type *SrcTy = V->getType();
3801  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3802  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3803  "Cannot noop or zero extend with non-integer arguments!");
3804  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3805  "getNoopOrZeroExtend cannot truncate!");
3806  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3807  return V; // No conversion
3808  return getZeroExtendExpr(V, Ty);
3809 }
3810 
3811 const SCEV *
3812 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
3813  Type *SrcTy = V->getType();
3814  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3815  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3816  "Cannot noop or sign extend with non-integer arguments!");
3817  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3818  "getNoopOrSignExtend cannot truncate!");
3819  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3820  return V; // No conversion
3821  return getSignExtendExpr(V, Ty);
3822 }
3823 
3824 const SCEV *
3825 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
3826  Type *SrcTy = V->getType();
3827  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3828  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3829  "Cannot noop or any extend with non-integer arguments!");
3830  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3831  "getNoopOrAnyExtend cannot truncate!");
3832  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3833  return V; // No conversion
3834  return getAnyExtendExpr(V, Ty);
3835 }
3836 
3837 const SCEV *
3838 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
3839  Type *SrcTy = V->getType();
3840  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3841  (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3842  "Cannot truncate or noop with non-integer arguments!");
3843  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
3844  "getTruncateOrNoop cannot extend!");
3845  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3846  return V; // No conversion
3847  return getTruncateExpr(V, Ty);
3848 }
3849 
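/// Zero-extend the narrower operand to the wider type, then form the umax.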
3850 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
3851  const SCEV *RHS) {
3852  const SCEV *PromotedLHS = LHS;
3853  const SCEV *PromotedRHS = RHS;
3854 
3855  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
3856  PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
3857  else
3858  PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
3859 
3860  return getUMaxExpr(PromotedLHS, PromotedRHS);
3861 }
3862 
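/// Zero-extend the narrower operand to the wider type, then form the umin.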
3863 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
3864  const SCEV *RHS) {
3865  const SCEV *PromotedLHS = LHS;
3866  const SCEV *PromotedRHS = RHS;
3867 
3868  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
3869  PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
3870  else
3871  PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
3872 
3873  return getUMinExpr(PromotedLHS, PromotedRHS);
3874 }
3875 
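/// Walk through casts and n-ary expressions with a single pointer operand to
/// find the base pointer underlying this expression, if any.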
3876 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
3877  // A pointer operand may evaluate to a nonpointer expression, such as null.
3878  if (!V->getType()->isPointerTy())
3879  return V;
3880 
3881  if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
3882  return getPointerBase(Cast->getOperand());
3883  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
3884  const SCEV *PtrOp = nullptr;
3885  for (const SCEV *NAryOp : NAry->operands()) {
3886  if (NAryOp->getType()->isPointerTy()) {
3887  // Cannot find the base of an expression with multiple pointer operands.
3888  if (PtrOp)
3889  return V;
3890  PtrOp = NAryOp;
3891  }
3892  }
3893  if (!PtrOp)
3894  return V;
3895  return getPointerBase(PtrOp);
3896  }
3897  return V;
3898 }
3899 
3900 /// Push users of the given Instruction onto the given Worklist.
3901 static void
3902 PushDefUseChildren(Instruction *I,
3903  SmallVectorImpl<Instruction *> &Worklist) {
3904  // Push the def-use children onto the Worklist stack.
3905  for (User *U : I->users())
3906  Worklist.push_back(cast<Instruction>(U));
3907 }
3908 
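/// Invalidate cached SCEVs that still refer to SymName, walking the def-use
/// chain outward from PN.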
3909 void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
3910  SmallVector<Instruction *, 16> Worklist;
3911  PushDefUseChildren(PN, Worklist);
3912 
3913  SmallPtrSet<Instruction *, 8> Visited;
3914  Visited.insert(PN);
3915  while (!Worklist.empty()) {
3916  Instruction *I = Worklist.pop_back_val();
3917  if (!Visited.insert(I).second)
3918  continue;
3919 
3920  auto It = ValueExprMap.find_as(static_cast<Value *>(I));
3921  if (It != ValueExprMap.end()) {
3922  const SCEV *Old = It->second;
3923 
3924  // Short-circuit the def-use traversal if the symbolic name
3925  // ceases to appear in expressions.
3926  if (Old != SymName && !hasOperand(Old, SymName))
3927  continue;
3928 
3929  // SCEVUnknown for a PHI either means that it has an unrecognized
3930  // structure, it's a PHI that's in the process of being computed
3931  // by createNodeForPHI, or it's a single-value PHI. In the first case,
3932  // additional loop trip count information isn't going to change anything.
3933  // In the second case, createNodeForPHI will perform the necessary
3934  // updates on its own when it gets to that point. In the third, we do
3935  // want to forget the SCEVUnknown.
3936  if (!isa<PHINode>(I) ||
3937  !isa<SCEVUnknown>(Old) ||
3938  (I != PN && Old == SymName)) {
3939  eraseValueFromMap(It->first);
3940  forgetMemoizedResults(Old);
3941  }
3942  }
3943 
3944  PushDefUseChildren(I, Worklist);
3945  }
3946 }
3947 
3948 namespace {
3949 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
3950 public:
3951  static const SCEV *rewrite(const SCEV *S, const Loop *L,
3952  ScalarEvolution &SE) {
3953  SCEVInitRewriter Rewriter(L, SE);
3954  const SCEV *Result = Rewriter.visit(S);
3955  return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
3956  }
3957 
3958  SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
3959  : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
3960 
3961  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
3962  if (!SE.isLoopInvariant(Expr, L))
3963  Valid = false;
3964  return Expr;
3965  }
3966 
3967  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
3968  // Only allow AddRecExprs for this loop.
3969  if (Expr->getLoop() == L)
3970  return Expr->getStart();
3971  Valid = false;
3972  return Expr;
3973  }
3974 
3975  bool isValid() { return Valid; }
3976 
3977 private:
3978  const Loop *L;
3979  bool Valid;
3980 };
3981 
3982 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
3983 public:
3984  static const SCEV *rewrite(const SCEV *S, const Loop *L,
3985  ScalarEvolution &SE) {
3986  SCEVShiftRewriter Rewriter(L, SE);
3987  const SCEV *Result = Rewriter.visit(S);
3988  return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
3989  }
3990 
3991  SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
3992  : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
3993 
3994  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
3995  // Only allow loop-invariant unknowns.
3996  if (!SE.isLoopInvariant(Expr, L))
3997  Valid = false;
3998  return Expr;
3999  }
4000 
4001  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4002  if (Expr->getLoop() == L && Expr->isAffine())
4003  return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
4004  Valid = false;
4005  return Expr;
4006  }
4007  bool isValid() { return Valid; }
4008 
4009 private:
4010  const Loop *L;
4011  bool Valid;
4012 };
4013 } // end anonymous namespace
4014 
4015 SCEV::NoWrapFlags
4016 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
4017  if (!AR->isAffine())
4018  return SCEV::FlagAnyWrap;
4019 
4020  typedef OverflowingBinaryOperator OBO;
4021  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
4022 
4023  if (!AR->hasNoSignedWrap()) {
4024  ConstantRange AddRecRange = getSignedRange(AR);
4025  ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
4026 
4027  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4028  Instruction::Add, IncRange, OBO::NoSignedWrap);
4029  if (NSWRegion.contains(AddRecRange))
4030  Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
4031  }
4032 
4033  if (!AR->hasNoUnsignedWrap()) {
4034  ConstantRange AddRecRange = getUnsignedRange(AR);
4035  ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
4036 
4038  Instruction::Add, IncRange, OBO::NoUnsignedWrap);
4039  if (NUWRegion.contains(AddRecRange))
4040  Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
4041  }
4042 
4043  return Result;
4044 }
4045 
4046 namespace {
4047 /// Represents an abstract binary operation. This may exist as a
4048 /// normal instruction or constant expression, or may have been
4049 /// derived from an expression tree.
4050 struct BinaryOp {
4051  unsigned Opcode;
4052  Value *LHS;
4053  Value *RHS;
4054  bool IsNSW;
4055  bool IsNUW;
4056 
4057  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
4058  /// constant expression.
4059  Operator *Op;
4060 
4061  explicit BinaryOp(Operator *Op)
4062  : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
4063  IsNSW(false), IsNUW(false), Op(Op) {
4064  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
4065  IsNSW = OBO->hasNoSignedWrap();
4066  IsNUW = OBO->hasNoUnsignedWrap();
4067  }
4068  }
4069 
4070  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
4071  bool IsNUW = false)
4072  : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW),
4073  Op(nullptr) {}
4074 };
4075 }
4076 
4077 
4078 /// Try to map \p V into a BinaryOp, and return \c None on failure.
4079 static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
4080  auto *Op = dyn_cast<Operator>(V);
4081  if (!Op)
4082  return None;
4083 
4084  // Implementation detail: all the cleverness here should happen without
4085  // creating new SCEV expressions -- our caller knows tricks to avoid creating
4086  // SCEV expressions when possible, and we should not break that.
4087 
4088  switch (Op->getOpcode()) {
4089  case Instruction::Add:
4090  case Instruction::Sub:
4091  case Instruction::Mul:
4092  case Instruction::UDiv:
4093  case Instruction::And:
4094  case Instruction::Or:
4095  case Instruction::AShr:
4096  case Instruction::Shl:
4097  return BinaryOp(Op);
4098 
4099  case Instruction::Xor:
4100  if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
4101  // If the RHS of the xor is a signmask, then this is just an add.
4102  // Instcombine turns add of signmask into xor as a strength reduction step.
4103  if (RHSC->getValue().isSignMask())
4104  return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
4105  return BinaryOp(Op);
4106 
4107  case Instruction::LShr:
4108  // Turn logical shift right of a constant into an unsigned divide.
4109  if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
4110  uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
4111 
4112  // If the shift count is not less than the bitwidth, the result of
4113  // the shift is undefined. Don't try to analyze it, because the
4114  // resolution chosen here may differ from the resolution chosen in
4115  // other parts of the compiler.
4116  if (SA->getValue().ult(BitWidth)) {
4117  Constant *X =
4118  ConstantInt::get(SA->getContext(),
4119  APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
4120  return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
4121  }
4122  }
4123  return BinaryOp(Op);
4124 
4125  case Instruction::ExtractValue: {
4126  auto *EVI = cast<ExtractValueInst>(Op);
4127  if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
4128  break;
4129 
4130  auto *CI = dyn_cast<CallInst>(EVI->getAggregateOperand());
4131  if (!CI)
4132  break;
4133 
4134  if (auto *F = CI->getCalledFunction())
4135  switch (F->getIntrinsicID()) {
4136  case Intrinsic::sadd_with_overflow:
4137  case Intrinsic::uadd_with_overflow: {
4138  if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT))
4139  return BinaryOp(Instruction::Add, CI->getArgOperand(0),
4140  CI->getArgOperand(1));
4141 
4142  // Now that we know that all uses of the arithmetic-result component of
4143  // CI are guarded by the overflow check, we can go ahead and pretend
4144  // that the arithmetic is non-overflowing.
4145  if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow)
4146  return BinaryOp(Instruction::Add, CI->getArgOperand(0),
4147  CI->getArgOperand(1), /* IsNSW = */ true,
4148  /* IsNUW = */ false);
4149  else
4150  return BinaryOp(Instruction::Add, CI->getArgOperand(0),
4151  CI->getArgOperand(1), /* IsNSW = */ false,
4152  /* IsNUW*/ true);
4153  }
4154 
4155  case Intrinsic::ssub_with_overflow:
4156  case Intrinsic::usub_with_overflow:
4157  return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
4158  CI->getArgOperand(1));
4159 
4160  case Intrinsic::smul_with_overflow:
4161  case Intrinsic::umul_with_overflow:
4162  return BinaryOp(Instruction::Mul, CI->getArgOperand(0),
4163  CI->getArgOperand(1));
4164  default:
4165  break;
4166  }
4167  }
4168 
4169  default:
4170  break;
4171  }
4172 
4173  return None;
4174 }
4175 
4176 /// Helper function to createAddRecFromPHIWithCasts. We have a phi
4177 /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
4178 /// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
4179 /// way. This function checks if \p Op, an operand of this SCEVAddExpr,
4180 /// follows one of the following patterns:
4181 /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
4182 /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
4183 /// If the SCEV expression of \p Op conforms with one of the expected patterns
4184 /// we return the type of the truncation operation, and indicate whether the
4185 /// truncated type should be treated as signed/unsigned by setting
4186 /// \p Signed to true/false, respectively.
4187 static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
4188  bool &Signed, ScalarEvolution &SE) {
4189 
4190  // The case where Op == SymbolicPHI (that is, with no type conversions on
4191  // the way) is handled by the regular add recurrence creating logic and
4192  // would have already been triggered in createAddRecFromPHI. Reaching it here
4193  // means that createAddRecFromPHI had failed for this PHI before (e.g.,
4194  // because one of the other operands of the SCEVAddExpr updating this PHI is
4195  // not invariant).
4196  //
4197  // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
4198  // this case predicates that allow us to prove that Op == SymbolicPHI will
4199  // be added.
4200  if (Op == SymbolicPHI)
4201  return nullptr;
4202 
4203  unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
4204  unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
4205  if (SourceBits != NewBits)
4206  return nullptr;
4207 
4208  const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
4209  const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
4210  if (!SExt && !ZExt)
4211  return nullptr;
4212  const SCEVTruncateExpr *Trunc =
4213  SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
4214  : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
4215  if (!Trunc)
4216  return nullptr;
4217  const SCEV *X = Trunc->getOperand();
4218  if (X != SymbolicPHI)
4219  return nullptr;
4220  Signed = SExt != nullptr;
4221  return Trunc->getType();
4222 }
4223 
4224 static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
4225  if (!PN->getType()->isIntegerTy())
4226  return nullptr;
4227  const Loop *L = LI.getLoopFor(PN->getParent());
4228  if (!L || L->getHeader() != PN->getParent())
4229  return nullptr;
4230  return L;
4231 }
4232 
4233 // Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
4234 // computation that updates the phi follows the following pattern:
4235 // (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
4236 // which correspond to a phi->trunc->sext/zext->add->phi update chain.
4237 // If so, try to see if it can be rewritten as an AddRecExpr under some
4238 // Predicates. If successful, return them as a pair. Also cache the results
4239 // of the analysis.
4240 //
4241 // Example usage scenario:
4242 // Say the Rewriter is called for the following SCEV:
4243 // 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
4244 // where:
4245 // %X = phi i64 (%Start, %BEValue)
4246 // It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
4247 // and call this function with %SymbolicPHI = %X.
4248 //
4249 // The analysis will find that the value coming around the backedge has
4250 // the following SCEV:
4251 // BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
4252 // Upon concluding that this matches the desired pattern, the function
4253 // will return the pair {NewAddRec, SmallPredsVec} where:
4254 // NewAddRec = {%Start,+,%Step}
4255 // SmallPredsVec = {P1, P2, P3} as follows:
4256 // P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
4257 // P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
4258 // P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
4259 // The returned pair means that SymbolicPHI can be rewritten into NewAddRec
4260 // under the predicates {P1,P2,P3}.
4261 // This predicated rewrite will be cached in PredicatedSCEVRewrites:
4262 // PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
4263 //
4264 // TODO's:
4265 //
4266 // 1) Extend the Induction descriptor to also support inductions that involve
4267 // casts: When needed (namely, when we are called in the context of the
4268 // vectorizer induction analysis), a Set of cast instructions will be
4269 // populated by this method, and provided back to isInductionPHI. This is
4270 // needed to allow the vectorizer to properly record them to be ignored by
4271 // the cost model and to avoid vectorizing them (otherwise these casts,
4272 // which are redundant under the runtime overflow checks, will be
4273 // vectorized, which can be costly).
4274 //
4275 // 2) Support additional induction/PHISCEV patterns: We also want to support
4276 // inductions where the sext-trunc / zext-trunc operations (partly) occur
4277 // after the induction update operation (the induction increment):
4278 //
4279 // (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
4280 // which correspond to a phi->add->trunc->sext/zext->phi update chain.
4281 //
4282 // (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
4283 // which correspond to a phi->trunc->add->sext/zext->phi update chain.
4284 //
4285 // 3) Outline common code with createAddRecFromPHI to avoid duplication.
4286 //
4287 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
4288 ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
4289  SmallVector<const SCEVPredicate *, 3> Predicates;
4290 
4291  // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
4292  // return an AddRec expression under some predicate.
4293 
4294  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
4295  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
4296  assert (L && "Expecting an integer loop header phi");
4297 
4298  // The loop may have multiple entrances or multiple exits; we can analyze
4299  // this phi as an addrec if it has a unique entry value and a unique
4300  // backedge value.
4301  Value *BEValueV = nullptr, *StartValueV = nullptr;
4302  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
4303  Value *V = PN->getIncomingValue(i);
4304  if (L->contains(PN->getIncomingBlock(i))) {
4305  if (!BEValueV) {
4306  BEValueV = V;
4307  } else if (BEValueV != V) {
4308  BEValueV = nullptr;
4309  break;
4310  }
4311  } else if (!StartValueV) {
4312  StartValueV = V;
4313  } else if (StartValueV != V) {
4314  StartValueV = nullptr;
4315  break;
4316  }
4317  }
4318  if (!BEValueV || !StartValueV)
4319  return None;
4320 
4321  const SCEV *BEValue = getSCEV(BEValueV);
4322 
4323  // If the value coming around the backedge is an add with the symbolic
4324  // value we just inserted, possibly with casts that we can ignore under
4325  // an appropriate runtime guard, then we found a simple induction variable!
4326  const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
4327  if (!Add)
4328  return None;
4329 
4330  // If there is a single occurrence of the symbolic value, possibly
4331  // casted, replace it with a recurrence.
4332  unsigned FoundIndex = Add->getNumOperands();
4333  Type *TruncTy = nullptr;
4334  bool Signed;
4335  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4336  if ((TruncTy =
4337  isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
4338  if (FoundIndex == e) {
4339  FoundIndex = i;
4340  break;
4341  }
4342 
4343  if (FoundIndex == Add->getNumOperands())
4344  return None;
4345 
4346  // Create an add with everything but the specified operand.
4348  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4349  if (i != FoundIndex)
4350  Ops.push_back(Add->getOperand(i));
4351  const SCEV *Accum = getAddExpr(Ops);
4352 
4353  // The runtime checks will not be valid if the step amount is
4354  // varying inside the loop.
4355  if (!isLoopInvariant(Accum, L))
4356  return None;
4357 
4358 
4359  // *** Part2: Create the predicates
4360 
4361  // Analysis was successful: we have a phi-with-cast pattern for which we
4362  // can return an AddRec expression under the following predicates:
4363  //
4364  // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
4365  // fits within the truncated type (does not overflow) for i = 0 to n-1.
4366  // P2: An Equal predicate that guarantees that
4367  // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
4368  // P3: An Equal predicate that guarantees that
4369  // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
4370  //
4371  // As we next prove, the above predicates guarantee that:
4372  // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
4373  //
4374  //
4375  // More formally, we want to prove that:
4376  // Expr(i+1) = Start + (i+1) * Accum
4377  // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
4378  //
4379  // Given that:
4380  // 1) Expr(0) = Start
4381  // 2) Expr(1) = Start + Accum
4382  // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
4383  // 3) Induction hypothesis (step i):
4384  // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
4385  //
4386  // Proof:
4387  // Expr(i+1) =
4388  // = Start + (i+1)*Accum
4389  // = (Start + i*Accum) + Accum
4390  // = Expr(i) + Accum
4391  // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
4392  // :: from step i
4393  //
4394  // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
4395  //
4396  // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
4397  // + (Ext ix (Trunc iy (Accum) to ix) to iy)
4398  // + Accum :: from P3
4399  //
4400  // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
4401  // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
4402  //
4403  // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
4404  // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
4405  //
4406  // By induction, the same applies to all iterations 1<=i<n:
4407  //
4408 
4409  // Create a truncated addrec for which we will add a no overflow check (P1).
4410  const SCEV *StartVal = getSCEV(StartValueV);
4411  const SCEV *PHISCEV =
4412  getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
4413  getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
4414  const auto *AR = cast<SCEVAddRecExpr>(PHISCEV);
4415 
4416  SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
4417  Signed ? SCEVWrapPredicate::IncrementNSSW
4418  : SCEVWrapPredicate::IncrementNUSW;
4419  const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
4420  Predicates.push_back(AddRecPred);
4421 
4422  // Create the Equal Predicates P2,P3:
4423  auto AppendPredicate = [&](const SCEV *Expr) -> void {
4424  assert (isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
4425  const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
4426  const SCEV *ExtendedExpr =
4427  Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType())
4428  : getZeroExtendExpr(TruncatedExpr, Expr->getType());
4429  if (Expr != ExtendedExpr &&
4430  !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
4431  const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
4432  DEBUG (dbgs() << "Added Predicate: " << *Pred);
4433  Predicates.push_back(Pred);
4434  }
4435  };
4436 
4437  AppendPredicate(StartVal);
4438  AppendPredicate(Accum);
4439 
4440  // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
4441  // which the casts had been folded away. The caller can rewrite SymbolicPHI
4442  // into NewAR if it will also add the runtime overflow checks specified in
4443  // Predicates.
4444  auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
4445 
4446  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
4447  std::make_pair(NewAR, Predicates);
4448  // Remember the result of the analysis for this SCEV at this location.
4449  PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
4450  return PredRewrite;
4451 }
4452 
4453 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
4454 ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
4455 
4456  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
4457  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
4458  if (!L)
4459  return None;
4460 
4461  // Check to see if we already analyzed this PHI.
4462  auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
4463  if (I != PredicatedSCEVRewrites.end()) {
4464  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
4465  I->second;
4466  // Analysis was done before and failed to create an AddRec:
4467  if (Rewrite.first == SymbolicPHI)
4468  return None;
4469  // Analysis was done before and succeeded in creating an AddRec under
4470  // a predicate:
4471  assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
4472  assert(!(Rewrite.second).empty() && "Expected to find Predicates");
4473  return Rewrite;
4474  }
4475 
4476  Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
4477  Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
4478 
4479  // Record in the cache that the analysis failed
4480  if (!Rewrite) {
4481  SmallVector<const SCEVPredicate *, 3> Predicates;
4482  PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
4483  return None;
4484  }
4485 
4486  return Rewrite;
4487 }
4488 
4489 /// A helper function for createAddRecFromPHI to handle simple cases.
4490 ///
4491 /// This function tries to find an AddRec expression for the simplest (yet most
4492 /// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
4493 /// If it fails, createAddRecFromPHI will use a more general, but slow,
4494 /// technique for finding the AddRec expression.
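/// For example, a loop counter such as
///   %iv = phi i64 [ 0, %preheader ], [ %iv.next, %loop ]
///   %iv.next = add nuw nsw i64 %iv, 1
/// fits this shape and maps to the AddRec {0,+,1}<nuw><nsw> for the enclosing
/// loop (the block names here are illustrative).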
4495 const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
4496  Value *BEValueV,
4497  Value *StartValueV) {
4498  const Loop *L = LI.getLoopFor(PN->getParent());
4499  assert(L && L->getHeader() == PN->getParent());
4500  assert(BEValueV && StartValueV);
4501 
4502  auto BO = MatchBinaryOp(BEValueV, DT);
4503  if (!BO)
4504  return nullptr;
4505 
4506  if (BO->Opcode != Instruction::Add)
4507  return nullptr;
4508 
4509  const SCEV *Accum = nullptr;
4510  if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
4511  Accum = getSCEV(BO->RHS);
4512  else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
4513  Accum = getSCEV(BO->LHS);
4514 
4515  if (!Accum)
4516  return nullptr;
4517 
4518  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
4519  if (BO->IsNUW)
4520  Flags = setFlags(Flags, SCEV::FlagNUW);
4521  if (BO->IsNSW)
4522  Flags = setFlags(Flags, SCEV::FlagNSW);
4523 
4524  const SCEV *StartVal = getSCEV(StartValueV);
4525  const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
4526 
4527  ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
4528 
4529  // We can add Flags to the post-inc expression only if we
4530  // know that it is *undefined behavior* for BEValueV to
4531  // overflow.
4532  if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
4533  if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
4534  (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
4535 
4536  return PHISCEV;
4537 }
4538 
4539 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
4540  const Loop *L = LI.getLoopFor(PN->getParent());
4541  if (!L || L->getHeader() != PN->getParent())
4542  return nullptr;
4543 
4544  // The loop may have multiple entrances or multiple exits; we can analyze
4545  // this phi as an addrec if it has a unique entry value and a unique
4546  // backedge value.
4547  Value *BEValueV = nullptr, *StartValueV = nullptr;
4548  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
4549  Value *V = PN->getIncomingValue(i);
4550  if (L->contains(PN->getIncomingBlock(i))) {
4551  if (!BEValueV) {
4552  BEValueV = V;
4553  } else if (BEValueV != V) {
4554  BEValueV = nullptr;
4555  break;
4556  }
4557  } else if (!StartValueV) {
4558  StartValueV = V;
4559  } else if (StartValueV != V) {
4560  StartValueV = nullptr;
4561  break;
4562  }
4563  }
4564  if (!BEValueV || !StartValueV)
4565  return nullptr;
4566 
4567  assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
4568  "PHI node already processed?");
4569 
4570  // First, try to find an AddRec expression without creating a fictitious symbolic
4571  // value for PN.
4572  if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
4573  return S;
4574 
4575  // Handle PHI node value symbolically.
4576  const SCEV *SymbolicName = getUnknown(PN);
4577  ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName});
4578 
4579  // Using this symbolic name for the PHI, analyze the value coming around
4580  // the back-edge.
4581  const SCEV *BEValue = getSCEV(BEValueV);
4582 
4583  // NOTE: If BEValue is loop invariant, we know that the PHI node just
4584  // has a special value for the first iteration of the loop.
4585 
4586  // If the value coming around the backedge is an add with the symbolic
4587  // value we just inserted, then we found a simple induction variable!
4588  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
4589  // If there is a single occurrence of the symbolic value, replace it
4590  // with a recurrence.
4591  unsigned FoundIndex = Add->getNumOperands();
4592  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4593  if (Add->getOperand(i) == SymbolicName)
4594  if (FoundIndex == e) {
4595  FoundIndex = i;
4596  break;
4597  }
4598 
4599  if (FoundIndex != Add->getNumOperands()) {
4600  // Create an add with everything but the specified operand.
4601  SmallVector<const SCEV *, 8> Ops;
4602  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4603  if (i != FoundIndex)
4604  Ops.push_back(Add->getOperand(i));
4605  const SCEV *Accum = getAddExpr(Ops);
4606 
4607  // This is not a valid addrec if the step amount is varying each
4608  // loop iteration, but is not itself an addrec in this loop.
4609  if (isLoopInvariant(Accum, L) ||
4610  (isa<SCEVAddRecExpr>(Accum) &&
4611  cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
4612  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
4613 
4614  if (auto BO = MatchBinaryOp(BEValueV, DT)) {
4615  if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
4616  if (BO->IsNUW)
4617  Flags = setFlags(Flags, SCEV::FlagNUW);
4618  if (BO->IsNSW)
4619  Flags = setFlags(Flags, SCEV::FlagNSW);
4620  }
4621  } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
4622  // If the increment is an inbounds GEP, then we know the address
4623  // space cannot be wrapped around. We cannot make any guarantee
4624  // about signed or unsigned overflow because pointers are
4625  // unsigned but we may have a negative index from the base
4626  // pointer. We can guarantee that no unsigned wrap occurs if the
4627  // indices form a positive value.
4628  if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
4629  Flags = setFlags(Flags, SCEV::FlagNW);
4630 
4631  const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
4632  if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
4633  Flags = setFlags(Flags, SCEV::FlagNUW);
4634  }
4635 
4636  // We cannot transfer nuw and nsw flags from subtraction
4637  // operations -- sub nuw X, Y is not the same as add nuw X, -Y
4638  // for instance.
4639  }
4640 
4641  const SCEV *StartVal = getSCEV(StartValueV);
4642  const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
4643 
4644  // Okay, for the entire analysis of this edge we assumed the PHI
4645  // to be symbolic. We now need to go back and purge all of the
4646  // entries for the scalars that use the symbolic expression.
4647  forgetSymbolicName(PN, SymbolicName);
4648  ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
4649 
4650  // We can add Flags to the post-inc expression only if we
4651  // know that it is *undefined behavior* for BEValueV to
4652  // overflow.
4653  if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
4654  if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
4655  (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
4656 
4657  return PHISCEV;
4658  }
4659  }
4660  } else {
4661  // Otherwise, this could be a loop like this:
4662  // i = 0; for (j = 1; ..; ++j) { .... i = j; }
4663  // In this case, j = {1,+,1} and BEValue is j.
4664  // Because the other in-value of i (0) fits the evolution of BEValue,
4665  // i really is an addrec evolution.
4666  //
4667  // We can generalize this by saying that i is the shifted value of BEValue
4668  // by one iteration:
4669  // PHI(f(0), f({1,+,1})) --> f({0,+,1})
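 // Below, SCEVShiftRewriter computes the shifted expression f({0,+,1}) from
 // BEValue, and SCEVInitRewriter evaluates it at iteration 0; if that start
 // value matches the other incoming value of the phi, the phi itself is the
 // shifted expression.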
4670  const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
4671  const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this);
4672  if (Shifted != getCouldNotCompute() &&
4673  Start != getCouldNotCompute()) {
4674  const SCEV *StartVal = getSCEV(StartValueV);
4675  if (Start == StartVal) {
4676  // Okay, for the entire analysis of this edge we assumed the PHI
4677  // to be symbolic. We now need to go back and purge all of the
4678  // entries for the scalars that use the symbolic expression.
4679  forgetSymbolicName(PN, SymbolicName);
4680  ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
4681  return Shifted;
4682  }
4683  }
4684  }
4685 
4686  // Remove the temporary PHI node SCEV that has been inserted while intending
4687  // to create an AddRecExpr for this PHI node. We cannot keep this temporary
4688  // as it will prevent later (possibly simpler) SCEV expressions from being
4689  // added to the ValueExprMap.
4690  eraseValueFromMap(PN);
4691 
4692  return nullptr;
4693 }
4694 
4695 // Checks if the SCEV S is available at BB. S is considered available at BB
4696 // if S can be materialized at BB without introducing a fault.
4697 static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
4698  BasicBlock *BB) {
4699  struct CheckAvailable {
4700  bool TraversalDone = false;
4701  bool Available = true;
4702 
4703  const Loop *L = nullptr; // The loop BB is in (can be nullptr)
4704  BasicBlock *BB = nullptr;
4705  DominatorTree &DT;
4706 
4707  CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
4708  : L(L), BB(BB), DT(DT) {}
4709 
4710  bool setUnavailable() {
4711  TraversalDone = true;
4712  Available = false;
4713  return false;
4714  }
4715 
4716  bool follow(const SCEV *S) {
4717  switch (S->getSCEVType()) {
4718  case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
4719  case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
4720  // These expressions are available if their operand(s) is/are.
4721  return true;
4722 
4723  case scAddRecExpr: {
4724  // We allow add recurrences whose loop is the loop BB is in, or some
4725  // outer loop. This guarantees availability because the value of the
4726  // add recurrence at BB is simply the "current" value of the induction
4727  // variable. We can relax this in the future; for instance an add
4728  // recurrence on a sibling dominating loop is also available at BB.
4729  const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
4730  if (L && (ARLoop == L || ARLoop->contains(L)))
4731  return true;
4732 
4733  return setUnavailable();
4734  }
4735 
4736  case scUnknown: {
4737  // For SCEVUnknown, we check for simple dominance.
4738  const auto *SU = cast<SCEVUnknown>(S);
4739  Value *V = SU->getValue();
4740 
4741  if (isa<Argument>(V))
4742  return false;
4743 
4744  if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
4745  return false;
4746 
4747  return setUnavailable();
4748  }
4749 
4750  case scUDivExpr:
4751  case scCouldNotCompute:
4752  // We do not try to be smart about these at all.
4753  return setUnavailable();
4754  }
4755  llvm_unreachable("switch should be fully covered!");
4756  }
4757 
4758  bool isDone() { return TraversalDone; }
4759  };
4760 
4761  CheckAvailable CA(L, BB, DT);
4762  SCEVTraversal<CheckAvailable> ST(CA);
4763 
4764  ST.visitAll(S);
4765  return CA.Available;
4766 }
4767 
4768 // Try to match a control flow sequence that branches out at BI and merges back
4769 // at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
4770 // match.
4771 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
4772  Value *&C, Value *&LHS, Value *&RHS) {
4773  C = BI->getCondition();
4774 
4775  BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
4776  BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
4777 
4778  if (!LeftEdge.isSingleEdge())
4779  return false;
4780 
4781  assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
4782 
4783  Use &LeftUse = Merge->getOperandUse(0);
4784  Use &RightUse = Merge->getOperandUse(1);
4785 
4786  if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
4787  LHS = LeftUse;
4788  RHS = RightUse;
4789  return true;
4790  }
4791 
4792  if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
4793  LHS = RightUse;
4794  RHS = LeftUse;
4795  return true;
4796  }
4797 
4798  return false;
4799 }
4800 
4801 const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
4802  auto IsReachable =
4803  [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
4804  if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
4805  const Loop *L = LI.getLoopFor(PN->getParent());
4806 
4807  // We don't want to break LCSSA, even in a SCEV expression tree.
4808  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
4809  if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
4810  return nullptr;
4811 
4812  // Try to match
4813  //
4814  // br %cond, label %left, label %right
4815  // left:
4816  // br label %merge
4817  // right:
4818  // br label %merge
4819  // merge:
4820  // V = phi [ %x, %left ], [ %y, %right ]
4821  //
4822  // as "select %cond, %x, %y"
4823 
4824  BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
4825  assert(IDom && "At least the entry block should dominate PN");
4826 
4827  auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
4828  Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
4829 
4830  if (BI && BI->isConditional() &&
4831  BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
4832  IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
4833  IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
4834  return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
4835  }
4836 
4837  return nullptr;
4838 }
4839 
4840 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
4841  if (const SCEV *S = createAddRecFromPHI(PN))
4842  return S;
4843 
4844  if (const SCEV *S = createNodeFromSelectLikePHI(PN))
4845  return S;
4846 
4847  // If the PHI has a single incoming value, follow that value, unless the
4848  // PHI's incoming blocks are in a different loop, in which case doing so
4849  // risks breaking LCSSA form. Instcombine would normally zap these, but
4850  // it doesn't have DominatorTree information, so it may miss cases.
4851  if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
4852  if (LI.replacementPreservesLCSSAForm(PN, V))
4853  return getSCEV(V);
4854 
4855  // If it's not a loop phi, we can't handle it yet.
4856  return getUnknown(PN);
4857 }
4858 
4859 const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
4860  Value *Cond,
4861  Value *TrueVal,
4862  Value *FalseVal) {
4863  // Handle "constant" branch or select. This can occur for instance when a
4864  // loop pass transforms an inner loop and moves on to process the outer loop.
4865  if (auto *CI = dyn_cast<ConstantInt>(Cond))
4866  return getSCEV(CI->isOne() ? TrueVal : FalseVal);
4867 
4868  // Try to match some simple smax or umax patterns.
4869  auto *ICI = dyn_cast<ICmpInst>(Cond);
4870  if (!ICI)
4871  return getUnknown(I);
4872 
4873  Value *LHS = ICI->getOperand(0);
4874  Value *RHS = ICI->getOperand(1);
4875 
4876  switch (ICI->getPredicate()) {
4877  case ICmpInst::ICMP_SLT:
4878  case ICmpInst::ICMP_SLE:
4879  std::swap(LHS, RHS);
4880  LLVM_FALLTHROUGH;
4881  case ICmpInst::ICMP_SGT:
4882  case ICmpInst::ICMP_SGE:
4883  // a >s b ? a+x : b+x -> smax(a, b)+x
4884  // a >s b ? b+x : a+x -> smin(a, b)+x
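 // For instance, in "a >s b ? a+1 : b+1" both arms differ from the compared
 // values by the same amount (LA - LS == RA - RS == 1), so the select folds
 // to smax(a, b)+1. The LDiff/RDiff checks below detect that common
 // difference x.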
4885  if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
4886  const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
4887  const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
4888  const SCEV *LA = getSCEV(TrueVal);
4889  const SCEV *RA = getSCEV(FalseVal);
4890  const SCEV *LDiff = getMinusSCEV(LA, LS);
4891  const SCEV *RDiff = getMinusSCEV(RA, RS);
4892  if (LDiff == RDiff)
4893  return getAddExpr(getSMaxExpr(LS, RS), LDiff);
4894  LDiff = getMinusSCEV(LA, RS);
4895  RDiff = getMinusSCEV(RA, LS);
4896  if (LDiff == RDiff)
4897  return getAddExpr(getSMinExpr(LS, RS), LDiff);
4898  }
4899  break;
4900  case ICmpInst::ICMP_ULT:
4901  case ICmpInst::ICMP_ULE:
4902  std::swap(LHS, RHS);
4903  LLVM_FALLTHROUGH;
4904  case ICmpInst::ICMP_UGT:
4905  case ICmpInst::ICMP_UGE:
4906  // a >u b ? a+x : b+x -> umax(a, b)+x
4907  // a >u b ? b+x : a+x -> umin(a, b)+x
4908  if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
4909  const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
4910  const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
4911  const SCEV *LA = getSCEV(TrueVal);
4912  const SCEV *RA = getSCEV(FalseVal);
4913  const SCEV *LDiff = getMinusSCEV(LA, LS);
4914  const SCEV *RDiff = getMinusSCEV(RA, RS);
4915  if (LDiff == RDiff)
4916  return getAddExpr(getUMaxExpr(LS, RS), LDiff);
4917  LDiff = getMinusSCEV(LA, RS);
4918  RDiff = getMinusSCEV(RA, LS);
4919  if (LDiff == RDiff)
4920  return getAddExpr(getUMinExpr(LS, RS), LDiff);
4921  }
4922  break;
4923  case ICmpInst::ICMP_NE:
4924  // n != 0 ? n+x : 1+x -> umax(n, 1)+x
4925  if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
4926  isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
4927  const SCEV *One = getOne(I->getType());
4928  const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
4929  const SCEV *LA = getSCEV(TrueVal);
4930  const SCEV *RA = getSCEV(FalseVal);
4931  const SCEV *LDiff = getMinusSCEV(LA, LS);
4932  const SCEV *RDiff = getMinusSCEV(RA, One);
4933  if (LDiff == RDiff)
4934  return getAddExpr(getUMaxExpr(One, LS), LDiff);
4935  }
4936  break;
4937  case ICmpInst::ICMP_EQ:
4938  // n == 0 ? 1+x : n+x -> umax(n, 1)+x
4939  if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
4940  isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
4941  const SCEV *One = getOne(I->getType());
4942  const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
4943  const SCEV *LA = getSCEV(TrueVal);
4944  const SCEV *RA = getSCEV(FalseVal);
4945  const SCEV *LDiff = getMinusSCEV(LA, One);
4946  const SCEV *RDiff = getMinusSCEV(RA, LS);
4947  if (LDiff == RDiff)
4948  return getAddExpr(getUMaxExpr(One, LS), LDiff);
4949  }
4950  break;
4951  default:
4952  break;
4953  }
4954 
4955  return getUnknown(I);
4956 }
4957 
4958 /// Expand GEP instructions into add and multiply operations. This allows them
4959 /// to be analyzed by regular SCEV code.
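/// For instance, assuming i32 is four bytes wide on the target,
///   getelementptr i32, i32* %p, i64 %i
/// is modeled as the SCEV expression (%p + 4 * %i).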
4960 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
4961  // Don't attempt to analyze GEPs over unsized objects.
4962  if (!GEP->getSourceElementType()->isSized())
4963  return getUnknown(GEP);
4964 
4965  SmallVector<const SCEV *, 4> IndexExprs;
4966  for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index)
4967  IndexExprs.push_back(getSCEV(*Index));
4968  return getGEPExpr(GEP, IndexExprs);
4969 }
4970 
4971 uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
4972  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
4973  return C->getAPInt().countTrailingZeros();
4974 
4975  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
4976  return std::min(GetMinTrailingZeros(T->getOperand()),
4977  (uint32_t)getTypeSizeInBits(T->getType()));
4978 
4979  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
4980  uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
4981  return OpRes == getTypeSizeInBits(E->getOperand()->getType())
4982  ? getTypeSizeInBits(E->getType())
4983  : OpRes;
4984  }
4985 
4986  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
4987  uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
4988  return OpRes == getTypeSizeInBits(E->getOperand()->getType())
4989  ? getTypeSizeInBits(E->getType())
4990  : OpRes;
4991  }
4992 
4993  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
4994  // The result is the min of all operands results.
4995  uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
4996  for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
4997  MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
4998  return MinOpRes;
4999  }
5000 
5001  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
5002  // The result is the sum of all operands results.
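 // For example, 4 * 6 == 24 == 0b11000, and the sum tz(4) + tz(6) == 2 + 1
 // gives its 3 trailing zero bits (clamped to the bit width below).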
5003  uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
5004  uint32_t BitWidth = getTypeSizeInBits(M->getType());
5005  for (unsigned i = 1, e = M->getNumOperands();
5006  SumOpRes != BitWidth && i != e; ++i)
5007  SumOpRes =
5008  std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
5009  return SumOpRes;
5010  }
5011 
5012  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
5013  // The result is the min of all operands results.
5014  uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
5015  for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
5016  MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
5017  return MinOpRes;
5018  }
5019 
5020  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
5021  // The result is the min of all operands results.
5022  uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
5023  for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
5024  MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
5025  return MinOpRes;
5026  }
5027 
5028  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
5029  // The result is the min of all operands results.
5030  uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
5031  for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
5032  MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
5033  return MinOpRes;
5034  }
5035 
5036  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
5037  // For a SCEVUnknown, ask ValueTracking.
5038  KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
5039  return Known.countMinTrailingZeros();
5040  }
5041 
5042  // SCEVUDivExpr
5043  return 0;
5044 }
5045 
5046 uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
5047  auto I = MinTrailingZerosCache.find(S);
5048  if (I != MinTrailingZerosCache.end())
5049  return I->second;
5050 
5051  uint32_t Result = GetMinTrailingZerosImpl(S);
5052  auto InsertPair = MinTrailingZerosCache.insert({S, Result});
5053  assert(InsertPair.second && "Should insert a new key");
5054  return InsertPair.first->second;
5055 }
5056 
5057 /// Helper method to assign a range to V from metadata present in the IR.
5058 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
5059  if (Instruction *I = dyn_cast<Instruction>(V))
5060  if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
5061  return getConstantRangeFromMetadata(*MD);
5062 
5063  return None;
5064 }
5065 
5066 /// Determine the range for a particular SCEV. If SignHint is
5067 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
5068 /// with a "cleaner" unsigned (resp. signed) representation.
5069 const ConstantRange &
5070 ScalarEvolution::getRangeRef(const SCEV *S,
5071  ScalarEvolution::RangeSignHint SignHint) {
5072  DenseMap<const SCEV *, ConstantRange> &Cache =
5073  SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
5074  : SignedRanges;
5075 
5076  // See if we've computed this range already.
5077  DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
5078  if (I != Cache.end())
5079  return I->second;
5080 
5081  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
5082  return setRange(C, SignHint, ConstantRange(C->getAPInt()));
5083 
5084  unsigned BitWidth = getTypeSizeInBits(S->getType());
5085  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
5086 
5087  // If the value has known trailing zeros, the maximum value will have those
5088  // trailing zeros as well.
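 // For example, in the unsigned case, a bit width of 8 and two known
 // trailing zeros give the range [0, 253), i.e. a maximum value of 252
 // (0b11111100).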
5089  uint32_t TZ = GetMinTrailingZeros(S);
5090  if (TZ != 0) {
5091  if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
5092  ConservativeResult =
5093  ConstantRange(APInt::getMinValue(BitWidth),
5094  APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
5095  else
5096  ConservativeResult = ConstantRange(
5097  APInt::getSignedMinValue(BitWidth),
5098  APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
5099  }
5100 
5101  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
5102  ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
5103  for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
5104  X = X.add(getRangeRef(Add->getOperand(i), SignHint));
5105  return setRange(Add, SignHint, ConservativeResult.intersectWith(X));
5106  }
5107 
5108  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
5109  ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
5110  for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
5111  X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
5112  return setRange(Mul, SignHint, ConservativeResult.intersectWith(X));
5113  }
5114 
5115  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
5116  ConstantRange X = getRangeRef(SMax->getOperand(0), SignHint);
5117  for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
5118  X = X.smax(getRangeRef(SMax->getOperand(i), SignHint));
5119  return setRange(SMax, SignHint, ConservativeResult.intersectWith(X));
5120  }
5121 
5122  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
5123  ConstantRange X = getRangeRef(UMax->getOperand(0), SignHint);
5124  for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
5125  X = X.umax(getRangeRef(UMax->getOperand(i), SignHint));
5126  return setRange(UMax, SignHint, ConservativeResult.intersectWith(X));
5127  }
5128 
5129  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
5130  ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
5131  ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
5132  return setRange(UDiv, SignHint,
5133  ConservativeResult.intersectWith(X.udiv(Y)));
5134  }
5135 
5136  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
5137  ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
5138  return setRange(ZExt, SignHint,
5139  ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
5140  }
5141 
5142  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
5143  ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
5144  return setRange(SExt, SignHint,
5145  ConservativeResult.intersectWith(X.signExtend(BitWidth)));
5146  }
5147 
5148  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
5149  ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
5150  return setRange(Trunc, SignHint,
5151  ConservativeResult.intersectWith(X.truncate(BitWidth)));
5152  }
5153 
5154  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
5155  // If there's no unsigned wrap, the value will never be less than its
5156  // initial value.
5157  if (AddRec->hasNoUnsignedWrap())
5158  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
5159  if (!C->getValue()->isZero())
5160  ConservativeResult = ConservativeResult.intersectWith(
5161  ConstantRange(C->getAPInt(), APInt(BitWidth, 0)));
5162 
5163  // If there's no signed wrap, and all the operands have the same sign or
5164  // zero, the value won't ever change sign.
5165  if (AddRec->hasNoSignedWrap()) {
5166  bool AllNonNeg = true;
5167  bool AllNonPos = true;
5168  for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
5169  if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
5170  if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
5171  }
5172  if (AllNonNeg)
5173  ConservativeResult = ConservativeResult.intersectWith(
5174  ConstantRange(APInt(BitWidth, 0),
5175  APInt::getSignedMinValue(BitWidth)));
5176  else if (AllNonPos)
5177  ConservativeResult = ConservativeResult.intersectWith(
5178  ConstantRange(APInt::getSignedMinValue(BitWidth),
5179  APInt(BitWidth, 1)));
5180  }
5181 
5182  // TODO: non-affine addrec
5183  if (AddRec->isAffine()) {
5184  const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
5185  if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
5186  getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
5187  auto RangeFromAffine = getRangeForAffineAR(
5188  AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
5189  BitWidth);
5190  if (!RangeFromAffine.isFullSet())
5191  ConservativeResult =
5192  ConservativeResult.intersectWith(RangeFromAffine);
5193 
5194  auto RangeFromFactoring = getRangeViaFactoring(
5195  AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
5196  BitWidth);
5197  if (!RangeFromFactoring.isFullSet())
5198  ConservativeResult =
5199  ConservativeResult.intersectWith(RangeFromFactoring);
5200  }
5201  }
5202 
5203  return setRange(AddRec, SignHint, std::move(ConservativeResult));
5204  }
5205 
5206  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
5207  // Check if the IR explicitly contains !range metadata.
5208  Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
5209  if (MDRange.hasValue())
5210  ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
5211 
5212  // Split here to avoid paying the compile-time cost of calling both
5213  // computeKnownBits and ComputeNumSignBits. This restriction can be lifted
5214  // if needed.
5215  const DataLayout &DL = getDataLayout();
5216  if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
5217  // For a SCEVUnknown, ask ValueTracking.
5218  KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
5219  if (Known.One != ~Known.Zero + 1)
5220  ConservativeResult =
5221  ConservativeResult.intersectWith(ConstantRange(Known.One,
5222  ~Known.Zero + 1));
5223  } else {
5224  assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED &&
5225  "generalize as needed!");
5226  unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
5227  if (NS > 1)
5228  ConservativeResult = ConservativeResult.intersectWith(
5229  ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
5230  APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1));
5231  }
5232 
5233  return setRange(U, SignHint, std::move(ConservativeResult));
5234  }
5235 
5236  return setRange(S, SignHint, std::move(ConservativeResult));
5237 }
5238 
5239 // Given a StartRange, Step and MaxBECount for an expression compute a range of
5240 // values that the expression can take. Initially, the expression has a value
5241 // from StartRange and then is changed by Step up to MaxBECount times. Signed
5242 // argument defines if we treat Step as signed or unsigned.
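// As an illustrative example, with BitWidth 8, StartRange [0, 2), Step 3 and
// MaxBECount 10, the helper returns [0, 2 + 3*10) == [0, 32).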
5243 static ConstantRange getRangeForAffineARHelper(APInt Step,
5244  const ConstantRange &StartRange,
5245  const APInt &MaxBECount,
5246  unsigned BitWidth, bool Signed) {
5247  // If either Step or MaxBECount is 0, then the expression won't change, and we
5248  // just need to return the initial range.
5249  if (Step == 0 || MaxBECount == 0)
5250  return StartRange;
5251 
5252  // If we don't know anything about the initial value (i.e. StartRange is
5253  // FullRange), then we don't know anything about the final range either.
5254  // Return FullRange.
5255  if (StartRange.isFullSet())
5256  return ConstantRange(BitWidth, /* isFullSet = */ true);
5257 
5258  // If Step is signed and negative, then we use its absolute value, but we also
5259  // note that we're moving in the opposite direction.
5260  bool Descending = Signed && Step.isNegative();
5261 
5262  if (Signed)
5263  // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
5264  // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
5265  // This equation holds true due to the well-defined wrap-around behavior of
5266  // APInt.
5267  Step = Step.abs();
5268 
5269  // Check if the total change (Step * MaxBECount) is more than the full span
5270  // of BitWidth. If it is, the expression is guaranteed to overflow.
5271  if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
5272  return ConstantRange(BitWidth, /* isFullSet = */ true);
5273 
5274  // Offset is by how much the expression can change. Checks above guarantee no
5275  // overflow here.
5276  APInt Offset = Step * MaxBECount;
5277 
5278  // Minimum value of the final range will match the minimal value of StartRange
5279  // if the expression is increasing and will be decreased by Offset otherwise.
5280  // Maximum value of the final range will match the maximal value of StartRange
5281  // if the expression is decreasing and will be increased by Offset otherwise.
5282  APInt StartLower = StartRange.getLower();
5283  APInt StartUpper = StartRange.getUpper() - 1;
5284  APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
5285  : (StartUpper + std::move(Offset));
5286 
5287  // It's possible that the new minimum/maximum value will fall into the initial
5288  // range (due to wrap around). This means that the expression can take any
5289  // value in this bitwidth, and we have to return full range.
5290  if (StartRange.contains(MovedBoundary))
5291  return ConstantRange(BitWidth, /* isFullSet = */ true);
5292 
5293  APInt NewLower =
5294  Descending ? std::move(MovedBoundary) : std::move(StartLower);
5295  APInt NewUpper =
5296  Descending ? std::move(StartUpper) : std::move(MovedBoundary);
5297  NewUpper += 1;
5298 
5299  // If we end up with full range, return a proper full range.
5300  if (NewLower == NewUpper)
5301  return ConstantRange(BitWidth, /* isFullSet = */ true);
5302 
5303  // No overflow detected; return the computed [NewLower, NewUpper) range.
5304  return ConstantRange(std::move(NewLower), std::move(NewUpper));
5305 }
5306 
5307 ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
5308  const SCEV *Step,
5309  const SCEV *MaxBECount,
5310  unsigned BitWidth) {
5311  assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
5312  getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
5313  "Precondition!");
5314 
5315  MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
5316  APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount);
5317 
5318  // First, consider step signed.
5319  ConstantRange StartSRange = getSignedRange(Start);
5320  ConstantRange StepSRange = getSignedRange(Step);
5321 
5322  // If Step can be both positive and negative, we need to find ranges for the
5323  // maximum absolute step values in both directions and union them.
5324  ConstantRange SR =
5325  getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange,
5326  MaxBECountValue, BitWidth, /* Signed = */ true);
5327  SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
5328  StartSRange, MaxBECountValue,
5329  BitWidth, /* Signed = */ true));
5330 
5331  // Next, consider step unsigned.
5332  ConstantRange UR = getRangeForAffineARHelper(
5333  getUnsignedRangeMax(Step), getUnsignedRange(Start),
5334  MaxBECountValue, BitWidth, /* Signed = */ false);
5335 
5336  // Finally, intersect signed and unsigned ranges.
5337  return SR.intersectWith(UR);
5338 }
5339 
5340 ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
5341  const SCEV *Step,
5342  const SCEV *MaxBECount,
5343  unsigned BitWidth) {
5344  // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
5345  // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
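 // For instance (values chosen for illustration), Start == (%c ? 0 : 8) and
 // Step == (%c ? 1 : 2) is handled as RangeOf({0,+,1}) union RangeOf({8,+,2}).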
5346 
5347  struct SelectPattern {
5348  Value *Condition = nullptr;
5349  APInt TrueValue;
5350  APInt FalseValue;
5351 
5352  explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
5353  const SCEV *S) {
5354  Optional<unsigned> CastOp;
5355  APInt Offset(BitWidth, 0);
5356 
5357  assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
5358  "Should be!");
5359 
5360  // Peel off a constant offset:
5361  if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
5362  // In the future we could consider being smarter here and handle
5363  // {Start+Step,+,Step} too.
5364  if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
5365  return;
5366 
5367  Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
5368  S = SA->getOperand(1);
5369  }
5370 
5371  // Peel off a cast operation
5372  if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) {
5373  CastOp = SCast->getSCEVType();
5374  S = SCast->getOperand();
5375  }
5376 
5377  using namespace llvm::PatternMatch;
5378 
5379  auto *SU = dyn_cast<SCEVUnknown>(S);
5380  const APInt *TrueVal, *FalseVal;
5381  if (!SU ||
5382  !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
5383  m_APInt(FalseVal)))) {
5384  Condition = nullptr;
5385  return;
5386  }
5387 
5388  TrueValue = *TrueVal;
5389  FalseValue = *FalseVal;
5390 
5391  // Re-apply the cast we peeled off earlier
5392  if (CastOp.hasValue())
5393  switch (*CastOp) {
5394  default:
5395  llvm_unreachable("Unknown SCEV cast type!");
5396 
5397  case scTruncate:
5398  TrueValue = TrueValue.trunc(BitWidth);
5399  FalseValue = FalseValue.trunc(BitWidth);
5400  break;
5401  case scZeroExtend:
5402  TrueValue = TrueValue.zext(BitWidth);
5403  FalseValue = FalseValue.zext(BitWidth);
5404  break;
5405  case scSignExtend:
5406  TrueValue = TrueValue.sext(BitWidth);
5407  FalseValue = FalseValue.sext(BitWidth);
5408  break;
5409  }
5410 
5411  // Re-apply the constant offset we peeled off earlier
5412  TrueValue += Offset;
5413  FalseValue += Offset;
5414  }
5415 
5416  bool isRecognized() { return Condition != nullptr; }
5417  };
5418 
5419  SelectPattern StartPattern(*this, BitWidth, Start);
5420  if (!StartPattern.isRecognized())
5421  return ConstantRange(BitWidth, /* isFullSet = */ true);
5422 
5423  SelectPattern StepPattern(*this, BitWidth, Step);
5424  if (!StepPattern.isRecognized())
5425  return ConstantRange(BitWidth, /* isFullSet = */ true);
5426 
5427  if (StartPattern.Condition != StepPattern.Condition) {
5428  // We don't handle this case today; but we could, by considering four
5429  // possibilities below instead of two. I'm not sure if there are cases where
5430  // that will help over what getRange already does, though.
5431  return ConstantRange(BitWidth, /* isFullSet = */ true);
5432  }
5433 
5434  // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
5435  // construct arbitrary general SCEV expressions here. This function is called
5436  // from deep in the call stack, and calling getSCEV (on a sext instruction,
5437  // say) can end up caching a suboptimal value.
5438 
5439  // FIXME: without the explicit `this` receiver below, MSVC errors out with
5440  // C2352 and C2512 (otherwise it isn't needed).
5441 
5442  const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
5443  const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
5444  const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
5445  const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
5446 
5447  ConstantRange TrueRange =
5448  this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
5449  ConstantRange FalseRange =
5450  this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
5451 
5452  return TrueRange.unionWith(FalseRange);
5453 }
5454 
5455 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
5456  if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
5457  const BinaryOperator *BinOp = cast<BinaryOperator>(V);
5458 
5459  // Return early if there are no flags to propagate to the SCEV.
5460  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5461  if (BinOp->hasNoUnsignedWrap())
5462  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
5463  if (BinOp->