1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the scalar evolution analysis
10 // engine, which is used primarily to analyze expressions involving induction
11 // variables in loops.
12 //
13 // There are several aspects to this library. First is the representation of
14 // scalar expressions, which are represented as subclasses of the SCEV class.
15 // These classes are used to represent certain types of subexpressions that we
16 // can handle. We only create one SCEV of a particular shape, so
17 // pointer-comparisons for equality are legal.
18 //
19 // One important aspect of the SCEV objects is that they are never cyclic, even
20 // if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
22 // recurrence) then we represent it directly as a recurrence node, otherwise we
23 // represent it as a SCEVUnknown node.
24 //
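// For illustration: a canonical induction variable such as
//
//   for (i = 0; i != n; ++i) { ... }
//
// is modeled by the polynomial recurrence {0,+,1}<%loop> (a SCEVAddRecExpr
// whose start is 0 and whose step is 1), while a PHI whose evolution this
// library cannot express is wrapped in a SCEVUnknown around the Value.
//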
25 // In addition to being able to represent expressions of various types, we also
26 // have folders that are used to build the *canonical* representation for a
27 // particular expression. These folders are capable of using a variety of
28 // rewrite rules to simplify the expressions.
29 //
30 // Once the folders are defined, we can implement the more interesting
31 // higher-level code, such as the code that recognizes PHI nodes of various
32 // types, computes the execution count of a loop, etc.
33 //
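// For example, building (1 + %x) + 1 and 2 + %x through getAddExpr yields
// the same unique (2 + %x) object, which is what makes the pointer-equality
// comparison mentioned above meaningful.
//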
34 // TODO: We should use these routines and value representations to implement
35 // dependence analysis!
36 //
37 //===----------------------------------------------------------------------===//
38 //
39 // There are several good references for the techniques used in this analysis.
40 //
41 // Chains of recurrences -- a method to expedite the evaluation
42 // of closed-form functions
43 // Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44 //
45 // On computational properties of chains of recurrences
46 // Eugene V. Zima
47 //
48 // Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49 // Robert A. van Engelen
50 //
51 // Efficient Symbolic Analysis for Optimizing Compilers
52 // Robert A. van Engelen
53 //
54 // Using the chains of recurrences algebra for data dependence testing and
55 // induction variable substitution
56 // MS Thesis, Johnie Birch
57 //
58 //===----------------------------------------------------------------------===//
59 
60 #include "llvm/Analysis/ScalarEvolution.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/ArrayRef.h"
63 #include "llvm/ADT/DenseMap.h"
66 #include "llvm/ADT/FoldingSet.h"
67 #include "llvm/ADT/None.h"
68 #include "llvm/ADT/Optional.h"
69 #include "llvm/ADT/STLExtras.h"
70 #include "llvm/ADT/ScopeExit.h"
71 #include "llvm/ADT/Sequence.h"
72 #include "llvm/ADT/SetVector.h"
73 #include "llvm/ADT/SmallPtrSet.h"
74 #include "llvm/ADT/SmallSet.h"
75 #include "llvm/ADT/SmallVector.h"
76 #include "llvm/ADT/Statistic.h"
77 #include "llvm/ADT/StringRef.h"
81 #include "llvm/Analysis/LoopInfo.h"
86 #include "llvm/Config/llvm-config.h"
87 #include "llvm/IR/Argument.h"
88 #include "llvm/IR/BasicBlock.h"
89 #include "llvm/IR/CFG.h"
90 #include "llvm/IR/Constant.h"
91 #include "llvm/IR/ConstantRange.h"
92 #include "llvm/IR/Constants.h"
93 #include "llvm/IR/DataLayout.h"
94 #include "llvm/IR/DerivedTypes.h"
95 #include "llvm/IR/Dominators.h"
96 #include "llvm/IR/Function.h"
97 #include "llvm/IR/GlobalAlias.h"
98 #include "llvm/IR/GlobalValue.h"
99 #include "llvm/IR/GlobalVariable.h"
100 #include "llvm/IR/InstIterator.h"
101 #include "llvm/IR/InstrTypes.h"
102 #include "llvm/IR/Instruction.h"
103 #include "llvm/IR/Instructions.h"
104 #include "llvm/IR/IntrinsicInst.h"
105 #include "llvm/IR/Intrinsics.h"
106 #include "llvm/IR/LLVMContext.h"
107 #include "llvm/IR/Metadata.h"
108 #include "llvm/IR/Operator.h"
109 #include "llvm/IR/PatternMatch.h"
110 #include "llvm/IR/Type.h"
111 #include "llvm/IR/Use.h"
112 #include "llvm/IR/User.h"
113 #include "llvm/IR/Value.h"
114 #include "llvm/IR/Verifier.h"
115 #include "llvm/InitializePasses.h"
116 #include "llvm/Pass.h"
117 #include "llvm/Support/Casting.h"
118 #include "llvm/Support/CommandLine.h"
119 #include "llvm/Support/Compiler.h"
120 #include "llvm/Support/Debug.h"
121 #include "llvm/Support/ErrorHandling.h"
122 #include "llvm/Support/KnownBits.h"
123 #include "llvm/Support/SaveAndRestore.h"
124 #include "llvm/Support/raw_ostream.h"
125 #include <algorithm>
126 #include <cassert>
127 #include <climits>
128 #include <cstddef>
129 #include <cstdint>
130 #include <cstdlib>
131 #include <map>
132 #include <memory>
133 #include <tuple>
134 #include <utility>
135 #include <vector>
136 
137 using namespace llvm;
138 using namespace PatternMatch;
139 
140 #define DEBUG_TYPE "scalar-evolution"
141 
142 STATISTIC(NumArrayLenItCounts,
143  "Number of trip counts computed with array length");
144 STATISTIC(NumTripCountsComputed,
145  "Number of loops with predictable loop counts");
146 STATISTIC(NumTripCountsNotComputed,
147  "Number of loops without predictable loop counts");
148 STATISTIC(NumBruteForceTripCountsComputed,
149  "Number of loops with trip counts computed by force");
150 
151 static cl::opt<unsigned>
152 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
153  cl::ZeroOrMore,
154  cl::desc("Maximum number of iterations SCEV will "
155  "symbolically execute a constant "
156  "derived loop"),
157  cl::init(100));
158 
159 // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
160 static cl::opt<bool> VerifySCEV(
161  "verify-scev", cl::Hidden,
162  cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163 static cl::opt<bool> VerifySCEVStrict(
164  "verify-scev-strict", cl::Hidden,
165  cl::desc("Enable stricter verification when -verify-scev is passed"));
166 static cl::opt<bool>
167  VerifySCEVMap("verify-scev-maps", cl::Hidden,
168  cl::desc("Verify no dangling value in ScalarEvolution's "
169  "ExprValueMap (slow)"));
170 
171 static cl::opt<bool> VerifyIR(
172  "scev-verify-ir", cl::Hidden,
173  cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
174  cl::init(false));
175 
176 static cl::opt<unsigned> MulOpsInlineThreshold(
177  "scev-mulops-inline-threshold", cl::Hidden,
178  cl::desc("Threshold for inlining multiplication operands into a SCEV"),
179  cl::init(32));
180 
181 static cl::opt<unsigned> AddOpsInlineThreshold(
182  "scev-addops-inline-threshold", cl::Hidden,
183  cl::desc("Threshold for inlining addition operands into a SCEV"),
184  cl::init(500));
185 
186 static cl::opt<unsigned> MaxSCEVCompareDepth(
187  "scalar-evolution-max-scev-compare-depth", cl::Hidden,
188  cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
189  cl::init(32));
190 
191 static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
192  "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
193  cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
194  cl::init(2));
195 
196 static cl::opt<unsigned> MaxValueCompareDepth(
197  "scalar-evolution-max-value-compare-depth", cl::Hidden,
198  cl::desc("Maximum depth of recursive value complexity comparisons"),
199  cl::init(2));
200 
201 static cl::opt<unsigned>
202  MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
203  cl::desc("Maximum depth of recursive arithmetics"),
204  cl::init(32));
205 
206 static cl::opt<unsigned> MaxConstantEvolvingDepth(
207  "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
208  cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
209 
210 static cl::opt<unsigned>
211  MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
212  cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
213  cl::init(8));
214 
215 static cl::opt<unsigned>
216  MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
217  cl::desc("Max coefficients in AddRec during evolving"),
218  cl::init(8));
219 
220 static cl::opt<unsigned>
221  HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
222  cl::desc("Size of the expression which is considered huge"),
223  cl::init(4096));
224 
225 static cl::opt<bool>
226 ClassifyExpressions("scalar-evolution-classify-expressions",
227  cl::Hidden, cl::init(true),
228  cl::desc("When printing analysis, include information on every instruction"));
229 
230 static cl::opt<bool> UseExpensiveRangeSharpening(
231  "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
232  cl::init(false),
233  cl::desc("Use more powerful methods of sharpening expression ranges. May "
234  "be costly in terms of compile time"));
235 
236 //===----------------------------------------------------------------------===//
237 // SCEV class definitions
238 //===----------------------------------------------------------------------===//
239 
240 //===----------------------------------------------------------------------===//
241 // Implementation of the SCEV class.
242 //
243 
244 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
245 LLVM_DUMP_METHOD void SCEV::dump() const {
246  print(dbgs());
247  dbgs() << '\n';
248 }
249 #endif
250 
251 void SCEV::print(raw_ostream &OS) const {
252  switch (getSCEVType()) {
253  case scConstant:
254  cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
255  return;
256  case scPtrToInt: {
257  const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
258  const SCEV *Op = PtrToInt->getOperand();
259  OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
260  << *PtrToInt->getType() << ")";
261  return;
262  }
263  case scTruncate: {
264  const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
265  const SCEV *Op = Trunc->getOperand();
266  OS << "(trunc " << *Op->getType() << " " << *Op << " to "
267  << *Trunc->getType() << ")";
268  return;
269  }
270  case scZeroExtend: {
271  const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
272  const SCEV *Op = ZExt->getOperand();
273  OS << "(zext " << *Op->getType() << " " << *Op << " to "
274  << *ZExt->getType() << ")";
275  return;
276  }
277  case scSignExtend: {
278  const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
279  const SCEV *Op = SExt->getOperand();
280  OS << "(sext " << *Op->getType() << " " << *Op << " to "
281  << *SExt->getType() << ")";
282  return;
283  }
284  case scAddRecExpr: {
285  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
286  OS << "{" << *AR->getOperand(0);
287  for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
288  OS << ",+," << *AR->getOperand(i);
289  OS << "}<";
290  if (AR->hasNoUnsignedWrap())
291  OS << "nuw><";
292  if (AR->hasNoSignedWrap())
293  OS << "nsw><";
294  if (AR->hasNoSelfWrap() &&
295  !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
296  OS << "nw><";
297  AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
298  OS << ">";
299  return;
300  }
301  case scAddExpr:
302  case scMulExpr:
303  case scUMaxExpr:
304  case scSMaxExpr:
305  case scUMinExpr:
306  case scSMinExpr: {
307  const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
308  const char *OpStr = nullptr;
309  switch (NAry->getSCEVType()) {
310  case scAddExpr: OpStr = " + "; break;
311  case scMulExpr: OpStr = " * "; break;
312  case scUMaxExpr: OpStr = " umax "; break;
313  case scSMaxExpr: OpStr = " smax "; break;
314  case scUMinExpr:
315  OpStr = " umin ";
316  break;
317  case scSMinExpr:
318  OpStr = " smin ";
319  break;
320  default:
321  llvm_unreachable("There are no other nary expression types.");
322  }
323  OS << "(";
324  ListSeparator LS(OpStr);
325  for (const SCEV *Op : NAry->operands())
326  OS << LS << *Op;
327  OS << ")";
328  switch (NAry->getSCEVType()) {
329  case scAddExpr:
330  case scMulExpr:
331  if (NAry->hasNoUnsignedWrap())
332  OS << "<nuw>";
333  if (NAry->hasNoSignedWrap())
334  OS << "<nsw>";
335  break;
336  default:
337  // Nothing to print for other nary expressions.
338  break;
339  }
340  return;
341  }
342  case scUDivExpr: {
343  const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
344  OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
345  return;
346  }
347  case scUnknown: {
348  const SCEVUnknown *U = cast<SCEVUnknown>(this);
349  Type *AllocTy;
350  if (U->isSizeOf(AllocTy)) {
351  OS << "sizeof(" << *AllocTy << ")";
352  return;
353  }
354  if (U->isAlignOf(AllocTy)) {
355  OS << "alignof(" << *AllocTy << ")";
356  return;
357  }
358 
359  Type *CTy;
360  Constant *FieldNo;
361  if (U->isOffsetOf(CTy, FieldNo)) {
362  OS << "offsetof(" << *CTy << ", ";
363  FieldNo->printAsOperand(OS, false);
364  OS << ")";
365  return;
366  }
367 
368  // Otherwise just print it normally.
369  U->getValue()->printAsOperand(OS, false);
370  return;
371  }
372  case scCouldNotCompute:
373  OS << "***COULDNOTCOMPUTE***";
374  return;
375  }
376  llvm_unreachable("Unknown SCEV kind!");
377 }
378 
379 Type *SCEV::getType() const {
380  switch (getSCEVType()) {
381  case scConstant:
382  return cast<SCEVConstant>(this)->getType();
383  case scPtrToInt:
384  case scTruncate:
385  case scZeroExtend:
386  case scSignExtend:
387  return cast<SCEVCastExpr>(this)->getType();
388  case scAddRecExpr:
389  return cast<SCEVAddRecExpr>(this)->getType();
390  case scMulExpr:
391  return cast<SCEVMulExpr>(this)->getType();
392  case scUMaxExpr:
393  case scSMaxExpr:
394  case scUMinExpr:
395  case scSMinExpr:
396  return cast<SCEVMinMaxExpr>(this)->getType();
397  case scAddExpr:
398  return cast<SCEVAddExpr>(this)->getType();
399  case scUDivExpr:
400  return cast<SCEVUDivExpr>(this)->getType();
401  case scUnknown:
402  return cast<SCEVUnknown>(this)->getType();
403  case scCouldNotCompute:
404  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
405  }
406  llvm_unreachable("Unknown SCEV kind!");
407 }
408 
409 bool SCEV::isZero() const {
410  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
411  return SC->getValue()->isZero();
412  return false;
413 }
414 
415 bool SCEV::isOne() const {
416  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
417  return SC->getValue()->isOne();
418  return false;
419 }
420 
421 bool SCEV::isAllOnesValue() const {
422  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
423  return SC->getValue()->isMinusOne();
424  return false;
425 }
426 
427 bool SCEV::isNonConstantNegative() const {
428  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
429  if (!Mul) return false;
430 
431  // If there is a constant factor, it will be first.
432  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
433  if (!SC) return false;
434 
435  // Return true if the value is negative, this matches things like (-42 * V).
436  return SC->getAPInt().isNegative();
437 }
438 
439 SCEVCouldNotCompute::SCEVCouldNotCompute() :
440  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
441 
442 bool SCEVCouldNotCompute::classof(const SCEV *S) {
443  return S->getSCEVType() == scCouldNotCompute;
444 }
445 
446 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
447  FoldingSetNodeID ID;
448  ID.AddInteger(scConstant);
449  ID.AddPointer(V);
450  void *IP = nullptr;
451  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
452  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
453  UniqueSCEVs.InsertNode(S, IP);
454  return S;
455 }
456 
457 const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
458  return getConstant(ConstantInt::get(getContext(), Val));
459 }
460 
461 const SCEV *
462 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
463  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
464  return getConstant(ConstantInt::get(ITy, V, isSigned));
465 }
466 
467 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
468  const SCEV *op, Type *ty)
469  : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
470  Operands[0] = op;
471 }
472 
473 SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
474  Type *ITy)
475  : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
476  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
477  "Must be a non-bit-width-changing pointer-to-integer cast!");
478 }
479 
480 SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
481  SCEVTypes SCEVTy, const SCEV *op,
482  Type *ty)
483  : SCEVCastExpr(ID, SCEVTy, op, ty) {}
484 
485 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
486  Type *ty)
487  : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
488  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
489  "Cannot truncate non-integer value!");
490 }
491 
492 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
493  const SCEV *op, Type *ty)
494  : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
495  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
496  "Cannot zero extend non-integer value!");
497 }
498 
499 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
500  const SCEV *op, Type *ty)
501  : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
502  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
503  "Cannot sign extend non-integer value!");
504 }
505 
506 void SCEVUnknown::deleted() {
507  // Clear this SCEVUnknown from various maps.
508  SE->forgetMemoizedResults(this);
509 
510  // Remove this SCEVUnknown from the uniquing map.
511  SE->UniqueSCEVs.RemoveNode(this);
512 
513  // Release the value.
514  setValPtr(nullptr);
515 }
516 
517 void SCEVUnknown::allUsesReplacedWith(Value *New) {
518  // Remove this SCEVUnknown from the uniquing map.
519  SE->UniqueSCEVs.RemoveNode(this);
520 
521  // Update this SCEVUnknown to point to the new value. This is needed
522  // because there may still be outstanding SCEVs which still point to
523  // this SCEVUnknown.
524  setValPtr(New);
525 }
526 
527 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
528  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
529  if (VCE->getOpcode() == Instruction::PtrToInt)
530  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
531  if (CE->getOpcode() == Instruction::GetElementPtr &&
532  CE->getOperand(0)->isNullValue() &&
533  CE->getNumOperands() == 2)
534  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
535  if (CI->isOne()) {
536  AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
537  return true;
538  }
539 
540  return false;
541 }
542 
543 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
544  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
545  if (VCE->getOpcode() == Instruction::PtrToInt)
546  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
547  if (CE->getOpcode() == Instruction::GetElementPtr &&
548  CE->getOperand(0)->isNullValue()) {
549  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
550  if (StructType *STy = dyn_cast<StructType>(Ty))
551  if (!STy->isPacked() &&
552  CE->getNumOperands() == 3 &&
553  CE->getOperand(1)->isNullValue()) {
554  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
555  if (CI->isOne() &&
556  STy->getNumElements() == 2 &&
557  STy->getElementType(0)->isIntegerTy(1)) {
558  AllocTy = STy->getElementType(1);
559  return true;
560  }
561  }
562  }
563 
564  return false;
565 }
566 
567 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
568  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
569  if (VCE->getOpcode() == Instruction::PtrToInt)
570  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
571  if (CE->getOpcode() == Instruction::GetElementPtr &&
572  CE->getNumOperands() == 3 &&
573  CE->getOperand(0)->isNullValue() &&
574  CE->getOperand(1)->isNullValue()) {
575  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
576  // Ignore vector types here so that ScalarEvolutionExpander doesn't
577  // emit getelementptrs that index into vectors.
578  if (Ty->isStructTy() || Ty->isArrayTy()) {
579  CTy = Ty;
580  FieldNo = CE->getOperand(2);
581  return true;
582  }
583  }
584 
585  return false;
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // SCEV Utilities
590 //===----------------------------------------------------------------------===//
591 
592 /// Compare the two values \p LV and \p RV in terms of their "complexity" where
593 /// "complexity" is a partial (and somewhat ad-hoc) relation used to order
594 /// operands in SCEV expressions. \p EqCache is a set of pairs of values that
595 /// have been previously deemed to be "equally complex" by this routine. It is
596 /// intended to avoid exponential time complexity in cases like:
597 ///
598 /// %a = f(%x, %y)
599 /// %b = f(%a, %a)
600 /// %c = f(%b, %b)
601 ///
602 /// %d = f(%x, %y)
603 /// %e = f(%d, %d)
604 /// %f = f(%e, %e)
605 ///
606 /// CompareValueComplexity(%f, %c)
607 ///
608 /// Since we do not continue running this routine on expression trees once we
609 /// have seen unequal values, there is no need to track them in the cache.
610 static int
611 CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
612  const LoopInfo *const LI, Value *LV, Value *RV,
613  unsigned Depth) {
614  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
615  return 0;
616 
617  // Order pointer values after integer values. This helps SCEVExpander form
618  // GEPs.
619  bool LIsPointer = LV->getType()->isPointerTy(),
620  RIsPointer = RV->getType()->isPointerTy();
621  if (LIsPointer != RIsPointer)
622  return (int)LIsPointer - (int)RIsPointer;
623 
624  // Compare getValueID values.
625  unsigned LID = LV->getValueID(), RID = RV->getValueID();
626  if (LID != RID)
627  return (int)LID - (int)RID;
628 
629  // Sort arguments by their position.
630  if (const auto *LA = dyn_cast<Argument>(LV)) {
631  const auto *RA = cast<Argument>(RV);
632  unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
633  return (int)LArgNo - (int)RArgNo;
634  }
635 
636  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
637  const auto *RGV = cast<GlobalValue>(RV);
638 
639  const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
640  auto LT = GV->getLinkage();
641  return !(GlobalValue::isPrivateLinkage(LT) ||
642  GlobalValue::isInternalLinkage(LT));
643  };
644 
645  // Use the names to distinguish the two values, but only if the
646  // names are semantically important.
647  if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
648  return LGV->getName().compare(RGV->getName());
649  }
650 
651  // For instructions, compare their loop depth, and their operand count. This
652  // is pretty loose.
653  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
654  const auto *RInst = cast<Instruction>(RV);
655 
656  // Compare loop depths.
657  const BasicBlock *LParent = LInst->getParent(),
658  *RParent = RInst->getParent();
659  if (LParent != RParent) {
660  unsigned LDepth = LI->getLoopDepth(LParent),
661  RDepth = LI->getLoopDepth(RParent);
662  if (LDepth != RDepth)
663  return (int)LDepth - (int)RDepth;
664  }
665 
666  // Compare the number of operands.
667  unsigned LNumOps = LInst->getNumOperands(),
668  RNumOps = RInst->getNumOperands();
669  if (LNumOps != RNumOps)
670  return (int)LNumOps - (int)RNumOps;
671 
672  for (unsigned Idx : seq(0u, LNumOps)) {
673  int Result =
674  CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
675  RInst->getOperand(Idx), Depth + 1);
676  if (Result != 0)
677  return Result;
678  }
679  }
680 
681  EqCacheValue.unionSets(LV, RV);
682  return 0;
683 }
684 
685 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
686 // than RHS, respectively. A three-way result allows recursive comparisons to be
687 // more efficient.
688 // If the max analysis depth was reached, return None, assuming we do not know
689 // if they are equivalent for sure.
690 static Optional<int>
691 CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
692  EquivalenceClasses<const Value *> &EqCacheValue,
693  const LoopInfo *const LI, const SCEV *LHS,
694  const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
695  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
696  if (LHS == RHS)
697  return 0;
698 
699  // Primarily, sort the SCEVs by their getSCEVType().
700  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
701  if (LType != RType)
702  return (int)LType - (int)RType;
703 
704  if (EqCacheSCEV.isEquivalent(LHS, RHS))
705  return 0;
706 
707  if (Depth > MaxSCEVCompareDepth)
708  return None;
709 
710  // Aside from the getSCEVType() ordering, the particular ordering
711  // isn't very important except that it's beneficial to be consistent,
712  // so that (a + b) and (b + a) don't end up as different expressions.
713  switch (LType) {
714  case scUnknown: {
715  const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
716  const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
717 
718  int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
719  RU->getValue(), Depth + 1);
720  if (X == 0)
721  EqCacheSCEV.unionSets(LHS, RHS);
722  return X;
723  }
724 
725  case scConstant: {
726  const SCEVConstant *LC = cast<SCEVConstant>(LHS);
727  const SCEVConstant *RC = cast<SCEVConstant>(RHS);
728 
729  // Compare constant values.
730  const APInt &LA = LC->getAPInt();
731  const APInt &RA = RC->getAPInt();
732  unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
733  if (LBitWidth != RBitWidth)
734  return (int)LBitWidth - (int)RBitWidth;
735  return LA.ult(RA) ? -1 : 1;
736  }
737 
738  case scAddRecExpr: {
739  const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
740  const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
741 
742  // There is always a dominance between two recs that are used by one SCEV,
743  // so we can safely sort recs by loop header dominance. We require such
744  // order in getAddExpr.
745  const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
746  if (LLoop != RLoop) {
747  const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
748  assert(LHead != RHead && "Two loops share the same header?");
749  if (DT.dominates(LHead, RHead))
750  return 1;
751  else
752  assert(DT.dominates(RHead, LHead) &&
753  "No dominance between recurrences used by one SCEV?");
754  return -1;
755  }
756 
757  // Addrec complexity grows with operand count.
758  unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
759  if (LNumOps != RNumOps)
760  return (int)LNumOps - (int)RNumOps;
761 
762  // Lexicographically compare.
763  for (unsigned i = 0; i != LNumOps; ++i) {
764  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
765  LA->getOperand(i), RA->getOperand(i), DT,
766  Depth + 1);
767  if (X != 0)
768  return X;
769  }
770  EqCacheSCEV.unionSets(LHS, RHS);
771  return 0;
772  }
773 
774  case scAddExpr:
775  case scMulExpr:
776  case scSMaxExpr:
777  case scUMaxExpr:
778  case scSMinExpr:
779  case scUMinExpr: {
780  const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
781  const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
782 
783  // Lexicographically compare n-ary expressions.
784  unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
785  if (LNumOps != RNumOps)
786  return (int)LNumOps - (int)RNumOps;
787 
788  for (unsigned i = 0; i != LNumOps; ++i) {
789  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
790  LC->getOperand(i), RC->getOperand(i), DT,
791  Depth + 1);
792  if (X != 0)
793  return X;
794  }
795  EqCacheSCEV.unionSets(LHS, RHS);
796  return 0;
797  }
798 
799  case scUDivExpr: {
800  const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
801  const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
802 
803  // Lexicographically compare udiv expressions.
804  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
805  RC->getLHS(), DT, Depth + 1);
806  if (X != 0)
807  return X;
808  X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
809  RC->getRHS(), DT, Depth + 1);
810  if (X == 0)
811  EqCacheSCEV.unionSets(LHS, RHS);
812  return X;
813  }
814 
815  case scPtrToInt:
816  case scTruncate:
817  case scZeroExtend:
818  case scSignExtend: {
819  const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
820  const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
821 
822  // Compare cast expressions by operand.
823  auto X =
824  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
825  RC->getOperand(), DT, Depth + 1);
826  if (X == 0)
827  EqCacheSCEV.unionSets(LHS, RHS);
828  return X;
829  }
830 
831  case scCouldNotCompute:
832  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
833  }
834  llvm_unreachable("Unknown SCEV kind!");
835 }
836 
837 /// Given a list of SCEV objects, order them by their complexity, and group
838 /// objects of the same complexity together by value. When this routine is
839 /// finished, we know that any duplicates in the vector are consecutive and that
840 /// complexity is monotonically increasing.
841 ///
842 /// Note that we take special precautions to ensure that we get deterministic
843 /// results from this routine. In other words, we don't want the results of
844 /// this to depend on where the addresses of various SCEV objects happened to
845 /// land in memory.
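/// As a small illustration (assuming the usual ordering where constants
/// compare as less complex than unknowns): Ops = [%x, 2, %y, %x] would be
/// reordered to [2, %x, %x, %y], leaving the duplicate %x operands adjacent
/// so callers can merge them in a single scan.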
846 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
847  LoopInfo *LI, DominatorTree &DT) {
848  if (Ops.size() < 2) return; // Noop
849 
850  EquivalenceClasses<const SCEV *> EqCacheSCEV;
851  EquivalenceClasses<const Value *> EqCacheValue;
852 
853  // Whether LHS has provably less complexity than RHS.
854  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
855  auto Complexity =
856  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
857  return Complexity && *Complexity < 0;
858  };
859  if (Ops.size() == 2) {
860  // This is the common case, which also happens to be trivially simple.
861  // Special case it.
862  const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
863  if (IsLessComplex(RHS, LHS))
864  std::swap(LHS, RHS);
865  return;
866  }
867 
868  // Do the rough sort by complexity.
869  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
870  return IsLessComplex(LHS, RHS);
871  });
872 
873  // Now that we are sorted by complexity, group elements of the same
874  // complexity. Note that this is, at worst, N^2, but the vector is likely to
875  // be extremely short in practice. Note that we take this approach because we
876  // do not want to depend on the addresses of the objects we are grouping.
877  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
878  const SCEV *S = Ops[i];
879  unsigned Complexity = S->getSCEVType();
880 
881  // If there are any objects of the same complexity and same value as this
882  // one, group them.
883  for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
884  if (Ops[j] == S) { // Found a duplicate.
885  // Move it to immediately after i'th element.
886  std::swap(Ops[i+1], Ops[j]);
887  ++i; // no need to rescan it.
888  if (i == e-2) return; // Done!
889  }
890  }
891  }
892 }
893 
894 /// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
895 /// least HugeExprThreshold nodes).
896 static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
897  return any_of(Ops, [](const SCEV *S) {
898  return S->getExpressionSize() >= HugeExprThreshold;
899  });
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // Simple SCEV method implementations
904 //===----------------------------------------------------------------------===//
905 
906 /// Compute BC(It, K). The result has width W. Assume, K > 0.
907 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
908  ScalarEvolution &SE,
909  Type *ResultTy) {
910  // Handle the simplest case efficiently.
911  if (K == 1)
912  return SE.getTruncateOrZeroExtend(It, ResultTy);
913 
914  // We are using the following formula for BC(It, K):
915  //
916  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
917  //
918  // Suppose, W is the bitwidth of the return value. We must be prepared for
919  // overflow. Hence, we must assure that the result of our computation is
920  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
921  // safe in modular arithmetic.
922  //
923  // However, this code doesn't use exactly that formula; the formula it uses
924  // is something like the following, where T is the number of factors of 2 in
925  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
926  // exponentiation:
927  //
928  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
929  //
930  // This formula is trivially equivalent to the previous formula. However,
931  // this formula can be implemented much more efficiently. The trick is that
932  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
933  // arithmetic. To do exact division in modular arithmetic, all we have
934  // to do is multiply by the inverse. Therefore, this step can be done at
935  // width W.
936  //
937  // The next issue is how to safely do the division by 2^T. The way this
938  // is done is by doing the multiplication step at a width of at least W + T
939  // bits. This way, the bottom W+T bits of the product are accurate. Then,
940  // when we perform the division by 2^T (which is equivalent to a right shift
941  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
942  // truncated out after the division by 2^T.
943  //
944  // In comparison to just directly using the first formula, this technique
945  // is much more efficient; using the first formula requires W * K bits,
946 // but this formula requires less than W + K bits. Also, the first formula requires
947  // a division step, whereas this formula only requires multiplies and shifts.
948  //
949  // It doesn't matter whether the subtraction step is done in the calculation
950  // width or the input iteration count's width; if the subtraction overflows,
951  // the result must be zero anyway. We prefer here to do it in the width of
952  // the induction variable because it helps a lot for certain cases; CodeGen
953  // isn't smart enough to ignore the overflow, which leads to much less
954  // efficient code if the width of the subtraction is wider than the native
955  // register width.
956  //
957  // (It's possible to not widen at all by pulling out factors of 2 before
958  // the multiplication; for example, K=2 can be calculated as
959  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
960  // extra arithmetic, so it's not an obvious win, and it gets
961  // much more complicated for K > 3.)
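  //
  // As a rough worked case: for K = 3, K! = 6 = 2^1 * 3, so T = 1 and the
  // odd part K!/2^T is 3. The product It*(It-1)*(It-2) is formed at width
  // W + 1, divided exactly by 2^T, truncated back to W bits, and multiplied
  // by the multiplicative inverse of 3 modulo 2^W, which performs the
  // remaining exact division by the odd factor.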
962 
963  // Protection from insane SCEVs; this bound is conservative,
964  // but it probably doesn't matter.
965  if (K > 1000)
966  return SE.getCouldNotCompute();
967 
968  unsigned W = SE.getTypeSizeInBits(ResultTy);
969 
970  // Calculate K! / 2^T and T; we divide out the factors of two before
971  // multiplying for calculating K! / 2^T to avoid overflow.
972  // Other overflow doesn't matter because we only care about the bottom
973  // W bits of the result.
974  APInt OddFactorial(W, 1);
975  unsigned T = 1;
976  for (unsigned i = 3; i <= K; ++i) {
977  APInt Mult(W, i);
978  unsigned TwoFactors = Mult.countTrailingZeros();
979  T += TwoFactors;
980  Mult.lshrInPlace(TwoFactors);
981  OddFactorial *= Mult;
982  }
983 
984  // We need at least W + T bits for the multiplication step
985  unsigned CalculationBits = W + T;
986 
987  // Calculate 2^T, at width T+W.
988  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
989 
990  // Calculate the multiplicative inverse of K! / 2^T;
991  // this multiplication factor will perform the exact division by
992  // K! / 2^T.
993  APInt Mod = APInt::getSignedMinValue(W+1);
994  APInt MultiplyFactor = OddFactorial.zext(W+1);
995  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
996  MultiplyFactor = MultiplyFactor.trunc(W);
997 
998  // Calculate the product, at width T+W
999  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1000  CalculationBits);
1001  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1002  for (unsigned i = 1; i != K; ++i) {
1003  const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1004  Dividend = SE.getMulExpr(Dividend,
1005  SE.getTruncateOrZeroExtend(S, CalculationTy));
1006  }
1007 
1008  // Divide by 2^T
1009  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1010 
1011  // Truncate the result, and divide by K! / 2^T.
1012 
1013  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1014  SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1015 }
1016 
1017 /// Return the value of this chain of recurrences at the specified iteration
1018 /// number. We can evaluate this recurrence by multiplying each element in the
1019 /// chain by the binomial coefficient corresponding to it. In other words, we
1020 /// can evaluate {A,+,B,+,C,+,D} as:
1021 ///
1022 /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1023 ///
1024 /// where BC(It, k) stands for binomial coefficient.
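/// For instance, the affine recurrence {A,+,B} evaluates to A + B*It, and
/// {A,+,B,+,C} evaluates to A + B*It + C*(It*(It-1)/2), BC(It, 2) being the
/// "It choose 2" coefficient.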
1025 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1026  ScalarEvolution &SE) const {
1027  return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
1028 }
1029 
1030 const SCEV *
1031 SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
1032  const SCEV *It, ScalarEvolution &SE) {
1033  assert(Operands.size() > 0);
1034  const SCEV *Result = Operands[0];
1035  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1036  // The computation is correct in the face of overflow provided that the
1037  // multiplication is performed _after_ the evaluation of the binomial
1038  // coefficient.
1039  const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1040  if (isa<SCEVCouldNotCompute>(Coeff))
1041  return Coeff;
1042 
1043  Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1044  }
1045  return Result;
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // SCEV Expression folder implementations
1050 //===----------------------------------------------------------------------===//
1051 
1052 const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1053  unsigned Depth) {
1054  assert(Depth <= 1 &&
1055  "getLosslessPtrToIntExpr() should self-recurse at most once.");
1056 
1057  // We could be called with an integer-typed operand during SCEV rewrites.
1058  // Since the operand is an integer already, just perform zext/trunc/self cast.
1059  if (!Op->getType()->isPointerTy())
1060  return Op;
1061 
1062  // What would be an ID for such a SCEV cast expression?
1063  FoldingSetNodeID ID;
1064  ID.AddInteger(scPtrToInt);
1065  ID.AddPointer(Op);
1066 
1067  void *IP = nullptr;
1068 
1069  // Is there already an expression for such a cast?
1070  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1071  return S;
1072 
1073  // It isn't legal for optimizations to construct new ptrtoint expressions
1074  // for non-integral pointers.
1075  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1076  return getCouldNotCompute();
1077 
1078  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1079 
1080  // We can only trivially model ptrtoint if SCEV's effective (integer) type
1081  // is sufficiently wide to represent all possible pointer values.
1082  // We could theoretically teach SCEV to truncate wider pointers, but
1083  // that isn't implemented for now.
1084  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1085  getDataLayout().getTypeSizeInBits(IntPtrTy))
1086  return getCouldNotCompute();
1087 
1088  // If not, is this expression something we can't reduce any further?
1089  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1090  // Perform some basic constant folding. If the operand of the ptr2int cast
1091  // is a null pointer, don't create a ptr2int SCEV expression (that will be
1092  // left as-is), but produce a zero constant.
1093  // NOTE: We could handle a more general case, but lack motivational cases.
1094  if (isa<ConstantPointerNull>(U->getValue()))
1095  return getZero(IntPtrTy);
1096 
1097  // Create an explicit cast node.
1098  // We can reuse the existing insert position since if we get here,
1099  // we won't have made any changes which would invalidate it.
1100  SCEV *S = new (SCEVAllocator)
1101  SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1102  UniqueSCEVs.InsertNode(S, IP);
1103  addToLoopUseLists(S);
1104  return S;
1105  }
1106 
1107  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1108  "non-SCEVUnknown's.");
1109 
1110  // Otherwise, we've got some expression that is more complex than just a
1111  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1112  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1113  // only, and the expressions must otherwise be integer-typed.
1114  // So sink the cast down to the SCEVUnknown's.
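  //
  // For example (roughly): for a pointer-typed SCEV (%base + 8), where %base
  // is a pointer SCEVUnknown, the rewriter below produces
  // ((ptrtoint %base) + 8), so only the SCEVUnknown is wrapped in the cast
  // and the surrounding arithmetic becomes integer-typed.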
1115 
1116  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1117  /// which computes a pointer-typed value, and rewrites the whole expression
1118  /// tree so that *all* the computations are done on integers, and the only
1119  /// pointer-typed operands in the expression are SCEVUnknown.
1120  class SCEVPtrToIntSinkingRewriter
1121  : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1122  using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1123 
1124  public:
1125  SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1126 
1127  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1128  SCEVPtrToIntSinkingRewriter Rewriter(SE);
1129  return Rewriter.visit(Scev);
1130  }
1131 
1132  const SCEV *visit(const SCEV *S) {
1133  Type *STy = S->getType();
1134  // If the expression is not pointer-typed, just keep it as-is.
1135  if (!STy->isPointerTy())
1136  return S;
1137  // Else, recursively sink the cast down into it.
1138  return Base::visit(S);
1139  }
1140 
1141  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1142  SmallVector<const SCEV *, 2> Operands;
1143  bool Changed = false;
1144  for (auto *Op : Expr->operands()) {
1145  Operands.push_back(visit(Op));
1146  Changed |= Op != Operands.back();
1147  }
1148  return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1149  }
1150 
1151  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1152  SmallVector<const SCEV *, 2> Operands;
1153  bool Changed = false;
1154  for (auto *Op : Expr->operands()) {
1155  Operands.push_back(visit(Op));
1156  Changed |= Op != Operands.back();
1157  }
1158  return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1159  }
1160 
1161  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1162  assert(Expr->getType()->isPointerTy() &&
1163  "Should only reach pointer-typed SCEVUnknown's.");
1164  return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1165  }
1166  };
1167 
1168  // And actually perform the cast sinking.
1169  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1170  assert(IntOp->getType()->isIntegerTy() &&
1171  "We must have succeeded in sinking the cast, "
1172  "and ending up with an integer-typed expression!");
1173  return IntOp;
1174 }
1175 
1176 const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1177  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1178 
1179  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1180  if (isa<SCEVCouldNotCompute>(IntOp))
1181  return IntOp;
1182 
1183  return getTruncateOrZeroExtend(IntOp, Ty);
1184 }
1185 
1186 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1187  unsigned Depth) {
1188  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1189  "This is not a truncating conversion!");
1190  assert(isSCEVable(Ty) &&
1191  "This is not a conversion to a SCEVable type!");
1192  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1193  Ty = getEffectiveSCEVType(Ty);
1194 
1195  FoldingSetNodeID ID;
1196  ID.AddInteger(scTruncate);
1197  ID.AddPointer(Op);
1198  ID.AddPointer(Ty);
1199  void *IP = nullptr;
1200  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1201 
1202  // Fold if the operand is constant.
1203  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1204  return getConstant(
1205  cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1206 
1207  // trunc(trunc(x)) --> trunc(x)
1208  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1209  return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1210 
1211  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1212  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1213  return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1214 
1215  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1216  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1217  return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1218 
1219  if (Depth > MaxCastDepth) {
1220  SCEV *S =
1221  new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1222  UniqueSCEVs.InsertNode(S, IP);
1223  addToLoopUseLists(S);
1224  return S;
1225  }
1226 
1227  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1228  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1229  // if after transforming we have at most one truncate, not counting truncates
1230  // that replace other casts.
1231  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1232  auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1233  SmallVector<const SCEV *, 4> Operands;
1234  unsigned numTruncs = 0;
1235  for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1236  ++i) {
1237  const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1238  if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1239  isa<SCEVTruncateExpr>(S))
1240  numTruncs++;
1241  Operands.push_back(S);
1242  }
1243  if (numTruncs < 2) {
1244  if (isa<SCEVAddExpr>(Op))
1245  return getAddExpr(Operands);
1246  else if (isa<SCEVMulExpr>(Op))
1247  return getMulExpr(Operands);
1248  else
1249  llvm_unreachable("Unexpected SCEV type for Op.");
1250  }
1251  // Although we checked in the beginning that ID is not in the cache, it is
1252  // possible that during recursion and different modification ID was inserted
1253  // into the cache. So if we find it, just return it.
1254  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1255  return S;
1256  }
1257 
1258  // If the input value is a chrec scev, truncate the chrec's operands.
1259  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1260  SmallVector<const SCEV *, 4> Operands;
1261  for (const SCEV *Op : AddRec->operands())
1262  Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1263  return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1264  }
1265 
1266  // Return zero if truncating to known zeros.
1267  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
1268  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1269  return getZero(Ty);
1270 
1271  // The cast wasn't folded; create an explicit cast node. We can reuse
1272  // the existing insert position since if we get here, we won't have
1273  // made any changes which would invalidate it.
1274  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1275  Op, Ty);
1276  UniqueSCEVs.InsertNode(S, IP);
1277  addToLoopUseLists(S);
1278  return S;
1279 }
1280 
1281 // Get the limit of a recurrence such that incrementing by Step cannot cause
1282 // signed overflow as long as the value of the recurrence within the
1283 // loop does not exceed this limit before incrementing.
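//
// For example (roughly): for a known-positive Step, the limit is
// SIGNED_MAX - max(Step) with predicate "slt"; while the recurrence stays
// signed-less-than that limit, adding Step cannot leave the signed range.
// The unsigned helper below is analogous.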
1284 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1285  ICmpInst::Predicate *Pred,
1286  ScalarEvolution *SE) {
1287  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1288  if (SE->isKnownPositive(Step)) {
1289  *Pred = ICmpInst::ICMP_SLT;
1290  return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1291  SE->getSignedRangeMax(Step));
1292  }
1293  if (SE->isKnownNegative(Step)) {
1294  *Pred = ICmpInst::ICMP_SGT;
1295  return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1296  SE->getSignedRangeMin(Step));
1297  }
1298  return nullptr;
1299 }
1300 
1301 // Get the limit of a recurrence such that incrementing by Step cannot cause
1302 // unsigned overflow as long as the value of the recurrence within the loop does
1303 // not exceed this limit before incrementing.
1304 static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1305  ICmpInst::Predicate *Pred,
1306  ScalarEvolution *SE) {
1307  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1308  *Pred = ICmpInst::ICMP_ULT;
1309 
1310  return SE->getConstant(APInt::getMinValue(BitWidth) -
1311  SE->getUnsignedRangeMax(Step));
1312 }
1313 
1314 namespace {
1315 
1316 struct ExtendOpTraitsBase {
1317  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1318  unsigned);
1319 };
1320 
1321 // Used to make code generic over signed and unsigned overflow.
1322 template <typename ExtendOp> struct ExtendOpTraits {
1323  // Members present:
1324  //
1325  // static const SCEV::NoWrapFlags WrapType;
1326  //
1327  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1328  //
1329  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1330  // ICmpInst::Predicate *Pred,
1331  // ScalarEvolution *SE);
1332 };
1333 
1334 template <>
1335 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1336  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1337 
1338  static const GetExtendExprTy GetExtendExpr;
1339 
1340  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1341  ICmpInst::Predicate *Pred,
1342  ScalarEvolution *SE) {
1343  return getSignedOverflowLimitForStep(Step, Pred, SE);
1344  }
1345 };
1346 
1347 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1348  SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1349 
1350 template <>
1351 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1352  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1353 
1354  static const GetExtendExprTy GetExtendExpr;
1355 
1356  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1357  ICmpInst::Predicate *Pred,
1358  ScalarEvolution *SE) {
1359  return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1360  }
1361 };
1362 
1363 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1364  SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1365 
1366 } // end anonymous namespace
1367 
1368 // The recurrence AR has been shown to have no signed/unsigned wrap or something
1369 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1370 // easily prove NSW/NUW for its preincrement or postincrement sibling. This
1371 // allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1372 // Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1373 // expression "Step + sext/zext(PreIncAR)" is congruent with
1374 // "sext/zext(PostIncAR)"
1375 template <typename ExtendOpTy>
1376 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1377  ScalarEvolution *SE, unsigned Depth) {
1378  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1379  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1380 
1381  const Loop *L = AR->getLoop();
1382  const SCEV *Start = AR->getStart();
1383  const SCEV *Step = AR->getStepRecurrence(*SE);
1384 
1385  // Check for a simple looking step prior to loop entry.
1386  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1387  if (!SA)
1388  return nullptr;
1389 
1390  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1391  // subtraction is expensive. For this purpose, perform a quick and dirty
1392  // difference, by checking for Step in the operand list.
1393  SmallVector<const SCEV *, 4> DiffOps;
1394  for (const SCEV *Op : SA->operands())
1395  if (Op != Step)
1396  DiffOps.push_back(Op);
1397 
1398  if (DiffOps.size() == SA->getNumOperands())
1399  return nullptr;
1400 
1401  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1402  // `Step`:
1403 
1404  // 1. NSW/NUW flags on the step increment.
1405  auto PreStartFlags =
1406  ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1407  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1408  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1409  SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1410 
1411  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1412  // "S+X does not sign/unsign-overflow".
1413  //
1414 
1415  const SCEV *BECount = SE->getBackedgeTakenCount(L);
1416  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1417  !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1418  return PreStart;
1419 
1420  // 2. Direct overflow check on the step operation's expression.
1421  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1422  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1423  const SCEV *OperandExtendedStart =
1424  SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1425  (SE->*GetExtendExpr)(Step, WideTy, Depth));
1426  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1427  if (PreAR && AR->getNoWrapFlags(WrapType)) {
1428  // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1429  // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1430  // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1431  SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1432  }
1433  return PreStart;
1434  }
1435 
1436  // 3. Loop precondition.
1437  ICmpInst::Predicate Pred;
1438  const SCEV *OverflowLimit =
1439  ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1440 
1441  if (OverflowLimit &&
1442  SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1443  return PreStart;
1444 
1445  return nullptr;
1446 }
1447 
1448 // Get the normalized zero or sign extended expression for this AddRec's Start.
1449 template <typename ExtendOpTy>
1450 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1451  ScalarEvolution *SE,
1452  unsigned Depth) {
1453  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1454 
1455  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1456  if (!PreStart)
1457  return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1458 
1459  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1460  Depth),
1461  (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1462 }
1463 
1464 // Try to prove away overflow by looking at "nearby" add recurrences. A
1465 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1466 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1467 //
1468 // Formally:
1469 //
1470 // {S,+,X} == {S-T,+,X} + T
1471 // => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1472 //
1473 // If ({S-T,+,X} + T) does not overflow ... (1)
1474 //
1475 // RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1476 //
1477 // If {S-T,+,X} does not overflow ... (2)
1478 //
1479 // RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1480 // == {Ext(S-T)+Ext(T),+,Ext(X)}
1481 //
1482 // If (S-T)+T does not overflow ... (3)
1483 //
1484 // RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1485 // == {Ext(S),+,Ext(X)} == LHS
1486 //
1487 // Thus, if (1), (2) and (3) are true for some T, then
1488 // Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1489 //
1490 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1491 // does not overflow" restricted to the 0th iteration. Therefore we only need
1492 // to check for (1) and (2).
1493 //
1494 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1495 // is `Delta` (defined below).
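//
// Tying this to the motivating example (roughly): with Start = 1, Step = 4
// and Delta = 1, PreStart is the constant 0. If the cached {0,+,4} already
// carries the no-wrap flag (condition (2)) and is known to be ult -1, so
// adding Delta back cannot wrap (condition (1)), then {1,+,4} is proved not
// to wrap either.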
1496 template <typename ExtendOpTy>
1497 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1498  const SCEV *Step,
1499  const Loop *L) {
1500  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1501 
1502  // We restrict `Start` to a constant to prevent SCEV from spending too much
1503  // time here. It is correct (but more expensive) to continue with a
1504  // non-constant `Start` and do a general SCEV subtraction to compute
1505  // `PreStart` below.
1506  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1507  if (!StartC)
1508  return false;
1509 
1510  APInt StartAI = StartC->getAPInt();
1511 
1512  for (unsigned Delta : {-2, -1, 1, 2}) {
1513  const SCEV *PreStart = getConstant(StartAI - Delta);
1514 
1515  FoldingSetNodeID ID;
1516  ID.AddInteger(scAddRecExpr);
1517  ID.AddPointer(PreStart);
1518  ID.AddPointer(Step);
1519  ID.AddPointer(L);
1520  void *IP = nullptr;
1521  const auto *PreAR =
1522  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1523 
1524  // Give up if we don't already have the add recurrence we need because
1525  // actually constructing an add recurrence is relatively expensive.
1526  if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1527  const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1528  ICmpInst::Predicate Pred;
1529  const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1530  DeltaS, &Pred, this);
1531  if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1532  return true;
1533  }
1534  }
1535 
1536  return false;
1537 }
1538 
1539 // Finds an integer D for an expression (C + x + y + ...) such that the top
1540 // level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1541 // unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1542 // maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1543 // the (C + x + y + ...) expression is \p WholeAddExpr.
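//
// Numeric sketch: if C = 0x1F and every other operand of the add is known
// to have at least 4 trailing zero bits, then D = 0xF. The remainder
// C - D = 0x10 keeps those 4 trailing zeros, so the inner sum stays a
// multiple of 16, and adding D back only fills the low 4 bits and cannot
// carry into (or wrap) the higher bits.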
1544 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1545  const SCEVConstant *ConstantTerm,
1546  const SCEVAddExpr *WholeAddExpr) {
1547  const APInt &C = ConstantTerm->getAPInt();
1548  const unsigned BitWidth = C.getBitWidth();
1549  // Find number of trailing zeros of (x + y + ...) w/o the C first:
1550  uint32_t TZ = BitWidth;
1551  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1552  TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
1553  if (TZ) {
1554  // Set D to be as many least significant bits of C as possible while still
1555  // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1556  return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1557  }
1558  return APInt(BitWidth, 0);
1559 }
1560 
1561 // Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1562 // level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1563 // number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1564 // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1565 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1566  const APInt &ConstantStart,
1567  const SCEV *Step) {
1568  const unsigned BitWidth = ConstantStart.getBitWidth();
1569  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
1570  if (TZ)
1571  return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1572  : ConstantStart;
1573  return APInt(BitWidth, 0);
1574 }
1575 
1576 const SCEV *
1577 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1578  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1579  "This is not an extending conversion!");
1580  assert(isSCEVable(Ty) &&
1581  "This is not a conversion to a SCEVable type!");
1582  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1583  Ty = getEffectiveSCEVType(Ty);
1584 
1585  // Fold if the operand is constant.
1586  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1587  return getConstant(
1588  cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1589 
1590  // zext(zext(x)) --> zext(x)
1591  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1592  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1593 
1594  // Before doing any expensive analysis, check to see if we've already
1595  // computed a SCEV for this Op and Ty.
1596  FoldingSetNodeID ID;
1597  ID.AddInteger(scZeroExtend);
1598  ID.AddPointer(Op);
1599  ID.AddPointer(Ty);
1600  void *IP = nullptr;
1601  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1602  if (Depth > MaxCastDepth) {
1603  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1604  Op, Ty);
1605  UniqueSCEVs.InsertNode(S, IP);
1606  addToLoopUseLists(S);
1607  return S;
1608  }
1609 
1610  // zext(trunc(x)) --> zext(x) or x or trunc(x)
1611  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1612  // It's possible the bits taken off by the truncate were all zero bits. If
1613  // so, we should be able to simplify this further.
1614  const SCEV *X = ST->getOperand();
1615  ConstantRange CR = getUnsignedRange(X);
1616  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1617  unsigned NewBits = getTypeSizeInBits(Ty);
1618  if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1619  CR.zextOrTrunc(NewBits)))
1620  return getTruncateOrZeroExtend(X, Ty, Depth);
1621  }
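// Worked example of the fold above: if %x is an i32 whose unsigned range is
// known to be [0, 100] and Op is (trunc %x to i8), then truncating that range
// to 8 bits and zero-extending it back still covers the original range, so
// the zero-extend of the truncate simplifies to %x itself (the
// getTruncateOrZeroExtend call is a no-op once the types match).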
1622 
1623  // If the input value is a chrec scev, and we can prove that the value
1624  // did not overflow the old, smaller, value, we can zero extend all of the
1625  // operands (often constants). This allows analysis of something like
1626  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1627  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1628  if (AR->isAffine()) {
1629  const SCEV *Start = AR->getStart();
1630  const SCEV *Step = AR->getStepRecurrence(*this);
1631  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1632  const Loop *L = AR->getLoop();
1633 
1634  if (!AR->hasNoUnsignedWrap()) {
1635  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1636  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1637  }
1638 
1639  // If we have special knowledge that this addrec won't overflow,
1640  // we don't need to do any further analysis.
1641  if (AR->hasNoUnsignedWrap())
1642  return getAddRecExpr(
1643  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1644  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1645 
1646  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1647  // Note that this serves two purposes: It filters out loops that are
1648  // simply not analyzable, and it covers the case where this code is
1649  // being called from within backedge-taken count analysis, such that
1650  // attempting to ask for the backedge-taken count would likely result
1651  // in infinite recursion. In the latter case, the analysis code will
1652  // cope with a conservative value, and it will take care to purge
1653  // that value once it has finished.
1654  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1655  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1656  // Manually compute the final value for AR, checking for overflow.
1657 
1658  // Check whether the backedge-taken count can be losslessly casted to
1659  // the addrec's type. The count is always unsigned.
1660  const SCEV *CastedMaxBECount =
1661  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1662  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1663  CastedMaxBECount, MaxBECount->getType(), Depth);
1664  if (MaxBECount == RecastedMaxBECount) {
1665  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1666  // Check whether Start+Step*MaxBECount has no unsigned overflow.
1667  const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1668  SCEV::FlagAnyWrap, Depth + 1);
1669  const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1670  SCEV::FlagAnyWrap,
1671  Depth + 1),
1672  WideTy, Depth + 1);
1673  const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1674  const SCEV *WideMaxBECount =
1675  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1676  const SCEV *OperandExtendedAdd =
1677  getAddExpr(WideStart,
1678  getMulExpr(WideMaxBECount,
1679  getZeroExtendExpr(Step, WideTy, Depth + 1),
1680  SCEV::FlagAnyWrap, Depth + 1),
1681  SCEV::FlagAnyWrap, Depth + 1);
1682  if (ZAdd == OperandExtendedAdd) {
1683  // Cache knowledge of AR NUW, which is propagated to this AddRec.
1684  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1685  // Return the expression with the addrec on the outside.
1686  return getAddRecExpr(
1687  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1688  Depth + 1),
1689  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1690  AR->getNoWrapFlags());
1691  }
1692  // Similar to above, only this time treat the step value as signed.
1693  // This covers loops that count down.
1694  OperandExtendedAdd =
1695  getAddExpr(WideStart,
1696  getMulExpr(WideMaxBECount,
1697  getSignExtendExpr(Step, WideTy, Depth + 1),
1698  SCEV::FlagAnyWrap, Depth + 1),
1699  SCEV::FlagAnyWrap, Depth + 1);
1700  if (ZAdd == OperandExtendedAdd) {
1701  // Cache knowledge of AR NW, which is propagated to this AddRec.
1702  // Negative step causes unsigned wrap, but it still can't self-wrap.
1703  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1704  // Return the expression with the addrec on the outside.
1705  return getAddRecExpr(
1706  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1707  Depth + 1),
1708  getSignExtendExpr(Step, Ty, Depth + 1), L,
1709  AR->getNoWrapFlags());
1710  }
1711  }
1712  }
1713 
1714  // Normally, in the cases we can prove no-overflow via a
1715  // backedge guarding condition, we can also compute a backedge
1716  // taken count for the loop. The exceptions are assumptions and
1717  // guards present in the loop -- SCEV is not great at exploiting
1718  // these to compute max backedge taken counts, but can still use
1719  // these to prove lack of overflow. Use this fact to avoid
1720  // doing extra work that may not pay off.
1721  if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1722  !AC.assumptions().empty()) {
1723 
1724  auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1725  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1726  if (AR->hasNoUnsignedWrap()) {
1727  // Same as nuw case above - duplicated here to avoid a compile time
1728  // issue. It's not clear that the order of checks matters, but it's
1729  // one of two possible causes for a change which was reverted. Be
1730  // conservative for the moment.
1731  return getAddRecExpr(
1732  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1733  Depth + 1),
1734  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1735  AR->getNoWrapFlags());
1736  }
1737 
1738  // For a negative step, we can extend the operands iff doing so only
1739  // traverses values in the range zext([0,UINT_MAX]).
1740  if (isKnownNegative(Step)) {
1741  const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1742  getSignedRangeMin(Step));
1743  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1744  isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1745  // Cache knowledge of AR NW, which is propagated to this
1746  // AddRec. Negative step causes unsigned wrap, but it
1747  // still can't self-wrap.
1748  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1749  // Return the expression with the addrec on the outside.
1750  return getAddRecExpr(
1751  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1752  Depth + 1),
1753  getSignExtendExpr(Step, Ty, Depth + 1), L,
1754  AR->getNoWrapFlags());
1755  }
1756  }
1757  }
1758 
1759  // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1760  // if D + (C - D + Step * n) could be proven to not unsigned wrap
1761  // where D maximizes the number of trailing zeros of (C - D + Step * n)
1762  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1763  const APInt &C = SC->getAPInt();
1764  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1765  if (D != 0) {
1766  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1767  const SCEV *SResidual =
1768  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1769  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1770  return getAddExpr(SZExtD, SZExtR,
1771  (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW),
1772  Depth + 1);
1773  }
1774  }
1775 
1776  if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1777  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1778  return getAddRecExpr(
1779  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1780  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1781  }
1782  }
1783 
1784  // zext(A % B) --> zext(A) % zext(B)
1785  {
1786  const SCEV *LHS;
1787  const SCEV *RHS;
1788  if (matchURem(Op, LHS, RHS))
1789  return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1790  getZeroExtendExpr(RHS, Ty, Depth + 1));
1791  }
1792 
1793  // zext(A / B) --> zext(A) / zext(B).
1794  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1795  return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1796  getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1797 
1798  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1799  // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1800  if (SA->hasNoUnsignedWrap()) {
1801  // If the addition does not unsign overflow then we can, by definition,
1802  // commute the zero extension with the addition operation.
1803  SmallVector<const SCEV *, 4> Ops;
1804  for (const auto *Op : SA->operands())
1805  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1806  return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1807  }
1808 
1809  // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1810  // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1811  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1812  //
1813  // Often address arithmetic contains expressions like
1814  // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1815  // This transformation is useful while proving that such expressions are
1816  // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1817  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1818  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1819  if (D != 0) {
1820  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1821  const SCEV *SResidual =
1822  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1823  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1824  return getAddExpr(SZExtD, SZExtR,
1825  (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW),
1826  Depth + 1);
1827  }
1828  }
1829  }
1830 
1831  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1832  // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1833  if (SM->hasNoUnsignedWrap()) {
1834  // If the multiply does not unsign overflow then we can, by definition,
1835  // commute the zero extension with the multiply operation.
1836  SmallVector<const SCEV *, 4> Ops;
1837  for (const auto *Op : SM->operands())
1838  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1839  return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1840  }
1841 
1842  // zext(2^K * (trunc X to iN)) to iM ->
1843  // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1844  //
1845  // Proof:
1846  //
1847  // zext(2^K * (trunc X to iN)) to iM
1848  // = zext((trunc X to iN) << K) to iM
1849  // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1850  // (because shl removes the top K bits)
1851  // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1852  // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1853  //
1854  if (SM->getNumOperands() == 2)
1855  if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1856  if (MulLHS->getAPInt().isPowerOf2())
1857  if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1858  int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1859  MulLHS->getAPInt().logBase2();
1860  Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1861  return getMulExpr(
1862  getZeroExtendExpr(MulLHS, Ty),
1863  getZeroExtendExpr(
1864  getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1865  SCEV::FlagNUW, Depth + 1);
1866  }
1867  }
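// Concrete instance of the proof above, with K == 2, N == 8, M == 32:
//   zext(4 * (trunc %x to i8)) to i32
//     == zext((trunc %x to i8) << 2) to i32
//     == zext((trunc %x to i6) << 2)<nuw> to i32   ; shl discards the top 2 bits
//     == (4 * (zext(trunc %x to i6) to i32))<nuw>
// which is the getMulExpr(..., SCEV::FlagNUW, ...) form built above, with
// NewTruncBits == 8 - 2 == 6.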
1868 
1869  // The cast wasn't folded; create an explicit cast node.
1870  // Recompute the insert position, as it may have been invalidated.
1871  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1872  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1873  Op, Ty);
1874  UniqueSCEVs.InsertNode(S, IP);
1875  addToLoopUseLists(S);
1876  return S;
1877 }
1878 
1879 const SCEV *
1880 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1881  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1882  "This is not an extending conversion!");
1883  assert(isSCEVable(Ty) &&
1884  "This is not a conversion to a SCEVable type!");
1885  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1886  Ty = getEffectiveSCEVType(Ty);
1887 
1888  // Fold if the operand is constant.
1889  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1890  return getConstant(
1891  cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1892 
1893  // sext(sext(x)) --> sext(x)
1894  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1895  return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1896 
1897  // sext(zext(x)) --> zext(x)
1898  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1899  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1900 
1901  // Before doing any expensive analysis, check to see if we've already
1902  // computed a SCEV for this Op and Ty.
1903  FoldingSetNodeID ID;
1904  ID.AddInteger(scSignExtend);
1905  ID.AddPointer(Op);
1906  ID.AddPointer(Ty);
1907  void *IP = nullptr;
1908  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1909  // Limit recursion depth.
1910  if (Depth > MaxCastDepth) {
1911  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1912  Op, Ty);
1913  UniqueSCEVs.InsertNode(S, IP);
1914  addToLoopUseLists(S);
1915  return S;
1916  }
1917 
1918  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1919  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1920  // It's possible the bits taken off by the truncate were all sign bits. If
1921  // so, we should be able to simplify this further.
1922  const SCEV *X = ST->getOperand();
1923  ConstantRange CR = getSignedRange(X);
1924  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1925  unsigned NewBits = getTypeSizeInBits(Ty);
1926  if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1927  CR.sextOrTrunc(NewBits)))
1928  return getTruncateOrSignExtend(X, Ty, Depth);
1929  }
1930 
1931  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1932  // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1933  if (SA->hasNoSignedWrap()) {
1934  // If the addition does not sign overflow then we can, by definition,
1935  // commute the sign extension with the addition operation.
1936  SmallVector<const SCEV *, 4> Ops;
1937  for (const auto *Op : SA->operands())
1938  Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1939  return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1940  }
1941 
1942  // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1943  // if D + (C - D + x + y + ...) could be proven to not signed wrap
1944  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1945  //
1946  // For instance, this will bring two seemingly different expressions:
1947  // 1 + sext(5 + 20 * %x + 24 * %y) and
1948  // sext(6 + 20 * %x + 24 * %y)
1949  // to the same form:
1950  // 2 + sext(4 + 20 * %x + 24 * %y)
1951  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1952  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1953  if (D != 0) {
1954  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1955  const SCEV *SResidual =
1956  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1957  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1958  return getAddExpr(SSExtD, SSExtR,
1959  (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW),
1960  Depth + 1);
1961  }
1962  }
1963  }
1964  // If the input value is a chrec scev, and we can prove that the value
1965  // did not overflow the old, smaller, value, we can sign extend all of the
1966  // operands (often constants). This allows analysis of something like
1967  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1968  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1969  if (AR->isAffine()) {
1970  const SCEV *Start = AR->getStart();
1971  const SCEV *Step = AR->getStepRecurrence(*this);
1972  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1973  const Loop *L = AR->getLoop();
1974 
1975  if (!AR->hasNoSignedWrap()) {
1976  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1977  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1978  }
1979 
1980  // If we have special knowledge that this addrec won't overflow,
1981  // we don't need to do any further analysis.
1982  if (AR->hasNoSignedWrap())
1983  return getAddRecExpr(
1984  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1985  getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
1986 
1987  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1988  // Note that this serves two purposes: It filters out loops that are
1989  // simply not analyzable, and it covers the case where this code is
1990  // being called from within backedge-taken count analysis, such that
1991  // attempting to ask for the backedge-taken count would likely result
1992  // in infinite recursion. In the latter case, the analysis code will
1993  // cope with a conservative value, and it will take care to purge
1994  // that value once it has finished.
1995  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1996  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1997  // Manually compute the final value for AR, checking for
1998  // overflow.
1999 
2000  // Check whether the backedge-taken count can be losslessly casted to
2001  // the addrec's type. The count is always unsigned.
2002  const SCEV *CastedMaxBECount =
2003  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2004  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2005  CastedMaxBECount, MaxBECount->getType(), Depth);
2006  if (MaxBECount == RecastedMaxBECount) {
2007  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2008  // Check whether Start+Step*MaxBECount has no signed overflow.
2009  const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2010  SCEV::FlagAnyWrap, Depth + 1);
2011  const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2012  SCEV::FlagAnyWrap,
2013  Depth + 1),
2014  WideTy, Depth + 1);
2015  const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2016  const SCEV *WideMaxBECount =
2017  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2018  const SCEV *OperandExtendedAdd =
2019  getAddExpr(WideStart,
2020  getMulExpr(WideMaxBECount,
2021  getSignExtendExpr(Step, WideTy, Depth + 1),
2022  SCEV::FlagAnyWrap, Depth + 1),
2023  SCEV::FlagAnyWrap, Depth + 1);
2024  if (SAdd == OperandExtendedAdd) {
2025  // Cache knowledge of AR NSW, which is propagated to this AddRec.
2026  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2027  // Return the expression with the addrec on the outside.
2028  return getAddRecExpr(
2029  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2030  Depth + 1),
2031  getSignExtendExpr(Step, Ty, Depth + 1), L,
2032  AR->getNoWrapFlags());
2033  }
2034  // Similar to above, only this time treat the step value as unsigned.
2035  // This covers loops that count up with an unsigned step.
2036  OperandExtendedAdd =
2037  getAddExpr(WideStart,
2038  getMulExpr(WideMaxBECount,
2039  getZeroExtendExpr(Step, WideTy, Depth + 1),
2040  SCEV::FlagAnyWrap, Depth + 1),
2041  SCEV::FlagAnyWrap, Depth + 1);
2042  if (SAdd == OperandExtendedAdd) {
2043  // If AR wraps around then
2044  //
2045  // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2046  // => SAdd != OperandExtendedAdd
2047  //
2048  // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2049  // (SAdd == OperandExtendedAdd => AR is NW)
2050 
2051  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2052 
2053  // Return the expression with the addrec on the outside.
2054  return getAddRecExpr(
2055  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2056  Depth + 1),
2057  getZeroExtendExpr(Step, Ty, Depth + 1), L,
2058  AR->getNoWrapFlags());
2059  }
2060  }
2061  }
2062 
2063  auto NewFlags = proveNoSignedWrapViaInduction(AR);
2064  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2065  if (AR->hasNoSignedWrap()) {
2066  // Same as nsw case above - duplicated here to avoid a compile time
2067  // issue. It's not clear that the order of checks matters, but it's
2068  // one of two possible causes for a change which was reverted. Be
2069  // conservative for the moment.
2070  return getAddRecExpr(
2071  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2072  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2073  }
2074 
2075  // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2076  // if D + (C - D + Step * n) could be proven to not signed wrap
2077  // where D maximizes the number of trailing zeros of (C - D + Step * n)
2078  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2079  const APInt &C = SC->getAPInt();
2080  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2081  if (D != 0) {
2082  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2083  const SCEV *SResidual =
2084  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2085  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2086  return getAddExpr(SSExtD, SSExtR,
2087  (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW),
2088  Depth + 1);
2089  }
2090  }
2091 
2092  if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2093  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2094  return getAddRecExpr(
2095  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2096  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2097  }
2098  }
2099 
2100  // If the input value is provably positive and we could not simplify
2101  // away the sext, build a zext instead.
2102  if (isKnownNonNegative(Op))
2103  return getZeroExtendExpr(Op, Ty, Depth + 1);
2104 
2105  // The cast wasn't folded; create an explicit cast node.
2106  // Recompute the insert position, as it may have been invalidated.
2107  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2108  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2109  Op, Ty);
2110  UniqueSCEVs.InsertNode(S, IP);
2111  addToLoopUseLists(S);
2112  return S;
2113 }
2114 
2115 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2116 /// unspecified bits out to the given type.
2117 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2118  Type *Ty) {
2119  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2120  "This is not an extending conversion!");
2121  assert(isSCEVable(Ty) &&
2122  "This is not a conversion to a SCEVable type!");
2123  Ty = getEffectiveSCEVType(Ty);
2124 
2125  // Sign-extend negative constants.
2126  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2127  if (SC->getAPInt().isNegative())
2128  return getSignExtendExpr(Op, Ty);
2129 
2130  // Peel off a truncate cast.
2131  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2132  const SCEV *NewOp = T->getOperand();
2133  if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2134  return getAnyExtendExpr(NewOp, Ty);
2135  return getTruncateOrNoop(NewOp, Ty);
2136  }
2137 
2138  // Next try a zext cast. If the cast is folded, use it.
2139  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2140  if (!isa<SCEVZeroExtendExpr>(ZExt))
2141  return ZExt;
2142 
2143  // Next try a sext cast. If the cast is folded, use it.
2144  const SCEV *SExt = getSignExtendExpr(Op, Ty);
2145  if (!isa<SCEVSignExtendExpr>(SExt))
2146  return SExt;
2147 
2148  // Force the cast to be folded into the operands of an addrec.
2149  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2150  SmallVector<const SCEV *, 4> Ops;
2151  for (const SCEV *Op : AR->operands())
2152  Ops.push_back(getAnyExtendExpr(Op, Ty));
2153  return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2154  }
2155 
2156  // If the expression is obviously signed, use the sext cast value.
2157  if (isa<SCEVSMaxExpr>(Op))
2158  return SExt;
2159 
2160  // Absent any other information, use the zext cast value.
2161  return ZExt;
2162 }
2163 
2164 /// Process the given Ops list, which is a list of operands to be added under
2165 /// the given scale, update the given map. This is a helper function for
2166 /// getAddRecExpr. As an example of what it does, given a sequence of operands
2167 /// that would form an add expression like this:
2168 ///
2169 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2170 ///
2171 /// where A and B are constants, update the map with these values:
2172 ///
2173 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2174 ///
2175 /// and add 13 + A*B*29 to AccumulatedConstant.
2176 /// This will allow getAddRecExpr to produce this:
2177 ///
2178 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2179 ///
2180 /// This form often exposes folding opportunities that are hidden in
2181 /// the original operand list.
2182 ///
2183 /// Return true iff it appears that any interesting folding opportunities
2184 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
2185 /// the common case where no interesting opportunities are present, and
2186 /// is also used as a check to avoid infinite recursion.
2187 static bool
2188 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2189  SmallVectorImpl<const SCEV *> &NewOps,
2190  APInt &AccumulatedConstant,
2191  const SCEV *const *Ops, size_t NumOperands,
2192  const APInt &Scale,
2193  ScalarEvolution &SE) {
2194  bool Interesting = false;
2195 
2196  // Iterate over the add operands. They are sorted, with constants first.
2197  unsigned i = 0;
2198  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2199  ++i;
2200  // Pull a buried constant out to the outside.
2201  if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2202  Interesting = true;
2203  AccumulatedConstant += Scale * C->getAPInt();
2204  }
2205 
2206  // Next comes everything else. We're especially interested in multiplies
2207  // here, but they're in the middle, so just visit the rest with one loop.
2208  for (; i != NumOperands; ++i) {
2209  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2210  if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2211  APInt NewScale =
2212  Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2213  if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2214  // A multiplication of a constant with another add; recurse.
2215  const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2216  Interesting |=
2217  CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2218  Add->op_begin(), Add->getNumOperands(),
2219  NewScale, SE);
2220  } else {
2221  // A multiplication of a constant with some other value. Update
2222  // the map.
2223  SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2224  const SCEV *Key = SE.getMulExpr(MulOps);
2225  auto Pair = M.insert({Key, NewScale});
2226  if (Pair.second) {
2227  NewOps.push_back(Pair.first->first);
2228  } else {
2229  Pair.first->second += NewScale;
2230  // The map already had an entry for this value, which may indicate
2231  // a folding opportunity.
2232  Interesting = true;
2233  }
2234  }
2235  } else {
2236  // An ordinary operand. Update the map.
2237  std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2238  M.insert({Ops[i], Scale});
2239  if (Pair.second) {
2240  NewOps.push_back(Pair.first->first);
2241  } else {
2242  Pair.first->second += Scale;
2243  // The map already had an entry for this value, which may indicate
2244  // a folding opportunity.
2245  Interesting = true;
2246  }
2247  }
2248  }
2249 
2250  return Interesting;
2251 }
2252 
2253 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2254  const SCEV *LHS, const SCEV *RHS) {
2255  const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2256  SCEV::NoWrapFlags, unsigned);
2257  switch (BinOp) {
2258  default:
2259  llvm_unreachable("Unsupported binary op");
2260  case Instruction::Add:
2261  Operation = &ScalarEvolution::getAddExpr;
2262  break;
2263  case Instruction::Sub:
2264  Operation = &ScalarEvolution::getMinusSCEV;
2265  break;
2266  case Instruction::Mul:
2267  Operation = &ScalarEvolution::getMulExpr;
2268  break;
2269  }
2270 
2271  const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2272  Signed ? &ScalarEvolution::getSignExtendExpr
2273  : &ScalarEvolution::getZeroExtendExpr;
2274 
2275  // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2276  auto *NarrowTy = cast<IntegerType>(LHS->getType());
2277  auto *WideTy =
2278  IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2279 
2280  const SCEV *A = (this->*Extension)(
2281  (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2282  const SCEV *B = (this->*Operation)((this->*Extension)(LHS, WideTy, 0),
2283  (this->*Extension)(RHS, WideTy, 0),
2284  SCEV::FlagAnyWrap, 0);
2285  return A == B;
2286 }
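// Worked example (i8 add, unsigned): for constants LHS == 100 and RHS == 20,
// the narrow sum folds to 120 and zext(120) to i16 equals
// zext(100) + zext(20) == 120, so A == B and the routine returns true. For
// LHS == 100 and RHS == 200 the narrow sum wraps to 44, so A == 44 while
// B == 300, and it returns false.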
2287 
2288 std::pair<SCEV::NoWrapFlags, bool /*Deduced*/>
2289 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2290  const OverflowingBinaryOperator *OBO) {
2291  SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2292 
2293  if (OBO->hasNoUnsignedWrap())
2294  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2295  if (OBO->hasNoSignedWrap())
2296  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2297 
2298  bool Deduced = false;
2299 
2300  if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2301  return {Flags, Deduced};
2302 
2303  if (OBO->getOpcode() != Instruction::Add &&
2304  OBO->getOpcode() != Instruction::Sub &&
2305  OBO->getOpcode() != Instruction::Mul)
2306  return {Flags, Deduced};
2307 
2308  const SCEV *LHS = getSCEV(OBO->getOperand(0));
2309  const SCEV *RHS = getSCEV(OBO->getOperand(1));
2310 
2311  if (!OBO->hasNoUnsignedWrap() &&
2312  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2313  /* Signed */ false, LHS, RHS)) {
2314  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2315  Deduced = true;
2316  }
2317 
2318  if (!OBO->hasNoSignedWrap() &&
2319  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2320  /* Signed */ true, LHS, RHS)) {
2321  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2322  Deduced = true;
2323  }
2324 
2325  return {Flags, Deduced};
2326 }
2327 
2328 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2329 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2330 // can't-overflow flags for the operation if possible.
2331 static SCEV::NoWrapFlags
2332 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2333  const ArrayRef<const SCEV *> Ops,
2334  SCEV::NoWrapFlags Flags) {
2335  using namespace std::placeholders;
2336 
2337  using OBO = OverflowingBinaryOperator;
2338 
2339  bool CanAnalyze =
2340  Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2341  (void)CanAnalyze;
2342  assert(CanAnalyze && "don't call from other places!");
2343 
2344  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2345  SCEV::NoWrapFlags SignOrUnsignWrap =
2346  ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2347 
2348  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2349  auto IsKnownNonNegative = [&](const SCEV *S) {
2350  return SE->isKnownNonNegative(S);
2351  };
2352 
2353  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2354  Flags =
2355  ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2356 
2357  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2358 
2359  if (SignOrUnsignWrap != SignOrUnsignMask &&
2360  (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2361  isa<SCEVConstant>(Ops[0])) {
2362 
2363  auto Opcode = [&] {
2364  switch (Type) {
2365  case scAddExpr:
2366  return Instruction::Add;
2367  case scMulExpr:
2368  return Instruction::Mul;
2369  default:
2370  llvm_unreachable("Unexpected SCEV op.");
2371  }
2372  }();
2373 
2374  const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2375 
2376  // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2377  if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2378  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2379  Opcode, C, OBO::NoSignedWrap);
2380  if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2381  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2382  }
2383 
2384  // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2385  if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2386  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2387  Opcode, C, OBO::NoUnsignedWrap);
2388  if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2389  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2390  }
2391  }
2392 
2393  // <0,+,nonnegative><nw> is also nuw
2394  // TODO: Add corresponding nsw case
2395  if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2396  !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2397  Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2398  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2399 
2400  // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2401  if (Type == scMulExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2402  Ops.size() == 2) {
2403  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2404  if (UDiv->getOperand(1) == Ops[1])
2405  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2406  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2407  if (UDiv->getOperand(1) == Ops[0])
2408  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2409  }
2410 
2411  return Flags;
2412 }
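// Worked example of the constant-operand rule above: for (1 + %x) on i8 with
// no flags yet, makeGuaranteedNoWrapRegion(Add, 1, NoSignedWrap) is the set
// of values that can be incremented without signed wrap, i.e. everything but
// 127. If getSignedRange(%x) is, say, [0, 50], it lies inside that region and
// FlagNSW is added; the analogous NoUnsignedWrap region drives FlagNUW.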
2413 
2414 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2415  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2416 }
2417 
2418 /// Get a canonical add expression, or something simpler if possible.
2419 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2420  SCEV::NoWrapFlags OrigFlags,
2421  unsigned Depth) {
2422  assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2423  "only nuw or nsw allowed");
2424  assert(!Ops.empty() && "Cannot get empty add!");
2425  if (Ops.size() == 1) return Ops[0];
2426 #ifndef NDEBUG
2427  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2428  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2429  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2430  "SCEVAddExpr operand types don't match!");
2431  unsigned NumPtrs = count_if(
2432  Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2433  assert(NumPtrs <= 1 && "add has at most one pointer operand");
2434 #endif
2435 
2436  // Sort by complexity, this groups all similar expression types together.
2437  GroupByComplexity(Ops, &LI, DT);
2438 
2439  // If there are any constants, fold them together.
2440  unsigned Idx = 0;
2441  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2442  ++Idx;
2443  assert(Idx < Ops.size());
2444  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2445  // We found two constants, fold them together!
2446  Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2447  if (Ops.size() == 2) return Ops[0];
2448  Ops.erase(Ops.begin()+1); // Erase the folded element
2449  LHSC = cast<SCEVConstant>(Ops[0]);
2450  }
2451 
2452  // If we are left with a constant zero being added, strip it off.
2453  if (LHSC->getValue()->isZero()) {
2454  Ops.erase(Ops.begin());
2455  --Idx;
2456  }
2457 
2458  if (Ops.size() == 1) return Ops[0];
2459  }
2460 
2461  // Delay expensive flag strengthening until necessary.
2462  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2463  return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2464  };
2465 
2466  // Limit recursion calls depth.
2467  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2468  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2469 
2470  if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2471  // Don't strengthen flags if we have no new information.
2472  SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2473  if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2474  Add->setNoWrapFlags(ComputeFlags(Ops));
2475  return S;
2476  }
2477 
2478  // Okay, check to see if the same value occurs in the operand list more than
2479  // once. If so, merge them together into a multiply expression. Since we
2480  // sorted the list, these values are required to be adjacent.
2481  Type *Ty = Ops[0]->getType();
2482  bool FoundMatch = false;
2483  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2484  if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2485  // Scan ahead to count how many equal operands there are.
2486  unsigned Count = 2;
2487  while (i+Count != e && Ops[i+Count] == Ops[i])
2488  ++Count;
2489  // Merge the values into a multiply.
2490  const SCEV *Scale = getConstant(Ty, Count);
2491  const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2492  if (Ops.size() == Count)
2493  return Mul;
2494  Ops[i] = Mul;
2495  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2496  --i; e -= Count - 1;
2497  FoundMatch = true;
2498  }
2499  if (FoundMatch)
2500  return getAddExpr(Ops, OrigFlags, Depth + 1);
2501 
2502  // Check for truncates. If all the operands are truncated from the same
2503  // type, see if factoring out the truncate would permit the result to be
2504  // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2505  // if the contents of the resulting outer trunc fold to something simple.
2506  auto FindTruncSrcType = [&]() -> Type * {
2507  // We're ultimately looking to fold an addrec of truncs and muls of only
2508  // constants and truncs, so if we find any other types of SCEV
2509  // as operands of the addrec then we bail and return nullptr here.
2510  // Otherwise, we return the type of the operand of a trunc that we find.
2511  if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2512  return T->getOperand()->getType();
2513  if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2514  const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2515  if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2516  return T->getOperand()->getType();
2517  }
2518  return nullptr;
2519  };
2520  if (auto *SrcType = FindTruncSrcType()) {
2521  SmallVector<const SCEV *, 8> LargeOps;
2522  bool Ok = true;
2523  // Check all the operands to see if they can be represented in the
2524  // source type of the truncate.
2525  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2526  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2527  if (T->getOperand()->getType() != SrcType) {
2528  Ok = false;
2529  break;
2530  }
2531  LargeOps.push_back(T->getOperand());
2532  } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2533  LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2534  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2535  SmallVector<const SCEV *, 8> LargeMulOps;
2536  for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2537  if (const SCEVTruncateExpr *T =
2538  dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2539  if (T->getOperand()->getType() != SrcType) {
2540  Ok = false;
2541  break;
2542  }
2543  LargeMulOps.push_back(T->getOperand());
2544  } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2545  LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2546  } else {
2547  Ok = false;
2548  break;
2549  }
2550  }
2551  if (Ok)
2552  LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2553  } else {
2554  Ok = false;
2555  break;
2556  }
2557  }
2558  if (Ok) {
2559  // Evaluate the expression in the larger type.
2560  const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2561  // If it folds to something simple, use it. Otherwise, don't.
2562  if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2563  return getTruncateExpr(Fold, Ty);
2564  }
2565  }
2566 
2567  if (Ops.size() == 2) {
2568  // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2569  // C2 can be folded in a way that allows retaining wrapping flags of (X +
2570  // C1).
2571  const SCEV *A = Ops[0];
2572  const SCEV *B = Ops[1];
2573  auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2574  auto *C = dyn_cast<SCEVConstant>(A);
2575  if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2576  auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2577  auto C2 = C->getAPInt();
2578  SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2579 
2580  APInt ConstAdd = C1 + C2;
2581  auto AddFlags = AddExpr->getNoWrapFlags();
2582  // Adding a smaller constant is NUW if the original AddExpr was NUW.
2583  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2584  ConstAdd.ule(C1)) {
2585  PreservedFlags =
2586  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2587  }
2588 
2589  // Adding a constant with the same sign and small magnitude is NSW, if the
2590  // original AddExpr was NSW.
2591  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2592  C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2593  ConstAdd.abs().ule(C1.abs())) {
2594  PreservedFlags =
2595  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2596  }
2597 
2598  if (PreservedFlags != SCEV::FlagAnyWrap) {
2599  SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2600  NewOps[0] = getConstant(ConstAdd);
2601  return getAddExpr(NewOps, PreservedFlags);
2602  }
2603  }
2604  }
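// Worked example for the two-operand fold above: for Ops == {-3, (10 + %x)<nuw>}
// we get C1 == 10, C2 == -3 and ConstAdd == 7. As unsigned values 7 ule 10, so
// NUW is preserved; if the original add is also <nsw>, then since 10 and 7
// share a sign and |7| ule |10|, NSW is preserved too, and the expression is
// rebuilt as (7 + %x) with the preserved flags.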
2605 
2606  // Skip past any other cast SCEVs.
2607  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2608  ++Idx;
2609 
2610  // If there are add operands they would be next.
2611  if (Idx < Ops.size()) {
2612  bool DeletedAdd = false;
2613  // If the original flags and all inlined SCEVAddExprs are NUW, use the
2614  // common NUW flag for expression after inlining. Other flags cannot be
2615  // preserved, because they may depend on the original order of operations.
2616  SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2617  while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2618  if (Ops.size() > AddOpsInlineThreshold ||
2619  Add->getNumOperands() > AddOpsInlineThreshold)
2620  break;
2621  // If we have an add, expand the add operands onto the end of the operands
2622  // list.
2623  Ops.erase(Ops.begin()+Idx);
2624  Ops.append(Add->op_begin(), Add->op_end());
2625  DeletedAdd = true;
2626  CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2627  }
2628 
2629  // If we deleted at least one add, we added operands to the end of the list,
2630  // and they are not necessarily sorted. Recurse to resort and resimplify
2631  // any operands we just acquired.
2632  if (DeletedAdd)
2633  return getAddExpr(Ops, CommonFlags, Depth + 1);
2634  }
2635 
2636  // Skip over the add expression until we get to a multiply.
2637  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2638  ++Idx;
2639 
2640  // Check to see if there are any folding opportunities present with
2641  // operands multiplied by constant values.
2642  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2643  uint64_t BitWidth = getTypeSizeInBits(Ty);
2644  DenseMap<const SCEV *, APInt> M;
2645  SmallVector<const SCEV *, 8> NewOps;
2646  APInt AccumulatedConstant(BitWidth, 0);
2647  if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2648  Ops.data(), Ops.size(),
2649  APInt(BitWidth, 1), *this)) {
2650  struct APIntCompare {
2651  bool operator()(const APInt &LHS, const APInt &RHS) const {
2652  return LHS.ult(RHS);
2653  }
2654  };
2655 
2656  // Some interesting folding opportunity is present, so it's worthwhile to
2657  // re-generate the operands list. Group the operands by constant scale,
2658  // to avoid multiplying by the same constant scale multiple times.
2659  std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2660  for (const SCEV *NewOp : NewOps)
2661  MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2662  // Re-generate the operands list.
2663  Ops.clear();
2664  if (AccumulatedConstant != 0)
2665  Ops.push_back(getConstant(AccumulatedConstant));
2666  for (auto &MulOp : MulOpLists) {
2667  if (MulOp.first == 1) {
2668  Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2669  } else if (MulOp.first != 0) {
2670  Ops.push_back(getMulExpr(
2671  getConstant(MulOp.first),
2672  getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2673  SCEV::FlagAnyWrap, Depth + 1));
2674  }
2675  }
2676  if (Ops.empty())
2677  return getZero(Ty);
2678  if (Ops.size() == 1)
2679  return Ops[0];
2680  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2681  }
2682  }
2683 
2684  // If we are adding something to a multiply expression, make sure the
2685  // something is not already an operand of the multiply. If so, merge it into
2686  // the multiply.
2687  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2688  const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2689  for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2690  const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2691  if (isa<SCEVConstant>(MulOpSCEV))
2692  continue;
2693  for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2694  if (MulOpSCEV == Ops[AddOp]) {
2695  // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2696  const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2697  if (Mul->getNumOperands() != 2) {
2698  // If the multiply has more than two operands, we must get the
2699  // Y*Z term.
2700  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2701  Mul->op_begin()+MulOp);
2702  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2703  InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2704  }
2705  SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2706  const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2707  const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2708  SCEV::FlagAnyWrap, Depth + 1);
2709  if (Ops.size() == 2) return OuterMul;
2710  if (AddOp < Idx) {
2711  Ops.erase(Ops.begin()+AddOp);
2712  Ops.erase(Ops.begin()+Idx-1);
2713  } else {
2714  Ops.erase(Ops.begin()+Idx);
2715  Ops.erase(Ops.begin()+AddOp-1);
2716  }
2717  Ops.push_back(OuterMul);
2718  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2719  }
2720 
2721  // Check this multiply against other multiplies being added together.
2722  for (unsigned OtherMulIdx = Idx+1;
2723  OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2724  ++OtherMulIdx) {
2725  const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2726  // If MulOp occurs in OtherMul, we can fold the two multiplies
2727  // together.
2728  for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2729  OMulOp != e; ++OMulOp)
2730  if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2731  // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2732  const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2733  if (Mul->getNumOperands() != 2) {
2734  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2735  Mul->op_begin()+MulOp);
2736  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2737  InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2738  }
2739  const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2740  if (OtherMul->getNumOperands() != 2) {
2741  SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2742  OtherMul->op_begin()+OMulOp);
2743  MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2744  InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2745  }
2746  SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2747  const SCEV *InnerMulSum =
2748  getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2749  const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2750  SCEV::FlagAnyWrap, Depth + 1);
2751  if (Ops.size() == 2) return OuterMul;
2752  Ops.erase(Ops.begin()+Idx);
2753  Ops.erase(Ops.begin()+OtherMulIdx-1);
2754  Ops.push_back(OuterMul);
2755  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2756  }
2757  }
2758  }
2759  }
2760 
2761  // If there are any add recurrences in the operands list, see if any other
2762  // added values are loop invariant. If so, we can fold them into the
2763  // recurrence.
2764  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2765  ++Idx;
2766 
2767  // Scan over all recurrences, trying to fold loop invariants into them.
2768  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2769  // Scan all of the other operands to this add and add them to the vector if
2770  // they are loop invariant w.r.t. the recurrence.
2771  SmallVector<const SCEV *, 8> LIOps;
2772  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2773  const Loop *AddRecLoop = AddRec->getLoop();
2774  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2775  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2776  LIOps.push_back(Ops[i]);
2777  Ops.erase(Ops.begin()+i);
2778  --i; --e;
2779  }
2780 
2781  // If we found some loop invariants, fold them into the recurrence.
2782  if (!LIOps.empty()) {
2783  // Compute nowrap flags for the addition of the loop-invariant ops and
2784  // the addrec. Temporarily push it as an operand for that purpose.
2785  LIOps.push_back(AddRec);
2786  SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2787  LIOps.pop_back();
2788 
2789  // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2790  LIOps.push_back(AddRec->getStart());
2791 
2792  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2793  // This follows from the fact that the no-wrap flags on the outer add
2794  // expression are applicable on the 0th iteration, when the add recurrence
2795  // will be equal to its start value.
2796  AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1);
2797 
2798  // Build the new addrec. Propagate the NUW and NSW flags if both the
2799  // outer add and the inner addrec are guaranteed to have no overflow.
2800  // Always propagate NW.
2801  Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2802  const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2803 
2804  // If all of the other operands were loop invariant, we are done.
2805  if (Ops.size() == 1) return NewRec;
2806 
2807  // Otherwise, add the folded AddRec by the non-invariant parts.
2808  for (unsigned i = 0;; ++i)
2809  if (Ops[i] == AddRec) {
2810  Ops[i] = NewRec;
2811  break;
2812  }
2813  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2814  }
2815 
2816  // Okay, if there weren't any loop invariants to be folded, check to see if
2817  // there are multiple AddRec's with the same loop induction variable being
2818  // added together. If so, we can fold them.
2819  for (unsigned OtherIdx = Idx+1;
2820  OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2821  ++OtherIdx) {
2822  // We expect the AddRecExpr's to be sorted in reverse dominance order,
2823  // so that the 1st found AddRecExpr is dominated by all others.
2824  assert(DT.dominates(
2825  cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2826  AddRec->getLoop()->getHeader()) &&
2827  "AddRecExprs are not sorted in reverse dominance order?");
2828  if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2829  // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2830  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2831  for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2832  ++OtherIdx) {
2833  const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2834  if (OtherAddRec->getLoop() == AddRecLoop) {
2835  for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2836  i != e; ++i) {
2837  if (i >= AddRecOps.size()) {
2838  AddRecOps.append(OtherAddRec->op_begin()+i,
2839  OtherAddRec->op_end());
2840  break;
2841  }
2842  SmallVector<const SCEV *, 2> TwoOps = {
2843  AddRecOps[i], OtherAddRec->getOperand(i)};
2844  AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2845  }
2846  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2847  }
2848  }
2849  // Step size has changed, so we cannot guarantee no self-wraparound.
2850  Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2851  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2852  }
2853  }
2854 
2855  // Otherwise couldn't fold anything into this recurrence. Move onto the
2856  // next one.
2857  }
2858 
2859  // Okay, it looks like we really DO need an add expr. Check to see if we
2860  // already have one, otherwise create a new one.
2861  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2862 }
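// Informal end-to-end illustration of the folds above (loop L, %x invariant
// in L): getAddExpr({%x, 2, {3,+,1}<L>, %x}) first merges the repeated %x
// operands into 2 * %x, then folds the loop-invariant part 2 + 2 * %x into
// the recurrence start, yielding {(5 + 2 * %x),+,1}<L>.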
2863 
2864 const SCEV *
2865 ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2866  SCEV::NoWrapFlags Flags) {
2867  FoldingSetNodeID ID;
2868  ID.AddInteger(scAddExpr);
2869  for (const SCEV *Op : Ops)
2870  ID.AddPointer(Op);
2871  void *IP = nullptr;
2872  SCEVAddExpr *S =
2873  static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2874  if (!S) {
2875  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2876  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2877  S = new (SCEVAllocator)
2878  SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2879  UniqueSCEVs.InsertNode(S, IP);
2880  addToLoopUseLists(S);
2881  }
2882  S->setNoWrapFlags(Flags);
2883  return S;
2884 }
2885 
2886 const SCEV *
2887 ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2888  const Loop *L, SCEV::NoWrapFlags Flags) {
2889  FoldingSetNodeID ID;
2890  ID.AddInteger(scAddRecExpr);
2891  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2892  ID.AddPointer(Ops[i]);
2893  ID.AddPointer(L);
2894  void *IP = nullptr;
2895  SCEVAddRecExpr *S =
2896  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2897  if (!S) {
2898  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2899  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2900  S = new (SCEVAllocator)
2901  SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
2902  UniqueSCEVs.InsertNode(S, IP);
2903  addToLoopUseLists(S);
2904  }
2905  setNoWrapFlags(S, Flags);
2906  return S;
2907 }
2908 
2909 const SCEV *
2910 ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
2911  SCEV::NoWrapFlags Flags) {
2912  FoldingSetNodeID ID;
2913  ID.AddInteger(scMulExpr);
2914  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2915  ID.AddPointer(Ops[i]);
2916  void *IP = nullptr;
2917  SCEVMulExpr *S =
2918  static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2919  if (!S) {
2920  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2921  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2922  S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2923  O, Ops.size());
2924  UniqueSCEVs.InsertNode(S, IP);
2925  addToLoopUseLists(S);
2926  }
2927  S->setNoWrapFlags(Flags);
2928  return S;
2929 }
2930 
2931 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2932  uint64_t k = i*j;
2933  if (j > 1 && k / j != i) Overflow = true;
2934  return k;
2935 }
2936 
2937 /// Compute the result of "n choose k", the binomial coefficient. If an
2938 /// intermediate computation overflows, Overflow will be set and the return will
2939 /// be garbage. Overflow is not cleared on absence of overflow.
2940 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2941  // We use the multiplicative formula:
2942  // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2943  // At each iteration, we take the n-th term of the numerator and divide by the
2944  // (k-n)th term of the denominator. This division will always produce an
2945  // integral result, and helps reduce the chance of overflow in the
2946  // intermediate computations. However, we can still overflow even when the
2947  // final result would fit.
2948 
2949  if (n == 0 || n == k) return 1;
2950  if (k > n) return 0;
2951 
2952  if (k > n/2)
2953  k = n-k;
2954 
2955  uint64_t r = 1;
2956  for (uint64_t i = 1; i <= k; ++i) {
2957  r = umul_ov(r, n-(i-1), Overflow);
2958  r /= i;
2959  }
2960  return r;
2961 }
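// A minimal usage sketch for Choose (hypothetical helper, not part of this
// file). Callers clear Overflow once up front and only test whether any call
// in a sequence set it, since Choose never resets it:
//
//   static bool binomialFits(uint64_t n, uint64_t k, uint64_t &Result) {
//     bool Overflow = false;
//     Result = Choose(n, k, Overflow);  // e.g. Choose(5, 2, Overflow) == 10
//     return !Overflow;
//   }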
2962 
2963 /// Determine if any of the operands in this SCEV are a constant or if
2964 /// any of the add or multiply expressions in this SCEV contain a constant.
2965 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
2966  struct FindConstantInAddMulChain {
2967  bool FoundConstant = false;
2968 
2969  bool follow(const SCEV *S) {
2970  FoundConstant |= isa<SCEVConstant>(S);
2971  return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
2972  }
2973 
2974  bool isDone() const {
2975  return FoundConstant;
2976  }
2977  };
2978 
2979  FindConstantInAddMulChain F;
2980  SCEVTraversal<FindConstantInAddMulChain> ST(F);
2981  ST.visitAll(StartExpr);
2982  return F.FoundConstant;
2983 }
2984 
2985 /// Get a canonical multiply expression, or something simpler if possible.
 2986 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
 2987  SCEV::NoWrapFlags OrigFlags,
2988  unsigned Depth) {
2989  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
2990  "only nuw or nsw allowed");
2991  assert(!Ops.empty() && "Cannot get empty mul!");
2992  if (Ops.size() == 1) return Ops[0];
2993 #ifndef NDEBUG
2994  Type *ETy = Ops[0]->getType();
2995  assert(!ETy->isPointerTy());
2996  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2997  assert(Ops[i]->getType() == ETy &&
2998  "SCEVMulExpr operand types don't match!");
2999 #endif
3000 
3001  // Sort by complexity, this groups all similar expression types together.
3002  GroupByComplexity(Ops, &LI, DT);
3003 
3004  // If there are any constants, fold them together.
3005  unsigned Idx = 0;
3006  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3007  ++Idx;
3008  assert(Idx < Ops.size());
3009  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3010  // We found two constants, fold them together!
3011  Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3012  if (Ops.size() == 2) return Ops[0];
3013  Ops.erase(Ops.begin()+1); // Erase the folded element
3014  LHSC = cast<SCEVConstant>(Ops[0]);
3015  }
3016 
3017  // If we have a multiply of zero, it will always be zero.
3018  if (LHSC->getValue()->isZero())
3019  return LHSC;
3020 
3021  // If we are left with a constant one being multiplied, strip it off.
3022  if (LHSC->getValue()->isOne()) {
3023  Ops.erase(Ops.begin());
3024  --Idx;
3025  }
3026 
3027  if (Ops.size() == 1)
3028  return Ops[0];
3029  }
3030 
3031  // Delay expensive flag strengthening until necessary.
3032  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3033  return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3034  };
3035 
3036  // Limit recursion calls depth.
3037  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3038  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3039 
3040  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3041  // Don't strengthen flags if we have no new information.
3042  SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3043  if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3044  Mul->setNoWrapFlags(ComputeFlags(Ops));
3045  return S;
3046  }
3047 
3048  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3049  if (Ops.size() == 2) {
3050  // C1*(C2+V) -> C1*C2 + C1*V
3051  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3052  // If any of Add's ops are Adds or Muls with a constant, apply this
3053  // transformation as well.
3054  //
3055  // TODO: There are some cases where this transformation is not
3056  // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3057  // this transformation should be narrowed down.
3058  if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
3059  return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
3060  SCEV::FlagAnyWrap, Depth + 1),
3061  getMulExpr(LHSC, Add->getOperand(1),
3062  SCEV::FlagAnyWrap, Depth + 1),
3063  SCEV::FlagAnyWrap, Depth + 1);
3064 
3065  if (Ops[0]->isAllOnesValue()) {
3066  // If we have a mul by -1 of an add, try distributing the -1 among the
3067  // add operands.
3068  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
 3069  SmallVector<const SCEV *, 4> NewOps;
 3070  bool AnyFolded = false;
3071  for (const SCEV *AddOp : Add->operands()) {
3072  const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3073  Depth + 1);
3074  if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3075  NewOps.push_back(Mul);
3076  }
3077  if (AnyFolded)
3078  return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3079  } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3080  // Negation preserves a recurrence's no self-wrap property.
 3081  SmallVector<const SCEV *, 4> Operands;
 3082  for (const SCEV *AddRecOp : AddRec->operands())
3083  Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3084  Depth + 1));
3085 
3086  return getAddRecExpr(Operands, AddRec->getLoop(),
3087  AddRec->getNoWrapFlags(SCEV::FlagNW));
3088  }
3089  }
3090  }
3091  }
3092 
3093  // Skip over the add expression until we get to a multiply.
3094  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3095  ++Idx;
3096 
3097  // If there are mul operands inline them all into this expression.
3098  if (Idx < Ops.size()) {
3099  bool DeletedMul = false;
3100  while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3101  if (Ops.size() > MulOpsInlineThreshold)
3102  break;
 3103  // If we have a mul, expand the mul operands onto the end of the
3104  // operands list.
3105  Ops.erase(Ops.begin()+Idx);
3106  Ops.append(Mul->op_begin(), Mul->op_end());
3107  DeletedMul = true;
3108  }
3109 
3110  // If we deleted at least one mul, we added operands to the end of the
3111  // list, and they are not necessarily sorted. Recurse to resort and
3112  // resimplify any operands we just acquired.
3113  if (DeletedMul)
3114  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3115  }
3116 
3117  // If there are any add recurrences in the operands list, see if any other
3118  // added values are loop invariant. If so, we can fold them into the
3119  // recurrence.
3120  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3121  ++Idx;
3122 
3123  // Scan over all recurrences, trying to fold loop invariants into them.
3124  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3125  // Scan all of the other operands to this mul and add them to the vector
3126  // if they are loop invariant w.r.t. the recurrence.
 3127  SmallVector<const SCEV *, 8> LIOps;
 3128  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3129  const Loop *AddRecLoop = AddRec->getLoop();
3130  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3131  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3132  LIOps.push_back(Ops[i]);
3133  Ops.erase(Ops.begin()+i);
3134  --i; --e;
3135  }
3136 
3137  // If we found some loop invariants, fold them into the recurrence.
3138  if (!LIOps.empty()) {
3139  // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
 3140  SmallVector<const SCEV *, 4> NewOps;
 3141  NewOps.reserve(AddRec->getNumOperands());
3142  const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3143  for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3144  NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3145  SCEV::FlagAnyWrap, Depth + 1));
3146 
3147  // Build the new addrec. Propagate the NUW and NSW flags if both the
3148  // outer mul and the inner addrec are guaranteed to have no overflow.
3149  //
3150  // No self-wrap cannot be guaranteed after changing the step size, but
3151  // will be inferred if either NUW or NSW is true.
3152  SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3153  const SCEV *NewRec = getAddRecExpr(
3154  NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
3155 
3156  // If all of the other operands were loop invariant, we are done.
3157  if (Ops.size() == 1) return NewRec;
3158 
3159  // Otherwise, multiply the folded AddRec by the non-invariant parts.
3160  for (unsigned i = 0;; ++i)
3161  if (Ops[i] == AddRec) {
3162  Ops[i] = NewRec;
3163  break;
3164  }
3165  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3166  }
3167 
3168  // Okay, if there weren't any loop invariants to be folded, check to see
3169  // if there are multiple AddRec's with the same loop induction variable
3170  // being multiplied together. If so, we can fold them.
3171 
3172  // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3173  // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3174  // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3175  // ]]],+,...up to x=2n}.
3176  // Note that the arguments to choose() are always integers with values
3177  // known at compile time, never SCEV objects.
3178  //
3179  // The implementation avoids pointless extra computations when the two
3180  // addrec's are of different length (mathematically, it's equivalent to
3181  // an infinite stream of zeros on the right).
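 // Editorial note (illustrative, not from the original source): for two affine
 // recurrences in the same loop the formula above reduces to
 //
 //   {a,+,b}<L> * {c,+,d}<L>  =  {a*c, +, a*d + b*c + b*d, +, 2*b*d}<L>
 //
 // e.g. {0,+,1} * {0,+,1} (i.e. i*i) becomes {0,+,1,+,2}, whose value at
 // iteration n is C(n,1) + 2*C(n,2) = n + n*(n-1) = n*n, as expected.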
3182  bool OpsModified = false;
3183  for (unsigned OtherIdx = Idx+1;
3184  OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3185  ++OtherIdx) {
3186  const SCEVAddRecExpr *OtherAddRec =
3187  dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3188  if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
3189  continue;
3190 
3191  // Limit max number of arguments to avoid creation of unreasonably big
3192  // SCEVAddRecs with very complex operands.
3193  if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3194  MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3195  continue;
3196 
3197  bool Overflow = false;
3198  Type *Ty = AddRec->getType();
3199  bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3200  SmallVector<const SCEV*, 7> AddRecOps;
3201  for (int x = 0, xe = AddRec->getNumOperands() +
3202  OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
 3203  SmallVector<const SCEV *, 7> SumOps;
 3204  for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3205  uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3206  for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3207  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3208  z < ze && !Overflow; ++z) {
3209  uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3210  uint64_t Coeff;
3211  if (LargerThan64Bits)
3212  Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3213  else
3214  Coeff = Coeff1*Coeff2;
3215  const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3216  const SCEV *Term1 = AddRec->getOperand(y-z);
3217  const SCEV *Term2 = OtherAddRec->getOperand(z);
3218  SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3219  SCEV::FlagAnyWrap, Depth + 1));
3220  }
3221  }
3222  if (SumOps.empty())
3223  SumOps.push_back(getZero(Ty));
3224  AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3225  }
3226  if (!Overflow) {
 3227  const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
 3228  SCEV::FlagAnyWrap);
3229  if (Ops.size() == 2) return NewAddRec;
3230  Ops[Idx] = NewAddRec;
3231  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3232  OpsModified = true;
3233  AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3234  if (!AddRec)
3235  break;
3236  }
3237  }
3238  if (OpsModified)
3239  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3240 
3241  // Otherwise couldn't fold anything into this recurrence. Move onto the
3242  // next one.
3243  }
3244 
 3245  // Okay, it looks like we really DO need a mul expr. Check to see if we
3246  // already have one, otherwise create a new one.
3247  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3248 }
3249 
3250 /// Represents an unsigned remainder expression based on unsigned division.
 3251 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
 3252  const SCEV *RHS) {
 3253  assert(getEffectiveSCEVType(LHS->getType()) ==
 3254  getEffectiveSCEVType(RHS->getType()) &&
3255  "SCEVURemExpr operand types don't match!");
3256 
3257  // Short-circuit easy cases
3258  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3259  // If constant is one, the result is trivial
3260  if (RHSC->getValue()->isOne())
3261  return getZero(LHS->getType()); // X urem 1 --> 0
3262 
3263  // If constant is a power of two, fold into a zext(trunc(LHS)).
3264  if (RHSC->getAPInt().isPowerOf2()) {
3265  Type *FullTy = LHS->getType();
3266  Type *TruncTy =
3267  IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3268  return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3269  }
3270  }
3271 
 3272  // Fall back to the identity: %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3273  const SCEV *UDiv = getUDivExpr(LHS, RHS);
3274  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3275  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3276 }
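 // Editorial sketch (not part of this file): the fallback above is just the
 // integer identity  x urem y == x - (x / y) * y  spelled with SCEV nodes.
 // A standalone analogue over plain unsigned integers:
 //
 //   #include <cstdint>
 //   #include <cassert>
 //
 //   static uint64_t uremViaUDiv(uint64_t X, uint64_t Y) {
 //     assert(Y != 0 && "division by zero");
 //     return X - (X / Y) * Y;   // matches getMinusSCEV(LHS, UDiv * RHS)
 //   }
 //
 //   // uremViaUDiv(29, 8) == 5; for a power-of-two divisor the result equals
 //   // the low bits, which is what the zext(trunc(LHS)) fold above computes.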
3277 
3278 /// Get a canonical unsigned division expression, or something simpler if
3279 /// possible.
 3280 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
 3281  const SCEV *RHS) {
3282  assert(!LHS->getType()->isPointerTy() &&
3283  "SCEVUDivExpr operand can't be pointer!");
3284  assert(LHS->getType() == RHS->getType() &&
3285  "SCEVUDivExpr operand types don't match!");
3286 
 3287  FoldingSetNodeID ID;
 3288  ID.AddInteger(scUDivExpr);
3289  ID.AddPointer(LHS);
3290  ID.AddPointer(RHS);
3291  void *IP = nullptr;
3292  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3293  return S;
3294 
3295  // 0 udiv Y == 0
3296  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3297  if (LHSC->getValue()->isZero())
3298  return LHS;
3299 
3300  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3301  if (RHSC->getValue()->isOne())
3302  return LHS; // X udiv 1 --> x
3303  // If the denominator is zero, the result of the udiv is undefined. Don't
3304  // try to analyze it, because the resolution chosen here may differ from
3305  // the resolution chosen in other parts of the compiler.
3306  if (!RHSC->getValue()->isZero()) {
 3307  // Determine if the division can be folded into the operands of
 3308  // the LHS.
3309  // TODO: Generalize this to non-constants by using known-bits information.
3310  Type *Ty = LHS->getType();
3311  unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3312  unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3313  // For non-power-of-two values, effectively round the value up to the
3314  // nearest power of two.
3315  if (!RHSC->getAPInt().isPowerOf2())
3316  ++MaxShiftAmt;
3317  IntegerType *ExtTy =
3318  IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3319  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3320  if (const SCEVConstant *Step =
3321  dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3322  // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3323  const APInt &StepInt = Step->getAPInt();
3324  const APInt &DivInt = RHSC->getAPInt();
3325  if (!StepInt.urem(DivInt) &&
3326  getZeroExtendExpr(AR, ExtTy) ==
3327  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3328  getZeroExtendExpr(Step, ExtTy),
3329  AR->getLoop(), SCEV::FlagAnyWrap)) {
 3330  SmallVector<const SCEV *, 4> Operands;
 3331  for (const SCEV *Op : AR->operands())
3332  Operands.push_back(getUDivExpr(Op, RHS));
3333  return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3334  }
3335  /// Get a canonical UDivExpr for a recurrence.
3336  /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3337  // We can currently only fold X%N if X is constant.
3338  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3339  if (StartC && !DivInt.urem(StepInt) &&
3340  getZeroExtendExpr(AR, ExtTy) ==
3341  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3342  getZeroExtendExpr(Step, ExtTy),
3343  AR->getLoop(), SCEV::FlagAnyWrap)) {
3344  const APInt &StartInt = StartC->getAPInt();
3345  const APInt &StartRem = StartInt.urem(StepInt);
3346  if (StartRem != 0) {
3347  const SCEV *NewLHS =
3348  getAddRecExpr(getConstant(StartInt - StartRem), Step,
3349  AR->getLoop(), SCEV::FlagNW);
3350  if (LHS != NewLHS) {
3351  LHS = NewLHS;
3352 
3353  // Reset the ID to include the new LHS, and check if it is
3354  // already cached.
3355  ID.clear();
3356  ID.AddInteger(scUDivExpr);
3357  ID.AddPointer(LHS);
3358  ID.AddPointer(RHS);
3359  IP = nullptr;
3360  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3361  return S;
3362  }
3363  }
3364  }
3365  }
3366  // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3367  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
 3368  SmallVector<const SCEV *, 4> Operands;
 3369  for (const SCEV *Op : M->operands())
3370  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3371  if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3372  // Find an operand that's safely divisible.
3373  for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3374  const SCEV *Op = M->getOperand(i);
3375  const SCEV *Div = getUDivExpr(Op, RHSC);
3376  if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3377  Operands = SmallVector<const SCEV *, 4>(M->operands());
3378  Operands[i] = Div;
3379  return getMulExpr(Operands);
3380  }
3381  }
3382  }
3383 
3384  // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3385  if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3386  if (auto *DivisorConstant =
3387  dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3388  bool Overflow = false;
3389  APInt NewRHS =
3390  DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3391  if (Overflow) {
3392  return getConstant(RHSC->getType(), 0, false);
3393  }
3394  return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3395  }
3396  }
3397 
3398  // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3399  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
 3400  SmallVector<const SCEV *, 4> Operands;
 3401  for (const SCEV *Op : A->operands())
3402  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3403  if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3404  Operands.clear();
3405  for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3406  const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3407  if (isa<SCEVUDivExpr>(Op) ||
3408  getMulExpr(Op, RHS) != A->getOperand(i))
3409  break;
3410  Operands.push_back(Op);
3411  }
3412  if (Operands.size() == A->getNumOperands())
3413  return getAddExpr(Operands);
3414  }
3415  }
3416 
3417  // Fold if both operands are constant.
3418  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
3419  Constant *LHSCV = LHSC->getValue();
3420  Constant *RHSCV = RHSC->getValue();
3421  return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
3422  RHSCV)));
3423  }
3424  }
3425  }
3426 
3427  // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3428  // changes). Make sure we get a new one.
3429  IP = nullptr;
3430  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3431  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3432  LHS, RHS);
3433  UniqueSCEVs.InsertNode(S, IP);
3434  addToLoopUseLists(S);
3435  return S;
3436 }
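 // Editorial note (illustrative, not from the original source): concrete
 // instances of the folds above, assuming the zero-extension checks succeed:
 //
 //   {X,+,4}<L> /u 2   -->  {X /u 2,+,2}<L>      (step divisible by divisor)
 //   (8 * A)    /u 4   -->  2 * A                ((A*B)/C with B/C foldable)
 //   (A /u 2)   /u 4   -->  A /u 8               ((A/B)/C with B*C foldable)
 //   (8 + 4*A)  /u 4   -->  2 + A                ((A+B)/C with both foldable)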
3437 
3438 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3439  APInt A = C1->getAPInt().abs();
3440  APInt B = C2->getAPInt().abs();
3441  uint32_t ABW = A.getBitWidth();
3442  uint32_t BBW = B.getBitWidth();
3443 
3444  if (ABW > BBW)
3445  B = B.zext(ABW);
3446  else if (ABW < BBW)
3447  A = A.zext(BBW);
3448 
 3449  return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
 3450 }
3451 
3452 /// Get a canonical unsigned division expression, or something simpler if
3453 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3454 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3455 /// it's not exact because the udiv may be clearing bits.
 3456 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
 3457  const SCEV *RHS) {
3458  // TODO: we could try to find factors in all sorts of things, but for now we
3459  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3460  // end of this file for inspiration.
3461 
3462  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3463  if (!Mul || !Mul->hasNoUnsignedWrap())
3464  return getUDivExpr(LHS, RHS);
3465 
3466  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3467  // If the mulexpr multiplies by a constant, then that constant must be the
3468  // first element of the mulexpr.
3469  if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3470  if (LHSCst == RHSCst) {
 3471  SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
 3472  return getMulExpr(Operands);
3473  }
3474 
 3475  // We can't just assume that LHSCst divides RHSCst cleanly; it could be
3476  // that there's a factor provided by one of the other terms. We need to
3477  // check.
3478  APInt Factor = gcd(LHSCst, RHSCst);
3479  if (!Factor.isIntN(1)) {
3480  LHSCst =
3481  cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3482  RHSCst =
3483  cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
 3484  SmallVector<const SCEV *, 2> Operands;
 3485  Operands.push_back(LHSCst);
3486  Operands.append(Mul->op_begin() + 1, Mul->op_end());
3487  LHS = getMulExpr(Operands);
3488  RHS = RHSCst;
3489  Mul = dyn_cast<SCEVMulExpr>(LHS);
3490  if (!Mul)
3491  return getUDivExactExpr(LHS, RHS);
3492  }
3493  }
3494  }
3495 
3496  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3497  if (Mul->getOperand(i) == RHS) {
 3498  SmallVector<const SCEV *, 2> Operands;
 3499  Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3500  Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3501  return getMulExpr(Operands);
3502  }
3503  }
3504 
3505  return getUDivExpr(LHS, RHS);
3506 }
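 // Editorial sketch (not part of this file): the gcd-based factoring above on
 // concrete values. For an exact divide of a non-wrapping product,
 //
 //   (6 * X) /u(exact) 4:  gcd(6, 4) = 2, so rewrite both sides to
 //   (3 * X) /u(exact) 2;  had the divisor been exactly 6, the constant would
 //   cancel entirely and the result would be X.
 //
 // This is only valid because the division is known to be exact; a plain udiv
 // may be discarding low bits, in which case pushing the division into the
 // product changes the result, e.g. (3*5) /u 2 == 7 while 3*(5 /u 2) == 6.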
3507 
3508 /// Get an add recurrence expression for the specified loop. Simplify the
3509 /// expression as much as possible.
3510 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3511  const Loop *L,
3512  SCEV::NoWrapFlags Flags) {
 3513  SmallVector<const SCEV *, 4> Operands;
 3514  Operands.push_back(Start);
3515  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3516  if (StepChrec->getLoop() == L) {
3517  Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3518  return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3519  }
3520 
3521  Operands.push_back(Step);
3522  return getAddRecExpr(Operands, L, Flags);
3523 }
3524 
3525 /// Get an add recurrence expression for the specified loop. Simplify the
3526 /// expression as much as possible.
3527 const SCEV *
 3528 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
 3529  const Loop *L, SCEV::NoWrapFlags Flags) {
3530  if (Operands.size() == 1) return Operands[0];
3531 #ifndef NDEBUG
3532  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3533  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
 3534  assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
 3535  "SCEVAddRecExpr operand types don't match!");
3536  assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3537  }
3538  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
 3539  assert(isLoopInvariant(Operands[i], L) &&
 3540  "SCEVAddRecExpr operand is not loop-invariant!");
3541 #endif
3542 
3543  if (Operands.back()->isZero()) {
3544  Operands.pop_back();
3545  return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3546  }
3547 
 3548  // It's tempting to call getConstantMaxBackedgeTakenCount here and
3549  // use that information to infer NUW and NSW flags. However, computing a
3550  // BE count requires calling getAddRecExpr, so we may not yet have a
3551  // meaningful BE count at this point (and if we don't, we'd be stuck
3552  // with a SCEVCouldNotCompute as the cached BE count).
3553 
3554  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3555 
 3556  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3557  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3558  const Loop *NestedLoop = NestedAR->getLoop();
3559  if (L->contains(NestedLoop)
3560  ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3561  : (!NestedLoop->contains(L) &&
3562  DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3563  SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3564  Operands[0] = NestedAR->getStart();
3565  // AddRecs require their operands be loop-invariant with respect to their
3566  // loops. Don't perform this transformation if it would break this
3567  // requirement.
3568  bool AllInvariant = all_of(
3569  Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3570 
3571  if (AllInvariant) {
3572  // Create a recurrence for the outer loop with the same step size.
3573  //
3574  // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3575  // inner recurrence has the same property.
3576  SCEV::NoWrapFlags OuterFlags =
3577  maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3578 
3579  NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3580  AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3581  return isLoopInvariant(Op, NestedLoop);
3582  });
3583 
3584  if (AllInvariant) {
3585  // Ok, both add recurrences are valid after the transformation.
3586  //
3587  // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3588  // the outer recurrence has the same property.
3589  SCEV::NoWrapFlags InnerFlags =
3590  maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3591  return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3592  }
3593  }
3594  // Reset Operands to its original state.
3595  Operands[0] = NestedAR;
3596  }
3597  }
3598 
3599  // Okay, it looks like we really DO need an addrec expr. Check to see if we
3600  // already have one, otherwise create a new one.
3601  return getOrCreateAddRecExpr(Operands, L, Flags);
3602 }
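 // Editorial note (illustrative, not from the original source): when all the
 // operands are suitably loop-invariant, the nesting canonicalization above
 // turns
 //
 //   { {S,+,X}<L_inner>, +, Y }<L_outer>   into   { {S,+,Y}<L_outer>, +, X }<L_inner>
 //
 // so the recurrence for the deeper loop becomes the outermost node of the
 // expression, with NUW/NSW retained only where both recurrences had them.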
3603 
3604 const SCEV *
 3605 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
 3606  const SmallVectorImpl<const SCEV *> &IndexExprs) {
3607  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3608  // getSCEV(Base)->getType() has the same address space as Base->getType()
3609  // because SCEV::getType() preserves the address space.
3610  Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3611  // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP
3612  // instruction to its SCEV, because the Instruction may be guarded by control
3613  // flow and the no-overflow bits may not be valid for the expression in any
3614  // context. This can be fixed similarly to how these flags are handled for
3615  // adds.
3616  SCEV::NoWrapFlags OffsetWrap =
3617  GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3618 
3619  Type *CurTy = GEP->getType();
3620  bool FirstIter = true;
 3621  SmallVector<const SCEV *, 4> Offsets;
 3622  for (const SCEV *IndexExpr : IndexExprs) {
3623  // Compute the (potentially symbolic) offset in bytes for this index.
3624  if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3625  // For a struct, add the member offset.
3626  ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3627  unsigned FieldNo = Index->getZExtValue();
3628  const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3629  Offsets.push_back(FieldOffset);
3630 
3631  // Update CurTy to the type of the field at Index.
3632  CurTy = STy->getTypeAtIndex(Index);
3633  } else {
3634  // Update CurTy to its element type.
3635  if (FirstIter) {
3636  assert(isa<PointerType>(CurTy) &&
3637  "The first index of a GEP indexes a pointer");
3638  CurTy = GEP->getSourceElementType();
3639  FirstIter = false;
3640  } else {
3641  CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3642  }
3643  // For an array, add the element offset, explicitly scaled.
3644  const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3645  // Getelementptr indices are signed.
3646  IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3647 
3648  // Multiply the index by the element size to compute the element offset.
3649  const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3650  Offsets.push_back(LocalOffset);
3651  }
3652  }
3653 
3654  // Handle degenerate case of GEP without offsets.
3655  if (Offsets.empty())
3656  return BaseExpr;
3657 
3658  // Add the offsets together, assuming nsw if inbounds.
3659  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3660  // Add the base address and the offset. We cannot use the nsw flag, as the
3661  // base address is unsigned. However, if we know that the offset is
3662  // non-negative, we can use nuw.
 3663  SCEV::NoWrapFlags BaseWrap = GEP->isInBounds() && isKnownNonNegative(Offset)
 3664  ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3665  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3666  assert(BaseExpr->getType() == GEPExpr->getType() &&
3667  "GEP should not change type mid-flight.");
3668  return GEPExpr;
3669 }
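 // Editorial sketch (not part of this file; the type and layout are
 // hypothetical): what the loop above produces for
 //   getelementptr inbounds %S, %S* %p, i64 %i, i32 1
 // with %S = type { i32, [4 x i64] } under a typical 64-bit data layout
 // (sizeof(%S) == 40, offset of field 1 == 8):
 //
 //   Offsets = { %i * 40,   ; first index scaled by sizeof(%S)
 //               8 }        ; struct field offset
 //   result  = %p + (40 * %i + 8)
 //
 // The offset adds are tagged nsw because of "inbounds"; the final base+offset
 // add gets nuw only when the combined offset is known non-negative.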
3670 
3671 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3672  ArrayRef<const SCEV *> Ops) {
 3673  FoldingSetNodeID ID;
 3674  ID.AddInteger(SCEVType);
3675  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3676  ID.AddPointer(Ops[i]);
3677  void *IP = nullptr;
3678  return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3679 }
3680 
3681 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
 3682  const SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
 3683  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3684 }
3685 
 3686 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
 3687  SmallVectorImpl<const SCEV *> &Ops) {
 3688  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3689  if (Ops.size() == 1) return Ops[0];
3690 #ifndef NDEBUG
3691  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3692  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3693  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3694  "Operand types don't match!");
3695  assert(Ops[0]->getType()->isPointerTy() ==
3696  Ops[i]->getType()->isPointerTy() &&
3697  "min/max should be consistently pointerish");
3698  }
3699 #endif
3700 
3701  bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3702  bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3703 
3704  // Sort by complexity, this groups all similar expression types together.
3705  GroupByComplexity(Ops, &LI, DT);
3706 
3707  // Check if we have created the same expression before.
3708  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3709  return S;
3710  }
3711 
3712  // If there are any constants, fold them together.
3713  unsigned Idx = 0;
3714  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3715  ++Idx;
3716  assert(Idx < Ops.size());
3717  auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3718  if (Kind == scSMaxExpr)
3719  return APIntOps::smax(LHS, RHS);
3720  else if (Kind == scSMinExpr)
3721  return APIntOps::smin(LHS, RHS);
3722  else if (Kind == scUMaxExpr)
3723  return APIntOps::umax(LHS, RHS);
3724  else if (Kind == scUMinExpr)
3725  return APIntOps::umin(LHS, RHS);
3726  llvm_unreachable("Unknown SCEV min/max opcode");
3727  };
3728 
3729  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3730  // We found two constants, fold them together!
3731  ConstantInt *Fold = ConstantInt::get(
3732  getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3733  Ops[0] = getConstant(Fold);
3734  Ops.erase(Ops.begin()+1); // Erase the folded element
3735  if (Ops.size() == 1) return Ops[0];
3736  LHSC = cast<SCEVConstant>(Ops[0]);
3737  }
3738 
3739  bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3740  bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3741 
3742  if (IsMax ? IsMinV : IsMaxV) {
3743  // If we are left with a constant minimum(/maximum)-int, strip it off.
3744  Ops.erase(Ops.begin());
3745  --Idx;
3746  } else if (IsMax ? IsMaxV : IsMinV) {
3747  // If we have a max(/min) with a constant maximum(/minimum)-int,
3748  // it will always be the extremum.
3749  return LHSC;
3750  }
3751 
3752  if (Ops.size() == 1) return Ops[0];
3753  }
3754 
3755  // Find the first operation of the same kind
3756  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3757  ++Idx;
3758 
3759  // Check to see if one of the operands is of the same kind. If so, expand its
3760  // operands onto our operand list, and recurse to simplify.
3761  if (Idx < Ops.size()) {
3762  bool DeletedAny = false;
3763  while (Ops[Idx]->getSCEVType() == Kind) {
3764  const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3765  Ops.erase(Ops.begin()+Idx);
3766  Ops.append(SMME->op_begin(), SMME->op_end());
3767  DeletedAny = true;
3768  }
3769 
3770  if (DeletedAny)
3771  return getMinMaxExpr(Kind, Ops);
3772  }
3773 
3774  // Okay, check to see if the same value occurs in the operand list twice. If
3775  // so, delete one. Since we sorted the list, these values are required to
3776  // be adjacent.
 3777  llvm::CmpInst::Predicate GEPred =
 3778  IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
 3779  llvm::CmpInst::Predicate LEPred =
 3780  IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3781  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3782  llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3783  for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3784  if (Ops[i] == Ops[i + 1] ||
3785  isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3786  // X op Y op Y --> X op Y
3787  // X op Y --> X, if we know X, Y are ordered appropriately
3788  Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3789  --i;
3790  --e;
3791  } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3792  Ops[i + 1])) {
3793  // X op Y --> Y, if we know X, Y are ordered appropriately
3794  Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3795  --i;
3796  --e;
3797  }
3798  }
3799 
3800  if (Ops.size() == 1) return Ops[0];
3801 
3802  assert(!Ops.empty() && "Reduced smax down to nothing!");
3803 
3804  // Okay, it looks like we really DO need an expr. Check to see if we
3805  // already have one, otherwise create a new one.
 3806  FoldingSetNodeID ID;
 3807  ID.AddInteger(Kind);
3808  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3809  ID.AddPointer(Ops[i]);
3810  void *IP = nullptr;
3811  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3812  if (ExistingSCEV)
3813  return ExistingSCEV;
3814  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3815  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3816  SCEV *S = new (SCEVAllocator)
3817  SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3818 
3819  UniqueSCEVs.InsertNode(S, IP);
3820  addToLoopUseLists(S);
3821  return S;
3822 }
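 // Editorial note (illustrative, not from the original source): the folds
 // above, spelled out for a signed max over some integer type:
 //
 //   smax(5, smax(X, 3))  --flatten nested smax-->  smax(5, X, 3)
 //                        --fold constants------->  smax(X, 5)
 //   smax(X, SIGNED_MIN)  --identity element----->  X
 //   smax(X, SIGNED_MAX)  --absorbing element---->  SIGNED_MAX
 //   smax(X, Y, Y)        --duplicate operand---->  smax(X, Y)
 //
 // The duplicate-removal step also fires when isKnownViaNonRecursiveReasoning
 // can prove an ordering between adjacent non-constant operands.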
3823 
3824 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
3825  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3826  return getSMaxExpr(Ops);
3827 }
3828 
 3829 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
 3830  return getMinMaxExpr(scSMaxExpr, Ops);
3831 }
3832 
3833 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
3834  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3835  return getUMaxExpr(Ops);
3836 }
3837 
 3838 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
 3839  return getMinMaxExpr(scUMaxExpr, Ops);
3840 }
3841 
 3842 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
 3843  const SCEV *RHS) {
3844  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
3845  return getSMinExpr(Ops);
3846 }
3847 
 3848 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
 3849  return getMinMaxExpr(scSMinExpr, Ops);
3850 }
3851 
 3852 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
 3853  const SCEV *RHS) {
3854  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
3855  return getUMinExpr(Ops);
3856 }
3857 
 3858 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
 3859  return getMinMaxExpr(scUMinExpr, Ops);
3860 }
3861 
3862 const SCEV *
 3863 ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
 3864  ScalableVectorType *ScalableTy) {
3865  Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
3866  Constant *One = ConstantInt::get(IntTy, 1);
3867  Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
 3868  // Note that the expression we created is the final expression; we don't
 3869  // want to simplify it any further. Also, if we call a normal getSCEV(),
3870  // we'll end up in an endless recursion. So just create an SCEVUnknown.
3871  return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
3872 }
3873 
3874 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
3875  if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
3876  return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
3877  // We can bypass creating a target-independent constant expression and then
3878  // folding it back into a ConstantInt. This is just a compile-time
3879  // optimization.
3880  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
3881 }
3882 
 3883 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
 3884  if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
3885  return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
3886  // We can bypass creating a target-independent constant expression and then
3887  // folding it back into a ConstantInt. This is just a compile-time
3888  // optimization.
3889  return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
3890 }
3891 
 3892 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
 3893  StructType *STy,
3894  unsigned FieldNo) {
3895  // We can bypass creating a target-independent constant expression and then
3896  // folding it back into a ConstantInt. This is just a compile-time
3897  // optimization.
3898  return getConstant(
3899  IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
3900 }
3901 
 3902 const SCEV *ScalarEvolution::getUnknown(Value *V) {
 3903  // Don't attempt to do anything other than create a SCEVUnknown object
3904  // here. createSCEV only calls getUnknown after checking for all other
3905  // interesting possibilities, and any other code that calls getUnknown
3906  // is doing so in order to hide a value from SCEV canonicalization.
3907 
 3908  FoldingSetNodeID ID;
 3909  ID.AddInteger(scUnknown);
3910  ID.AddPointer(V);
3911  void *IP = nullptr;
3912  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3913  assert(cast<SCEVUnknown>(S)->getValue() == V &&
3914  "Stale SCEVUnknown in uniquing map!");
3915  return S;
3916  }
3917  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
3918  FirstUnknown);
3919  FirstUnknown = cast<SCEVUnknown>(S);
3920  UniqueSCEVs.InsertNode(S, IP);
3921  return S;
3922 }
3923 
3924 //===----------------------------------------------------------------------===//
3925 // Basic SCEV Analysis and PHI Idiom Recognition Code
3926 //
3927 
3928 /// Test if values of the given type are analyzable within the SCEV
3929 /// framework. This primarily includes integer types, and it can optionally
3930 /// include pointer types if the ScalarEvolution class has access to
3931 /// target-specific information.
 3932 bool ScalarEvolution::isSCEVable(Type *Ty) const {
 3933  // Integers and pointers are always SCEVable.
3934  return Ty->isIntOrPtrTy();
3935 }
3936 
3937 /// Return the size in bits of the specified type, for which isSCEVable must
3938 /// return true.
 3939 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
 3940  assert(isSCEVable(Ty) && "Type is not SCEVable!");
 3941  if (Ty->isPointerTy())
 3942  return getDataLayout().getIndexTypeSizeInBits(Ty);
3943  return getDataLayout().getTypeSizeInBits(Ty);
3944 }
3945 
3946 /// Return a type with the same bitwidth as the given type and which represents
3947 /// how SCEV will treat the given type, for which isSCEVable must return
3948 /// true. For pointer types, this is the pointer index sized integer type.
 3949 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
 3950  assert(isSCEVable(Ty) && "Type is not SCEVable!");
3951 
3952  if (Ty->isIntegerTy())
3953  return Ty;
3954 
 3955  // The only other supported type is pointer.
3956  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
3957  return getDataLayout().getIndexType(Ty);
3958 }
3959 
 3960 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
 3961  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
3962 }
3963 
 3964 const SCEV *ScalarEvolution::getCouldNotCompute() {
 3965  return CouldNotCompute.get();
3966 }
3967 
3968 bool ScalarEvolution::checkValidity(const SCEV *S) const {
3969  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
3970  auto *SU = dyn_cast<SCEVUnknown>(S);
3971  return SU && SU->getValue() == nullptr;
3972  });
3973 
3974  return !ContainsNulls;
3975 }
3976 
 3977 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
 3978  HasRecMapType::iterator I = HasRecMap.find(S);
3979  if (I != HasRecMap.end())
3980  return I->second;
3981 
3982  bool FoundAddRec =
3983  SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
3984  HasRecMap.insert({S, FoundAddRec});
3985  return FoundAddRec;
3986 }
3987 
3988 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
3989 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
3990 /// offset I, then return {S', I}, else return {\p S, nullptr}.
3991 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
3992  const auto *Add = dyn_cast<SCEVAddExpr>(S);
3993  if (!Add)
3994  return {S, nullptr};
3995 
3996  if (Add->getNumOperands() != 2)
3997  return {S, nullptr};
3998 
3999  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
4000  if (!ConstOp)
4001  return {S, nullptr};
4002 
4003  return {Add->getOperand(1), ConstOp->getValue()};
4004 }
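 // Editorial note (illustrative, not from the original source): splitAddExpr
 // only peels a leading constant off a two-operand add, so
 //
 //   (4 + %x)       -->  { %x, i32 4 }
 //   (%x + %y)      -->  { (%x + %y), nullptr }      (no constant operand)
 //   (1 + %x + %y)  -->  { (1 + %x + %y), nullptr }  (more than two operands)
 //
 // The {S', C} form lets ExprValueMap record, roughly, that S' can later be
 // re-materialized from an existing value V known to compute S by subtracting
 // the recorded constant C.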
4005 
4006 /// Return the ValueOffsetPair set for \p S. \p S can be represented
4007 /// by the value and offset from any ValueOffsetPair in the set.
 4008 SetVector<ScalarEvolution::ValueOffsetPair> *
 4009 ScalarEvolution::getSCEVValues(const SCEV *S) {
4010  ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4011  if (SI == ExprValueMap.end())
4012  return nullptr;
4013 #ifndef NDEBUG
4014  if (VerifySCEVMap) {
4015  // Check there is no dangling Value in the set returned.
4016  for (const auto &VE : SI->second)
4017  assert(ValueExprMap.count(VE.first));
4018  }
4019 #endif
4020  return &SI->second;
4021 }
4022 
4023 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4024 /// cannot be used separately. eraseValueFromMap should be used to remove
4025 /// V from ValueExprMap and ExprValueMap at the same time.
 4026 void ScalarEvolution::eraseValueFromMap(Value *V) {
 4027  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4028  if (I != ValueExprMap.end()) {
4029  const SCEV *S = I->second;
4030  // Remove {V, 0} from the set of ExprValueMap[S]
4031  if (auto *SV = getSCEVValues(S))
4032  SV->remove({V, nullptr});
4033 
4034  // Remove {V, Offset} from the set of ExprValueMap[Stripped]
4035  const SCEV *Stripped;
 4036  ConstantInt *Offset;
 4037  std::tie(Stripped, Offset) = splitAddExpr(S);
4038  if (Offset != nullptr) {
4039  if (auto *SV = getSCEVValues(Stripped))
4040  SV->remove({V, Offset});
4041  }
4042  ValueExprMap.erase(V);
4043  }
4044 }
4045 
4046 /// Check whether value has nuw/nsw/exact set but SCEV does not.
4047 /// TODO: In reality it is better to check the poison recursively
4048 /// but this is better than nothing.
4049 static bool SCEVLostPoisonFlags(const SCEV *S, const Value *V) {
4050  if (auto *I = dyn_cast<Instruction>(V)) {
4051  if (isa<OverflowingBinaryOperator>(I)) {
4052  if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) {
4053  if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap())
4054  return true;
4055  if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap())
4056  return true;
4057  }
4058  } else if (isa<PossiblyExactOperator>(I) && I->isExact())
4059  return true;
4060  }
4061  return false;
4062 }
4063 
4064 /// Return an existing SCEV if it exists, otherwise analyze the expression and
4065 /// create a new one.
 4066 const SCEV *ScalarEvolution::getSCEV(Value *V) {
 4067  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4068 
4069  const SCEV *S = getExistingSCEV(V);
4070  if (S == nullptr) {
4071  S = createSCEV(V);
4072  // During PHI resolution, it is possible to create two SCEVs for the same
 4073  // V, so we need to double-check whether V->S was inserted into
 4074  // ValueExprMap before inserting S->{V, 0} into ExprValueMap.
4075  std::pair<ValueExprMapType::iterator, bool> Pair =
4076  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4077  if (Pair.second && !SCEVLostPoisonFlags(S, V)) {
4078  ExprValueMap[S].insert({V, nullptr});
4079 
4080  // If S == Stripped + Offset, add Stripped -> {V, Offset} into
4081  // ExprValueMap.
4082  const SCEV *Stripped = S;
4083  ConstantInt *Offset = nullptr;
4084  std::tie(Stripped, Offset) = splitAddExpr(S);
4085  // If stripped is SCEVUnknown, don't bother to save
4086  // Stripped -> {V, offset}. It doesn't simplify and sometimes even
 4087  // increases the complexity of the expansion code.
4088  // If V is GetElementPtrInst, don't save Stripped -> {V, offset}
4089  // because it may generate add/sub instead of GEP in SCEV expansion.
4090  if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
4091  !isa<GetElementPtrInst>(V))
4092  ExprValueMap[Stripped].insert({V, Offset});
4093  }
4094  }
4095  return S;
4096 }
4097 
4098 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4099  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4100 
4101  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4102  if (I != ValueExprMap.end()) {
4103  const SCEV *S = I->second;
4104  if (checkValidity(S))
4105  return S;
4106  eraseValueFromMap(V);
4107  forgetMemoizedResults(S);
4108  }
4109  return nullptr;
4110 }
4111 
4112 /// Return a SCEV corresponding to -V = -1*V
 4113 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
 4114  SCEV::NoWrapFlags Flags) {
4115  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4116  return getConstant(
4117  cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4118 
4119  Type *Ty = V->getType();
4120  Ty = getEffectiveSCEVType(Ty);
4121  return getMulExpr(V, getMinusOne(Ty), Flags);
4122 }
4123 
4124 /// If Expr computes ~A, return A else return nullptr
4125 static const SCEV *MatchNotExpr(const SCEV *Expr) {
4126  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4127  if (!Add || Add->getNumOperands() != 2 ||
4128  !Add->getOperand(0)->isAllOnesValue())
4129  return nullptr;
4130 
4131  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4132  if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4133  !AddRHS->getOperand(0)->isAllOnesValue())
4134  return nullptr;
4135 
4136  return AddRHS->getOperand(1);
4137 }
4138 
4139 /// Return a SCEV corresponding to ~V = -1-V
 4140 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
 4141  assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4142 
4143  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4144  return getConstant(
4145  cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4146 
4147  // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4148  if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4149  auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4150  SmallVector<const SCEV *, 2> MatchedOperands;
4151  for (const SCEV *Operand : MME->operands()) {
4152  const SCEV *Matched = MatchNotExpr(Operand);
4153  if (!Matched)
4154  return (const SCEV *)nullptr;
4155  MatchedOperands.push_back(Matched);
4156  }
4157  return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4158  MatchedOperands);
4159  };
4160  if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4161  return Replaced;
4162  }
4163 
4164  Type *Ty = V->getType();
4165  Ty = getEffectiveSCEVType(Ty);
4166  return getMinusSCEV(getMinusOne(Ty), V);
4167 }
4168 
 4169 const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
 4170  assert(P->getType()->isPointerTy());
4171 
4172  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4173  // The base of an AddRec is the first operand.
4174  SmallVector<const SCEV *> Ops{AddRec->operands()};
4175  Ops[0] = removePointerBase(Ops[0]);
4176  // Don't try to transfer nowrap flags for now. We could in some cases
4177  // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4178  return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4179  }
4180  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4181  // The base of an Add is the pointer operand.
4182  SmallVector<const SCEV *> Ops{Add->operands()};
4183  const SCEV **PtrOp = nullptr;
4184  for (const SCEV *&AddOp : Ops) {
4185  if (AddOp->getType()->isPointerTy()) {
4186  assert(!PtrOp && "Cannot have multiple pointer ops");
4187  PtrOp = &AddOp;
4188  }
4189  }
4190  *PtrOp = removePointerBase(*PtrOp);
4191  // Don't try to transfer nowrap flags for now. We could in some cases
4192  // (for example, if the pointer operand of the Add is a SCEVUnknown).
4193  return getAddExpr(Ops);
4194  }
4195  // Any other expression must be a pointer base.
4196  return getZero(P->getType());
4197 }
4198 
4199 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4200  SCEV::NoWrapFlags Flags,
4201  unsigned Depth) {
4202  // Fast path: X - X --> 0.
4203  if (LHS == RHS)
4204  return getZero(LHS->getType());
4205 
4206  // If we subtract two pointers with different pointer bases, bail.
4207  // Eventually, we're going to add an assertion to getMulExpr that we
4208  // can't multiply by a pointer.
4209  if (RHS->getType()->isPointerTy()) {
4210  if (!LHS->getType()->isPointerTy() ||
4211  getPointerBase(LHS) != getPointerBase(RHS))
4212  return getCouldNotCompute();
4213  LHS = removePointerBase(LHS);
4214  RHS = removePointerBase(RHS);
4215  }
4216 
4217  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4218  // makes it so that we cannot make much use of NUW.
4219  auto AddFlags = SCEV::FlagAnyWrap;
 4220  const bool RHSIsNotMinSigned =
 4221  !getSignedRangeMin(RHS).isMinSignedValue();
4222  if (hasFlags(Flags, SCEV::FlagNSW)) {
4223  // Let M be the minimum representable signed value. Then (-1)*RHS
4224  // signed-wraps if and only if RHS is M. That can happen even for
4225  // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4226  // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4227  // (-1)*RHS, we need to prove that RHS != M.
4228  //
4229  // If LHS is non-negative and we know that LHS - RHS does not
4230  // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4231  // either by proving that RHS > M or that LHS >= 0.
4232  if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4233  AddFlags = SCEV::FlagNSW;
4234  }
4235  }
4236 
4237  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4238  // RHS is NSW and LHS >= 0.
4239  //
4240  // The difficulty here is that the NSW flag may have been proven
4241  // relative to a loop that is to be found in a recurrence in LHS and
4242  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4243  // larger scope than intended.
4244  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4245 
4246  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4247 }
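 // Editorial note (illustrative, not from the original source): a concrete
 // 8-bit instance of the caveat above. Take M = -128, the minimum signed
 // value. The subtraction (-1) - (-128) = 127 does not signed-wrap, yet
 // rewriting it as (-1) + (-1)*(-128) introduces the product 128, which does
 // wrap. Hence NSW can only be transferred to the add form when RHS is
 // provably not M, or when LHS is known non-negative (which, combined with
 // no-signed-wrap on the subtraction, rules out RHS == M).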
4248 
 4249 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
 4250  unsigned Depth) {
4251  Type *SrcTy = V->getType();
4252  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4253  "Cannot truncate or zero extend with non-integer arguments!");
4254  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4255  return V; // No conversion
4256  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4257  return getTruncateExpr(V, Ty, Depth);
4258  return getZeroExtendExpr(V, Ty, Depth);
4259 }
4260 
 4261 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
 4262  unsigned Depth) {
4263  Type *SrcTy = V->getType();
4264  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
 4265  "Cannot truncate or sign extend with non-integer arguments!");
4266  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4267  return V; // No conversion
4268  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4269  return getTruncateExpr(V, Ty, Depth);
4270  return getSignExtendExpr(V, Ty, Depth);
4271 }
4272 
4273 const SCEV *
 4274 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
 4275  Type *SrcTy = V->getType();
4276  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4277  "Cannot noop or zero extend with non-integer arguments!");
4278  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4279  "getNoopOrZeroExtend cannot truncate!");
4280  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4281  return V; // No conversion
4282  return getZeroExtendExpr(V, Ty);
4283 }
4284 
4285 const SCEV *
 4286 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
 4287  Type *SrcTy = V->getType();
4288  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4289  "Cannot noop or sign extend with non-integer arguments!");
4290  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4291  "getNoopOrSignExtend cannot truncate!");
4292  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4293  return V; // No conversion
4294  return getSignExtendExpr(V, Ty);
4295 }
4296 
4297 const SCEV *
 4298 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
 4299  Type *SrcTy = V->getType();
4300  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4301  "Cannot noop or any extend with non-integer arguments!");
4302  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4303  "getNoopOrAnyExtend cannot truncate!");
4304  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4305  return V; // No conversion
4306  return getAnyExtendExpr(V, Ty);
4307 }
4308 
4309 const SCEV *
 4310 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
 4311  Type *SrcTy = V->getType();
4312  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4313  "Cannot truncate or noop with non-integer arguments!");
4314  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4315  "getTruncateOrNoop cannot extend!");
4316  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4317  return V; // No conversion
4318  return getTruncateExpr(V, Ty);
4319 }
4320 
 4321 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
 4322  const SCEV *RHS) {
4323  const SCEV *PromotedLHS = LHS;
4324  const SCEV *PromotedRHS = RHS;
4325 
4326  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4327  PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4328  else
4329  PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4330 
4331  return getUMaxExpr(PromotedLHS, PromotedRHS);
4332 }
4333 
 4334 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
 4335  const SCEV *RHS) {
4336  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4337  return getUMinFromMismatchedTypes(Ops);
4338 }
4339 
 4340 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
 4341  SmallVectorImpl<const SCEV *> &Ops) {
 4342  assert(!Ops.empty() && "At least one operand must be!");
4343  // Trivial case.
4344  if (Ops.size() == 1)
4345  return Ops[0];
4346 
4347  // Find the max type first.
4348  Type *MaxType = nullptr;
4349  for (auto *S : Ops)
4350  if (MaxType)
4351  MaxType = getWiderType(MaxType, S->getType());
4352  else
4353  MaxType = S->getType();
4354  assert(MaxType && "Failed to find maximum type!");
4355 
4356  // Extend all ops to max type.
4357  SmallVector<const SCEV *, 2> PromotedOps;
4358  for (auto *S : Ops)
4359  PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4360 
4361  // Generate umin.
4362  return getUMinExpr(PromotedOps);
4363 }
4364 
 4365 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
 4366  // A pointer operand may evaluate to a nonpointer expression, such as null.
4367  if (!V->getType()->isPointerTy())
4368  return V;
4369 
4370  while (true) {
4371  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4372  V = AddRec->getStart();
4373  } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4374  const SCEV *PtrOp = nullptr;
4375  for (const SCEV *AddOp : Add->operands()) {
4376  if (AddOp->getType()->isPointerTy()) {
4377  assert(!PtrOp && "Cannot have multiple pointer ops");
4378  PtrOp = AddOp;
4379  }
4380  }
4381  assert(PtrOp && "Must have pointer op");
4382  V = PtrOp;
4383  } else // Not something we can look further into.
4384  return V;
4385  }
4386 }
4387 
4388 /// Push users of the given Instruction onto the given Worklist.
4389 static void
 4390 PushDefUseChildren(Instruction *I,
 4391  SmallVectorImpl<Instruction *> &Worklist) {
4392  // Push the def-use children onto the Worklist stack.
4393  for (User *U : I->users())
4394  Worklist.push_back(cast<Instruction>(U));
4395 }
4396 
4397 void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
 4398  SmallVector<Instruction *, 16> Worklist;
 4399  PushDefUseChildren(PN, Worklist);
4400 
 4401  SmallPtrSet<Instruction *, 8> Visited;
 4402  Visited.insert(PN);
4403  while (!Worklist.empty()) {
4404  Instruction *I = Worklist.pop_back_val();
4405  if (!Visited.insert(I).second)
4406  continue;
4407 
4408  auto It = ValueExprMap.find_as(static_cast<Value *>(I));
4409  if (It != ValueExprMap.end()) {
4410  const SCEV *Old = It->second;
4411 
4412  // Short-circuit the def-use traversal if the symbolic name
4413  // ceases to appear in expressions.
4414  if (Old != SymName && !hasOperand(Old, SymName))
4415  continue;
4416 
4417  // SCEVUnknown for a PHI either means that it has an unrecognized
4418  // structure, it's a PHI that's in the progress of being computed
4419  // by createNodeForPHI, or it's a single-value PHI. In the first case,
4420  // additional loop trip count information isn't going to change anything.
4421  // In the second case, createNodeForPHI will perform the necessary
4422  // updates on its own when it gets to that point. In the third, we do
4423  // want to forget the SCEVUnknown.
4424  if (!isa<PHINode>(I) ||
4425  !isa<SCEVUnknown>(Old) ||
4426  (I != PN && Old == SymName)) {
4427  eraseValueFromMap(It->first);
4428  forgetMemoizedResults(Old);
4429  }
4430  }
4431 
4432  PushDefUseChildren(I, Worklist);
4433  }
4434 }
4435 
4436 namespace {
4437 
4438 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
 4439 /// expression if its loop is L. If its loop is not L, then use the AddRec
 4440 /// itself when IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
 4441 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
 4442 /// done.
4443 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4444 public:
4445  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4446  bool IgnoreOtherLoops = true) {
4447  SCEVInitRewriter Rewriter(L, SE);
4448  const SCEV *Result = Rewriter.visit(S);
4449  if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4450  return SE.getCouldNotCompute();
4451  return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4452  ? SE.getCouldNotCompute()
4453  : Result;
4454  }
4455 
4456  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4457  if (!SE.isLoopInvariant(Expr, L))
4458  SeenLoopVariantSCEVUnknown = true;
4459  return Expr;
4460  }
4461 
4462  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4463  // Only re-write AddRecExprs for this loop.
4464  if (Expr->getLoop() == L)
4465  return Expr->getStart();
4466  SeenOtherLoops = true;
4467  return Expr;
4468  }
4469 
4470  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4471 
4472  bool hasSeenOtherLoops() { return SeenOtherLoops; }
4473 
4474 private:
4475  explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4476  : SCEVRewriteVisitor(SE), L(L) {}
4477 
4478  const Loop *L;
4479  bool SeenLoopVariantSCEVUnknown = false;
4480  bool SeenOtherLoops = false;
4481 };
4482 
4483 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
 4484 /// increment expression if its loop is L; if its loop is not L, use the
 4485 /// AddRec itself.
 4486 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4487 class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4488 public:
4489  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4490  SCEVPostIncRewriter Rewriter(L, SE);
4491  const SCEV *Result = Rewriter.visit(S);
4492  return Rewriter.hasSeenLoopVariantSCEVUnknown()
4493  ? SE.getCouldNotCompute()
4494  : Result;
4495  }
4496 
4497  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4498  if (!SE.isLoopInvariant(Expr, L))
4499  SeenLoopVariantSCEVUnknown = true;
4500  return Expr;
4501  }
4502 
4503  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4504  // Only re-write AddRecExprs for this loop.
4505  if (Expr->getLoop() == L)
4506  return Expr->getPostIncExpr(SE);
4507  SeenOtherLoops = true;
4508  return Expr;
4509  }
4510 
4511  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4512 
4513  bool hasSeenOtherLoops() { return SeenOtherLoops; }
4514 
4515 private:
4516  explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4517  : SCEVRewriteVisitor(SE), L(L) {}
4518 
4519  const Loop *L;
4520  bool SeenLoopVariantSCEVUnknown = false;
4521  bool SeenOtherLoops = false;
4522 };
4523 
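// Editor's note: illustrative sketch only; not part of the original file, and
// the helper name is hypothetical. For an affine recurrence {Start,+,Step}<L>,
// SCEVPostIncRewriter yields {Start+Step,+,Step}<L>, i.e. the value of the
// expression after the backedge has been taken once more.
static const SCEV *rewriteToPostIncValue(const SCEV *S, const Loop *L,
                                         ScalarEvolution &SE) {
  // Returns SE.getCouldNotCompute() if S depends on a loop-variant
  // SCEVUnknown; AddRecs of other loops are left untouched.
  return SCEVPostIncRewriter::rewrite(S, L, SE);
}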
4524 /// This class evaluates a compare condition by matching it against the
4525 /// condition of the loop latch. If there is a match, we assume the condition
4526 /// has the value that takes the backedge while building SCEV nodes.
4527 class SCEVBackedgeConditionFolder
4528  : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4529 public:
4530  static const SCEV *rewrite(const SCEV *S, const Loop *L,
4531  ScalarEvolution &SE) {
4532  bool IsPosBECond = false;
4533  Value *BECond = nullptr;
4534  if (BasicBlock *Latch = L->getLoopLatch()) {
4535  BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4536  if (BI && BI->isConditional()) {
4537  assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4538  "Both outgoing branches should not target the same header!");
4539  BECond = BI->getCondition();
4540  IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4541  } else {
4542  return S;
4543  }
4544  }
4545  SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4546  return Rewriter.visit(S);
4547  }
4548 
4549  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4550  const SCEV *Result = Expr;
4551  bool InvariantF = SE.isLoopInvariant(Expr, L);
4552 
4553  if (!InvariantF) {
4554  Instruction *I = cast<Instruction>(Expr->getValue());
4555  switch (I->getOpcode()) {
4556  case Instruction::Select: {
4557  SelectInst *SI = cast<SelectInst>(I);
4558  Optional<const SCEV *> Res =
4559  compareWithBackedgeCondition(SI->getCondition());
4560  if (Res.hasValue()) {
4561  bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
4562  Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4563  }
4564  break;
4565  }
4566  default: {
4567  Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4568  if (Res.hasValue())
4569  Result = Res.getValue();
4570  break;
4571  }
4572  }
4573  }
4574  return Result;
4575  }
4576 
4577 private:
4578  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4579  bool IsPosBECond, ScalarEvolution &SE)
4580  : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4581  IsPositiveBECond(IsPosBECond) {}
4582 
4583  Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
4584 
4585  const Loop *L;
4586  /// Loop backedge condition.
4587  Value *BackedgeCond = nullptr;
4588  /// Set to true if the backedge is taken when the branch condition is true.
4589  bool IsPositiveBECond;
4590 };
4591 
4592 Optional<const SCEV *>
4593 SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
4594 
4595  // If the value matches the backedge condition of the loop latch,
4596  // return an i1 constant reflecting whether the backedge is taken
4597  // on the true or the false edge of that condition.
4598  if (BackedgeCond == IC)
4599  return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
4600  : SE.getZero(Type::getInt1Ty(SE.getContext()));
4601  return None;
4602 }
4603 
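// Editor's note: illustrative sketch only; not part of the original file, and
// the helper name is hypothetical. The folder is useful when an expression is
// evaluated along the backedge: for a latch branch on %c, a loop-variant value
//   %v = select i1 %c, i64 %a, i64 %b
// is rewritten to getSCEV(%a) if the backedge is taken on the true edge of %c,
// and to getSCEV(%b) otherwise, because any iteration that reaches the
// backedge has a known value for %c.
static const SCEV *foldAlongBackedge(const SCEV *S, const Loop *L,
                                     ScalarEvolution &SE) {
  return SCEVBackedgeConditionFolder::rewrite(S, L, SE);
}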
4604 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
4605 public:
4606  static const SCEV *rewrite(const SCEV *S, const Loop *L,
4607  ScalarEvolution &SE) {
4608  SCEVShiftRewriter Rewriter(L, SE);
4609  const SCEV *Result = Rewriter.visit(S);
4610  return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
4611  }
4612 
4613  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4614  // Loop-variant values are only allowed inside AddRecExprs for this loop.
4615  if (!SE.isLoopInvariant(Expr, L))
4616  Valid = false;
4617  return Expr;
4618  }
4619 
4620  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4621  if (Expr->getLoop() == L && Expr->isAffine())
4622  return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
4623  Valid = false;
4624  return Expr;
4625  }
4626 
4627  bool isValid() { return Valid; }
4628 
4629 private:
4630  explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
4631  : SCEVRewriteVisitor(SE), L(L) {}
4632 
4633  const Loop *L;
4634  bool Valid = true;
4635 };
4636 
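// Editor's note: illustrative sketch only; not part of the original file, and
// the helper name is hypothetical. SCEVShiftRewriter shifts an expression back
// by one iteration of L: every affine {Start,+,Step}<L> becomes
// ({Start,+,Step}<L> - Step). The rewrite is rejected (CouldNotCompute) if the
// expression contains a SCEVUnknown that varies in L, a non-affine AddRec, or
// an AddRec of another loop.
static const SCEV *shiftBackOneIteration(const SCEV *S, const Loop *L,
                                         ScalarEvolution &SE) {
  return SCEVShiftRewriter::rewrite(S, L, SE);
}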
4637 } // end anonymous namespace
4638 
4639 SCEV::NoWrapFlags
4640 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
4641  if (!AR->isAffine())
4642  return SCEV::FlagAnyWrap;
4643 
4644  using OBO = OverflowingBinaryOperator;
4645 
4646  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
4647 
4648  if (!AR->hasNoSignedWrap()) {
4649  ConstantRange AddRecRange = getSignedRange(AR);
4650  ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
4651 
4652  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4653  Instruction::Add, IncRange, OBO::NoSignedWrap);
4654  if (NSWRegion.contains(AddRecRange))
4655  Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
4656  }
4657 
4658  if (!AR->hasNoUnsignedWrap()) {
4659  ConstantRange AddRecRange = getUnsignedRange(AR);
4660  ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
4661 
4662  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4663  Instruction::Add, IncRange, OBO::NoUnsignedWrap);
4664  if (NUWRegion.contains(AddRecRange))
4665  Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
4666  }
4667 
4668  return Result;
4669 }
4670 
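// Editor's note: illustrative sketch only; not part of the original file. The
// helper name and the concrete ranges below are made up for the example. It
// restates the containment argument used above: makeGuaranteedNoWrapRegion
// gives the set of values to which any increment drawn from IncRange can be
// added without wrapping, so containment of the AddRec's range proves no-wrap.
static bool exampleNoSignedWrapByContainment() {
  using OBO = OverflowingBinaryOperator;
  // An i8 step known to be exactly 1: X + 1 cannot signed-wrap unless X is
  // 127, i.e. X must lie in the region [-128, 127).
  ConstantRange IncRange(APInt(8, 1), APInt(8, 2));
  ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
      Instruction::Add, IncRange, OBO::NoSignedWrap);
  // An AddRec whose signed range is [0, 100) lies entirely inside that
  // region, so every step of the recurrence is provably NSW.
  ConstantRange AddRecRange(APInt(8, 0), APInt(8, 100));
  return NSWRegion.contains(AddRecRange);
}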
4671 SCEV::NoWrapFlags
4672 ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
4673  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
4674 
4675  if (AR->hasNoSignedWrap())
4676  return Result;
4677 
4678  if (!AR->isAffine())
4679  return Result;
4680 
4681  const SCEV *Step = AR->getStepRecurrence(*this);
4682  const Loop *L = AR->getLoop();
4683 
4684  // Check whether the backedge-taken count is SCEVCouldNotCompute.
4685  // Note that this serves two purposes: It filters out loops that are
4686  // simply not analyzable, and it covers the case where this code is
4687  // being called from within backedge-taken count analysis, such that
4688  // attempting to ask for the backedge-taken count would likely result
4689  // in infinite recursion. In the latter case, the analysis code will
4690  // cope with a conservative value, and it will take care to purge
4691  // that value once it has finished.
4692  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
4693 
4694  // Normally, in the cases we can prove no-overflow via a
4695  // backedge guarding condition, we can also compute a backedge
4696  // taken count for the loop. The exceptions are assumptions and
4697  // guards present in the loop -- SCEV is not great at exploiting
4698  // these to compute max backedge taken counts, but can still use
4699  // these to prove lack of overflow. Use this fact to avoid
4700  // doing extra work that may not pay off.
4701 
4702  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
4703  AC.assumptions().empty())
4704  return Result;