ScalarEvolution.cpp (LLVM 14.0.0git)
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the scalar evolution analysis
10 // engine, which is used primarily to analyze expressions involving induction
11 // variables in loops.
12 //
13 // There are several aspects to this library. First is the representation of
14 // scalar expressions, which are represented as subclasses of the SCEV class.
15 // These classes are used to represent certain types of subexpressions that we
16 // can handle. We only create one SCEV of a particular shape, so
17 // pointer-comparisons for equality are legal.
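// For example, an induction variable that starts at 0 and advances by 4 on
// each iteration of a loop with header %loop is the add recurrence
// {0,+,4}<%loop>; because expressions are uniqued, requesting that expression
// twice yields the same SCEV pointer.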
18 //
19 // One important aspect of the SCEV objects is that they are never cyclic, even
20 // if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
22 // recurrence) then we represent it directly as a recurrence node, otherwise we
23 // represent it as a SCEVUnknown node.
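// For example, a PHI node that merges two unrelated loaded values does not
// match any recurrence idiom, so it is wrapped in a SCEVUnknown and treated
// as an opaque leaf of the expression.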
24 //
25 // In addition to being able to represent expressions of various types, we also
26 // have folders that are used to build the *canonical* representation for a
27 // particular expression. These folders are capable of using a variety of
28 // rewrite rules to simplify the expressions.
29 //
30 // Once the folders are defined, we can implement the more interesting
31 // higher-level code, such as the code that recognizes PHI nodes of various
32 // types, computes the execution count of a loop, etc.
33 //
34 // TODO: We should use these routines and value representations to implement
35 // dependence analysis!
36 //
37 //===----------------------------------------------------------------------===//
38 //
39 // There are several good references for the techniques used in this analysis.
40 //
41 // Chains of recurrences -- a method to expedite the evaluation
42 // of closed-form functions
43 // Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44 //
45 // On computational properties of chains of recurrences
46 // Eugene V. Zima
47 //
48 // Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49 // Robert A. van Engelen
50 //
51 // Efficient Symbolic Analysis for Optimizing Compilers
52 // Robert A. van Engelen
53 //
54 // Using the chains of recurrences algebra for data dependence testing and
55 // induction variable substitution
56 // MS Thesis, Johnie Birch
57 //
58 //===----------------------------------------------------------------------===//
59 
60 #include "llvm/Analysis/ScalarEvolution.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/ArrayRef.h"
63 #include "llvm/ADT/DenseMap.h"
66 #include "llvm/ADT/FoldingSet.h"
67 #include "llvm/ADT/None.h"
68 #include "llvm/ADT/Optional.h"
69 #include "llvm/ADT/STLExtras.h"
70 #include "llvm/ADT/ScopeExit.h"
71 #include "llvm/ADT/Sequence.h"
72 #include "llvm/ADT/SetVector.h"
73 #include "llvm/ADT/SmallPtrSet.h"
74 #include "llvm/ADT/SmallSet.h"
75 #include "llvm/ADT/SmallVector.h"
76 #include "llvm/ADT/Statistic.h"
77 #include "llvm/ADT/StringRef.h"
81 #include "llvm/Analysis/LoopInfo.h"
86 #include "llvm/Config/llvm-config.h"
87 #include "llvm/IR/Argument.h"
88 #include "llvm/IR/BasicBlock.h"
89 #include "llvm/IR/CFG.h"
90 #include "llvm/IR/Constant.h"
91 #include "llvm/IR/ConstantRange.h"
92 #include "llvm/IR/Constants.h"
93 #include "llvm/IR/DataLayout.h"
94 #include "llvm/IR/DerivedTypes.h"
95 #include "llvm/IR/Dominators.h"
96 #include "llvm/IR/Function.h"
97 #include "llvm/IR/GlobalAlias.h"
98 #include "llvm/IR/GlobalValue.h"
99 #include "llvm/IR/GlobalVariable.h"
100 #include "llvm/IR/InstIterator.h"
101 #include "llvm/IR/InstrTypes.h"
102 #include "llvm/IR/Instruction.h"
103 #include "llvm/IR/Instructions.h"
104 #include "llvm/IR/IntrinsicInst.h"
105 #include "llvm/IR/Intrinsics.h"
106 #include "llvm/IR/LLVMContext.h"
107 #include "llvm/IR/Metadata.h"
108 #include "llvm/IR/Operator.h"
109 #include "llvm/IR/PatternMatch.h"
110 #include "llvm/IR/Type.h"
111 #include "llvm/IR/Use.h"
112 #include "llvm/IR/User.h"
113 #include "llvm/IR/Value.h"
114 #include "llvm/IR/Verifier.h"
115 #include "llvm/InitializePasses.h"
116 #include "llvm/Pass.h"
117 #include "llvm/Support/Casting.h"
119 #include "llvm/Support/Compiler.h"
120 #include "llvm/Support/Debug.h"
122 #include "llvm/Support/KnownBits.h"
125 #include <algorithm>
126 #include <cassert>
127 #include <climits>
128 #include <cstddef>
129 #include <cstdint>
130 #include <cstdlib>
131 #include <map>
132 #include <memory>
133 #include <tuple>
134 #include <utility>
135 #include <vector>
136 
137 using namespace llvm;
138 using namespace PatternMatch;
139 
140 #define DEBUG_TYPE "scalar-evolution"
141 
142 STATISTIC(NumTripCountsComputed,
143  "Number of loops with predictable loop counts");
144 STATISTIC(NumTripCountsNotComputed,
145  "Number of loops without predictable loop counts");
146 STATISTIC(NumBruteForceTripCountsComputed,
147  "Number of loops with trip counts computed by force");
148 
149 static cl::opt<unsigned>
150 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
152  cl::desc("Maximum number of iterations SCEV will "
153  "symbolically execute a constant "
154  "derived loop"),
155  cl::init(100));
156 
157 // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
158 static cl::opt<bool> VerifySCEV(
159  "verify-scev", cl::Hidden,
160  cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
161 static cl::opt<bool> VerifySCEVStrict(
162  "verify-scev-strict", cl::Hidden,
163  cl::desc("Enable stricter verification when -verify-scev is passed"));
164 static cl::opt<bool>
165  VerifySCEVMap("verify-scev-maps", cl::Hidden,
166  cl::desc("Verify no dangling value in ScalarEvolution's "
167  "ExprValueMap (slow)"));
168 
169 static cl::opt<bool> VerifyIR(
170  "scev-verify-ir", cl::Hidden,
171  cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
172  cl::init(false));
173 
174 static cl::opt<unsigned> MulOpsInlineThreshold(
175  "scev-mulops-inline-threshold", cl::Hidden,
176  cl::desc("Threshold for inlining multiplication operands into a SCEV"),
177  cl::init(32));
178 
179 static cl::opt<unsigned> AddOpsInlineThreshold(
180  "scev-addops-inline-threshold", cl::Hidden,
181  cl::desc("Threshold for inlining addition operands into a SCEV"),
182  cl::init(500));
183 
184 static cl::opt<unsigned> MaxSCEVCompareDepth(
185  "scalar-evolution-max-scev-compare-depth", cl::Hidden,
186  cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
187  cl::init(32));
188 
189 static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
190  "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
191  cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
192  cl::init(2));
193 
194 static cl::opt<unsigned> MaxValueCompareDepth(
195  "scalar-evolution-max-value-compare-depth", cl::Hidden,
196  cl::desc("Maximum depth of recursive value complexity comparisons"),
197  cl::init(2));
198 
199 static cl::opt<unsigned>
200  MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
201  cl::desc("Maximum depth of recursive arithmetics"),
202  cl::init(32));
203 
204 static cl::opt<unsigned> MaxConstantEvolvingDepth(
205  "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
206  cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
207 
208 static cl::opt<unsigned>
209  MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
210  cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
211  cl::init(8));
212 
213 static cl::opt<unsigned>
214  MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
215  cl::desc("Max coefficients in AddRec during evolving"),
216  cl::init(8));
217 
218 static cl::opt<unsigned>
219  HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
220  cl::desc("Size of the expression which is considered huge"),
221  cl::init(4096));
222 
223 static cl::opt<bool>
224 ClassifyExpressions("scalar-evolution-classify-expressions",
225  cl::Hidden, cl::init(true),
226  cl::desc("When printing analysis, include information on every instruction"));
227 
228 static cl::opt<bool> UseExpensiveRangeSharpening(
229  "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
230  cl::init(false),
231  cl::desc("Use more powerful methods of sharpening expression ranges. May "
232  "be costly in terms of compile time"));
233 
234 //===----------------------------------------------------------------------===//
235 // SCEV class definitions
236 //===----------------------------------------------------------------------===//
237 
238 //===----------------------------------------------------------------------===//
239 // Implementation of the SCEV class.
240 //
241 
242 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
243 LLVM_DUMP_METHOD void SCEV::dump() const {
244  print(dbgs());
245  dbgs() << '\n';
246 }
247 #endif
248 
249 void SCEV::print(raw_ostream &OS) const {
250  switch (getSCEVType()) {
251  case scConstant:
252  cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
253  return;
254  case scPtrToInt: {
255  const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
256  const SCEV *Op = PtrToInt->getOperand();
257  OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
258  << *PtrToInt->getType() << ")";
259  return;
260  }
261  case scTruncate: {
262  const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
263  const SCEV *Op = Trunc->getOperand();
264  OS << "(trunc " << *Op->getType() << " " << *Op << " to "
265  << *Trunc->getType() << ")";
266  return;
267  }
268  case scZeroExtend: {
269  const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
270  const SCEV *Op = ZExt->getOperand();
271  OS << "(zext " << *Op->getType() << " " << *Op << " to "
272  << *ZExt->getType() << ")";
273  return;
274  }
275  case scSignExtend: {
276  const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
277  const SCEV *Op = SExt->getOperand();
278  OS << "(sext " << *Op->getType() << " " << *Op << " to "
279  << *SExt->getType() << ")";
280  return;
281  }
282  case scAddRecExpr: {
283  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
284  OS << "{" << *AR->getOperand(0);
285  for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
286  OS << ",+," << *AR->getOperand(i);
287  OS << "}<";
288  if (AR->hasNoUnsignedWrap())
289  OS << "nuw><";
290  if (AR->hasNoSignedWrap())
291  OS << "nsw><";
292  if (AR->hasNoSelfWrap() &&
293  !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
294  OS << "nw><";
295  AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
296  OS << ">";
297  return;
298  }
299  case scAddExpr:
300  case scMulExpr:
301  case scUMaxExpr:
302  case scSMaxExpr:
303  case scUMinExpr:
304  case scSMinExpr: {
305  const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
306  const char *OpStr = nullptr;
307  switch (NAry->getSCEVType()) {
308  case scAddExpr: OpStr = " + "; break;
309  case scMulExpr: OpStr = " * "; break;
310  case scUMaxExpr: OpStr = " umax "; break;
311  case scSMaxExpr: OpStr = " smax "; break;
312  case scUMinExpr:
313  OpStr = " umin ";
314  break;
315  case scSMinExpr:
316  OpStr = " smin ";
317  break;
318  default:
319  llvm_unreachable("There are no other nary expression types.");
320  }
321  OS << "(";
322  ListSeparator LS(OpStr);
323  for (const SCEV *Op : NAry->operands())
324  OS << LS << *Op;
325  OS << ")";
326  switch (NAry->getSCEVType()) {
327  case scAddExpr:
328  case scMulExpr:
329  if (NAry->hasNoUnsignedWrap())
330  OS << "<nuw>";
331  if (NAry->hasNoSignedWrap())
332  OS << "<nsw>";
333  break;
334  default:
335  // Nothing to print for other nary expressions.
336  break;
337  }
338  return;
339  }
340  case scUDivExpr: {
341  const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
342  OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
343  return;
344  }
345  case scUnknown: {
346  const SCEVUnknown *U = cast<SCEVUnknown>(this);
347  Type *AllocTy;
348  if (U->isSizeOf(AllocTy)) {
349  OS << "sizeof(" << *AllocTy << ")";
350  return;
351  }
352  if (U->isAlignOf(AllocTy)) {
353  OS << "alignof(" << *AllocTy << ")";
354  return;
355  }
356 
357  Type *CTy;
358  Constant *FieldNo;
359  if (U->isOffsetOf(CTy, FieldNo)) {
360  OS << "offsetof(" << *CTy << ", ";
361  FieldNo->printAsOperand(OS, false);
362  OS << ")";
363  return;
364  }
365 
366  // Otherwise just print it normally.
367  U->getValue()->printAsOperand(OS, false);
368  return;
369  }
370  case scCouldNotCompute:
371  OS << "***COULDNOTCOMPUTE***";
372  return;
373  }
374  llvm_unreachable("Unknown SCEV kind!");
375 }
376 
377 Type *SCEV::getType() const {
378  switch (getSCEVType()) {
379  case scConstant:
380  return cast<SCEVConstant>(this)->getType();
381  case scPtrToInt:
382  case scTruncate:
383  case scZeroExtend:
384  case scSignExtend:
385  return cast<SCEVCastExpr>(this)->getType();
386  case scAddRecExpr:
387  return cast<SCEVAddRecExpr>(this)->getType();
388  case scMulExpr:
389  return cast<SCEVMulExpr>(this)->getType();
390  case scUMaxExpr:
391  case scSMaxExpr:
392  case scUMinExpr:
393  case scSMinExpr:
394  return cast<SCEVMinMaxExpr>(this)->getType();
395  case scAddExpr:
396  return cast<SCEVAddExpr>(this)->getType();
397  case scUDivExpr:
398  return cast<SCEVUDivExpr>(this)->getType();
399  case scUnknown:
400  return cast<SCEVUnknown>(this)->getType();
401  case scCouldNotCompute:
402  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
403  }
404  llvm_unreachable("Unknown SCEV kind!");
405 }
406 
407 bool SCEV::isZero() const {
408  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
409  return SC->getValue()->isZero();
410  return false;
411 }
412 
413 bool SCEV::isOne() const {
414  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
415  return SC->getValue()->isOne();
416  return false;
417 }
418 
419 bool SCEV::isAllOnesValue() const {
420  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
421  return SC->getValue()->isMinusOne();
422  return false;
423 }
424 
425 bool SCEV::isNonConstantNegative() const {
426  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
427  if (!Mul) return false;
428 
429  // If there is a constant factor, it will be first.
430  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
431  if (!SC) return false;
432 
433  // Return true if the value is negative, this matches things like (-42 * V).
434  return SC->getAPInt().isNegative();
435 }
436 
437 SCEVCouldNotCompute::SCEVCouldNotCompute() :
438  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
439 
440 bool SCEVCouldNotCompute::classof(const SCEV *S) {
441  return S->getSCEVType() == scCouldNotCompute;
442 }
443 
444 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
445  FoldingSetNodeID ID;
446  ID.AddInteger(scConstant);
447  ID.AddPointer(V);
448  void *IP = nullptr;
449  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
450  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
451  UniqueSCEVs.InsertNode(S, IP);
452  return S;
453 }
454 
455 const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
456  return getConstant(ConstantInt::get(getContext(), Val));
457 }
458 
459 const SCEV *
460 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
461  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
462  return getConstant(ConstantInt::get(ITy, V, isSigned));
463 }
464 
465 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
466  const SCEV *op, Type *ty)
467  : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
468  Operands[0] = op;
469 }
470 
471 SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
472  Type *ITy)
473  : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
474  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
475  "Must be a non-bit-width-changing pointer-to-integer cast!");
476 }
477 
478 SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
479  SCEVTypes SCEVTy, const SCEV *op,
480  Type *ty)
481  : SCEVCastExpr(ID, SCEVTy, op, ty) {}
482 
483 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
484  Type *ty)
485  : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
486  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
487  "Cannot truncate non-integer value!");
488 }
489 
490 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
491  const SCEV *op, Type *ty)
492  : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
493  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
494  "Cannot zero extend non-integer value!");
495 }
496 
497 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
498  const SCEV *op, Type *ty)
499  : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
500  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
501  "Cannot sign extend non-integer value!");
502 }
503 
504 void SCEVUnknown::deleted() {
505  // Clear this SCEVUnknown from various maps.
506  SE->forgetMemoizedResults(this);
507 
508  // Remove this SCEVUnknown from the uniquing map.
509  SE->UniqueSCEVs.RemoveNode(this);
510 
511  // Release the value.
512  setValPtr(nullptr);
513 }
514 
515 void SCEVUnknown::allUsesReplacedWith(Value *New) {
516  // Remove this SCEVUnknown from the uniquing map.
517  SE->UniqueSCEVs.RemoveNode(this);
518 
519  // Update this SCEVUnknown to point to the new value. This is needed
520  // because there may still be outstanding SCEVs which still point to
521  // this SCEVUnknown.
522  setValPtr(New);
523 }
524 
525 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
526  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
527  if (VCE->getOpcode() == Instruction::PtrToInt)
528  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
529  if (CE->getOpcode() == Instruction::GetElementPtr &&
530  CE->getOperand(0)->isNullValue() &&
531  CE->getNumOperands() == 2)
532  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
533  if (CI->isOne()) {
534  AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
535  return true;
536  }
537 
538  return false;
539 }
540 
541 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
542  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
543  if (VCE->getOpcode() == Instruction::PtrToInt)
544  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
545  if (CE->getOpcode() == Instruction::GetElementPtr &&
546  CE->getOperand(0)->isNullValue()) {
547  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
548  if (StructType *STy = dyn_cast<StructType>(Ty))
549  if (!STy->isPacked() &&
550  CE->getNumOperands() == 3 &&
551  CE->getOperand(1)->isNullValue()) {
552  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
553  if (CI->isOne() &&
554  STy->getNumElements() == 2 &&
555  STy->getElementType(0)->isIntegerTy(1)) {
556  AllocTy = STy->getElementType(1);
557  return true;
558  }
559  }
560  }
561 
562  return false;
563 }
564 
565 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
566  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
567  if (VCE->getOpcode() == Instruction::PtrToInt)
568  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
569  if (CE->getOpcode() == Instruction::GetElementPtr &&
570  CE->getNumOperands() == 3 &&
571  CE->getOperand(0)->isNullValue() &&
572  CE->getOperand(1)->isNullValue()) {
573  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
574  // Ignore vector types here so that ScalarEvolutionExpander doesn't
575  // emit getelementptrs that index into vectors.
576  if (Ty->isStructTy() || Ty->isArrayTy()) {
577  CTy = Ty;
578  FieldNo = CE->getOperand(2);
579  return true;
580  }
581  }
582 
583  return false;
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // SCEV Utilities
588 //===----------------------------------------------------------------------===//
589 
590 /// Compare the two values \p LV and \p RV in terms of their "complexity" where
591 /// "complexity" is a partial (and somewhat ad-hoc) relation used to order
592 /// operands in SCEV expressions. \p EqCache is a set of pairs of values that
593 /// have been previously deemed to be "equally complex" by this routine. It is
594 /// intended to avoid exponential time complexity in cases like:
595 ///
596 /// %a = f(%x, %y)
597 /// %b = f(%a, %a)
598 /// %c = f(%b, %b)
599 ///
600 /// %d = f(%x, %y)
601 /// %e = f(%d, %d)
602 /// %f = f(%e, %e)
603 ///
604 /// CompareValueComplexity(%f, %c)
605 ///
606 /// Since we do not continue running this routine on expression trees once we
607 /// have seen unequal values, there is no need to track them in the cache.
608 static int
609 CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
610  const LoopInfo *const LI, Value *LV, Value *RV,
611  unsigned Depth) {
612  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
613  return 0;
614 
615  // Order pointer values after integer values. This helps SCEVExpander form
616  // GEPs.
617  bool LIsPointer = LV->getType()->isPointerTy(),
618  RIsPointer = RV->getType()->isPointerTy();
619  if (LIsPointer != RIsPointer)
620  return (int)LIsPointer - (int)RIsPointer;
621 
622  // Compare getValueID values.
623  unsigned LID = LV->getValueID(), RID = RV->getValueID();
624  if (LID != RID)
625  return (int)LID - (int)RID;
626 
627  // Sort arguments by their position.
628  if (const auto *LA = dyn_cast<Argument>(LV)) {
629  const auto *RA = cast<Argument>(RV);
630  unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
631  return (int)LArgNo - (int)RArgNo;
632  }
633 
634  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
635  const auto *RGV = cast<GlobalValue>(RV);
636 
637  const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
638  auto LT = GV->getLinkage();
639  return !(GlobalValue::isPrivateLinkage(LT) ||
640  GlobalValue::isInternalLinkage(LT));
641  };
642 
643  // Use the names to distinguish the two values, but only if the
644  // names are semantically important.
645  if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
646  return LGV->getName().compare(RGV->getName());
647  }
648 
649  // For instructions, compare their loop depth, and their operand count. This
650  // is pretty loose.
651  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
652  const auto *RInst = cast<Instruction>(RV);
653 
654  // Compare loop depths.
655  const BasicBlock *LParent = LInst->getParent(),
656  *RParent = RInst->getParent();
657  if (LParent != RParent) {
658  unsigned LDepth = LI->getLoopDepth(LParent),
659  RDepth = LI->getLoopDepth(RParent);
660  if (LDepth != RDepth)
661  return (int)LDepth - (int)RDepth;
662  }
663 
664  // Compare the number of operands.
665  unsigned LNumOps = LInst->getNumOperands(),
666  RNumOps = RInst->getNumOperands();
667  if (LNumOps != RNumOps)
668  return (int)LNumOps - (int)RNumOps;
669 
670  for (unsigned Idx : seq(0u, LNumOps)) {
671  int Result =
672  CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
673  RInst->getOperand(Idx), Depth + 1);
674  if (Result != 0)
675  return Result;
676  }
677  }
678 
679  EqCacheValue.unionSets(LV, RV);
680  return 0;
681 }
682 
683 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
684 // than RHS, respectively. A three-way result allows recursive comparisons to be
685 // more efficient.
686 // If the max analysis depth was reached, return None, assuming we do not know
687 // if they are equivalent for sure.
688 static Optional<int>
689 CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
690  EquivalenceClasses<const Value *> &EqCacheValue,
691  const LoopInfo *const LI, const SCEV *LHS,
692  const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
693  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
694  if (LHS == RHS)
695  return 0;
696 
697  // Primarily, sort the SCEVs by their getSCEVType().
698  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
699  if (LType != RType)
700  return (int)LType - (int)RType;
701 
702  if (EqCacheSCEV.isEquivalent(LHS, RHS))
703  return 0;
704 
705  if (Depth > MaxSCEVCompareDepth)
706  return None;
707 
708  // Aside from the getSCEVType() ordering, the particular ordering
709  // isn't very important except that it's beneficial to be consistent,
710  // so that (a + b) and (b + a) don't end up as different expressions.
711  switch (LType) {
712  case scUnknown: {
713  const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
714  const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
715 
716  int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
717  RU->getValue(), Depth + 1);
718  if (X == 0)
719  EqCacheSCEV.unionSets(LHS, RHS);
720  return X;
721  }
722 
723  case scConstant: {
724  const SCEVConstant *LC = cast<SCEVConstant>(LHS);
725  const SCEVConstant *RC = cast<SCEVConstant>(RHS);
726 
727  // Compare constant values.
728  const APInt &LA = LC->getAPInt();
729  const APInt &RA = RC->getAPInt();
730  unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
731  if (LBitWidth != RBitWidth)
732  return (int)LBitWidth - (int)RBitWidth;
733  return LA.ult(RA) ? -1 : 1;
734  }
735 
736  case scAddRecExpr: {
737  const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
738  const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
739 
740  // There is always a dominance between two recs that are used by one SCEV,
741  // so we can safely sort recs by loop header dominance. We require such
742  // order in getAddExpr.
743  const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
744  if (LLoop != RLoop) {
745  const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
746  assert(LHead != RHead && "Two loops share the same header?");
747  if (DT.dominates(LHead, RHead))
748  return 1;
749  else
750  assert(DT.dominates(RHead, LHead) &&
751  "No dominance between recurrences used by one SCEV?");
752  return -1;
753  }
754 
755  // Addrec complexity grows with operand count.
756  unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
757  if (LNumOps != RNumOps)
758  return (int)LNumOps - (int)RNumOps;
759 
760  // Lexicographically compare.
761  for (unsigned i = 0; i != LNumOps; ++i) {
762  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
763  LA->getOperand(i), RA->getOperand(i), DT,
764  Depth + 1);
765  if (X != 0)
766  return X;
767  }
768  EqCacheSCEV.unionSets(LHS, RHS);
769  return 0;
770  }
771 
772  case scAddExpr:
773  case scMulExpr:
774  case scSMaxExpr:
775  case scUMaxExpr:
776  case scSMinExpr:
777  case scUMinExpr: {
778  const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
779  const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
780 
781  // Lexicographically compare n-ary expressions.
782  unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
783  if (LNumOps != RNumOps)
784  return (int)LNumOps - (int)RNumOps;
785 
786  for (unsigned i = 0; i != LNumOps; ++i) {
787  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
788  LC->getOperand(i), RC->getOperand(i), DT,
789  Depth + 1);
790  if (X != 0)
791  return X;
792  }
793  EqCacheSCEV.unionSets(LHS, RHS);
794  return 0;
795  }
796 
797  case scUDivExpr: {
798  const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
799  const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
800 
801  // Lexicographically compare udiv expressions.
802  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
803  RC->getLHS(), DT, Depth + 1);
804  if (X != 0)
805  return X;
806  X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
807  RC->getRHS(), DT, Depth + 1);
808  if (X == 0)
809  EqCacheSCEV.unionSets(LHS, RHS);
810  return X;
811  }
812 
813  case scPtrToInt:
814  case scTruncate:
815  case scZeroExtend:
816  case scSignExtend: {
817  const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
818  const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
819 
820  // Compare cast expressions by operand.
821  auto X =
822  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
823  RC->getOperand(), DT, Depth + 1);
824  if (X == 0)
825  EqCacheSCEV.unionSets(LHS, RHS);
826  return X;
827  }
828 
829  case scCouldNotCompute:
830  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
831  }
832  llvm_unreachable("Unknown SCEV kind!");
833 }
834 
835 /// Given a list of SCEV objects, order them by their complexity, and group
836 /// objects of the same complexity together by value. When this routine is
837 /// finished, we know that any duplicates in the vector are consecutive and that
838 /// complexity is monotonically increasing.
839 ///
840 /// Note that we take special precautions to ensure that we get deterministic
841 /// results from this routine. In other words, we don't want the results of
842 /// this to depend on where the addresses of various SCEV objects happened to
843 /// land in memory.
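// For example, given the operand list (%x, 2, %x), constants sort before
// SCEVUnknowns, and the duplicate %x values become adjacent: (2, %x, %x).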
844 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
845  LoopInfo *LI, DominatorTree &DT) {
846  if (Ops.size() < 2) return; // Noop
847 
848  EquivalenceClasses<const SCEV *> EqCacheSCEV;
849  EquivalenceClasses<const Value *> EqCacheValue;
850 
851  // Whether LHS has provably less complexity than RHS.
852  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
853  auto Complexity =
854  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
855  return Complexity && *Complexity < 0;
856  };
857  if (Ops.size() == 2) {
858  // This is the common case, which also happens to be trivially simple.
859  // Special case it.
860  const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
861  if (IsLessComplex(RHS, LHS))
862  std::swap(LHS, RHS);
863  return;
864  }
865 
866  // Do the rough sort by complexity.
867  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
868  return IsLessComplex(LHS, RHS);
869  });
870 
871  // Now that we are sorted by complexity, group elements of the same
872  // complexity. Note that this is, at worst, N^2, but the vector is likely to
873  // be extremely short in practice. Note that we take this approach because we
874  // do not want to depend on the addresses of the objects we are grouping.
875  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
876  const SCEV *S = Ops[i];
877  unsigned Complexity = S->getSCEVType();
878 
879  // If there are any objects of the same complexity and same value as this
880  // one, group them.
881  for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
882  if (Ops[j] == S) { // Found a duplicate.
883  // Move it to immediately after i'th element.
884  std::swap(Ops[i+1], Ops[j]);
885  ++i; // no need to rescan it.
886  if (i == e-2) return; // Done!
887  }
888  }
889  }
890 }
891 
892 /// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
893 /// least HugeExprThreshold nodes).
894 static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
895  return any_of(Ops, [](const SCEV *S) {
896  return S->getExpressionSize() >= HugeExprThreshold;
897  });
898 }
899 
900 //===----------------------------------------------------------------------===//
901 // Simple SCEV method implementations
902 //===----------------------------------------------------------------------===//
903 
904 /// Compute BC(It, K). The result has width W. Assume, K > 0.
905 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
906  ScalarEvolution &SE,
907  Type *ResultTy) {
908  // Handle the simplest case efficiently.
909  if (K == 1)
910  return SE.getTruncateOrZeroExtend(It, ResultTy);
911 
912  // We are using the following formula for BC(It, K):
913  //
914  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
915  //
916  // Suppose, W is the bitwidth of the return value. We must be prepared for
917  // overflow. Hence, we must assure that the result of our computation is
918  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
919  // safe in modular arithmetic.
920  //
921  // However, this code doesn't use exactly that formula; the formula it uses
922  // is something like the following, where T is the number of factors of 2 in
923  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
924  // exponentiation:
925  //
926  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
927  //
928  // This formula is trivially equivalent to the previous formula. However,
929  // this formula can be implemented much more efficiently. The trick is that
930  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
931  // arithmetic. To do exact division in modular arithmetic, all we have
932  // to do is multiply by the inverse. Therefore, this step can be done at
933  // width W.
934  //
935  // The next issue is how to safely do the division by 2^T. The way this
936  // is done is by doing the multiplication step at a width of at least W + T
937  // bits. This way, the bottom W+T bits of the product are accurate. Then,
938  // when we perform the division by 2^T (which is equivalent to a right shift
939  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
940  // truncated out after the division by 2^T.
941  //
942  // In comparison to just directly using the first formula, this technique
943  // is much more efficient; using the first formula requires W * K bits,
944 // but this formula needs less than W + K bits. Also, the first formula requires
945  // a division step, whereas this formula only requires multiplies and shifts.
946  //
947  // It doesn't matter whether the subtraction step is done in the calculation
948  // width or the input iteration count's width; if the subtraction overflows,
949  // the result must be zero anyway. We prefer here to do it in the width of
950  // the induction variable because it helps a lot for certain cases; CodeGen
951  // isn't smart enough to ignore the overflow, which leads to much less
952  // efficient code if the width of the subtraction is wider than the native
953  // register width.
954  //
955  // (It's possible to not widen at all by pulling out factors of 2 before
956  // the multiplication; for example, K=2 can be calculated as
957  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
958  // extra arithmetic, so it's not an obvious win, and it gets
959  // much more complicated for K > 3.)
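 // As a concrete instance of the scheme above: for K = 3 and W = 32, K! = 6
 // contains a single factor of two, so T = 1 and OddFactorial = 3. The product
 // It*(It-1)*(It-2) is formed at W + T = 33 bits, divided by 2^1, truncated
 // back to 32 bits, and finally multiplied by the multiplicative inverse of 3
 // modulo 2^32.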
960 
961  // Protection from insane SCEVs; this bound is conservative,
962  // but it probably doesn't matter.
963  if (K > 1000)
964  return SE.getCouldNotCompute();
965 
966  unsigned W = SE.getTypeSizeInBits(ResultTy);
967 
968  // Calculate K! / 2^T and T; we divide out the factors of two before
969  // multiplying for calculating K! / 2^T to avoid overflow.
970  // Other overflow doesn't matter because we only care about the bottom
971  // W bits of the result.
972  APInt OddFactorial(W, 1);
973  unsigned T = 1;
974  for (unsigned i = 3; i <= K; ++i) {
975  APInt Mult(W, i);
976  unsigned TwoFactors = Mult.countTrailingZeros();
977  T += TwoFactors;
978  Mult.lshrInPlace(TwoFactors);
979  OddFactorial *= Mult;
980  }
981 
982  // We need at least W + T bits for the multiplication step
983  unsigned CalculationBits = W + T;
984 
985  // Calculate 2^T, at width T+W.
986  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
987 
988  // Calculate the multiplicative inverse of K! / 2^T;
989  // this multiplication factor will perform the exact division by
990  // K! / 2^T.
991  APInt Mod = APInt::getSignedMinValue(W+1);
992  APInt MultiplyFactor = OddFactorial.zext(W+1);
993  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
994  MultiplyFactor = MultiplyFactor.trunc(W);
995 
996  // Calculate the product, at width T+W
997  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
998  CalculationBits);
999  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1000  for (unsigned i = 1; i != K; ++i) {
1001  const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1002  Dividend = SE.getMulExpr(Dividend,
1003  SE.getTruncateOrZeroExtend(S, CalculationTy));
1004  }
1005 
1006  // Divide by 2^T
1007  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1008 
1009  // Truncate the result, and divide by K! / 2^T.
1010 
1011  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1012  SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1013 }
1014 
1015 /// Return the value of this chain of recurrences at the specified iteration
1016 /// number. We can evaluate this recurrence by multiplying each element in the
1017 /// chain by the binomial coefficient corresponding to it. In other words, we
1018 /// can evaluate {A,+,B,+,C,+,D} as:
1019 ///
1020 /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1021 ///
1022 /// where BC(It, k) stands for binomial coefficient.
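/// For example, the recurrence {0,+,1,+,1} evaluates at iteration It to
/// 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2) = It + It*(It-1)/2.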
1023 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1024  ScalarEvolution &SE) const {
1025  return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
1026 }
1027 
1028 const SCEV *
1029 SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
1030  const SCEV *It, ScalarEvolution &SE) {
1031  assert(Operands.size() > 0);
1032  const SCEV *Result = Operands[0];
1033  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1034  // The computation is correct in the face of overflow provided that the
1035  // multiplication is performed _after_ the evaluation of the binomial
1036  // coefficient.
1037  const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1038  if (isa<SCEVCouldNotCompute>(Coeff))
1039  return Coeff;
1040 
1041  Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1042  }
1043  return Result;
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // SCEV Expression folder implementations
1048 //===----------------------------------------------------------------------===//
1049 
1050 const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1051  unsigned Depth) {
1052  assert(Depth <= 1 &&
1053  "getLosslessPtrToIntExpr() should self-recurse at most once.");
1054 
1055  // We could be called with an integer-typed operand during SCEV rewrites.
1056  // Since the operand is an integer already, just perform zext/trunc/self cast.
1057  if (!Op->getType()->isPointerTy())
1058  return Op;
1059 
1060  // What would be an ID for such a SCEV cast expression?
1061  FoldingSetNodeID ID;
1062  ID.AddInteger(scPtrToInt);
1063  ID.AddPointer(Op);
1064 
1065  void *IP = nullptr;
1066 
1067  // Is there already an expression for such a cast?
1068  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1069  return S;
1070 
1071  // It isn't legal for optimizations to construct new ptrtoint expressions
1072  // for non-integral pointers.
1073  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1074  return getCouldNotCompute();
1075 
1076  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1077 
1078  // We can only trivially model ptrtoint if SCEV's effective (integer) type
1079  // is sufficiently wide to represent all possible pointer values.
1080  // We could theoretically teach SCEV to truncate wider pointers, but
1081  // that isn't implemented for now.
1082  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) >
1083  getDataLayout().getTypeSizeInBits(IntPtrTy))
1084  return getCouldNotCompute();
1085 
1086  // If not, is this expression something we can't reduce any further?
1087  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1088  // Perform some basic constant folding. If the operand of the ptr2int cast
1089  // is a null pointer, don't create a ptr2int SCEV expression (that will be
1090  // left as-is), but produce a zero constant.
1091  // NOTE: We could handle a more general case, but lack motivational cases.
1092  if (isa<ConstantPointerNull>(U->getValue()))
1093  return getZero(IntPtrTy);
1094 
1095  // Create an explicit cast node.
1096  // We can reuse the existing insert position since if we get here,
1097  // we won't have made any changes which would invalidate it.
1098  SCEV *S = new (SCEVAllocator)
1099  SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1100  UniqueSCEVs.InsertNode(S, IP);
1101  registerUser(S, Op);
1102  return S;
1103  }
1104 
1105  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1106  "non-SCEVUnknown's.");
1107 
1108  // Otherwise, we've got some expression that is more complex than just a
1109  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1110  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1111  // only, and the expressions must otherwise be integer-typed.
1112  // So sink the cast down to the SCEVUnknown's.
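  // For example, a pointer-typed expression (16 + %ptr) becomes
  // (16 + (ptrtoint %ptr)), leaving the SCEVUnknown %ptr as the only operand
  // that still needs an explicit ptrtoint.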
1113 
1114  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1115  /// which computes a pointer-typed value, and rewrites the whole expression
1116  /// tree so that *all* the computations are done on integers, and the only
1117  /// pointer-typed operands in the expression are SCEVUnknown.
1118  class SCEVPtrToIntSinkingRewriter
1119  : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1120  using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1121 
1122  public:
1123  SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1124 
1125  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1126  SCEVPtrToIntSinkingRewriter Rewriter(SE);
1127  return Rewriter.visit(Scev);
1128  }
1129 
1130  const SCEV *visit(const SCEV *S) {
1131  Type *STy = S->getType();
1132  // If the expression is not pointer-typed, just keep it as-is.
1133  if (!STy->isPointerTy())
1134  return S;
1135  // Else, recursively sink the cast down into it.
1136  return Base::visit(S);
1137  }
1138 
1139  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1140  SmallVector<const SCEV *, 2> Operands;
1141  bool Changed = false;
1142  for (auto *Op : Expr->operands()) {
1143  Operands.push_back(visit(Op));
1144  Changed |= Op != Operands.back();
1145  }
1146  return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1147  }
1148 
1149  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1150  SmallVector<const SCEV *, 2> Operands;
1151  bool Changed = false;
1152  for (auto *Op : Expr->operands()) {
1153  Operands.push_back(visit(Op));
1154  Changed |= Op != Operands.back();
1155  }
1156  return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1157  }
1158 
1159  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1160  assert(Expr->getType()->isPointerTy() &&
1161  "Should only reach pointer-typed SCEVUnknown's.");
1162  return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1163  }
1164  };
1165 
1166  // And actually perform the cast sinking.
1167  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1168  assert(IntOp->getType()->isIntegerTy() &&
1169  "We must have succeeded in sinking the cast, "
1170  "and ending up with an integer-typed expression!");
1171  return IntOp;
1172 }
1173 
1174 const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1175  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1176 
1177  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1178  if (isa<SCEVCouldNotCompute>(IntOp))
1179  return IntOp;
1180 
1181  return getTruncateOrZeroExtend(IntOp, Ty);
1182 }
1183 
1184 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1185  unsigned Depth) {
1186  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1187  "This is not a truncating conversion!");
1188  assert(isSCEVable(Ty) &&
1189  "This is not a conversion to a SCEVable type!");
1190  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1191  Ty = getEffectiveSCEVType(Ty);
1192 
1193  FoldingSetNodeID ID;
1194  ID.AddInteger(scTruncate);
1195  ID.AddPointer(Op);
1196  ID.AddPointer(Ty);
1197  void *IP = nullptr;
1198  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1199 
1200  // Fold if the operand is constant.
1201  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1202  return getConstant(
1203  cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1204 
1205  // trunc(trunc(x)) --> trunc(x)
1206  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1207  return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1208 
1209  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1210  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1211  return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1212 
1213  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1214  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1215  return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1216 
1217  if (Depth > MaxCastDepth) {
1218  SCEV *S =
1219  new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1220  UniqueSCEVs.InsertNode(S, IP);
1221  registerUser(S, Op);
1222  return S;
1223  }
1224 
1225  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1226  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1227  // if after transforming we have at most one truncate, not counting truncates
1228  // that replace other casts.
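  // For example, truncating ((zext i8 %a to i64) + (sext i8 %b to i64)) to i32
  // gives ((zext i8 %a to i32) + (sext i8 %b to i32)), which contains no
  // truncates at all, so the transform is a clear win.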
1229  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1230  auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1231  SmallVector<const SCEV *, 4> Operands;
1232  unsigned numTruncs = 0;
1233  for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1234  ++i) {
1235  const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1236  if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1237  isa<SCEVTruncateExpr>(S))
1238  numTruncs++;
1239  Operands.push_back(S);
1240  }
1241  if (numTruncs < 2) {
1242  if (isa<SCEVAddExpr>(Op))
1243  return getAddExpr(Operands);
1244  else if (isa<SCEVMulExpr>(Op))
1245  return getMulExpr(Operands);
1246  else
1247  llvm_unreachable("Unexpected SCEV type for Op.");
1248  }
1249  // Although we checked in the beginning that ID is not in the cache, it is
1250  // possible that it was inserted into the cache during the recursion above.
1251  // So if we find it, just return it.
1252  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1253  return S;
1254  }
1255 
1256  // If the input value is a chrec scev, truncate the chrec's operands.
1257  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1258  SmallVector<const SCEV *, 4> Operands;
1259  for (const SCEV *Op : AddRec->operands())
1260  Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1261  return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1262  }
1263 
1264  // Return zero if truncating to known zeros.
1265  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
1266  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1267  return getZero(Ty);
1268 
1269  // The cast wasn't folded; create an explicit cast node. We can reuse
1270  // the existing insert position since if we get here, we won't have
1271  // made any changes which would invalidate it.
1272  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1273  Op, Ty);
1274  UniqueSCEVs.InsertNode(S, IP);
1275  registerUser(S, Op);
1276  return S;
1277 }
1278 
1279 // Get the limit of a recurrence such that incrementing by Step cannot cause
1280 // signed overflow as long as the value of the recurrence within the
1281 // loop does not exceed this limit before incrementing.
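// For example, for an i8 recurrence whose step has a maximum signed value of
// 3, the limit is 127 - 3 = 124 and the predicate is slt: while the recurrence
// stays slt 124, adding the step cannot move it past SCHAR_MAX.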
1282 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1283  ICmpInst::Predicate *Pred,
1284  ScalarEvolution *SE) {
1285  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1286  if (SE->isKnownPositive(Step)) {
1287  *Pred = ICmpInst::ICMP_SLT;
1288  return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1289  SE->getSignedRangeMax(Step));
1290  }
1291  if (SE->isKnownNegative(Step)) {
1292  *Pred = ICmpInst::ICMP_SGT;
1293  return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1294  SE->getSignedRangeMin(Step));
1295  }
1296  return nullptr;
1297 }
1298 
1299 // Get the limit of a recurrence such that incrementing by Step cannot cause
1300 // unsigned overflow as long as the value of the recurrence within the loop does
1301 // not exceed this limit before incrementing.
1302 static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1303  ICmpInst::Predicate *Pred,
1304  ScalarEvolution *SE) {
1305  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1306  *Pred = ICmpInst::ICMP_ULT;
1307 
1308  return SE->getConstant(APInt::getMinValue(BitWidth) -
1309  SE->getUnsignedRangeMax(Step));
1310 }
1311 
1312 namespace {
1313 
1314 struct ExtendOpTraitsBase {
1315  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1316  unsigned);
1317 };
1318 
1319 // Used to make code generic over signed and unsigned overflow.
1320 template <typename ExtendOp> struct ExtendOpTraits {
1321  // Members present:
1322  //
1323  // static const SCEV::NoWrapFlags WrapType;
1324  //
1325  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1326  //
1327  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1328  // ICmpInst::Predicate *Pred,
1329  // ScalarEvolution *SE);
1330 };
1331 
1332 template <>
1333 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1334  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1335 
1336  static const GetExtendExprTy GetExtendExpr;
1337 
1338  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1339  ICmpInst::Predicate *Pred,
1340  ScalarEvolution *SE) {
1341  return getSignedOverflowLimitForStep(Step, Pred, SE);
1342  }
1343 };
1344 
1345 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1346  SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1347 
1348 template <>
1349 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1350  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1351 
1352  static const GetExtendExprTy GetExtendExpr;
1353 
1354  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1355  ICmpInst::Predicate *Pred,
1356  ScalarEvolution *SE) {
1357  return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1358  }
1359 };
1360 
1361 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1362  SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1363 
1364 } // end anonymous namespace
1365 
1366 // The recurrence AR has been shown to have no signed/unsigned wrap or something
1367 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1368 // easily prove NSW/NUW for its preincrement or postincrement sibling. This
1369 // allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1370 // Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1371 // expression "Step + sext/zext(PreIncAR)" is congruent with
1372 // "sext/zext(PostIncAR)"
1373 template <typename ExtendOpTy>
1374 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1375  ScalarEvolution *SE, unsigned Depth) {
1376  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1377  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1378 
1379  const Loop *L = AR->getLoop();
1380  const SCEV *Start = AR->getStart();
1381  const SCEV *Step = AR->getStepRecurrence(*SE);
1382 
1383  // Check for a simple looking step prior to loop entry.
1384  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1385  if (!SA)
1386  return nullptr;
1387 
1388  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1389  // subtraction is expensive. For this purpose, perform a quick and dirty
1390  // difference, by checking for Step in the operand list.
1391  SmallVector<const SCEV *, 4> DiffOps;
1392  for (const SCEV *Op : SA->operands())
1393  if (Op != Step)
1394  DiffOps.push_back(Op);
1395 
1396  if (DiffOps.size() == SA->getNumOperands())
1397  return nullptr;
1398 
1399  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1400  // `Step`:
1401 
1402  // 1. NSW/NUW flags on the step increment.
1403  auto PreStartFlags =
1404  ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1405  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1406  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1407  SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1408 
1409  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1410  // "S+X does not sign/unsign-overflow".
1411  //
1412 
1413  const SCEV *BECount = SE->getBackedgeTakenCount(L);
1414  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1415  !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1416  return PreStart;
1417 
1418  // 2. Direct overflow check on the step operation's expression.
1419  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1420  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1421  const SCEV *OperandExtendedStart =
1422  SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1423  (SE->*GetExtendExpr)(Step, WideTy, Depth));
1424  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1425  if (PreAR && AR->getNoWrapFlags(WrapType)) {
1426  // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1427  // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1428  // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1429  SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1430  }
1431  return PreStart;
1432  }
1433 
1434  // 3. Loop precondition.
1435  ICmpInst::Predicate Pred;
1436  const SCEV *OverflowLimit =
1437  ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1438 
1439  if (OverflowLimit &&
1440  SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1441  return PreStart;
1442 
1443  return nullptr;
1444 }
1445 
1446 // Get the normalized zero or sign extended expression for this AddRec's Start.
1447 template <typename ExtendOpTy>
1448 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1449  ScalarEvolution *SE,
1450  unsigned Depth) {
1451  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1452 
1453  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1454  if (!PreStart)
1455  return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1456 
1457  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1458  Depth),
1459  (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1460 }
1461 
1462 // Try to prove away overflow by looking at "nearby" add recurrences. A
1463 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1464 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1465 //
1466 // Formally:
1467 //
1468 // {S,+,X} == {S-T,+,X} + T
1469 // => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1470 //
1471 // If ({S-T,+,X} + T) does not overflow ... (1)
1472 //
1473 // RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1474 //
1475 // If {S-T,+,X} does not overflow ... (2)
1476 //
1477 // RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1478 // == {Ext(S-T)+Ext(T),+,Ext(X)}
1479 //
1480 // If (S-T)+T does not overflow ... (3)
1481 //
1482 // RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1483 // == {Ext(S),+,Ext(X)} == LHS
1484 //
1485 // Thus, if (1), (2) and (3) are true for some T, then
1486 // Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1487 //
1488 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1489 // does not overflow" restricted to the 0th iteration. Therefore we only need
1490 // to check for (1) and (2).
1491 //
1492 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1493 // is `Delta` (defined below).
1494 template <typename ExtendOpTy>
1495 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1496  const SCEV *Step,
1497  const Loop *L) {
1498  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1499 
1500  // We restrict `Start` to a constant to prevent SCEV from spending too much
1501  // time here. It is correct (but more expensive) to continue with a
1502  // non-constant `Start` and do a general SCEV subtraction to compute
1503  // `PreStart` below.
1504  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1505  if (!StartC)
1506  return false;
1507 
1508  APInt StartAI = StartC->getAPInt();
1509 
1510  for (unsigned Delta : {-2, -1, 1, 2}) {
1511  const SCEV *PreStart = getConstant(StartAI - Delta);
1512 
1514  ID.AddInteger(scAddRecExpr);
1515  ID.AddPointer(PreStart);
1516  ID.AddPointer(Step);
1517  ID.AddPointer(L);
1518  void *IP = nullptr;
1519  const auto *PreAR =
1520  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1521 
1522  // Give up if we don't already have the add recurrence we need because
1523  // actually constructing an add recurrence is relatively expensive.
1524  if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1525  const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1526  ICmpInst::Predicate Pred;
1527  const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1528  DeltaS, &Pred, this);
1529  if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1530  return true;
1531  }
1532  }
1533 
1534  return false;
1535 }
1536 
1537 // Finds an integer D for an expression (C + x + y + ...) such that the top
1538 // level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1539 // unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1540 // maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1541 // the (C + x + y + ...) expression is \p WholeAddExpr.
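// For example, if C is 13 (0b1101) and every other operand of the add is a
// multiple of 8 (TZ = 3), the result is D = 5, the low three bits of C:
// 13 - 5 = 8, so (C - D + x + y + ...) keeps at least three trailing zero bits
// and re-adding D cannot carry into them.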
1542 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1543  const SCEVConstant *ConstantTerm,
1544  const SCEVAddExpr *WholeAddExpr) {
1545  const APInt &C = ConstantTerm->getAPInt();
1546  const unsigned BitWidth = C.getBitWidth();
1547  // Find number of trailing zeros of (x + y + ...) w/o the C first:
1548  uint32_t TZ = BitWidth;
1549  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1550  TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
1551  if (TZ) {
1552  // Set D to be as many least significant bits of C as possible while still
1553  // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1554  return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1555  }
1556  return APInt(BitWidth, 0);
1557 }
1558 
1559 // Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1560 // level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1561 // number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1562 // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1563 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1564  const APInt &ConstantStart,
1565  const SCEV *Step) {
1566  const unsigned BitWidth = ConstantStart.getBitWidth();
1567  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
1568  if (TZ)
1569  return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1570  : ConstantStart;
1571  return APInt(BitWidth, 0);
1572 }
1573 
1574 const SCEV *
1575 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1576  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1577  "This is not an extending conversion!");
1578  assert(isSCEVable(Ty) &&
1579  "This is not a conversion to a SCEVable type!");
1580  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1581  Ty = getEffectiveSCEVType(Ty);
1582 
1583  // Fold if the operand is constant.
1584  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1585  return getConstant(
1586  cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1587 
1588  // zext(zext(x)) --> zext(x)
1589  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1590  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1591 
1592  // Before doing any expensive analysis, check to see if we've already
1593  // computed a SCEV for this Op and Ty.
1594  FoldingSetNodeID ID;
1595  ID.AddInteger(scZeroExtend);
1596  ID.AddPointer(Op);
1597  ID.AddPointer(Ty);
1598  void *IP = nullptr;
1599  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1600  if (Depth > MaxCastDepth) {
1601  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1602  Op, Ty);
1603  UniqueSCEVs.InsertNode(S, IP);
1604  registerUser(S, Op);
1605  return S;
1606  }
1607 
1608  // zext(trunc(x)) --> zext(x) or x or trunc(x)
1609  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1610  // It's possible the bits taken off by the truncate were all zero bits. If
1611  // so, we should be able to simplify this further.
1612  const SCEV *X = ST->getOperand();
1613  ConstantRange CR = getUnsignedRange(X);
1614  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1615  unsigned NewBits = getTypeSizeInBits(Ty);
1616  if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1617  CR.zextOrTrunc(NewBits)))
1618  return getTruncateOrZeroExtend(X, Ty, Depth);
1619  }
1620 
1621  // If the input value is a chrec scev, and we can prove that the value
1622  // did not overflow the old, smaller, value, we can zero extend all of the
1623  // operands (often constants). This allows analysis of something like
1624  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1625  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1626  if (AR->isAffine()) {
1627  const SCEV *Start = AR->getStart();
1628  const SCEV *Step = AR->getStepRecurrence(*this);
1629  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1630  const Loop *L = AR->getLoop();
1631 
1632  if (!AR->hasNoUnsignedWrap()) {
1633  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1634  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1635  }
1636 
1637  // If we have special knowledge that this addrec won't overflow,
1638  // we don't need to do any further analysis.
1639  if (AR->hasNoUnsignedWrap())
1640  return getAddRecExpr(
1641  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1642  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1643 
1644  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1645  // Note that this serves two purposes: It filters out loops that are
1646  // simply not analyzable, and it covers the case where this code is
1647  // being called from within backedge-taken count analysis, such that
1648  // attempting to ask for the backedge-taken count would likely result
1649  // in infinite recursion. In the latter case, the analysis code will
1650  // cope with a conservative value, and it will take care to purge
1651  // that value once it has finished.
1652  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1653  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1654  // Manually compute the final value for AR, checking for overflow.
1655 
1656  // Check whether the backedge-taken count can be losslessly casted to
1657  // the addrec's type. The count is always unsigned.
1658  const SCEV *CastedMaxBECount =
1659  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1660  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1661  CastedMaxBECount, MaxBECount->getType(), Depth);
1662  if (MaxBECount == RecastedMaxBECount) {
1663  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1664  // Check whether Start+Step*MaxBECount has no unsigned overflow.
1665  const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1666  SCEV::FlagAnyWrap, Depth + 1);
1667  const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1668  SCEV::FlagAnyWrap,
1669  Depth + 1),
1670  WideTy, Depth + 1);
1671  const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1672  const SCEV *WideMaxBECount =
1673  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1674  const SCEV *OperandExtendedAdd =
1675  getAddExpr(WideStart,
1676  getMulExpr(WideMaxBECount,
1677  getZeroExtendExpr(Step, WideTy, Depth + 1),
1678  SCEV::FlagAnyWrap, Depth + 1),
1679  SCEV::FlagAnyWrap, Depth + 1);
1680  if (ZAdd == OperandExtendedAdd) {
1681  // Cache knowledge of AR NUW, which is propagated to this AddRec.
1682  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1683  // Return the expression with the addrec on the outside.
1684  return getAddRecExpr(
1685  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1686  Depth + 1),
1687  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1688  AR->getNoWrapFlags());
1689  }
1690  // Similar to above, only this time treat the step value as signed.
1691  // This covers loops that count down.
1692  OperandExtendedAdd =
1693  getAddExpr(WideStart,
1694  getMulExpr(WideMaxBECount,
1695  getSignExtendExpr(Step, WideTy, Depth + 1),
1696  SCEV::FlagAnyWrap, Depth + 1),
1697  SCEV::FlagAnyWrap, Depth + 1);
1698  if (ZAdd == OperandExtendedAdd) {
1699  // Cache knowledge of AR NW, which is propagated to this AddRec.
1700  // Negative step causes unsigned wrap, but it still can't self-wrap.
1701  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1702  // Return the expression with the addrec on the outside.
1703  return getAddRecExpr(
1704  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1705  Depth + 1),
1706  getSignExtendExpr(Step, Ty, Depth + 1), L,
1707  AR->getNoWrapFlags());
1708  }
1709  }
1710  }
1711 
1712  // Normally, in the cases we can prove no-overflow via a
1713  // backedge guarding condition, we can also compute a backedge
1714  // taken count for the loop. The exceptions are assumptions and
1715  // guards present in the loop -- SCEV is not great at exploiting
1716  // these to compute max backedge taken counts, but can still use
1717  // these to prove lack of overflow. Use this fact to avoid
1718  // doing extra work that may not pay off.
1719  if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1720  !AC.assumptions().empty()) {
1721 
1722  auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1723  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1724  if (AR->hasNoUnsignedWrap()) {
1725  // Same as nuw case above - duplicated here to avoid a compile time
1726  // issue. It's not clear that the order of checks matters, but it's
1727  // one of two possible causes for a change which was reverted. Be
1728  // conservative for the moment.
1729  return getAddRecExpr(
1730  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1731  Depth + 1),
1732  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1733  AR->getNoWrapFlags());
1734  }
1735 
1736  // For a negative step, we can extend the operands iff doing so only
1737  // traverses values in the range zext([0,UINT_MAX]).
1738  if (isKnownNegative(Step)) {
1739  const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1740  getSignedRangeMin(Step));
1741  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1742  isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1743  // Cache knowledge of AR NW, which is propagated to this
1744  // AddRec. Negative step causes unsigned wrap, but it
1745  // still can't self-wrap.
1746  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1747  // Return the expression with the addrec on the outside.
1748  return getAddRecExpr(
1749  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1750  Depth + 1),
1751  getSignExtendExpr(Step, Ty, Depth + 1), L,
1752  AR->getNoWrapFlags());
1753  }
1754  }
1755  }
1756 
1757  // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1758  // if D + (C - D + Step * n) could be proven to not unsigned wrap
1759  // where D maximizes the number of trailing zeros of (C - D + Step * n)
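// For example, zext({6,+,4}) becomes (2 + zext({4,+,4}))<nuw><nsw>: D == 2, and
// adding 2 to the multiple-of-4 values of {4,+,4} cannot unsigned wrap.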
1760  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1761  const APInt &C = SC->getAPInt();
1762  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1763  if (D != 0) {
1764  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1765  const SCEV *SResidual =
1766  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1767  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1768  return getAddExpr(SZExtD, SZExtR,
1769  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1770  Depth + 1);
1771  }
1772  }
1773 
1774  if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1775  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1776  return getAddRecExpr(
1777  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1778  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1779  }
1780  }
1781 
1782  // zext(A % B) --> zext(A) % zext(B)
1783  {
1784  const SCEV *LHS;
1785  const SCEV *RHS;
1786  if (matchURem(Op, LHS, RHS))
1787  return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1788  getZeroExtendExpr(RHS, Ty, Depth + 1));
1789  }
1790 
1791  // zext(A / B) --> zext(A) / zext(B).
1792  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1793  return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1794  getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1795 
1796  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1797  // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1798  if (SA->hasNoUnsignedWrap()) {
1799  // If the addition does not unsign overflow then we can, by definition,
1800  // commute the zero extension with the addition operation.
1801  SmallVector<const SCEV *, 4> Ops;
1802  for (const auto *Op : SA->operands())
1803  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1804  return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1805  }
1806 
1807  // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1808  // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1809  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1810  //
1811  // Often address arithmetics contain expressions like
1812  // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1813  // This transformation is useful while proving that such expressions are
1814  // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1815  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1816  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1817  if (D != 0) {
1818  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1819  const SCEV *SResidual =
1820  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1821  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1822  return getAddExpr(SZExtD, SZExtR,
1823  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1824  Depth + 1);
1825  }
1826  }
1827  }
1828 
1829  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1830  // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1831  if (SM->hasNoUnsignedWrap()) {
1832  // If the multiply does not unsign overflow then we can, by definition,
1833  // commute the zero extension with the multiply operation.
1834  SmallVector<const SCEV *, 4> Ops;
1835  for (const auto *Op : SM->operands())
1836  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1837  return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1838  }
1839 
1840  // zext(2^K * (trunc X to iN)) to iM ->
1841  // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1842  //
1843  // Proof:
1844  //
1845  // zext(2^K * (trunc X to iN)) to iM
1846  // = zext((trunc X to iN) << K) to iM
1847  // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1848  // (because shl removes the top K bits)
1849  // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1850  // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1851  //
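// For example, zext(4 * (trunc i64 %x to i16)) to i32 becomes
// 4 * (zext (trunc i64 %x to i14) to i32), with the multiply marked <nuw>.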
1852  if (SM->getNumOperands() == 2)
1853  if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1854  if (MulLHS->getAPInt().isPowerOf2())
1855  if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1856  int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1857  MulLHS->getAPInt().logBase2();
1858  Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1859  return getMulExpr(
1860  getZeroExtendExpr(MulLHS, Ty),
1861  getZeroExtendExpr(
1862  getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1863  SCEV::FlagNUW, Depth + 1);
1864  }
1865  }
1866 
1867  // The cast wasn't folded; create an explicit cast node.
1868  // Recompute the insert position, as it may have been invalidated.
1869  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1870  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1871  Op, Ty);
1872  UniqueSCEVs.InsertNode(S, IP);
1873  registerUser(S, Op);
1874  return S;
1875 }
1876 
1877 const SCEV *
1878 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1879  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1880  "This is not an extending conversion!");
1881  assert(isSCEVable(Ty) &&
1882  "This is not a conversion to a SCEVable type!");
1883  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1884  Ty = getEffectiveSCEVType(Ty);
1885 
1886  // Fold if the operand is constant.
1887  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1888  return getConstant(
1889  cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1890 
1891  // sext(sext(x)) --> sext(x)
1892  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1893  return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1894 
1895  // sext(zext(x)) --> zext(x)
1896  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1897  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1898 
1899  // Before doing any expensive analysis, check to see if we've already
1900  // computed a SCEV for this Op and Ty.
1901  FoldingSetNodeID ID;
1902  ID.AddInteger(scSignExtend);
1903  ID.AddPointer(Op);
1904  ID.AddPointer(Ty);
1905  void *IP = nullptr;
1906  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1907  // Limit recursion depth.
1908  if (Depth > MaxCastDepth) {
1909  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1910  Op, Ty);
1911  UniqueSCEVs.InsertNode(S, IP);
1912  registerUser(S, Op);
1913  return S;
1914  }
1915 
1916  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1917  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1918  // It's possible the bits taken off by the truncate were all sign bits. If
1919  // so, we should be able to simplify this further.
1920  const SCEV *X = ST->getOperand();
1921  ConstantRange CR = getSignedRange(X);
1922  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1923  unsigned NewBits = getTypeSizeInBits(Ty);
1924  if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1925  CR.sextOrTrunc(NewBits)))
1926  return getTruncateOrSignExtend(X, Ty, Depth);
1927  }
1928 
1929  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1930  // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1931  if (SA->hasNoSignedWrap()) {
1932  // If the addition does not sign overflow then we can, by definition,
1933  // commute the sign extension with the addition operation.
1934  SmallVector<const SCEV *, 4> Ops;
1935  for (const auto *Op : SA->operands())
1936  Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1937  return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1938  }
1939 
1940  // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1941  // if D + (C - D + x + y + ...) could be proven to not signed wrap
1942  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1943  //
1944  // For instance, this will bring two seemingly different expressions:
1945  // 1 + sext(5 + 20 * %x + 24 * %y) and
1946  // sext(6 + 20 * %x + 24 * %y)
1947  // to the same form:
1948  // 2 + sext(4 + 20 * %x + 24 * %y)
1949  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1950  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1951  if (D != 0) {
1952  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1953  const SCEV *SResidual =
1954  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1955  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1956  return getAddExpr(SSExtD, SSExtR,
1957  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1958  Depth + 1);
1959  }
1960  }
1961  }
1962  // If the input value is a chrec scev, and we can prove that the value
1963  // did not overflow the old, smaller, value, we can sign extend all of the
1964  // operands (often constants). This allows analysis of something like
1965  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1966  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1967  if (AR->isAffine()) {
1968  const SCEV *Start = AR->getStart();
1969  const SCEV *Step = AR->getStepRecurrence(*this);
1970  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1971  const Loop *L = AR->getLoop();
1972 
1973  if (!AR->hasNoSignedWrap()) {
1974  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1975  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1976  }
1977 
1978  // If we have special knowledge that this addrec won't overflow,
1979  // we don't need to do any further analysis.
1980  if (AR->hasNoSignedWrap())
1981  return getAddRecExpr(
1982  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1983  getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
1984 
1985  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1986  // Note that this serves two purposes: It filters out loops that are
1987  // simply not analyzable, and it covers the case where this code is
1988  // being called from within backedge-taken count analysis, such that
1989  // attempting to ask for the backedge-taken count would likely result
1990  // in infinite recursion. In the latter case, the analysis code will
1991  // cope with a conservative value, and it will take care to purge
1992  // that value once it has finished.
1993  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1994  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1995  // Manually compute the final value for AR, checking for
1996  // overflow.
1997 
1998  // Check whether the backedge-taken count can be losslessly casted to
1999  // the addrec's type. The count is always unsigned.
2000  const SCEV *CastedMaxBECount =
2001  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2002  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2003  CastedMaxBECount, MaxBECount->getType(), Depth);
2004  if (MaxBECount == RecastedMaxBECount) {
2005  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2006  // Check whether Start+Step*MaxBECount has no signed overflow.
2007  const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2008  SCEV::FlagAnyWrap, Depth + 1);
2009  const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2010  SCEV::FlagAnyWrap,
2011  Depth + 1),
2012  WideTy, Depth + 1);
2013  const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2014  const SCEV *WideMaxBECount =
2015  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2016  const SCEV *OperandExtendedAdd =
2017  getAddExpr(WideStart,
2018  getMulExpr(WideMaxBECount,
2019  getSignExtendExpr(Step, WideTy, Depth + 1),
2020  SCEV::FlagAnyWrap, Depth + 1),
2021  SCEV::FlagAnyWrap, Depth + 1);
2022  if (SAdd == OperandExtendedAdd) {
2023  // Cache knowledge of AR NSW, which is propagated to this AddRec.
2024  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2025  // Return the expression with the addrec on the outside.
2026  return getAddRecExpr(
2027  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2028  Depth + 1),
2029  getSignExtendExpr(Step, Ty, Depth + 1), L,
2030  AR->getNoWrapFlags());
2031  }
2032  // Similar to above, only this time treat the step value as unsigned.
2033  // This covers loops that count up with an unsigned step.
2034  OperandExtendedAdd =
2035  getAddExpr(WideStart,
2036  getMulExpr(WideMaxBECount,
2037  getZeroExtendExpr(Step, WideTy, Depth + 1),
2038  SCEV::FlagAnyWrap, Depth + 1),
2039  SCEV::FlagAnyWrap, Depth + 1);
2040  if (SAdd == OperandExtendedAdd) {
2041  // If AR wraps around then
2042  //
2043  // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2044  // => SAdd != OperandExtendedAdd
2045  //
2046  // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2047  // (SAdd == OperandExtendedAdd => AR is NW)
2048 
2049  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2050 
2051  // Return the expression with the addrec on the outside.
2052  return getAddRecExpr(
2053  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2054  Depth + 1),
2055  getZeroExtendExpr(Step, Ty, Depth + 1), L,
2056  AR->getNoWrapFlags());
2057  }
2058  }
2059  }
2060 
2061  auto NewFlags = proveNoSignedWrapViaInduction(AR);
2062  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2063  if (AR->hasNoSignedWrap()) {
2064  // Same as nsw case above - duplicated here to avoid a compile time
2065  // issue. It's not clear that the order of checks matters, but it's
2066  // one of two possible causes for a change which was reverted. Be
2067  // conservative for the moment.
2068  return getAddRecExpr(
2069  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2070  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2071  }
2072 
2073  // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2074  // if D + (C - D + Step * n) could be proven to not signed wrap
2075  // where D maximizes the number of trailing zeros of (C - D + Step * n)
2076  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2077  const APInt &C = SC->getAPInt();
2078  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2079  if (D != 0) {
2080  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2081  const SCEV *SResidual =
2082  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2083  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2084  return getAddExpr(SSExtD, SSExtR,
2085  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2086  Depth + 1);
2087  }
2088  }
2089 
2090  if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2091  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2092  return getAddRecExpr(
2093  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2094  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2095  }
2096  }
2097 
2098  // If the input value is provably positive and we could not simplify
2099  // away the sext, build a zext instead.
2100  if (isKnownNonNegative(Op))
2101  return getZeroExtendExpr(Op, Ty, Depth + 1);
2102 
2103  // The cast wasn't folded; create an explicit cast node.
2104  // Recompute the insert position, as it may have been invalidated.
2105  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2106  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2107  Op, Ty);
2108  UniqueSCEVs.InsertNode(S, IP);
2109  registerUser(S, { Op });
2110  return S;
2111 }
2112 
2113 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2114 /// unspecified bits out to the given type.
2115 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2116  Type *Ty) {
2117  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2118  "This is not an extending conversion!");
2119  assert(isSCEVable(Ty) &&
2120  "This is not a conversion to a SCEVable type!");
2121  Ty = getEffectiveSCEVType(Ty);
2122 
2123  // Sign-extend negative constants.
2124  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2125  if (SC->getAPInt().isNegative())
2126  return getSignExtendExpr(Op, Ty);
2127 
2128  // Peel off a truncate cast.
2129  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2130  const SCEV *NewOp = T->getOperand();
2131  if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2132  return getAnyExtendExpr(NewOp, Ty);
2133  return getTruncateOrNoop(NewOp, Ty);
2134  }
2135 
2136  // Next try a zext cast. If the cast is folded, use it.
2137  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2138  if (!isa<SCEVZeroExtendExpr>(ZExt))
2139  return ZExt;
2140 
2141  // Next try a sext cast. If the cast is folded, use it.
2142  const SCEV *SExt = getSignExtendExpr(Op, Ty);
2143  if (!isa<SCEVSignExtendExpr>(SExt))
2144  return SExt;
2145 
2146  // Force the cast to be folded into the operands of an addrec.
2147  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2148  SmallVector<const SCEV *, 4> Ops;
2149  for (const SCEV *Op : AR->operands())
2150  Ops.push_back(getAnyExtendExpr(Op, Ty));
2151  return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2152  }
2153 
2154  // If the expression is obviously signed, use the sext cast value.
2155  if (isa<SCEVSMaxExpr>(Op))
2156  return SExt;
2157 
2158  // Absent any other information, use the zext cast value.
2159  return ZExt;
2160 }
2161 
2162 /// Process the given Ops list, which is a list of operands to be added under
2163 /// the given scale, update the given map. This is a helper function for
2164 /// getAddRecExpr. As an example of what it does, given a sequence of operands
2165 /// that would form an add expression like this:
2166 ///
2167 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2168 ///
2169 /// where A and B are constants, update the map with these values:
2170 ///
2171 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2172 ///
2173 /// and add 13 + A*B*29 to AccumulatedConstant.
2174 /// This will allow getAddRecExpr to produce this:
2175 ///
2176 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2177 ///
2178 /// This form often exposes folding opportunities that are hidden in
2179 /// the original operand list.
2180 ///
2181 /// Return true iff it appears that any interesting folding opportunities
2182 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
2183 /// the common case where no interesting opportunities are present, and
2184 /// is also used as a check to avoid infinite recursion.
2185 static bool
2186 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2187  SmallVectorImpl<const SCEV *> &NewOps,
2188  APInt &AccumulatedConstant,
2189  const SCEV *const *Ops, size_t NumOperands,
2190  const APInt &Scale,
2191  ScalarEvolution &SE) {
2192  bool Interesting = false;
2193 
2194  // Iterate over the add operands. They are sorted, with constants first.
2195  unsigned i = 0;
2196  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2197  ++i;
2198  // Pull a buried constant out to the outside.
2199  if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2200  Interesting = true;
2201  AccumulatedConstant += Scale * C->getAPInt();
2202  }
2203 
2204  // Next comes everything else. We're especially interested in multiplies
2205  // here, but they're in the middle, so just visit the rest with one loop.
2206  for (; i != NumOperands; ++i) {
2207  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2208  if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2209  APInt NewScale =
2210  Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2211  if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2212  // A multiplication of a constant with another add; recurse.
2213  const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2214  Interesting |=
2215  CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2216  Add->op_begin(), Add->getNumOperands(),
2217  NewScale, SE);
2218  } else {
2219  // A multiplication of a constant with some other value. Update
2220  // the map.
2221  SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2222  const SCEV *Key = SE.getMulExpr(MulOps);
2223  auto Pair = M.insert({Key, NewScale});
2224  if (Pair.second) {
2225  NewOps.push_back(Pair.first->first);
2226  } else {
2227  Pair.first->second += NewScale;
2228  // The map already had an entry for this value, which may indicate
2229  // a folding opportunity.
2230  Interesting = true;
2231  }
2232  }
2233  } else {
2234  // An ordinary operand. Update the map.
2235  std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2236  M.insert({Ops[i], Scale});
2237  if (Pair.second) {
2238  NewOps.push_back(Pair.first->first);
2239  } else {
2240  Pair.first->second += Scale;
2241  // The map already had an entry for this value, which may indicate
2242  // a folding opportunity.
2243  Interesting = true;
2244  }
2245  }
2246  }
2247 
2248  return Interesting;
2249 }
2250 
2251 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2252  const SCEV *LHS, const SCEV *RHS) {
2253  const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2254  SCEV::NoWrapFlags, unsigned);
2255  switch (BinOp) {
2256  default:
2257  llvm_unreachable("Unsupported binary op");
2258  case Instruction::Add:
2259  Operation = &ScalarEvolution::getAddExpr;
2260  break;
2261  case Instruction::Sub:
2262  Operation = &ScalarEvolution::getMinusSCEV;
2263  break;
2264  case Instruction::Mul:
2265  Operation = &ScalarEvolution::getMulExpr;
2266  break;
2267  }
2268 
2269  const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2270  Signed ? &ScalarEvolution::getSignExtendExpr
2271  : &ScalarEvolution::getZeroExtendExpr;
2272 
2273  // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
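// The equality below holds when SCEV can prove the narrow operation cannot
// wrap: both sides then fold to the same canonical expression.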
2274  auto *NarrowTy = cast<IntegerType>(LHS->getType());
2275  auto *WideTy =
2276  IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2277 
2278  const SCEV *A = (this->*Extension)(
2279  (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2280  const SCEV *B = (this->*Operation)((this->*Extension)(LHS, WideTy, 0),
2281  (this->*Extension)(RHS, WideTy, 0),
2282  SCEV::FlagAnyWrap, 0);
2283  return A == B;
2284 }
2285 
2286 std::pair<SCEV::NoWrapFlags, bool /*Deduced*/>
2287 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2288  const OverflowingBinaryOperator *OBO) {
2289  SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2290 
2291  if (OBO->hasNoUnsignedWrap())
2292  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2293  if (OBO->hasNoSignedWrap())
2294  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2295 
2296  bool Deduced = false;
2297 
2298  if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2299  return {Flags, Deduced};
2300 
2301  if (OBO->getOpcode() != Instruction::Add &&
2302  OBO->getOpcode() != Instruction::Sub &&
2303  OBO->getOpcode() != Instruction::Mul)
2304  return {Flags, Deduced};
2305 
2306  const SCEV *LHS = getSCEV(OBO->getOperand(0));
2307  const SCEV *RHS = getSCEV(OBO->getOperand(1));
2308 
2309  if (!OBO->hasNoUnsignedWrap() &&
2310  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2311  /* Signed */ false, LHS, RHS)) {
2312  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2313  Deduced = true;
2314  }
2315 
2316  if (!OBO->hasNoSignedWrap() &&
2317  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2318  /* Signed */ true, LHS, RHS)) {
2319  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2320  Deduced = true;
2321  }
2322 
2323  return {Flags, Deduced};
2324 }
2325 
2326 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2327 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2328 // can't-overflow flags for the operation if possible.
2329 static SCEV::NoWrapFlags
2330 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2331  const ArrayRef<const SCEV *> Ops,
2332  SCEV::NoWrapFlags Flags) {
2333  using namespace std::placeholders;
2334 
2335  using OBO = OverflowingBinaryOperator;
2336 
2337  bool CanAnalyze =
2338  Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2339  (void)CanAnalyze;
2340  assert(CanAnalyze && "don't call from other places!");
2341 
2342  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2343  SCEV::NoWrapFlags SignOrUnsignWrap =
2344  ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2345 
2346  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2347  auto IsKnownNonNegative = [&](const SCEV *S) {
2348  return SE->isKnownNonNegative(S);
2349  };
2350 
2351  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2352  Flags =
2353  ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2354 
2355  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2356 
2357  if (SignOrUnsignWrap != SignOrUnsignMask &&
2358  (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2359  isa<SCEVConstant>(Ops[0])) {
2360 
2361  auto Opcode = [&] {
2362  switch (Type) {
2363  case scAddExpr:
2364  return Instruction::Add;
2365  case scMulExpr:
2366  return Instruction::Mul;
2367  default:
2368  llvm_unreachable("Unexpected SCEV op.");
2369  }
2370  }();
2371 
2372  const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2373 
2374  // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2375  if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2376  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2377  Opcode, C, OBO::NoSignedWrap);
2378  if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2379  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2380  }
2381 
2382  // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2383  if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2384  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2385  Opcode, C, OBO::NoUnsignedWrap);
2386  if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2387  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2388  }
2389  }
2390 
2391  // <0,+,nonnegative><nw> is also nuw
2392  // TODO: Add corresponding nsw case
2393  if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2394  !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2395  Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2396  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2397 
2398  // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2399  if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2400  Ops.size() == 2) {
2401  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2402  if (UDiv->getOperand(1) == Ops[1])
2403  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2404  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2405  if (UDiv->getOperand(1) == Ops[0])
2406  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2407  }
2408 
2409  return Flags;
2410 }
2411 
2412 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2413  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2414 }
2415 
2416 /// Get a canonical add expression, or something simpler if possible.
2417 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2418  SCEV::NoWrapFlags OrigFlags,
2419  unsigned Depth) {
2420  assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2421  "only nuw or nsw allowed");
2422  assert(!Ops.empty() && "Cannot get empty add!");
2423  if (Ops.size() == 1) return Ops[0];
2424 #ifndef NDEBUG
2425  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2426  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2427  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2428  "SCEVAddExpr operand types don't match!");
2429  unsigned NumPtrs = count_if(
2430  Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2431  assert(NumPtrs <= 1 && "add has at most one pointer operand");
2432 #endif
2433 
2434  // Sort by complexity, this groups all similar expression types together.
2435  GroupByComplexity(Ops, &LI, DT);
2436 
2437  // If there are any constants, fold them together.
2438  unsigned Idx = 0;
2439  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2440  ++Idx;
2441  assert(Idx < Ops.size());
2442  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2443  // We found two constants, fold them together!
2444  Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2445  if (Ops.size() == 2) return Ops[0];
2446  Ops.erase(Ops.begin()+1); // Erase the folded element
2447  LHSC = cast<SCEVConstant>(Ops[0]);
2448  }
2449 
2450  // If we are left with a constant zero being added, strip it off.
2451  if (LHSC->getValue()->isZero()) {
2452  Ops.erase(Ops.begin());
2453  --Idx;
2454  }
2455 
2456  if (Ops.size() == 1) return Ops[0];
2457  }
2458 
2459  // Delay expensive flag strengthening until necessary.
2460  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2461  return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2462  };
2463 
2464  // Limit recursion calls depth.
2465  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2466  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2467 
2468  if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2469  // Don't strengthen flags if we have no new information.
2470  SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2471  if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2472  Add->setNoWrapFlags(ComputeFlags(Ops));
2473  return S;
2474  }
2475 
2476  // Okay, check to see if the same value occurs in the operand list more than
2477  // once. If so, merge them together into a multiply expression. Since we
2478  // sorted the list, these values are required to be adjacent.
2479  Type *Ty = Ops[0]->getType();
2480  bool FoundMatch = false;
2481  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2482  if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2483  // Scan ahead to count how many equal operands there are.
2484  unsigned Count = 2;
2485  while (i+Count != e && Ops[i+Count] == Ops[i])
2486  ++Count;
2487  // Merge the values into a multiply.
2488  const SCEV *Scale = getConstant(Ty, Count);
2489  const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2490  if (Ops.size() == Count)
2491  return Mul;
2492  Ops[i] = Mul;
2493  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2494  --i; e -= Count - 1;
2495  FoundMatch = true;
2496  }
2497  if (FoundMatch)
2498  return getAddExpr(Ops, OrigFlags, Depth + 1);
2499 
2500  // Check for truncates. If all the operands are truncated from the same
2501  // type, see if factoring out the truncate would permit the result to be
2502  // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2503  // if the contents of the resulting outer trunc fold to something simple.
2504  auto FindTruncSrcType = [&]() -> Type * {
2505  // We're ultimately looking to fold an addrec of truncs and muls of only
2506  // constants and truncs, so if we find any other types of SCEV
2507  // as operands of the addrec then we bail and return nullptr here.
2508  // Otherwise, we return the type of the operand of a trunc that we find.
2509  if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2510  return T->getOperand()->getType();
2511  if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2512  const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2513  if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2514  return T->getOperand()->getType();
2515  }
2516  return nullptr;
2517  };
2518  if (auto *SrcType = FindTruncSrcType()) {
2519  SmallVector<const SCEV *, 8> LargeOps;
2520  bool Ok = true;
2521  // Check all the operands to see if they can be represented in the
2522  // source type of the truncate.
2523  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2524  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2525  if (T->getOperand()->getType() != SrcType) {
2526  Ok = false;
2527  break;
2528  }
2529  LargeOps.push_back(T->getOperand());
2530  } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2531  LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2532  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2533  SmallVector<const SCEV *, 8> LargeMulOps;
2534  for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2535  if (const SCEVTruncateExpr *T =
2536  dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2537  if (T->getOperand()->getType() != SrcType) {
2538  Ok = false;
2539  break;
2540  }
2541  LargeMulOps.push_back(T->getOperand());
2542  } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2543  LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2544  } else {
2545  Ok = false;
2546  break;
2547  }
2548  }
2549  if (Ok)
2550  LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2551  } else {
2552  Ok = false;
2553  break;
2554  }
2555  }
2556  if (Ok) {
2557  // Evaluate the expression in the larger type.
2558  const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2559  // If it folds to something simple, use it. Otherwise, don't.
2560  if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2561  return getTruncateExpr(Fold, Ty);
2562  }
2563  }
2564 
2565  if (Ops.size() == 2) {
2566  // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2567  // C2 can be folded in a way that allows retaining wrapping flags of (X +
2568  // C1).
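// For example, (-5) + (X + 10)<nuw> folds to (X + 5)<nuw>: ConstAdd == 5 is
// unsigned-less-than-or-equal to C1 == 10, so NUW is preserved.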
2569  const SCEV *A = Ops[0];
2570  const SCEV *B = Ops[1];
2571  auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2572  auto *C = dyn_cast<SCEVConstant>(A);
2573  if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2574  auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2575  auto C2 = C->getAPInt();
2576  SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2577 
2578  APInt ConstAdd = C1 + C2;
2579  auto AddFlags = AddExpr->getNoWrapFlags();
2580  // Adding a smaller constant is NUW if the original AddExpr was NUW.
2581  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2582  ConstAdd.ule(C1)) {
2583  PreservedFlags =
2584  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2585  }
2586 
2587  // Adding a constant with the same sign and small magnitude is NSW, if the
2588  // original AddExpr was NSW.
2589  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2590  C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2591  ConstAdd.abs().ule(C1.abs())) {
2592  PreservedFlags =
2593  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2594  }
2595 
2596  if (PreservedFlags != SCEV::FlagAnyWrap) {
2597  SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2598  NewOps[0] = getConstant(ConstAdd);
2599  return getAddExpr(NewOps, PreservedFlags);
2600  }
2601  }
2602  }
2603 
2604  // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
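// That is, X - (X urem Y) is rewritten as Y * (X udiv Y).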
2605  if (Ops.size() == 2) {
2606  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2607  if (Mul && Mul->getNumOperands() == 2 &&
2608  Mul->getOperand(0)->isAllOnesValue()) {
2609  const SCEV *X;
2610  const SCEV *Y;
2611  if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2612  return getMulExpr(Y, getUDivExpr(X, Y));
2613  }
2614  }
2615  }
2616 
2617  // Skip past any other cast SCEVs.
2618  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2619  ++Idx;
2620 
2621  // If there are add operands they would be next.
2622  if (Idx < Ops.size()) {
2623  bool DeletedAdd = false;
2624  // If the original flags and all inlined SCEVAddExprs are NUW, use the
2625  // common NUW flag for expression after inlining. Other flags cannot be
2626  // preserved, because they may depend on the original order of operations.
2627  SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2628  while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2629  if (Ops.size() > AddOpsInlineThreshold ||
2630  Add->getNumOperands() > AddOpsInlineThreshold)
2631  break;
2632  // If we have an add, expand the add operands onto the end of the operands
2633  // list.
2634  Ops.erase(Ops.begin()+Idx);
2635  Ops.append(Add->op_begin(), Add->op_end());
2636  DeletedAdd = true;
2637  CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2638  }
2639 
2640  // If we deleted at least one add, we added operands to the end of the list,
2641  // and they are not necessarily sorted. Recurse to resort and resimplify
2642  // any operands we just acquired.
2643  if (DeletedAdd)
2644  return getAddExpr(Ops, CommonFlags, Depth + 1);
2645  }
2646 
2647  // Skip over the add expression until we get to a multiply.
2648  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2649  ++Idx;
2650 
2651  // Check to see if there are any folding opportunities present with
2652  // operands multiplied by constant values.
2653  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2654  uint64_t BitWidth = getTypeSizeInBits(Ty);
2655  DenseMap<const SCEV *, APInt> M;
2656  SmallVector<const SCEV *, 8> NewOps;
2657  APInt AccumulatedConstant(BitWidth, 0);
2658  if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2659  Ops.data(), Ops.size(),
2660  APInt(BitWidth, 1), *this)) {
2661  struct APIntCompare {
2662  bool operator()(const APInt &LHS, const APInt &RHS) const {
2663  return LHS.ult(RHS);
2664  }
2665  };
2666 
2667  // Some interesting folding opportunity is present, so it's worthwhile to
2668  // re-generate the operands list. Group the operands by constant scale,
2669  // to avoid multiplying by the same constant scale multiple times.
2670  std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2671  for (const SCEV *NewOp : NewOps)
2672  MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2673  // Re-generate the operands list.
2674  Ops.clear();
2675  if (AccumulatedConstant != 0)
2676  Ops.push_back(getConstant(AccumulatedConstant));
2677  for (auto &MulOp : MulOpLists) {
2678  if (MulOp.first == 1) {
2679  Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2680  } else if (MulOp.first != 0) {
2681  Ops.push_back(getMulExpr(
2682  getConstant(MulOp.first),
2683  getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2684  SCEV::FlagAnyWrap, Depth + 1));
2685  }
2686  }
2687  if (Ops.empty())
2688  return getZero(Ty);
2689  if (Ops.size() == 1)
2690  return Ops[0];
2691  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2692  }
2693  }
2694 
2695  // If we are adding something to a multiply expression, make sure the
2696  // something is not already an operand of the multiply. If so, merge it into
2697  // the multiply.
2698  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2699  const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2700  for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2701  const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2702  if (isa<SCEVConstant>(MulOpSCEV))
2703  continue;
2704  for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2705  if (MulOpSCEV == Ops[AddOp]) {
2706  // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2707  const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2708  if (Mul->getNumOperands() != 2) {
2709  // If the multiply has more than two operands, we must get the
2710  // Y*Z term.
2711  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2712  Mul->op_begin()+MulOp);
2713  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2714  InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2715  }
2716  SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2717  const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2718  const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2719  SCEV::FlagAnyWrap, Depth + 1);
2720  if (Ops.size() == 2) return OuterMul;
2721  if (AddOp < Idx) {
2722  Ops.erase(Ops.begin()+AddOp);
2723  Ops.erase(Ops.begin()+Idx-1);
2724  } else {
2725  Ops.erase(Ops.begin()+Idx);
2726  Ops.erase(Ops.begin()+AddOp-1);
2727  }
2728  Ops.push_back(OuterMul);
2729  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2730  }
2731 
2732  // Check this multiply against other multiplies being added together.
2733  for (unsigned OtherMulIdx = Idx+1;
2734  OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2735  ++OtherMulIdx) {
2736  const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2737  // If MulOp occurs in OtherMul, we can fold the two multiplies
2738  // together.
2739  for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2740  OMulOp != e; ++OMulOp)
2741  if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2742  // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2743  const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2744  if (Mul->getNumOperands() != 2) {
2745  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2746  Mul->op_begin()+MulOp);
2747  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2748  InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2749  }
2750  const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2751  if (OtherMul->getNumOperands() != 2) {
2752  SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2753  OtherMul->op_begin()+OMulOp);
2754  MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2755  InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2756  }
2757  SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2758  const SCEV *InnerMulSum =
2759  getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2760  const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2761  SCEV::FlagAnyWrap, Depth + 1);
2762  if (Ops.size() == 2) return OuterMul;
2763  Ops.erase(Ops.begin()+Idx);
2764  Ops.erase(Ops.begin()+OtherMulIdx-1);
2765  Ops.push_back(OuterMul);
2766  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2767  }
2768  }
2769  }
2770  }
2771 
2772  // If there are any add recurrences in the operands list, see if any other
2773  // added values are loop invariant. If so, we can fold them into the
2774  // recurrence.
2775  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2776  ++Idx;
2777 
2778  // Scan over all recurrences, trying to fold loop invariants into them.
2779  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2780  // Scan all of the other operands to this add and add them to the vector if
2781  // they are loop invariant w.r.t. the recurrence.
2782  SmallVector<const SCEV *, 8> LIOps;
2783  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2784  const Loop *AddRecLoop = AddRec->getLoop();
2785  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2786  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2787  LIOps.push_back(Ops[i]);
2788  Ops.erase(Ops.begin()+i);
2789  --i; --e;
2790  }
2791 
2792  // If we found some loop invariants, fold them into the recurrence.
2793  if (!LIOps.empty()) {
2794  // Compute nowrap flags for the addition of the loop-invariant ops and
2795  // the addrec. Temporarily push it as an operand for that purpose. These
2796  // flags are valid in the scope of the addrec only.
2797  LIOps.push_back(AddRec);
2798  SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2799  LIOps.pop_back();
2800 
2801  // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2802  LIOps.push_back(AddRec->getStart());
2803 
2804  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2805 
2806  // It is not in general safe to propagate flags valid on an add within
2807  // the addrec scope to one outside it. We must prove that the inner
2808  // scope is guaranteed to execute if the outer one does to be able to
2809  // safely propagate. We know the program is undefined if poison is
2810  // produced on the inner scoped addrec. We also know that *for this use*
2811  // the outer scoped add can't overflow (because of the flags we just
2812  // computed for the inner scoped add) without the program being undefined.
2813  // Proving that entry to the outer scope necessitates entry to the inner
2814  // scope, thus proves the program undefined if the flags would be violated
2815  // in the outer scope.
2816  SCEV::NoWrapFlags AddFlags = Flags;
2817  if (AddFlags != SCEV::FlagAnyWrap) {
2818  auto *DefI = getDefiningScopeBound(LIOps);
2819  auto *ReachI = &*AddRecLoop->getHeader()->begin();
2820  if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2821  AddFlags = SCEV::FlagAnyWrap;
2822  }
2823  AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2824 
2825  // Build the new addrec. Propagate the NUW and NSW flags if both the
2826  // outer add and the inner addrec are guaranteed to have no overflow.
2827  // Always propagate NW.
2828  Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2829  const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2830 
2831  // If all of the other operands were loop invariant, we are done.
2832  if (Ops.size() == 1) return NewRec;
2833 
2834  // Otherwise, add the folded AddRec by the non-invariant parts.
2835  for (unsigned i = 0;; ++i)
2836  if (Ops[i] == AddRec) {
2837  Ops[i] = NewRec;
2838  break;
2839  }
2840  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2841  }
2842 
2843  // Okay, if there weren't any loop invariants to be folded, check to see if
2844  // there are multiple AddRec's with the same loop induction variable being
2845  // added together. If so, we can fold them.
2846  for (unsigned OtherIdx = Idx+1;
2847  OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2848  ++OtherIdx) {
2849  // We expect the AddRecExpr's to be sorted in reverse dominance order,
2850  // so that the 1st found AddRecExpr is dominated by all others.
2851  assert(DT.dominates(
2852  cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2853  AddRec->getLoop()->getHeader()) &&
2854  "AddRecExprs are not sorted in reverse dominance order?");
2855  if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2856  // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2857  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2858  for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2859  ++OtherIdx) {
2860  const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2861  if (OtherAddRec->getLoop() == AddRecLoop) {
2862  for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2863  i != e; ++i) {
2864  if (i >= AddRecOps.size()) {
2865  AddRecOps.append(OtherAddRec->op_begin()+i,
2866  OtherAddRec->op_end());
2867  break;
2868  }
2869  SmallVector<const SCEV *, 2> TwoOps = {
2870  AddRecOps[i], OtherAddRec->getOperand(i)};
2871  AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2872  }
2873  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2874  }
2875  }
2876  // Step size has changed, so we cannot guarantee no self-wraparound.
2877  Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2878  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2879  }
2880  }
2881 
2882  // Otherwise couldn't fold anything into this recurrence. Move onto the
2883  // next one.
2884  }
2885 
2886  // Okay, it looks like we really DO need an add expr. Check to see if we
2887  // already have one, otherwise create a new one.
2888  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2889 }
2890 
2891 const SCEV *
2892 ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2893  SCEV::NoWrapFlags Flags) {
2894  FoldingSetNodeID ID;
2895  ID.AddInteger(scAddExpr);
2896  for (const SCEV *Op : Ops)
2897  ID.AddPointer(Op);
2898  void *IP = nullptr;
2899  SCEVAddExpr *S =
2900  static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2901  if (!S) {
2902  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2903  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2904  S = new (SCEVAllocator)
2905  SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2906  UniqueSCEVs.InsertNode(S, IP);
2907  registerUser(S, Ops);
2908  }
2909  S->setNoWrapFlags(Flags);
2910  return S;
2911 }
2912 
2913 const SCEV *
2914 ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2915  const Loop *L, SCEV::NoWrapFlags Flags) {
2916  FoldingSetNodeID ID;
2917  ID.AddInteger(scAddRecExpr);
2918  for (const SCEV *Op : Ops)
2919  ID.AddPointer(Op);
2920  ID.AddPointer(L);
2921  void *IP = nullptr;
2922  SCEVAddRecExpr *S =
2923  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2924  if (!S) {
2925  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2926  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2927  S = new (SCEVAllocator)
2928  SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
2929  UniqueSCEVs.InsertNode(S, IP);
2930  LoopUsers[L].push_back(S);
2931  registerUser(S, Ops);
2932  }
2933  setNoWrapFlags(S, Flags);
2934  return S;
2935 }
2936 
2937 const SCEV *
2938 ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
2939  SCEV::NoWrapFlags Flags) {
2940  FoldingSetNodeID ID;
2941  ID.AddInteger(scMulExpr);
2942  for (const SCEV *Op : Ops)
2943  ID.AddPointer(Op);
2944  void *IP = nullptr;
2945  SCEVMulExpr *S =
2946  static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2947  if (!S) {
2948  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2949  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2950  S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2951  O, Ops.size());
2952  UniqueSCEVs.InsertNode(S, IP);
2953  registerUser(S, Ops);
2954  }
2955  S->setNoWrapFlags(Flags);
2956  return S;
2957 }
2958 
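/// Multiply i and j, setting Overflow to true if the 64-bit product wraps.
/// Like Choose below, this never clears Overflow when no overflow occurs.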
2959 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2960  uint64_t k = i*j;
2961  if (j > 1 && k / j != i) Overflow = true;
2962  return k;
2963 }
2964 
2965 /// Compute the result of "n choose k", the binomial coefficient. If an
2966 /// intermediate computation overflows, Overflow will be set and the return will
2967 /// be garbage. Overflow is not cleared on absence of overflow.
2968 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2969  // We use the multiplicative formula:
2970  // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2971  // At each iteration, we take the n-th term of the numerator and divide by the
2972  // (k-n)th term of the denominator. This division will always produce an
2973  // integral result, and helps reduce the chance of overflow in the
2974  // intermediate computations. However, we can still overflow even when the
2975  // final result would fit.
2976 
2977  if (n == 0 || n == k) return 1;
2978  if (k > n) return 0;
2979 
2980  if (k > n/2)
2981  k = n-k;
2982 
2983  uint64_t r = 1;
2984  for (uint64_t i = 1; i <= k; ++i) {
2985  r = umul_ov(r, n-(i-1), Overflow);
2986  r /= i;
2987  }
2988  return r;
2989 }
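// Illustrative walk-through (added for exposition, not from the original
// source): Choose(5, 2) runs the loop with k = 2. At i = 1 we get
// r = 1*5/1 = 5, and at i = 2 we get r = 5*4/2 = 10, matching C(5,2) = 10
// with no overflow reported.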
2990 
2991 /// Determine if any of the operands in this SCEV are a constant or if
2992 /// any of the add or multiply expressions in this SCEV contain a constant.
2993 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
2994  struct FindConstantInAddMulChain {
2995  bool FoundConstant = false;
2996 
2997  bool follow(const SCEV *S) {
2998  FoundConstant |= isa<SCEVConstant>(S);
2999  return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3000  }
3001 
3002  bool isDone() const {
3003  return FoundConstant;
3004  }
3005  };
3006 
3007  FindConstantInAddMulChain F;
3008  SCEVTraversal<FindConstantInAddMulChain> ST(F);
3009  ST.visitAll(StartExpr);
3010  return F.FoundConstant;
3011 }
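// Illustrative example (added for exposition; %x and %y stand for hypothetical
// SCEVUnknown values): for (5 + %x) * %y the traversal descends through the
// outer mul and the inner add, sees the SCEVConstant 5, and returns true,
// while for %x * %y there is no constant and it returns false.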
3012 
3013 /// Get a canonical multiply expression, or something simpler if possible.
3014 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3015  SCEV::NoWrapFlags OrigFlags,
3016  unsigned Depth) {
3017  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3018  "only nuw or nsw allowed");
3019  assert(!Ops.empty() && "Cannot get empty mul!");
3020  if (Ops.size() == 1) return Ops[0];
3021 #ifndef NDEBUG
3022  Type *ETy = Ops[0]->getType();
3023  assert(!ETy->isPointerTy());
3024  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3025  assert(Ops[i]->getType() == ETy &&
3026  "SCEVMulExpr operand types don't match!");
3027 #endif
3028 
3029  // Sort by complexity, this groups all similar expression types together.
3030  GroupByComplexity(Ops, &LI, DT);
3031 
3032  // If there are any constants, fold them together.
3033  unsigned Idx = 0;
3034  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3035  ++Idx;
3036  assert(Idx < Ops.size());
3037  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3038  // We found two constants, fold them together!
3039  Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3040  if (Ops.size() == 2) return Ops[0];
3041  Ops.erase(Ops.begin()+1); // Erase the folded element
3042  LHSC = cast<SCEVConstant>(Ops[0]);
3043  }
3044 
3045  // If we have a multiply of zero, it will always be zero.
3046  if (LHSC->getValue()->isZero())
3047  return LHSC;
3048 
3049  // If we are left with a constant one being multiplied, strip it off.
3050  if (LHSC->getValue()->isOne()) {
3051  Ops.erase(Ops.begin());
3052  --Idx;
3053  }
3054 
3055  if (Ops.size() == 1)
3056  return Ops[0];
3057  }
3058 
3059  // Delay expensive flag strengthening until necessary.
3060  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3061  return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3062  };
3063 
3064  // Limit recursion calls depth.
3065  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3066  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3067 
3068  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3069  // Don't strengthen flags if we have no new information.
3070  SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3071  if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3072  Mul->setNoWrapFlags(ComputeFlags(Ops));
3073  return S;
3074  }
3075 
3076  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3077  if (Ops.size() == 2) {
3078  // C1*(C2+V) -> C1*C2 + C1*V
3079  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3080  // If any of Add's ops are Adds or Muls with a constant, apply this
3081  // transformation as well.
3082  //
3083  // TODO: There are some cases where this transformation is not
3084  // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3085  // this transformation should be narrowed down.
3086  if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
3087  return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
3088  SCEV::FlagAnyWrap, Depth + 1),
3089  getMulExpr(LHSC, Add->getOperand(1),
3090  SCEV::FlagAnyWrap, Depth + 1),
3091  SCEV::FlagAnyWrap, Depth + 1);
3092 
3093  if (Ops[0]->isAllOnesValue()) {
3094  // If we have a mul by -1 of an add, try distributing the -1 among the
3095  // add operands.
3096  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3097  SmallVector<const SCEV *, 4> NewOps;
3098  bool AnyFolded = false;
3099  for (const SCEV *AddOp : Add->operands()) {
3100  const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3101  Depth + 1);
3102  if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3103  NewOps.push_back(Mul);
3104  }
3105  if (AnyFolded)
3106  return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3107  } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3108  // Negation preserves a recurrence's no self-wrap property.
3109  SmallVector<const SCEV *, 4> Operands;
3110  for (const SCEV *AddRecOp : AddRec->operands())
3111  Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3112  Depth + 1));
3113 
3114  return getAddRecExpr(Operands, AddRec->getLoop(),
3115  AddRec->getNoWrapFlags(SCEV::FlagNW));
3116  }
3117  }
3118  }
3119  }
3120 
3121  // Skip over the add expression until we get to a multiply.
3122  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3123  ++Idx;
3124 
3125  // If there are mul operands inline them all into this expression.
3126  if (Idx < Ops.size()) {
3127  bool DeletedMul = false;
3128  while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3129  if (Ops.size() > MulOpsInlineThreshold)
3130  break;
3131  // If we have a mul, expand the mul operands onto the end of the
3132  // operands list.
3133  Ops.erase(Ops.begin()+Idx);
3134  Ops.append(Mul->op_begin(), Mul->op_end());
3135  DeletedMul = true;
3136  }
3137 
3138  // If we deleted at least one mul, we added operands to the end of the
3139  // list, and they are not necessarily sorted. Recurse to resort and
3140  // resimplify any operands we just acquired.
3141  if (DeletedMul)
3142  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3143  }
3144 
3145  // If there are any add recurrences in the operands list, see if any other
3146  // added values are loop invariant. If so, we can fold them into the
3147  // recurrence.
3148  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3149  ++Idx;
3150 
3151  // Scan over all recurrences, trying to fold loop invariants into them.
3152  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3153  // Scan all of the other operands to this mul and add them to the vector
3154  // if they are loop invariant w.r.t. the recurrence.
3155  SmallVector<const SCEV *, 8> LIOps;
3156  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3157  const Loop *AddRecLoop = AddRec->getLoop();
3158  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3159  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3160  LIOps.push_back(Ops[i]);
3161  Ops.erase(Ops.begin()+i);
3162  --i; --e;
3163  }
3164 
3165  // If we found some loop invariants, fold them into the recurrence.
3166  if (!LIOps.empty()) {
3167  // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3168  SmallVector<const SCEV *, 4> NewOps;
3169  NewOps.reserve(AddRec->getNumOperands());
3170  const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3171  for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3172  NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3173  SCEV::FlagAnyWrap, Depth + 1));
3174 
3175  // Build the new addrec. Propagate the NUW and NSW flags if both the
3176  // outer mul and the inner addrec are guaranteed to have no overflow.
3177  //
3178  // No self-wrap cannot be guaranteed after changing the step size, but
3179  // will be inferred if either NUW or NSW is true.
3180  SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3181  const SCEV *NewRec = getAddRecExpr(
3182  NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
3183 
3184  // If all of the other operands were loop invariant, we are done.
3185  if (Ops.size() == 1) return NewRec;
3186 
3187  // Otherwise, multiply the folded AddRec by the non-invariant parts.
3188  for (unsigned i = 0;; ++i)
3189  if (Ops[i] == AddRec) {
3190  Ops[i] = NewRec;
3191  break;
3192  }
3193  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3194  }
3195 
3196  // Okay, if there weren't any loop invariants to be folded, check to see
3197  // if there are multiple AddRec's with the same loop induction variable
3198  // being multiplied together. If so, we can fold them.
3199 
3200  // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3201  // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3202  // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3203  // ]]],+,...up to x=2n}.
3204  // Note that the arguments to choose() are always integers with values
3205  // known at compile time, never SCEV objects.
3206  //
3207  // The implementation avoids pointless extra computations when the two
3208  // addrec's are of different length (mathematically, it's equivalent to
3209  // an infinite stream of zeros on the right).
3210  bool OpsModified = false;
3211  for (unsigned OtherIdx = Idx+1;
3212  OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3213  ++OtherIdx) {
3214  const SCEVAddRecExpr *OtherAddRec =
3215  dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3216  if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
3217  continue;
3218 
3219  // Limit max number of arguments to avoid creation of unreasonably big
3220  // SCEVAddRecs with very complex operands.
3221  if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3222  MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3223  continue;
3224 
3225  bool Overflow = false;
3226  Type *Ty = AddRec->getType();
3227  bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3228  SmallVector<const SCEV*, 7> AddRecOps;
3229  for (int x = 0, xe = AddRec->getNumOperands() +
3230  OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3231  SmallVector<const SCEV*, 7> SumOps;
3232  for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3233  uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3234  for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3235  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3236  z < ze && !Overflow; ++z) {
3237  uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3238  uint64_t Coeff;
3239  if (LargerThan64Bits)
3240  Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3241  else
3242  Coeff = Coeff1*Coeff2;
3243  const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3244  const SCEV *Term1 = AddRec->getOperand(y-z);
3245  const SCEV *Term2 = OtherAddRec->getOperand(z);
3246  SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3247  SCEV::FlagAnyWrap, Depth + 1));
3248  }
3249  }
3250  if (SumOps.empty())
3251  SumOps.push_back(getZero(Ty));
3252  AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3253  }
3254  if (!Overflow) {
3255  const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
3256  SCEV::FlagAnyWrap);
3257  if (Ops.size() == 2) return NewAddRec;
3258  Ops[Idx] = NewAddRec;
3259  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3260  OpsModified = true;
3261  AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3262  if (!AddRec)
3263  break;
3264  }
3265  }
3266  if (OpsModified)
3267  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3268 
3269  // Otherwise couldn't fold anything into this recurrence. Move on to the
3270  // next one.
3271  }
3272 
3273  // Okay, it looks like we really DO need a mul expr. Check to see if we
3274  // already have one, otherwise create a new one.
3275  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3276 }
3277 
3278 /// Represents an unsigned remainder expression based on unsigned division.
3279 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3280  const SCEV *RHS) {
3281  assert(getEffectiveSCEVType(LHS->getType()) ==
3282  getEffectiveSCEVType(RHS->getType()) &&
3283  "SCEVURemExpr operand types don't match!");
3284 
3285  // Short-circuit easy cases
3286  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3287  // If constant is one, the result is trivial
3288  if (RHSC->getValue()->isOne())
3289  return getZero(LHS->getType()); // X urem 1 --> 0
3290 
3291  // If constant is a power of two, fold into a zext(trunc(LHS)).
3292  if (RHSC->getAPInt().isPowerOf2()) {
3293  Type *FullTy = LHS->getType();
3294  Type *TruncTy =
3295  IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3296  return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3297  }
3298  }
3299 
3300  // Fallback: %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3301  const SCEV *UDiv = getUDivExpr(LHS, RHS);
3302  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3303  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3304 }
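// Illustrative examples (added for exposition; %x and %y are hypothetical
// values): (i32 %x) urem 8 takes the power-of-two fast path and becomes
// zext(trunc %x to i3) to i32, since logBase2(8) == 3. For a non-constant
// divisor, %x urem %y falls back to %x - ((%x /u %y) * %y), built as the add
// of %x and the negated product.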
3305 
3306 /// Get a canonical unsigned division expression, or something simpler if
3307 /// possible.
3308 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3309  const SCEV *RHS) {
3310  assert(!LHS->getType()->isPointerTy() &&
3311  "SCEVUDivExpr operand can't be pointer!");
3312  assert(LHS->getType() == RHS->getType() &&
3313  "SCEVUDivExpr operand types don't match!");
3314 
3315  FoldingSetNodeID ID;
3316  ID.AddInteger(scUDivExpr);
3317  ID.AddPointer(LHS);
3318  ID.AddPointer(RHS);
3319  void *IP = nullptr;
3320  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3321  return S;
3322 
3323  // 0 udiv Y == 0
3324  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3325  if (LHSC->getValue()->isZero())
3326  return LHS;
3327 
3328  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3329  if (RHSC->getValue()->isOne())
3330  return LHS; // X udiv 1 --> x
3331  // If the denominator is zero, the result of the udiv is undefined. Don't
3332  // try to analyze it, because the resolution chosen here may differ from
3333  // the resolution chosen in other parts of the compiler.
3334  if (!RHSC->getValue()->isZero()) {
3335  // Determine if the division can be folded into the operands of
3336  // the LHS expression.
3337  // TODO: Generalize this to non-constants by using known-bits information.
3338  Type *Ty = LHS->getType();
3339  unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3340  unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3341  // For non-power-of-two values, effectively round the value up to the
3342  // nearest power of two.
3343  if (!RHSC->getAPInt().isPowerOf2())
3344  ++MaxShiftAmt;
3345  IntegerType *ExtTy =
3346  IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3347  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3348  if (const SCEVConstant *Step =
3349  dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3350  // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3351  const APInt &StepInt = Step->getAPInt();
3352  const APInt &DivInt = RHSC->getAPInt();
3353  if (!StepInt.urem(DivInt) &&
3354  getZeroExtendExpr(AR, ExtTy) ==
3355  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3356  getZeroExtendExpr(Step, ExtTy),
3357  AR->getLoop(), SCEV::FlagAnyWrap)) {
3358  SmallVector<const SCEV *, 4> Operands;
3359  for (const SCEV *Op : AR->operands())
3360  Operands.push_back(getUDivExpr(Op, RHS));
3361  return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3362  }
3363  /// Get a canonical UDivExpr for a recurrence.
3364  /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3365  // We can currently only fold X%N if X is constant.
3366  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3367  if (StartC && !DivInt.urem(StepInt) &&
3368  getZeroExtendExpr(AR, ExtTy) ==
3369  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3370  getZeroExtendExpr(Step, ExtTy),
3371  AR->getLoop(), SCEV::FlagAnyWrap)) {
3372  const APInt &StartInt = StartC->getAPInt();
3373  const APInt &StartRem = StartInt.urem(StepInt);
3374  if (StartRem != 0) {
3375  const SCEV *NewLHS =
3376  getAddRecExpr(getConstant(StartInt - StartRem), Step,
3377  AR->getLoop(), SCEV::FlagNW);
3378  if (LHS != NewLHS) {
3379  LHS = NewLHS;
3380 
3381  // Reset the ID to include the new LHS, and check if it is
3382  // already cached.
3383  ID.clear();
3384  ID.AddInteger(scUDivExpr);
3385  ID.AddPointer(LHS);
3386  ID.AddPointer(RHS);
3387  IP = nullptr;
3388  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3389  return S;
3390  }
3391  }
3392  }
3393  }
3394  // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3395  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3396  SmallVector<const SCEV *, 4> Operands;
3397  for (const SCEV *Op : M->operands())
3398  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3399  if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3400  // Find an operand that's safely divisible.
3401  for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3402  const SCEV *Op = M->getOperand(i);
3403  const SCEV *Div = getUDivExpr(Op, RHSC);
3404  if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3405  Operands = SmallVector<const SCEV *, 4>(M->operands());
3406  Operands[i] = Div;
3407  return getMulExpr(Operands);
3408  }
3409  }
3410  }
3411 
3412  // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3413  if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3414  if (auto *DivisorConstant =
3415  dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3416  bool Overflow = false;
3417  APInt NewRHS =
3418  DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3419  if (Overflow) {
3420  return getConstant(RHSC->getType(), 0, false);
3421  }
3422  return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3423  }
3424  }
3425 
3426  // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3427  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3428  SmallVector<const SCEV *, 4> Operands;
3429  for (const SCEV *Op : A->operands())
3430  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3431  if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3432  Operands.clear();
3433  for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3434  const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3435  if (isa<SCEVUDivExpr>(Op) ||
3436  getMulExpr(Op, RHS) != A->getOperand(i))
3437  break;
3438  Operands.push_back(Op);
3439  }
3440  if (Operands.size() == A->getNumOperands())
3441  return getAddExpr(Operands);
3442  }
3443  }
3444 
3445  // Fold if both operands are constant.
3446  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
3447  Constant *LHSCV = LHSC->getValue();
3448  Constant *RHSCV = RHSC->getValue();
3449  return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
3450  RHSCV)));
3451  }
3452  }
3453  }
3454 
3455  // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3456  // changes). Make sure we get a new one.
3457  IP = nullptr;
3458  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3459  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3460  LHS, RHS);
3461  UniqueSCEVs.InsertNode(S, IP);
3462  registerUser(S, {LHS, RHS});
3463  return S;
3464 }
3465 
3466 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3467  APInt A = C1->getAPInt().abs();
3468  APInt B = C2->getAPInt().abs();
3469  uint32_t ABW = A.getBitWidth();
3470  uint32_t BBW = B.getBitWidth();
3471 
3472  if (ABW > BBW)
3473  B = B.zext(ABW);
3474  else if (ABW < BBW)
3475  A = A.zext(BBW);
3476 
3477  return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3478 }
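// Illustrative example (added for exposition): for the constants 12 and -18
// the absolute values are 12 and 18; after widening to the common bit width
// the helper returns their greatest common divisor, 6.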
3479 
3480 /// Get a canonical unsigned division expression, or something simpler if
3481 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3482 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3483 /// it's not exact because the udiv may be clearing bits.
3484 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3485  const SCEV *RHS) {
3486  // TODO: we could try to find factors in all sorts of things, but for now we
3487  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3488  // end of this file for inspiration.
3489 
3490  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3491  if (!Mul || !Mul->hasNoUnsignedWrap())
3492  return getUDivExpr(LHS, RHS);
3493 
3494  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3495  // If the mulexpr multiplies by a constant, then that constant must be the
3496  // first element of the mulexpr.
3497  if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3498  if (LHSCst == RHSCst) {
3499  SmallVector<const SCEV *, 2> Operands(Mul->op_begin() + 1, Mul->op_end());
3500  return getMulExpr(Operands);
3501  }
3502 
3503  // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3504  // that there's a factor provided by one of the other terms. We need to
3505  // check.
3506  APInt Factor = gcd(LHSCst, RHSCst);
3507  if (!Factor.isIntN(1)) {
3508  LHSCst =
3509  cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3510  RHSCst =
3511  cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3512  SmallVector<const SCEV *, 2> Operands;
3513  Operands.push_back(LHSCst);
3514  Operands.append(Mul->op_begin() + 1, Mul->op_end());
3515  LHS = getMulExpr(Operands);
3516  RHS = RHSCst;
3517  Mul = dyn_cast<SCEVMulExpr>(LHS);
3518  if (!Mul)
3519  return getUDivExactExpr(LHS, RHS);
3520  }
3521  }
3522  }
3523 
3524  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3525  if (Mul->getOperand(i) == RHS) {
3526  SmallVector<const SCEV *, 4> Operands;
3527  Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3528  Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3529  return getMulExpr(Operands);
3530  }
3531  }
3532 
3533  return getUDivExpr(LHS, RHS);
3534 }
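// Illustrative examples (added for exposition; %x is a hypothetical value and
// the multiply is assumed to carry <nuw>): (6 * %x) /u 4 first divides both
// constants by gcd(6, 4) = 2, leaving (3 * %x) /u 2, while (4 * %x) /u %x
// drops the operand that matches the divisor and folds to 4.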
3535 
3536 /// Get an add recurrence expression for the specified loop. Simplify the
3537 /// expression as much as possible.
3538 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3539  const Loop *L,
3540  SCEV::NoWrapFlags Flags) {
3541  SmallVector<const SCEV *, 4> Operands;
3542  Operands.push_back(Start);
3543  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3544  if (StepChrec->getLoop() == L) {
3545  Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3546  return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3547  }
3548 
3549  Operands.push_back(Step);
3550  return getAddRecExpr(Operands, L, Flags);
3551 }
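// Illustrative example (added for exposition; %start is a hypothetical start
// value): getAddRecExpr(%start, {1,+,2}<L>, L, Flags) notices that the step is
// itself a recurrence on L, flattens it into the operand list, and returns
// {%start,+,1,+,2}<L>, keeping only the no-self-wrap bit of Flags.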
3552 
3553 /// Get an add recurrence expression for the specified loop. Simplify the
3554 /// expression as much as possible.
3555 const SCEV *
3556 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3557  const Loop *L, SCEV::NoWrapFlags Flags) {
3558  if (Operands.size() == 1) return Operands[0];
3559 #ifndef NDEBUG
3560  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3561  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3562  assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3563  "SCEVAddRecExpr operand types don't match!");
3564  assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3565  }
3566  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3567  assert(isLoopInvariant(Operands[i], L) &&
3568  "SCEVAddRecExpr operand is not loop-invariant!");
3569 #endif
3570 
3571  if (Operands.back()->isZero()) {
3572  Operands.pop_back();
3573  return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3574  }
3575 
3576  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3577  // use that information to infer NUW and NSW flags. However, computing a
3578  // BE count requires calling getAddRecExpr, so we may not yet have a
3579  // meaningful BE count at this point (and if we don't, we'd be stuck
3580  // with a SCEVCouldNotCompute as the cached BE count).
3581 
3582  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3583 
3584  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3585  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3586  const Loop *NestedLoop = NestedAR->getLoop();
3587  if (L->contains(NestedLoop)
3588  ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3589  : (!NestedLoop->contains(L) &&
3590  DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3591  SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3592  Operands[0] = NestedAR->getStart();
3593  // AddRecs require their operands be loop-invariant with respect to their
3594  // loops. Don't perform this transformation if it would break this
3595  // requirement.
3596  bool AllInvariant = all_of(
3597  Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3598 
3599  if (AllInvariant) {
3600  // Create a recurrence for the outer loop with the same step size.
3601  //
3602  // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3603  // inner recurrence has the same property.
3604  SCEV::NoWrapFlags OuterFlags =
3605  maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3606 
3607  NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3608  AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3609  return isLoopInvariant(Op, NestedLoop);
3610  });
3611 
3612  if (AllInvariant) {
3613  // Ok, both add recurrences are valid after the transformation.
3614  //
3615  // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3616  // the outer recurrence has the same property.
3617  SCEV::NoWrapFlags InnerFlags =
3618  maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3619  return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3620  }
3621  }
3622  // Reset Operands to its original state.
3623  Operands[0] = NestedAR;
3624  }
3625  }
3626 
3627  // Okay, it looks like we really DO need an addrec expr. Check to see if we
3628  // already have one, otherwise create a new one.
3629  return getOrCreateAddRecExpr(Operands, L, Flags);
3630 }
3631 
3632 const SCEV *
3633 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3634  const SmallVectorImpl<const SCEV *> &IndexExprs) {
3635  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3636  // getSCEV(Base)->getType() has the same address space as Base->getType()
3637  // because SCEV::getType() preserves the address space.
3638  Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3639  const bool AssumeInBoundsFlags = [&]() {
3640  if (!GEP->isInBounds())
3641  return false;
3642 
3643  // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3644  // but to do that, we have to ensure that said flag is valid in the entire
3645  // defined scope of the SCEV.
3646  auto *GEPI = dyn_cast<Instruction>(GEP);
3647  // TODO: non-instructions have global scope. We might be able to prove
3648  // some global scope cases
3649  return GEPI && isSCEVExprNeverPoison(GEPI);
3650  }();
3651 
3652  SCEV::NoWrapFlags OffsetWrap =
3653  AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3654 
3655  Type *CurTy = GEP->getType();
3656  bool FirstIter = true;
3657  SmallVector<const SCEV *, 4> Offsets;
3658  for (const SCEV *IndexExpr : IndexExprs) {
3659  // Compute the (potentially symbolic) offset in bytes for this index.
3660  if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3661  // For a struct, add the member offset.
3662  ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3663  unsigned FieldNo = Index->getZExtValue();
3664  const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3665  Offsets.push_back(FieldOffset);
3666 
3667  // Update CurTy to the type of the field at Index.
3668  CurTy = STy->getTypeAtIndex(Index);
3669  } else {
3670  // Update CurTy to its element type.
3671  if (FirstIter) {
3672  assert(isa<PointerType>(CurTy) &&
3673  "The first index of a GEP indexes a pointer");
3674  CurTy = GEP->getSourceElementType();
3675  FirstIter = false;
3676  } else {
3677  CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3678  }
3679  // For an array, add the element offset, explicitly scaled.
3680  const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3681  // Getelementptr indices are signed.
3682  IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3683 
3684  // Multiply the index by the element size to compute the element offset.
3685  const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3686  Offsets.push_back(LocalOffset);
3687  }
3688  }
3689 
3690  // Handle degenerate case of GEP without offsets.
3691  if (Offsets.empty())
3692  return BaseExpr;
3693 
3694  // Add the offsets together, assuming nsw if inbounds.
3695  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3696  // Add the base address and the offset. We cannot use the nsw flag, as the
3697  // base address is unsigned. However, if we know that the offset is
3698  // non-negative, we can use nuw.
3699  SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3700  ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3701  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3702  assert(BaseExpr->getType() == GEPExpr->getType() &&
3703  "GEP should not change type mid-flight.");
3704  return GEPExpr;
3705 }
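// Illustrative sketch (added for exposition; assumes a 64-bit index type and
// hypothetical %p, %i): for "getelementptr inbounds [10 x i32], [10 x i32]* %p,
// i64 0, i64 %i" the per-index offsets are 0 and (4 * %i); their nsw sum is
// added to %p, additionally with nuw when (4 * %i) is known non-negative.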
3706 
3707 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3708  ArrayRef<const SCEV *> Ops) {
3709  FoldingSetNodeID ID;
3710  ID.AddInteger(SCEVType);
3711  for (const SCEV *Op : Ops)
3712  ID.AddPointer(Op);
3713  void *IP = nullptr;
3714  return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3715 }
3716 
3717 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3718  SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3719  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3720 }
3721 
3722 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3723  SmallVectorImpl<const SCEV *> &Ops) {
3724  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3725  if (Ops.size() == 1) return Ops[0];
3726 #ifndef NDEBUG
3727  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3728  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3729  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3730  "Operand types don't match!");
3731  assert(Ops[0]->getType()->isPointerTy() ==
3732  Ops[i]->getType()->isPointerTy() &&
3733  "min/max should be consistently pointerish");
3734  }
3735 #endif
3736 
3737  bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3738  bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3739 
3740  // Sort by complexity, this groups all similar expression types together.
3741  GroupByComplexity(Ops, &LI, DT);
3742 
3743  // Check if we have created the same expression before.
3744  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3745  return S;
3746  }
3747 
3748  // If there are any constants, fold them together.
3749  unsigned Idx = 0;
3750  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3751  ++Idx;
3752  assert(Idx < Ops.size());
3753  auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3754  if (Kind == scSMaxExpr)
3755  return APIntOps::smax(LHS, RHS);
3756  else if (Kind == scSMinExpr)
3757  return APIntOps::smin(LHS, RHS);
3758  else if (Kind == scUMaxExpr)
3759  return APIntOps::umax(LHS, RHS);
3760  else if (Kind == scUMinExpr)
3761  return APIntOps::umin(LHS, RHS);
3762  llvm_unreachable("Unknown SCEV min/max opcode");
3763  };
3764 
3765  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3766  // We found two constants, fold them together!
3767  ConstantInt *Fold = ConstantInt::get(
3768  getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3769  Ops[0] = getConstant(Fold);
3770  Ops.erase(Ops.begin()+1); // Erase the folded element
3771  if (Ops.size() == 1) return Ops[0];
3772  LHSC = cast<SCEVConstant>(Ops[0]);
3773  }
3774 
3775  bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3776  bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3777 
3778  if (IsMax ? IsMinV : IsMaxV) {
3779  // If we are left with a constant minimum(/maximum)-int, strip it off.
3780  Ops.erase(Ops.begin());
3781  --Idx;
3782  } else if (IsMax ? IsMaxV : IsMinV) {
3783  // If we have a max(/min) with a constant maximum(/minimum)-int,
3784  // it will always be the extremum.
3785  return LHSC;
3786  }
3787 
3788  if (Ops.size() == 1) return Ops[0];
3789  }
3790 
3791  // Find the first operation of the same kind
3792  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3793  ++Idx;
3794 
3795  // Check to see if one of the operands is of the same kind. If so, expand its
3796  // operands onto our operand list, and recurse to simplify.
3797  if (Idx < Ops.size()) {
3798  bool DeletedAny = false;
3799  while (Ops[Idx]->getSCEVType() == Kind) {
3800  const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3801  Ops.erase(Ops.begin()+Idx);
3802  Ops.append(SMME->op_begin(), SMME->op_end());
3803  DeletedAny = true;
3804  }
3805 
3806  if (DeletedAny)
3807  return getMinMaxExpr(Kind, Ops);
3808  }
3809 
3810  // Okay, check to see if the same value occurs in the operand list twice. If
3811  // so, delete one. Since we sorted the list, these values are required to
3812  // be adjacent.
3813  llvm::CmpInst::Predicate GEPred =
3814  IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3815  llvm::CmpInst::Predicate LEPred =
3816  IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3817  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3818  llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3819  for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3820  if (Ops[i] == Ops[i + 1] ||
3821  isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3822  // X op Y op Y --> X op Y
3823  // X op Y --> X, if we know X, Y are ordered appropriately
3824  Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3825  --i;
3826  --e;
3827  } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3828  Ops[i + 1])) {
3829  // X op Y --> Y, if we know X, Y are ordered appropriately
3830  Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3831  --i;
3832  --e;
3833  }
3834  }
3835 
3836  if (Ops.size() == 1) return Ops[0];
3837 
3838  assert(!Ops.empty() && "Reduced smax down to nothing!");
3839 
3840  // Okay, it looks like we really DO need an expr. Check to see if we
3841  // already have one, otherwise create a new one.
3842  FoldingSetNodeID ID;
3843  ID.AddInteger(Kind);
3844  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3845  ID.AddPointer(Ops[i]);
3846  void *IP = nullptr;
3847  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3848  if (ExistingSCEV)
3849  return ExistingSCEV;
3850  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3851  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3852  SCEV *S = new (SCEVAllocator)
3853  SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3854 
3855  UniqueSCEVs.InsertNode(S, IP);
3856  registerUser(S, Ops);
3857  return S;
3858 }
3859 
3860 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
3861  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3862  return getSMaxExpr(Ops);
3863 }
3864 
3865 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3866  return getMinMaxExpr(scSMaxExpr, Ops);
3867 }
3868 
3869 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
3870  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3871  return getUMaxExpr(Ops);
3872 }
3873 
3874 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3875  return getMinMaxExpr(scUMaxExpr, Ops);
3876 }
3877 
3878 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
3879  const SCEV *RHS) {
3880  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
3881  return getSMinExpr(Ops);
3882 }
3883 
3884 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
3885  return getMinMaxExpr(scSMinExpr, Ops);
3886 }
3887 
3888 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
3889  const SCEV *RHS) {
3890  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
3891  return getUMinExpr(Ops);
3892 }
3893 
3894 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
3895  return getMinMaxExpr(scUMinExpr, Ops);
3896 }
3897 
3898 const SCEV *
3899 ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
3900  ScalableVectorType *ScalableTy) {
3901  Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
3902  Constant *One = ConstantInt::get(IntTy, 1);
3903  Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
3904  // Note that the expression we created is the final expression; we don't
3905  // want to simplify it any further. Also, if we call a normal getSCEV(),
3906  // we'll end up in an endless recursion. So just create an SCEVUnknown.
3907  return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
3908 }
3909 
3910 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
3911  if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
3912  return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
3913  // We can bypass creating a target-independent constant expression and then
3914  // folding it back into a ConstantInt. This is just a compile-time
3915  // optimization.
3916  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
3917 }
3918 
3919 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
3920  if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
3921  return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
3922  // We can bypass creating a target-independent constant expression and then
3923  // folding it back into a ConstantInt. This is just a compile-time
3924  // optimization.
3925  return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
3926 }
3927 
3928 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
3929  StructType *STy,
3930  unsigned FieldNo) {
3931  // We can bypass creating a target-independent constant expression and then
3932  // folding it back into a ConstantInt. This is just a compile-time
3933  // optimization.
3934  return getConstant(
3935  IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
3936 }
3937 
3938 const SCEV *ScalarEvolution::getUnknown(Value *V) {
3939  // Don't attempt to do anything other than create a SCEVUnknown object
3940  // here. createSCEV only calls getUnknown after checking for all other
3941  // interesting possibilities, and any other code that calls getUnknown
3942  // is doing so in order to hide a value from SCEV canonicalization.
3943 
3944  FoldingSetNodeID ID;
3945  ID.AddInteger(scUnknown);
3946  ID.AddPointer(V);
3947  void *IP = nullptr;
3948  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3949  assert(cast<SCEVUnknown>(S)->getValue() == V &&
3950  "Stale SCEVUnknown in uniquing map!");
3951  return S;
3952  }
3953  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
3954  FirstUnknown);
3955  FirstUnknown = cast<SCEVUnknown>(S);
3956  UniqueSCEVs.InsertNode(S, IP);
3957  return S;
3958 }
3959 
3960 //===----------------------------------------------------------------------===//
3961 // Basic SCEV Analysis and PHI Idiom Recognition Code
3962 //
3963 
3964 /// Test if values of the given type are analyzable within the SCEV
3965 /// framework. This primarily includes integer types, and it can optionally
3966 /// include pointer types if the ScalarEvolution class has access to
3967 /// target-specific information.
3968 bool ScalarEvolution::isSCEVable(Type *Ty) const {
3969  // Integers and pointers are always SCEVable.
3970  return Ty->isIntOrPtrTy();
3971 }
3972 
3973 /// Return the size in bits of the specified type, for which isSCEVable must
3974 /// return true.
3975 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
3976  assert(isSCEVable(Ty) && "Type is not SCEVable!");
3977  if (Ty->isPointerTy())
3978  return getDataLayout().getIndexTypeSizeInBits(Ty);
3979  return getDataLayout().getTypeSizeInBits(Ty);
3980 }
3981 
3982 /// Return a type with the same bitwidth as the given type and which represents
3983 /// how SCEV will treat the given type, for which isSCEVable must return
3984 /// true. For pointer types, this is the pointer index sized integer type.
3985 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
3986  assert(isSCEVable(Ty) && "Type is not SCEVable!");
3987 
3988  if (Ty->isIntegerTy())
3989  return Ty;
3990 
3991  // The only other supported type is pointer.
3992  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
3993  return getDataLayout().getIndexType(Ty);
3994 }
3995 
3996 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
3997  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
3998 }
3999 
4001  const SCEV *B) {
4002  /// For a valid use point to exist, the defining scope of one operand
4003  /// must dominate the other.
4004  bool PreciseA, PreciseB;
4005  auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4006  auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4007  if (!PreciseA || !PreciseB)
4008  // Can't tell.
4009  return false;
4010  return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4011  DT.dominates(ScopeB, ScopeA);
4012 }
4013 
4014 
4015 const SCEV *ScalarEvolution::getCouldNotCompute() {
4016  return CouldNotCompute.get();
4017 }
4018 
4019 bool ScalarEvolution::checkValidity(const SCEV *S) const {
4020  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4021  auto *SU = dyn_cast<SCEVUnknown>(S);
4022  return SU && SU->getValue() == nullptr;
4023  });
4024 
4025  return !ContainsNulls;
4026 }
4027 
4028 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4029  HasRecMapType::iterator I = HasRecMap.find(S);
4030  if (I != HasRecMap.end())
4031  return I->second;
4032 
4033  bool FoundAddRec =
4034  SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4035  HasRecMap.insert({S, FoundAddRec});
4036  return FoundAddRec;
4037 }
4038 
4039 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
4040 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
4041 /// offset I, then return {S', I}, else return {\p S, nullptr}.
4042 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
4043  const auto *Add = dyn_cast<SCEVAddExpr>(S);
4044  if (!Add)
4045  return {S, nullptr};
4046 
4047  if (Add->getNumOperands() != 2)
4048  return {S, nullptr};
4049 
4050  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
4051  if (!ConstOp)
4052  return {S, nullptr};
4053 
4054  return {Add->getOperand(1), ConstOp->getValue()};
4055 }
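// Illustrative example (added for exposition; %x is a hypothetical
// SCEVUnknown): (5 + %x) splits into {%x, 5}, while a non-add expression, an
// add without a leading constant, or an add with more than two operands is
// returned unchanged as {S, nullptr}.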
4056 
4057 /// Return the ValueOffsetPair set for \p S. \p S can be represented
4058 /// by the value and offset from any ValueOffsetPair in the set.
4059 SetVector<ScalarEvolution::ValueOffsetPair> *
4060 ScalarEvolution::getSCEVValues(const SCEV *S) {
4061  ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4062  if (SI == ExprValueMap.end())
4063  return nullptr;
4064 #ifndef NDEBUG
4065  if (VerifySCEVMap) {
4066  // Check there is no dangling Value in the set returned.
4067  for (const auto &VE : SI->second)
4068  assert(ValueExprMap.count(VE.first));
4069  }
4070 #endif
4071  return &SI->second;
4072 }
4073 
4074 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4075 /// cannot be used separately. eraseValueFromMap should be used to remove
4076 /// V from ValueExprMap and ExprValueMap at the same time.
4077 void ScalarEvolution::eraseValueFromMap(Value *V) {
4078  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4079  if (I != ValueExprMap.end()) {
4080  const SCEV *S = I->second;
4081  // Remove {V, 0} from the set of ExprValueMap[S]
4082  if (auto *SV = getSCEVValues(S))
4083  SV->remove({V, nullptr});
4084 
4085  // Remove {V, Offset} from the set of ExprValueMap[Stripped]
4086  const SCEV *Stripped;
4087  ConstantInt *Offset;
4088  std::tie(Stripped, Offset) = splitAddExpr(S);
4089  if (Offset != nullptr) {
4090  if (auto *SV = getSCEVValues(Stripped))
4091  SV->remove({V, Offset});
4092  }
4093  ValueExprMap.erase(V);
4094  }
4095 }
4096 
4097 void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4098  // A recursive query may have already computed the SCEV. It should be
4099  // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4100  // inferred nowrap flags.
4101  auto It = ValueExprMap.find_as(V);
4102  if (It == ValueExprMap.end()) {
4103  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4104  ExprValueMap[S].insert({V, nullptr});
4105  }
4106 }
4107 
4108 /// Return an existing SCEV if it exists, otherwise analyze the expression and
4109 /// create a new one.
4110 const SCEV *ScalarEvolution::getSCEV(Value *V) {
4111  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4112 
4113  const SCEV *S = getExistingSCEV(V);
4114  if (S == nullptr) {
4115  S = createSCEV(V);
4116  // During PHI resolution, it is possible to create two SCEVs for the same
4117  // V, so we need to double check whether V->S is inserted into
4118  // ValueExprMap before inserting S->{V, 0} into ExprValueMap.
4119  std::pair<ValueExprMapType::iterator, bool> Pair =
4120  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4121  if (Pair.second) {
4122  ExprValueMap[S].insert({V, nullptr});
4123 
4124  // If S == Stripped + Offset, add Stripped -> {V, Offset} into
4125  // ExprValueMap.
4126  const SCEV *Stripped = S;
4127  ConstantInt *Offset = nullptr;
4128  std::tie(Stripped, Offset) = splitAddExpr(S);
4129  // If stripped is SCEVUnknown, don't bother to save
4130  // Stripped -> {V, offset}. It doesn't simplify and sometimes even
4131  // increases the complexity of the expansion code.
4132  // If V is GetElementPtrInst, don't save Stripped -> {V, offset}
4133  // because it may generate add/sub instead of GEP in SCEV expansion.
4134  if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
4135  !isa<GetElementPtrInst>(V))
4136  ExprValueMap[Stripped].insert({V, Offset});
4137  }
4138  }
4139  return S;
4140 }
4141 
4142 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4143  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4144 
4145  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4146  if (I != ValueExprMap.end()) {
4147  const SCEV *S = I->second;
4148  assert(checkValidity(S) &&
4149  "existing SCEV has not been properly invalidated");
4150  return S;
4151  }
4152  return nullptr;
4153 }
4154 
4155 /// Return a SCEV corresponding to -V = -1*V
4156 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4157  SCEV::NoWrapFlags Flags) {
4158  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4159  return getConstant(
4160  cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4161 
4162  Type *Ty = V->getType();
4163  Ty = getEffectiveSCEVType(Ty);
4164  return getMulExpr(V, getMinusOne(Ty), Flags);
4165 }
4166 
4167 /// If Expr computes ~A, return A else return nullptr
4168 static const SCEV *MatchNotExpr(const SCEV *Expr) {
4169  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4170  if (!Add || Add->getNumOperands() != 2 ||
4171  !Add->getOperand(0)->isAllOnesValue())
4172  return nullptr;
4173 
4174  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4175  if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4176  !AddRHS->getOperand(0)->isAllOnesValue())
4177  return nullptr;
4178 
4179  return AddRHS->getOperand(1);
4180 }
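// Illustrative example (added for exposition; %x is a hypothetical value):
// SCEV canonicalizes ~%x as (-1 + (-1 * %x)), so MatchNotExpr returns %x for
// that expression and nullptr for anything that does not have this exact
// shape.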
4181 
4182 /// Return a SCEV corresponding to ~V = -1-V
4183 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4184  assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4185 
4186  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4187  return getConstant(
4188  cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4189 
4190  // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4191  if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4192  auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4193  SmallVector<const SCEV *, 2> MatchedOperands;
4194  for (const SCEV *Operand : MME->operands()) {
4195  const SCEV *Matched = MatchNotExpr(Operand);
4196  if (!Matched)
4197  return (const SCEV *)nullptr;
4198  MatchedOperands.push_back(Matched);
4199  }
4200  return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4201  MatchedOperands);
4202  };
4203  if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4204  return Replaced;
4205  }
4206 
4207  Type *Ty = V->getType();
4208  Ty = getEffectiveSCEVType(Ty);
4209  return getMinusSCEV(getMinusOne(Ty), V);
4210 }
4211 
4212 const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4213  assert(P->getType()->isPointerTy());
4214 
4215  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4216  // The base of an AddRec is the first operand.
4217  SmallVector<const SCEV *> Ops{AddRec->operands()};
4218  Ops[0] = removePointerBase(Ops[0]);
4219  // Don't try to transfer nowrap flags for now. We could in some cases
4220  // (for example, if the pointer operand of the AddRec is a SCEVUnknown).
4221  return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4222  }
4223  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4224  // The base of an Add is the pointer operand.
4225  SmallVector<const SCEV *> Ops{Add->operands()};
4226  const SCEV **PtrOp = nullptr;
4227  for (const SCEV *&AddOp : Ops) {
4228  if (AddOp->getType()->isPointerTy()) {
4229  assert(!PtrOp && "Cannot have multiple pointer ops");
4230  PtrOp = &AddOp;
4231  }
4232  }
4233  *PtrOp = removePointerBase(*PtrOp);
4234  // Don't try to transfer nowrap flags for now. We could in some cases
4235  // (for example, if the pointer operand of the Add is a SCEVUnknown).
4236  return getAddExpr(Ops);
4237  }
4238  // Any other expression must be a pointer base.
4239  return getZero(P->getType());
4240 }
4241 
4242 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4243  SCEV::NoWrapFlags Flags,
4244  unsigned Depth) {
4245  // Fast path: X - X --> 0.
4246  if (LHS == RHS)
4247  return getZero(LHS->getType());
4248 
4249  // If we subtract two pointers with different pointer bases, bail.
4250  // Eventually, we're going to add an assertion to getMulExpr that we
4251  // can't multiply by a pointer.
4252  if (RHS->getType()->isPointerTy()) {
4253  if (!LHS->getType()->isPointerTy() ||
4254  getPointerBase(LHS) != getPointerBase(RHS))
4255  return getCouldNotCompute();
4256  LHS = removePointerBase(LHS);
4257  RHS = removePointerBase(RHS);
4258  }
4259 
4260  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4261  // makes it so that we cannot make much use of NUW.
4262  auto AddFlags = SCEV::FlagAnyWrap;
4263  const bool RHSIsNotMinSigned =
4264  !getSignedRangeMin(RHS).isMinSignedValue();
4265  if (hasFlags(Flags, SCEV::FlagNSW)) {
4266  // Let M be the minimum representable signed value. Then (-1)*RHS
4267  // signed-wraps if and only if RHS is M. That can happen even for
4268  // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4269  // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4270  // (-1)*RHS, we need to prove that RHS != M.
4271  //
4272  // If LHS is non-negative and we know that LHS - RHS does not
4273  // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4274  // either by proving that RHS > M or that LHS >= 0.
4275  if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4276  AddFlags = SCEV::FlagNSW;
4277  }
4278  }
4279 
4280  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4281  // RHS is NSW and LHS >= 0.
4282  //
4283  // The difficulty here is that the NSW flag may have been proven
4284  // relative to a loop that is to be found in a recurrence in LHS and
4285  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4286  // larger scope than intended.
4287  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4288 
4289  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4290 }
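// Illustrative examples (added for exposition; %x, %p1, %p2 are hypothetical
// values): getMinusSCEV(%x, %x) folds to 0; subtracting pointers that share a
// base strips the base from both sides and subtracts the remaining offsets;
// pointers with different bases yield SCEVCouldNotCompute.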
4291 
4292 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4293  unsigned Depth) {
4294  Type *SrcTy = V->getType();
4295  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4296  "Cannot truncate or zero extend with non-integer arguments!");
4297  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4298  return V; // No conversion
4299  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4300  return getTruncateExpr(V, Ty, Depth);
4301  return getZeroExtendExpr(V, Ty, Depth);
4302 }
4303 
4304 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4305  unsigned Depth) {
4306  Type *SrcTy = V->getType();
4307  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4308  "Cannot truncate or zero extend with non-integer arguments!");
4309  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4310  return V; // No conversion
4311  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4312  return getTruncateExpr(V, Ty, Depth);
4313  return getSignExtendExpr(V, Ty, Depth);
4314 }
4315 
4316 const SCEV *
4317 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4318  Type *SrcTy = V->getType();
4319  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4320  "Cannot noop or zero extend with non-integer arguments!");
4321  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4322  "getNoopOrZeroExtend cannot truncate!");
4323  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4324  return V; // No conversion
4325  return getZeroExtendExpr(V, Ty);
4326 }
4327 
4328 const SCEV *
4329 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4330  Type *SrcTy = V->getType();
4331  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4332  "Cannot noop or sign extend with non-integer arguments!");
4333  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4334  "getNoopOrSignExtend cannot truncate!");
4335  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4336  return V; // No conversion
4337  return getSignExtendExpr(V, Ty);
4338 }
4339 
4340 const SCEV *
4341 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4342  Type *SrcTy = V->getType();
4343  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4344  "Cannot noop or any extend with non-integer arguments!");
4345  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4346  "getNoopOrAnyExtend cannot truncate!");
4347  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4348  return V; // No conversion
4349  return getAnyExtendExpr(V, Ty);
4350 }
4351 
4352 const SCEV *
4353 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4354  Type *SrcTy = V->getType();
4355  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4356  "Cannot truncate or noop with non-integer arguments!");
4357  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4358  "getTruncateOrNoop cannot extend!");
4359  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4360  return V; // No conversion
4361  return getTruncateExpr(V, Ty);
4362 }
4363 
4364 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4365  const SCEV *RHS) {
4366  const SCEV *PromotedLHS = LHS;
4367  const SCEV *PromotedRHS = RHS;
4368 
4369  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4370  PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4371  else
4372  PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4373 
4374  return getUMaxExpr(PromotedLHS, PromotedRHS);
4375 }
4376 
4377 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4378  const SCEV *RHS) {
4379  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4380  return getUMinFromMismatchedTypes(Ops);
4381 }
4382 
4383 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
4384  SmallVectorImpl<const SCEV *> &Ops) {
4385  assert(!Ops.empty() && "At least one operand must be!");
4386  // Trivial case.
4387  if (Ops.size() == 1)
4388  return Ops[0];
4389 
4390  // Find the max type first.
4391  Type *MaxType = nullptr;
4392  for (auto *S : Ops)
4393  if (MaxType)
4394  MaxType = getWiderType(MaxType, S->getType());
4395  else
4396  MaxType = S->getType();
4397  assert(MaxType && "Failed to find maximum type!");
4398 
4399  // Extend all ops to max type.
4400  SmallVector<const SCEV *, 2> PromotedOps;
4401  for (auto *S : Ops)
4402  PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4403 
4404  // Generate umin.
4405  return getUMinExpr(PromotedOps);
4406 }
4407 
4408 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4409  // A pointer operand may evaluate to a nonpointer expression, such as null.
4410  if (!V->getType()->isPointerTy())
4411  return V;
4412 
4413  while (true) {
4414  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4415  V = AddRec->getStart();
4416  } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4417  const SCEV *PtrOp = nullptr;
4418  for (const SCEV *AddOp : Add->operands()) {
4419  if (AddOp->getType()->isPointerTy()) {
4420  assert(!PtrOp && "Cannot have multiple pointer ops");
4421  PtrOp = AddOp;
4422  }
4423  }
4424  assert(PtrOp && "Must have pointer op");
4425  V = PtrOp;
4426  } else // Not something we can look further into.
4427  return V;
4428  }
4429 }
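// Illustrative example (added for exposition; %p is a hypothetical pointer
// SCEVUnknown): for {(16 + %p),+,4}<L> the walk takes the recurrence's start
// (16 + %p), then descends into its pointer operand and returns %p.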
4430 
4431 /// Push users of the given Instruction onto the given Worklist.
4432 static void PushDefUseChildren(Instruction *I,
4433  SmallVectorImpl<Instruction *> &Worklist,
4434  SmallPtrSetImpl<Instruction *> &Visited) {
4435  // Push the def-use children onto the Worklist stack.
4436  for (User *U : I->users()) {
4437  auto *UserInsn = cast<Instruction>(U);
4438  if (Visited.insert(UserInsn).second)
4439  Worklist.push_back(UserInsn);
4440  }
4441 }
4442 
4443 namespace {
4444 
4445 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4446 /// expression in case its Loop is L. If it is not L then
4447 /// if IgnoreOtherLoops is true then use AddRec itself
4448 /// otherwise rewrite cannot be done.
4449 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4450 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4451 public:
4452  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4453  bool IgnoreOtherLoops = true) {
4454  SCEVInitRewriter Rewriter(L, SE);
4455  const SCEV *Result = Rewriter.visit(S);
4456  if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4457  return SE.getCouldNotCompute();
4458  return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4459  ? SE.getCouldNotCompute()
4460  : Result;
4461  }
4462 
4463  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4464  if (!SE.isLoopInvariant(Expr, L))
4465  SeenLoopVariantSCEVUnknown = true;
4466  return Expr;
4467  }
4468 
4469  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4470  // Only re-write AddRecExprs for this loop.
4471  if (Expr->getLoop() == L)
4472  return Expr->getStart();
4473  SeenOtherLoops = true;
4474  return Expr;
4475  }
4476 
4477  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4478 
4479  bool hasSeenOtherLoops() { return SeenOtherLoops; }
4480 
4481 private:
4482  explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4483  : SCEVRewriteVisitor(SE), L(L) {}
4484 
4485  const Loop *L;
4486  bool SeenLoopVariantSCEVUnknown = false;
4487  bool SeenOtherLoops = false;
4488 };
4489 
4490 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4491 /// increment expression in case its Loop is L. If it is not L then
4492 /// use AddRec itself.
4493 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4494 class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4495 public:
4496  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4497  SCEVPostIncRewriter Rewriter(L, SE);
4498  const SCEV *Result = Rewriter.visit(S);
4499  return Rewriter.hasSeenLoopVariantSCEVUnknown()
4500  ? SE.getCouldNotCompute()
4501  : Result;
4502  }
4503 
4504  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4505  if (!SE.isLoopInvariant(Expr, L))
4506  SeenLoopVariantSCEVUnknown = true;
4507  return Expr;
4508  }
4509 
4510  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4511  // Only re-write AddRecExprs for this loop.
4512  if (Expr->getLoop() == L)
4513  return Expr->getPostIncExpr(SE);
4514  SeenOtherLoops = true;
4515  return Expr;
4516  }
4517 
4518  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4519 
4520  bool hasSeenOtherLoops() { return SeenOtherLoops; }
4521 
4522 private:
4523  explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4524  : SCEVRewriteVisitor(SE), L(L) {}
4525 
4526  const Loop *L;
4527  bool SeenLoopVariantSCEVUnknown = false;
4528  bool SeenOtherLoops = false;
4529 };
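// Editorial note (illustrative example, not part of the original source):
// for S = {%a,+,%b}<%L>, SCEVPostIncRewriter::rewrite(S, %L, SE) substitutes
// the post-increment form {(%a + %b),+,%b}<%L>, i.e. the value of the
// recurrence after the backedge is taken. AddRecs of other loops are left
// untouched, and a loop-variant SCEVUnknown yields SCEVCouldNotCompute.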
4530 
4531 /// This class evaluates the compare condition by matching it against the
4532 /// condition of the loop latch. If there is a match, we assume a true value
4533 /// for the condition while building SCEV nodes.
4534 class SCEVBackedgeConditionFolder
4535  : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4536 public:
4537  static const SCEV *rewrite(const SCEV *S, const Loop *L,
4538  ScalarEvolution &SE) {
4539  bool IsPosBECond = false;
4540  Value *BECond = nullptr;
4541  if (BasicBlock *Latch = L->getLoopLatch()) {
4542  BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4543  if (BI && BI->isConditional()) {
4544  assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4545  "Both outgoing branches should not target same header!");
4546  BECond = BI->getCondition();
4547  IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4548  } else {
4549  return S;
4550  }
4551  }
4552  SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4553  return Rewriter.visit(S);
4554  }
4555 
4556  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4557  const SCEV *Result = Expr;
4558  bool InvariantF = SE.isLoopInvariant(Expr, L);
4559 
4560  if (!InvariantF) {
4561  Instruction *I = cast<Instruction>(Expr->getValue());
4562  switch (I->getOpcode()) {
4563  case Instruction::Select: {
4564  SelectInst *SI = cast<SelectInst>(I);
4565  Optional<const SCEV *> Res =
4566  compareWithBackedgeCondition(SI->getCondition());
4567  if (Res.hasValue()) {
4568  bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
4569  Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4570  }
4571  break;
4572  }
4573  default: {
4574  Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4575  if (Res.hasValue())
4576  Result = Res.getValue();
4577  break;
4578  }
4579  }
4580  }
4581  return Result;
4582  }
4583 
4584 private:
4585  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4586  bool IsPosBECond, ScalarEvolution &SE)
4587  : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4588  IsPositiveBECond(IsPosBECond) {}
4589 
4590  Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
4591 
4592  const Loop *L;
4593  /// The loop's backedge condition.
4594  Value *BackedgeCond = nullptr;
4595  /// Set to true if the backedge is taken on the positive branch condition.
4596  bool IsPositiveBECond;
4597 };
4598 
4599 Optional<const SCEV *>
4600 SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
4601 
4602  // If the value matches the backedge condition of the loop latch,
4603  // then return a constant evolution node based on whether the
4604  // loopback branch is taken.
4605  if (BackedgeCond == IC)
4606  return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
4607  : SE.getZero(Type::getInt1Ty(SE.getContext()));
4608  return None;
4609 }
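// Editorial note (illustrative example, not part of the original source):
// suppose the latch of %L branches to the header when %cond is true. While
// folding SCEVs for %L, a select such as
//   %x = select i1 %cond, i32 %t, i32 %f
// is rewritten to the SCEV of %t, because %cond compares equal to the
// backedge condition and is therefore known true whenever the backedge is
// taken (the false arm would be chosen if the backedge were on the false
// successor instead).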
4610 
4611 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
4612 public:
4613  static const SCEV *rewrite(const SCEV *S, const Loop *L,
4614  ScalarEvolution &SE) {
4615  SCEVShiftRewriter Rewriter(L, SE);
4616  const SCEV *Result = Rewriter.visit(S);
4617  return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
4618  }
4619 
4620  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4621  // Only allow AddRecExprs for this loop.
4622  if (!SE.isLoopInvariant(Expr, L))
4623  Valid = false;
4624  return Expr;
4625  }
4626 
4627  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4628  if (Expr->getLoop() == L && Expr->isAffine())
4629  return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
4630  Valid = false;
4631  return Expr;
4632  }
4633 
4634  bool isValid() { return Valid; }
4635 
4636 private:
4637  explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
4638  : SCEVRewriteVisitor(SE), L(L) {}
4639 
4640  const Loop *L;
4641  bool Valid = true;
4642 };
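// Editorial note (illustrative example, not part of the original source):
// SCEVShiftRewriter shifts an affine AddRec of %L back by one iteration,
// rewriting {%a,+,%b}<%L> to {(%a - %b),+,%b}<%L>. Any expression that is
// neither an affine AddRec of %L nor invariant in %L marks the rewrite as
// invalid, in which case rewrite() returns SCEVCouldNotCompute.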
4643 
4644 } // end anonymous namespace
4645 
4646 SCEV::NoWrapFlags
4647 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
4648  if (!AR->isAffine())
4649  return SCEV::FlagAnyWrap;
4650 
4651  using OBO = OverflowingBinaryOperator;
4652 
4653  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
4654 
4655  if (!AR->hasNoSignedWrap()) {
4656  ConstantRange AddRecRange = getSignedRange(AR);
4657  ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
4658 
4659  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4660  Instruction::Add, IncRange, OBO::NoSignedWrap);
4661  if (NSWRegion.contains(AddRecRange))
4662  Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
4663  }
4664 
4665  if (!AR->hasNoUnsignedWrap()) {
4666  ConstantRange AddRecRange = getUnsignedRange(AR);
4667  ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
4668 
4669  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4670  Instruction::Add, IncRange, OBO::NoUnsignedWrap);
4671  if (NUWRegion.contains(AddRecRange))
4672  Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
4673  }
4674 
4675  return Result;
4676 }
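// Editorial note (illustrative example, not part of the original source):
// consider an i8 AddRec with step 1 whose computed signed range is [0, 100).
// The guaranteed no-signed-wrap region for adding 1 is [-128, 127), i.e.
// every value except 127. Since [0, 100) is contained in that region, the
// increment can never wrap in the signed sense and FlagNSW can be set.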
4677 
4678 SCEV::NoWrapFlags
4679 ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
4680  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
4681 
4682  if (AR->hasNoSignedWrap())
4683  return Result;
4684 
4685  if (!AR->isAffine())
4686  return Result;
4687 
4688  const SCEV *Step = AR->getStepRecurrence(*this);
4689  const Loop *L = AR->getLoop();
4690 
4691  // Check whether the backedge-taken count is SCEVCouldNotCompute.
4692  // Note that this serves two purposes: It filters out loops that are
4693  // simply not analyzable, and it covers the case where this code is
4694  // being called from within backedge-taken count analysis, such that
4695  // attempting to ask for the backedge-taken count would likely result
4696  // in infinite recursion. In the latter case, the analysis code will
4697  // cope with a conservative value, and it will take care to purge
4698  // that value once it has finished.
4699  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
4700 
4701  // Normally, in the cases we can prove no-overflow via a
4702  // backedge guarding condition, we can also compute a backedge
4703  // taken count for the loop. The exceptions are assumptions and
4704  // guards present in the loop -- SCEV is not great at exploiting
4705  // these to compute max backedg