1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the scalar evolution analysis
10 // engine, which is used primarily to analyze expressions involving induction
11 // variables in loops.
12 //
13 // There are several aspects to this library. First is the representation of
14 // scalar expressions, which are represented as subclasses of the SCEV class.
15 // These classes are used to represent certain types of subexpressions that we
16 // can handle. We only create one SCEV of a particular shape, so
17 // pointer-comparisons for equality are legal.
18 //
19 // One important aspect of the SCEV objects is that they are never cyclic, even
20 // if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
22 // recurrence) then we represent it directly as a recurrence node, otherwise we
23 // represent it as a SCEVUnknown node.
24 //
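// As an illustrative example (not part of the original comment): for a loop
// like "for (i = 0; i != n; i += 4)", the PHI node for i is a polynomial
// recurrence and is represented directly as the add recurrence
// {0,+,4}<%loop> (start 0, step 4 per iteration), whereas a value we cannot
// reason about, such as a volatile load feeding the loop, is wrapped in a
// SCEVUnknown. Because expressions are uniqued, computing i's SCEV twice
// yields the same object, so equality is a pointer comparison.
//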
25 // In addition to being able to represent expressions of various types, we also
26 // have folders that are used to build the *canonical* representation for a
27 // particular expression. These folders are capable of using a variety of
28 // rewrite rules to simplify the expressions.
29 //
30 // Once the folders are defined, we can implement the more interesting
31 // higher-level code, such as the code that recognizes PHI nodes of various
32 // types, computes the execution count of a loop, etc.
33 //
34 // TODO: We should use these routines and value representations to implement
35 // dependence analysis!
36 //
37 //===----------------------------------------------------------------------===//
38 //
39 // There are several good references for the techniques used in this analysis.
40 //
41 // Chains of recurrences -- a method to expedite the evaluation
42 // of closed-form functions
43 // Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44 //
45 // On computational properties of chains of recurrences
46 // Eugene V. Zima
47 //
48 // Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49 // Robert A. van Engelen
50 //
51 // Efficient Symbolic Analysis for Optimizing Compilers
52 // Robert A. van Engelen
53 //
54 // Using the chains of recurrences algebra for data dependence testing and
55 // induction variable substitution
56 // MS Thesis, Johnie Birch
57 //
58 //===----------------------------------------------------------------------===//
59 
60 #include "llvm/Analysis/ScalarEvolution.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/ArrayRef.h"
63 #include "llvm/ADT/DenseMap.h"
64 #include "llvm/ADT/DepthFirstIterator.h"
65 #include "llvm/ADT/EquivalenceClasses.h"
66 #include "llvm/ADT/FoldingSet.h"
67 #include "llvm/ADT/None.h"
68 #include "llvm/ADT/Optional.h"
69 #include "llvm/ADT/STLExtras.h"
70 #include "llvm/ADT/ScopeExit.h"
71 #include "llvm/ADT/Sequence.h"
72 #include "llvm/ADT/SetVector.h"
73 #include "llvm/ADT/SmallPtrSet.h"
74 #include "llvm/ADT/SmallSet.h"
75 #include "llvm/ADT/SmallVector.h"
76 #include "llvm/ADT/Statistic.h"
77 #include "llvm/ADT/StringRef.h"
78 #include "llvm/Analysis/AssumptionCache.h"
79 #include "llvm/Analysis/ConstantFolding.h"
80 #include "llvm/Analysis/InstructionSimplify.h"
81 #include "llvm/Analysis/LoopInfo.h"
82 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
83 #include "llvm/Analysis/TargetLibraryInfo.h"
84 #include "llvm/Analysis/ValueTracking.h"
85 #include "llvm/Config/llvm-config.h"
86 #include "llvm/IR/Argument.h"
87 #include "llvm/IR/BasicBlock.h"
88 #include "llvm/IR/CFG.h"
89 #include "llvm/IR/Constant.h"
90 #include "llvm/IR/ConstantRange.h"
91 #include "llvm/IR/Constants.h"
92 #include "llvm/IR/DataLayout.h"
93 #include "llvm/IR/DerivedTypes.h"
94 #include "llvm/IR/Dominators.h"
95 #include "llvm/IR/Function.h"
96 #include "llvm/IR/GlobalAlias.h"
97 #include "llvm/IR/GlobalValue.h"
98 #include "llvm/IR/InstIterator.h"
99 #include "llvm/IR/InstrTypes.h"
100 #include "llvm/IR/Instruction.h"
101 #include "llvm/IR/Instructions.h"
102 #include "llvm/IR/IntrinsicInst.h"
103 #include "llvm/IR/Intrinsics.h"
104 #include "llvm/IR/LLVMContext.h"
105 #include "llvm/IR/Operator.h"
106 #include "llvm/IR/PatternMatch.h"
107 #include "llvm/IR/Type.h"
108 #include "llvm/IR/Use.h"
109 #include "llvm/IR/User.h"
110 #include "llvm/IR/Value.h"
111 #include "llvm/IR/Verifier.h"
112 #include "llvm/InitializePasses.h"
113 #include "llvm/Pass.h"
114 #include "llvm/Support/Casting.h"
115 #include "llvm/Support/CommandLine.h"
116 #include "llvm/Support/Compiler.h"
117 #include "llvm/Support/Debug.h"
118 #include "llvm/Support/ErrorHandling.h"
119 #include "llvm/Support/KnownBits.h"
120 #include "llvm/Support/SaveAndRestore.h"
121 #include "llvm/Support/raw_ostream.h"
122 #include <algorithm>
123 #include <cassert>
124 #include <climits>
125 #include <cstdint>
126 #include <cstdlib>
127 #include <map>
128 #include <memory>
129 #include <tuple>
130 #include <utility>
131 #include <vector>
132 
133 using namespace llvm;
134 using namespace PatternMatch;
135 
136 #define DEBUG_TYPE "scalar-evolution"
137 
138 STATISTIC(NumTripCountsComputed,
139  "Number of loops with predictable loop counts");
140 STATISTIC(NumTripCountsNotComputed,
141  "Number of loops without predictable loop counts");
142 STATISTIC(NumBruteForceTripCountsComputed,
143  "Number of loops with trip counts computed by force");
144 
145 #ifdef EXPENSIVE_CHECKS
146 bool llvm::VerifySCEV = true;
147 #else
148 bool llvm::VerifySCEV = false;
149 #endif
150 
151 static cl::opt<unsigned>
152  MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
153  cl::desc("Maximum number of iterations SCEV will "
154  "symbolically execute a constant "
155  "derived loop"),
156  cl::init(100));
157 
158 static cl::opt<bool, true> VerifySCEVOpt(
159  "verify-scev", cl::Hidden, cl::location(VerifySCEV),
160  cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
161 static cl::opt<bool> VerifySCEVStrict(
162  "verify-scev-strict", cl::Hidden,
163  cl::desc("Enable stricter verification when -verify-scev is passed"));
164 static cl::opt<bool>
165  VerifySCEVMap("verify-scev-maps", cl::Hidden,
166  cl::desc("Verify no dangling value in ScalarEvolution's "
167  "ExprValueMap (slow)"));
168 
169 static cl::opt<bool> VerifyIR(
170  "scev-verify-ir", cl::Hidden,
171  cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
172  cl::init(false));
173 
174 static cl::opt<unsigned> MulOpsInlineThreshold(
175  "scev-mulops-inline-threshold", cl::Hidden,
176  cl::desc("Threshold for inlining multiplication operands into a SCEV"),
177  cl::init(32));
178 
179 static cl::opt<unsigned> AddOpsInlineThreshold(
180  "scev-addops-inline-threshold", cl::Hidden,
181  cl::desc("Threshold for inlining addition operands into a SCEV"),
182  cl::init(500));
183 
184 static cl::opt<unsigned> MaxSCEVCompareDepth(
185  "scalar-evolution-max-scev-compare-depth", cl::Hidden,
186  cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
187  cl::init(32));
188 
189 static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
190  "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
191  cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
192  cl::init(2));
193 
194 static cl::opt<unsigned> MaxValueCompareDepth(
195  "scalar-evolution-max-value-compare-depth", cl::Hidden,
196  cl::desc("Maximum depth of recursive value complexity comparisons"),
197  cl::init(2));
198 
199 static cl::opt<unsigned>
200  MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
201  cl::desc("Maximum depth of recursive arithmetics"),
202  cl::init(32));
203 
204 static cl::opt<unsigned> MaxConstantEvolvingDepth(
205  "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
206  cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
207 
208 static cl::opt<unsigned>
209  MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
210  cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
211  cl::init(8));
212 
213 static cl::opt<unsigned>
214  MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
215  cl::desc("Max coefficients in AddRec during evolving"),
216  cl::init(8));
217 
218 static cl::opt<unsigned>
219  HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
220  cl::desc("Size of the expression which is considered huge"),
221  cl::init(4096));
222 
223 static cl::opt<bool>
224 ClassifyExpressions("scalar-evolution-classify-expressions",
225  cl::Hidden, cl::init(true),
226  cl::desc("When printing analysis, include information on every instruction"));
227 
228 static cl::opt<bool> UseExpensiveRangeSharpening(
229  "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
230  cl::init(false),
231  cl::desc("Use more powerful methods of sharpening expression ranges. May "
232  "be costly in terms of compile time"));
233 
234 static cl::opt<unsigned> MaxPhiSCCAnalysisDepth(
235  "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
236  cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
237  "Phi strongly connected components"),
238  cl::init(8));
239 
240 static cl::opt<bool>
241  EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
242  cl::desc("Handle <= and >= in finite loops"),
243  cl::init(true));
244 
245 //===----------------------------------------------------------------------===//
246 // SCEV class definitions
247 //===----------------------------------------------------------------------===//
248 
249 //===----------------------------------------------------------------------===//
250 // Implementation of the SCEV class.
251 //
252 
253 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
254 LLVM_DUMP_METHOD void SCEV::dump() const {
255  print(dbgs());
256  dbgs() << '\n';
257 }
258 #endif
259 
260 void SCEV::print(raw_ostream &OS) const {
261  switch (getSCEVType()) {
262  case scConstant:
263  cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
264  return;
265  case scPtrToInt: {
266  const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
267  const SCEV *Op = PtrToInt->getOperand();
268  OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
269  << *PtrToInt->getType() << ")";
270  return;
271  }
272  case scTruncate: {
273  const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
274  const SCEV *Op = Trunc->getOperand();
275  OS << "(trunc " << *Op->getType() << " " << *Op << " to "
276  << *Trunc->getType() << ")";
277  return;
278  }
279  case scZeroExtend: {
280  const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
281  const SCEV *Op = ZExt->getOperand();
282  OS << "(zext " << *Op->getType() << " " << *Op << " to "
283  << *ZExt->getType() << ")";
284  return;
285  }
286  case scSignExtend: {
287  const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
288  const SCEV *Op = SExt->getOperand();
289  OS << "(sext " << *Op->getType() << " " << *Op << " to "
290  << *SExt->getType() << ")";
291  return;
292  }
293  case scAddRecExpr: {
294  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
295  OS << "{" << *AR->getOperand(0);
296  for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
297  OS << ",+," << *AR->getOperand(i);
298  OS << "}<";
299  if (AR->hasNoUnsignedWrap())
300  OS << "nuw><";
301  if (AR->hasNoSignedWrap())
302  OS << "nsw><";
303  if (AR->hasNoSelfWrap() &&
304  !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
305  OS << "nw><";
306  AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
307  OS << ">";
308  return;
309  }
310  case scAddExpr:
311  case scMulExpr:
312  case scUMaxExpr:
313  case scSMaxExpr:
314  case scUMinExpr:
315  case scSMinExpr:
316  case scSequentialUMinExpr: {
317  const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
318  const char *OpStr = nullptr;
319  switch (NAry->getSCEVType()) {
320  case scAddExpr: OpStr = " + "; break;
321  case scMulExpr: OpStr = " * "; break;
322  case scUMaxExpr: OpStr = " umax "; break;
323  case scSMaxExpr: OpStr = " smax "; break;
324  case scUMinExpr:
325  OpStr = " umin ";
326  break;
327  case scSMinExpr:
328  OpStr = " smin ";
329  break;
330  case scSequentialUMinExpr:
331  OpStr = " umin_seq ";
332  break;
333  default:
334  llvm_unreachable("There are no other nary expression types.");
335  }
336  OS << "(";
337  ListSeparator LS(OpStr);
338  for (const SCEV *Op : NAry->operands())
339  OS << LS << *Op;
340  OS << ")";
341  switch (NAry->getSCEVType()) {
342  case scAddExpr:
343  case scMulExpr:
344  if (NAry->hasNoUnsignedWrap())
345  OS << "<nuw>";
346  if (NAry->hasNoSignedWrap())
347  OS << "<nsw>";
348  break;
349  default:
350  // Nothing to print for other nary expressions.
351  break;
352  }
353  return;
354  }
355  case scUDivExpr: {
356  const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
357  OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
358  return;
359  }
360  case scUnknown: {
361  const SCEVUnknown *U = cast<SCEVUnknown>(this);
362  Type *AllocTy;
363  if (U->isSizeOf(AllocTy)) {
364  OS << "sizeof(" << *AllocTy << ")";
365  return;
366  }
367  if (U->isAlignOf(AllocTy)) {
368  OS << "alignof(" << *AllocTy << ")";
369  return;
370  }
371 
372  Type *CTy;
373  Constant *FieldNo;
374  if (U->isOffsetOf(CTy, FieldNo)) {
375  OS << "offsetof(" << *CTy << ", ";
376  FieldNo->printAsOperand(OS, false);
377  OS << ")";
378  return;
379  }
380 
381  // Otherwise just print it normally.
382  U->getValue()->printAsOperand(OS, false);
383  return;
384  }
385  case scCouldNotCompute:
386  OS << "***COULDNOTCOMPUTE***";
387  return;
388  }
389  llvm_unreachable("Unknown SCEV kind!");
390 }
391 
392 Type *SCEV::getType() const {
393  switch (getSCEVType()) {
394  case scConstant:
395  return cast<SCEVConstant>(this)->getType();
396  case scPtrToInt:
397  case scTruncate:
398  case scZeroExtend:
399  case scSignExtend:
400  return cast<SCEVCastExpr>(this)->getType();
401  case scAddRecExpr:
402  return cast<SCEVAddRecExpr>(this)->getType();
403  case scMulExpr:
404  return cast<SCEVMulExpr>(this)->getType();
405  case scUMaxExpr:
406  case scSMaxExpr:
407  case scUMinExpr:
408  case scSMinExpr:
409  return cast<SCEVMinMaxExpr>(this)->getType();
410  case scSequentialUMinExpr:
411  return cast<SCEVSequentialMinMaxExpr>(this)->getType();
412  case scAddExpr:
413  return cast<SCEVAddExpr>(this)->getType();
414  case scUDivExpr:
415  return cast<SCEVUDivExpr>(this)->getType();
416  case scUnknown:
417  return cast<SCEVUnknown>(this)->getType();
418  case scCouldNotCompute:
419  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
420  }
421  llvm_unreachable("Unknown SCEV kind!");
422 }
423 
424 bool SCEV::isZero() const {
425  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
426  return SC->getValue()->isZero();
427  return false;
428 }
429 
430 bool SCEV::isOne() const {
431  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
432  return SC->getValue()->isOne();
433  return false;
434 }
435 
436 bool SCEV::isAllOnesValue() const {
437  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
438  return SC->getValue()->isMinusOne();
439  return false;
440 }
441 
442 bool SCEV::isNonConstantNegative() const {
443  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
444  if (!Mul) return false;
445 
446  // If there is a constant factor, it will be first.
447  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
448  if (!SC) return false;
449 
450  // Return true if the value is negative, this matches things like (-42 * V).
451  return SC->getAPInt().isNegative();
452 }
453 
454 SCEVCouldNotCompute::SCEVCouldNotCompute() :
455  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
456 
457 bool SCEVCouldNotCompute::classof(const SCEV *S) {
458  return S->getSCEVType() == scCouldNotCompute;
459 }
460 
461 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
462  FoldingSetNodeID ID;
463  ID.AddInteger(scConstant);
464  ID.AddPointer(V);
465  void *IP = nullptr;
466  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
467  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
468  UniqueSCEVs.InsertNode(S, IP);
469  return S;
470 }
471 
472 const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
473  return getConstant(ConstantInt::get(getContext(), Val));
474 }
475 
476 const SCEV *
477 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
478  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
479  return getConstant(ConstantInt::get(ITy, V, isSigned));
480 }
481 
482 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
483  const SCEV *op, Type *ty)
484  : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
485  Operands[0] = op;
486 }
487 
488 SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
489  Type *ITy)
490  : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
491  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
492  "Must be a non-bit-width-changing pointer-to-integer cast!");
493 }
494 
495 SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
496  SCEVTypes SCEVTy, const SCEV *op,
497  Type *ty)
498  : SCEVCastExpr(ID, SCEVTy, op, ty) {}
499 
500 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
501  Type *ty)
502  : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
503  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
504  "Cannot truncate non-integer value!");
505 }
506 
507 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
508  const SCEV *op, Type *ty)
509  : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
510  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
511  "Cannot zero extend non-integer value!");
512 }
513 
514 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
515  const SCEV *op, Type *ty)
516  : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
517  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
518  "Cannot sign extend non-integer value!");
519 }
520 
521 void SCEVUnknown::deleted() {
522  // Clear this SCEVUnknown from various maps.
523  SE->forgetMemoizedResults(this);
524 
525  // Remove this SCEVUnknown from the uniquing map.
526  SE->UniqueSCEVs.RemoveNode(this);
527 
528  // Release the value.
529  setValPtr(nullptr);
530 }
531 
532 void SCEVUnknown::allUsesReplacedWith(Value *New) {
533  // Clear this SCEVUnknown from various maps.
534  SE->forgetMemoizedResults(this);
535 
536  // Remove this SCEVUnknown from the uniquing map.
537  SE->UniqueSCEVs.RemoveNode(this);
538 
539  // Replace the value pointer in case someone is still using this SCEVUnknown.
540  setValPtr(New);
541 }
542 
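// The following three helpers recognize the canonical constant expressions
// used to spell sizeof/alignof/offsetof in IR, e.g. (illustrative)
// "ptrtoint (%T* getelementptr (%T, %T* null, i32 1) to i64)" for
// sizeof(%T); the exact patterns accepted are the ones matched below.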
543 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
544  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
545  if (VCE->getOpcode() == Instruction::PtrToInt)
546  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
547  if (CE->getOpcode() == Instruction::GetElementPtr &&
548  CE->getOperand(0)->isNullValue() &&
549  CE->getNumOperands() == 2)
550  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
551  if (CI->isOne()) {
552  AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
553  return true;
554  }
555 
556  return false;
557 }
558 
559 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
560  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
561  if (VCE->getOpcode() == Instruction::PtrToInt)
562  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
563  if (CE->getOpcode() == Instruction::GetElementPtr &&
564  CE->getOperand(0)->isNullValue()) {
565  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
566  if (StructType *STy = dyn_cast<StructType>(Ty))
567  if (!STy->isPacked() &&
568  CE->getNumOperands() == 3 &&
569  CE->getOperand(1)->isNullValue()) {
570  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
571  if (CI->isOne() &&
572  STy->getNumElements() == 2 &&
573  STy->getElementType(0)->isIntegerTy(1)) {
574  AllocTy = STy->getElementType(1);
575  return true;
576  }
577  }
578  }
579 
580  return false;
581 }
582 
583 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
584  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
585  if (VCE->getOpcode() == Instruction::PtrToInt)
586  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
587  if (CE->getOpcode() == Instruction::GetElementPtr &&
588  CE->getNumOperands() == 3 &&
589  CE->getOperand(0)->isNullValue() &&
590  CE->getOperand(1)->isNullValue()) {
591  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
592  // Ignore vector types here so that ScalarEvolutionExpander doesn't
593  // emit getelementptrs that index into vectors.
594  if (Ty->isStructTy() || Ty->isArrayTy()) {
595  CTy = Ty;
596  FieldNo = CE->getOperand(2);
597  return true;
598  }
599  }
600 
601  return false;
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // SCEV Utilities
606 //===----------------------------------------------------------------------===//
607 
608 /// Compare the two values \p LV and \p RV in terms of their "complexity" where
609 /// "complexity" is a partial (and somewhat ad-hoc) relation used to order
610 /// operands in SCEV expressions. \p EqCache is a set of pairs of values that
611 /// have been previously deemed to be "equally complex" by this routine. It is
612 /// intended to avoid exponential time complexity in cases like:
613 ///
614 /// %a = f(%x, %y)
615 /// %b = f(%a, %a)
616 /// %c = f(%b, %b)
617 ///
618 /// %d = f(%x, %y)
619 /// %e = f(%d, %d)
620 /// %f = f(%e, %e)
621 ///
622 /// CompareValueComplexity(%f, %c)
623 ///
624 /// Since we do not continue running this routine on expression trees once we
625 /// have seen unequal values, there is no need to track them in the cache.
626 static int
627 CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
628  const LoopInfo *const LI, Value *LV, Value *RV,
629  unsigned Depth) {
630  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
631  return 0;
632 
633  // Order pointer values after integer values. This helps SCEVExpander form
634  // GEPs.
635  bool LIsPointer = LV->getType()->isPointerTy(),
636  RIsPointer = RV->getType()->isPointerTy();
637  if (LIsPointer != RIsPointer)
638  return (int)LIsPointer - (int)RIsPointer;
639 
640  // Compare getValueID values.
641  unsigned LID = LV->getValueID(), RID = RV->getValueID();
642  if (LID != RID)
643  return (int)LID - (int)RID;
644 
645  // Sort arguments by their position.
646  if (const auto *LA = dyn_cast<Argument>(LV)) {
647  const auto *RA = cast<Argument>(RV);
648  unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
649  return (int)LArgNo - (int)RArgNo;
650  }
651 
652  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
653  const auto *RGV = cast<GlobalValue>(RV);
654 
655  const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
656  auto LT = GV->getLinkage();
657  return !(GlobalValue::isPrivateLinkage(LT) ||
658  GlobalValue::isInternalLinkage(LT));
659  };
660 
661  // Use the names to distinguish the two values, but only if the
662  // names are semantically important.
663  if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
664  return LGV->getName().compare(RGV->getName());
665  }
666 
667  // For instructions, compare their loop depth, and their operand count. This
668  // is pretty loose.
669  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
670  const auto *RInst = cast<Instruction>(RV);
671 
672  // Compare loop depths.
673  const BasicBlock *LParent = LInst->getParent(),
674  *RParent = RInst->getParent();
675  if (LParent != RParent) {
676  unsigned LDepth = LI->getLoopDepth(LParent),
677  RDepth = LI->getLoopDepth(RParent);
678  if (LDepth != RDepth)
679  return (int)LDepth - (int)RDepth;
680  }
681 
682  // Compare the number of operands.
683  unsigned LNumOps = LInst->getNumOperands(),
684  RNumOps = RInst->getNumOperands();
685  if (LNumOps != RNumOps)
686  return (int)LNumOps - (int)RNumOps;
687 
688  for (unsigned Idx : seq(0u, LNumOps)) {
689  int Result =
690  CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
691  RInst->getOperand(Idx), Depth + 1);
692  if (Result != 0)
693  return Result;
694  }
695  }
696 
697  EqCacheValue.unionSets(LV, RV);
698  return 0;
699 }
700 
701 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
702 // than RHS, respectively. A three-way result allows recursive comparisons to be
703 // more efficient.
704 // If the max analysis depth was reached, return None, assuming we do not know
705 // if they are equivalent for sure.
706 static Optional<int>
707 CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
708  EquivalenceClasses<const Value *> &EqCacheValue,
709  const LoopInfo *const LI, const SCEV *LHS,
710  const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
711  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
712  if (LHS == RHS)
713  return 0;
714 
715  // Primarily, sort the SCEVs by their getSCEVType().
716  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
717  if (LType != RType)
718  return (int)LType - (int)RType;
719 
720  if (EqCacheSCEV.isEquivalent(LHS, RHS))
721  return 0;
722 
723  if (Depth > MaxSCEVCompareDepth)
724  return None;
725 
726  // Aside from the getSCEVType() ordering, the particular ordering
727  // isn't very important except that it's beneficial to be consistent,
728  // so that (a + b) and (b + a) don't end up as different expressions.
729  switch (LType) {
730  case scUnknown: {
731  const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
732  const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
733 
734  int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
735  RU->getValue(), Depth + 1);
736  if (X == 0)
737  EqCacheSCEV.unionSets(LHS, RHS);
738  return X;
739  }
740 
741  case scConstant: {
742  const SCEVConstant *LC = cast<SCEVConstant>(LHS);
743  const SCEVConstant *RC = cast<SCEVConstant>(RHS);
744 
745  // Compare constant values.
746  const APInt &LA = LC->getAPInt();
747  const APInt &RA = RC->getAPInt();
748  unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
749  if (LBitWidth != RBitWidth)
750  return (int)LBitWidth - (int)RBitWidth;
751  return LA.ult(RA) ? -1 : 1;
752  }
753 
754  case scAddRecExpr: {
755  const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
756  const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
757 
758  // There is always a dominance between two recs that are used by one SCEV,
759  // so we can safely sort recs by loop header dominance. We require such
760  // order in getAddExpr.
761  const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
762  if (LLoop != RLoop) {
763  const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
764  assert(LHead != RHead && "Two loops share the same header?");
765  if (DT.dominates(LHead, RHead))
766  return 1;
767  else
768  assert(DT.dominates(RHead, LHead) &&
769  "No dominance between recurrences used by one SCEV?");
770  return -1;
771  }
772 
773  // Addrec complexity grows with operand count.
774  unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
775  if (LNumOps != RNumOps)
776  return (int)LNumOps - (int)RNumOps;
777 
778  // Lexicographically compare.
779  for (unsigned i = 0; i != LNumOps; ++i) {
780  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
781  LA->getOperand(i), RA->getOperand(i), DT,
782  Depth + 1);
783  if (X != 0)
784  return X;
785  }
786  EqCacheSCEV.unionSets(LHS, RHS);
787  return 0;
788  }
789 
790  case scAddExpr:
791  case scMulExpr:
792  case scSMaxExpr:
793  case scUMaxExpr:
794  case scSMinExpr:
795  case scUMinExpr:
796  case scSequentialUMinExpr: {
797  const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
798  const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
799 
800  // Lexicographically compare n-ary expressions.
801  unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
802  if (LNumOps != RNumOps)
803  return (int)LNumOps - (int)RNumOps;
804 
805  for (unsigned i = 0; i != LNumOps; ++i) {
806  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
807  LC->getOperand(i), RC->getOperand(i), DT,
808  Depth + 1);
809  if (X != 0)
810  return X;
811  }
812  EqCacheSCEV.unionSets(LHS, RHS);
813  return 0;
814  }
815 
816  case scUDivExpr: {
817  const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
818  const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
819 
820  // Lexicographically compare udiv expressions.
821  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
822  RC->getLHS(), DT, Depth + 1);
823  if (X != 0)
824  return X;
825  X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
826  RC->getRHS(), DT, Depth + 1);
827  if (X == 0)
828  EqCacheSCEV.unionSets(LHS, RHS);
829  return X;
830  }
831 
832  case scPtrToInt:
833  case scTruncate:
834  case scZeroExtend:
835  case scSignExtend: {
836  const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
837  const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
838 
839  // Compare cast expressions by operand.
840  auto X =
841  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
842  RC->getOperand(), DT, Depth + 1);
843  if (X == 0)
844  EqCacheSCEV.unionSets(LHS, RHS);
845  return X;
846  }
847 
848  case scCouldNotCompute:
849  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
850  }
851  llvm_unreachable("Unknown SCEV kind!");
852 }
853 
854 /// Given a list of SCEV objects, order them by their complexity, and group
855 /// objects of the same complexity together by value. When this routine is
856 /// finished, we know that any duplicates in the vector are consecutive and that
857 /// complexity is monotonically increasing.
858 ///
859 /// Note that we take special precautions to ensure that we get deterministic
860 /// results from this routine. In other words, we don't want the results of
861 /// this to depend on where the addresses of various SCEV objects happened to
862 /// land in memory.
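///
/// Illustrative example (not part of the original comment): given the
/// operand list ((%a + %b), 5, (%a + %b)), sorting by complexity moves the
/// constant 5 to the front and makes the two identical add expressions
/// adjacent, so a caller such as getAddExpr can spot the duplicate in a
/// single linear scan and fold it into 2 * (%a + %b).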
863 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
864  LoopInfo *LI, DominatorTree &DT) {
865  if (Ops.size() < 2) return; // Noop
866 
867  EquivalenceClasses<const SCEV *> EqCacheSCEV;
868  EquivalenceClasses<const Value *> EqCacheValue;
869 
870  // Whether LHS has provably less complexity than RHS.
871  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
872  auto Complexity =
873  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
874  return Complexity && *Complexity < 0;
875  };
876  if (Ops.size() == 2) {
877  // This is the common case, which also happens to be trivially simple.
878  // Special case it.
879  const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
880  if (IsLessComplex(RHS, LHS))
881  std::swap(LHS, RHS);
882  return;
883  }
884 
885  // Do the rough sort by complexity.
886  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
887  return IsLessComplex(LHS, RHS);
888  });
889 
890  // Now that we are sorted by complexity, group elements of the same
891  // complexity. Note that this is, at worst, N^2, but the vector is likely to
892  // be extremely short in practice. Note that we take this approach because we
893  // do not want to depend on the addresses of the objects we are grouping.
894  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
895  const SCEV *S = Ops[i];
896  unsigned Complexity = S->getSCEVType();
897 
898  // If there are any objects of the same complexity and same value as this
899  // one, group them.
900  for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
901  if (Ops[j] == S) { // Found a duplicate.
902  // Move it to immediately after i'th element.
903  std::swap(Ops[i+1], Ops[j]);
904  ++i; // no need to rescan it.
905  if (i == e-2) return; // Done!
906  }
907  }
908  }
909 }
910 
911 /// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
912 /// least HugeExprThreshold nodes).
913 static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
914  return any_of(Ops, [](const SCEV *S) {
915  return S->getExpressionSize() >= HugeExprThreshold;
916  });
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // Simple SCEV method implementations
921 //===----------------------------------------------------------------------===//
922 
923 /// Compute BC(It, K). The result has width W. Assume, K > 0.
924 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
925  ScalarEvolution &SE,
926  Type *ResultTy) {
927  // Handle the simplest case efficiently.
928  if (K == 1)
929  return SE.getTruncateOrZeroExtend(It, ResultTy);
930 
931  // We are using the following formula for BC(It, K):
932  //
933  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
934  //
935  // Suppose, W is the bitwidth of the return value. We must be prepared for
936  // overflow. Hence, we must assure that the result of our computation is
937  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
938  // safe in modular arithmetic.
939  //
940  // However, this code doesn't use exactly that formula; the formula it uses
941  // is something like the following, where T is the number of factors of 2 in
942  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
943  // exponentiation:
944  //
945  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
946  //
947  // This formula is trivially equivalent to the previous formula. However,
948  // this formula can be implemented much more efficiently. The trick is that
949  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
950  // arithmetic. To do exact division in modular arithmetic, all we have
951  // to do is multiply by the inverse. Therefore, this step can be done at
952  // width W.
953  //
954  // The next issue is how to safely do the division by 2^T. The way this
955  // is done is by doing the multiplication step at a width of at least W + T
956  // bits. This way, the bottom W+T bits of the product are accurate. Then,
957  // when we perform the division by 2^T (which is equivalent to a right shift
958  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
959  // truncated out after the division by 2^T.
960  //
961  // In comparison to just directly using the first formula, this technique
962  // is much more efficient; using the first formula requires W * K bits,
963 // but this formula needs less than W + K bits. Also, the first formula requires
964  // a division step, whereas this formula only requires multiplies and shifts.
965  //
966  // It doesn't matter whether the subtraction step is done in the calculation
967  // width or the input iteration count's width; if the subtraction overflows,
968  // the result must be zero anyway. We prefer here to do it in the width of
969  // the induction variable because it helps a lot for certain cases; CodeGen
970  // isn't smart enough to ignore the overflow, which leads to much less
971  // efficient code if the width of the subtraction is wider than the native
972  // register width.
973  //
974  // (It's possible to not widen at all by pulling out factors of 2 before
975  // the multiplication; for example, K=2 can be calculated as
976  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
977  // extra arithmetic, so it's not an obvious win, and it gets
978  // much more complicated for K > 3.)
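  //
  // Worked example (illustrative, not part of the original comment): for
  // K = 3 and W = 32, K! = 6 = 2^1 * 3, so T = 1 and the odd factor is 3.
  // The product It*(It-1)*(It-2) is formed at W + T = 33 bits, divided by
  // 2^T = 2, truncated back to 32 bits, and then multiplied by the
  // multiplicative inverse of 3 modulo 2^32 (0xAAAAAAAB) to finish the
  // exact division by K!.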
979 
980  // Protection from insane SCEVs; this bound is conservative,
981  // but it probably doesn't matter.
982  if (K > 1000)
983  return SE.getCouldNotCompute();
984 
985  unsigned W = SE.getTypeSizeInBits(ResultTy);
986 
987  // Calculate K! / 2^T and T; we divide out the factors of two before
988  // multiplying for calculating K! / 2^T to avoid overflow.
989  // Other overflow doesn't matter because we only care about the bottom
990  // W bits of the result.
991  APInt OddFactorial(W, 1);
992  unsigned T = 1;
993  for (unsigned i = 3; i <= K; ++i) {
994  APInt Mult(W, i);
995  unsigned TwoFactors = Mult.countTrailingZeros();
996  T += TwoFactors;
997  Mult.lshrInPlace(TwoFactors);
998  OddFactorial *= Mult;
999  }
1000 
1001  // We need at least W + T bits for the multiplication step
1002  unsigned CalculationBits = W + T;
1003 
1004  // Calculate 2^T, at width T+W.
1005  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1006 
1007  // Calculate the multiplicative inverse of K! / 2^T;
1008  // this multiplication factor will perform the exact division by
1009  // K! / 2^T.
1010  APInt Mod = APInt::getSignedMinValue(W+1);
1011  APInt MultiplyFactor = OddFactorial.zext(W+1);
1012  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
1013  MultiplyFactor = MultiplyFactor.trunc(W);
1014 
1015  // Calculate the product, at width T+W
1016  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1017  CalculationBits);
1018  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1019  for (unsigned i = 1; i != K; ++i) {
1020  const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1021  Dividend = SE.getMulExpr(Dividend,
1022  SE.getTruncateOrZeroExtend(S, CalculationTy));
1023  }
1024 
1025  // Divide by 2^T
1026  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1027 
1028  // Truncate the result, and divide by K! / 2^T.
1029 
1030  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1031  SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1032 }
1033 
1034 /// Return the value of this chain of recurrences at the specified iteration
1035 /// number. We can evaluate this recurrence by multiplying each element in the
1036 /// chain by the binomial coefficient corresponding to it. In other words, we
1037 /// can evaluate {A,+,B,+,C,+,D} as:
1038 ///
1039 /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1040 ///
1041 /// where BC(It, k) stands for binomial coefficient.
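///
/// Worked example (illustrative, not from the original comment): the chrec
/// {0,+,3,+,2} evaluated at It = 4 is 0*BC(4,0) + 3*BC(4,1) + 2*BC(4,2)
/// = 0 + 12 + 12 = 24, which matches stepping the recurrence by hand:
/// 0 -> 3 -> 8 -> 15 -> 24.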
1042 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1043  ScalarEvolution &SE) const {
1044  return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
1045 }
1046 
1047 const SCEV *
1048 SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
1049  const SCEV *It, ScalarEvolution &SE) {
1050  assert(Operands.size() > 0);
1051  const SCEV *Result = Operands[0];
1052  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1053  // The computation is correct in the face of overflow provided that the
1054  // multiplication is performed _after_ the evaluation of the binomial
1055  // coefficient.
1056  const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1057  if (isa<SCEVCouldNotCompute>(Coeff))
1058  return Coeff;
1059 
1060  Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1061  }
1062  return Result;
1063 }
1064 
1065 //===----------------------------------------------------------------------===//
1066 // SCEV Expression folder implementations
1067 //===----------------------------------------------------------------------===//
1068 
1069 const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1070  unsigned Depth) {
1071  assert(Depth <= 1 &&
1072  "getLosslessPtrToIntExpr() should self-recurse at most once.");
1073 
1074  // We could be called with an integer-typed operand during SCEV rewrites.
1075  // Since the operand is an integer already, just perform zext/trunc/self cast.
1076  if (!Op->getType()->isPointerTy())
1077  return Op;
1078 
1079  // What would be an ID for such a SCEV cast expression?
1080  FoldingSetNodeID ID;
1081  ID.AddInteger(scPtrToInt);
1082  ID.AddPointer(Op);
1083 
1084  void *IP = nullptr;
1085 
1086  // Is there already an expression for such a cast?
1087  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1088  return S;
1089 
1090  // It isn't legal for optimizations to construct new ptrtoint expressions
1091  // for non-integral pointers.
1092  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1093  return getCouldNotCompute();
1094 
1095  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1096 
1097  // We can only trivially model ptrtoint if SCEV's effective (integer) type
1098  // is sufficiently wide to represent all possible pointer values.
1099  // We could theoretically teach SCEV to truncate wider pointers, but
1100  // that isn't implemented for now.
1101  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1102  getDataLayout().getTypeSizeInBits(IntPtrTy))
1103  return getCouldNotCompute();
1104 
1105  // If not, is this expression something we can't reduce any further?
1106  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1107  // Perform some basic constant folding. If the operand of the ptr2int cast
1108  // is a null pointer, don't create a ptr2int SCEV expression (that will be
1109  // left as-is), but produce a zero constant.
1110  // NOTE: We could handle a more general case, but lack motivational cases.
1111  if (isa<ConstantPointerNull>(U->getValue()))
1112  return getZero(IntPtrTy);
1113 
1114  // Create an explicit cast node.
1115  // We can reuse the existing insert position since if we get here,
1116  // we won't have made any changes which would invalidate it.
1117  SCEV *S = new (SCEVAllocator)
1118  SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1119  UniqueSCEVs.InsertNode(S, IP);
1120  registerUser(S, Op);
1121  return S;
1122  }
1123 
1124  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1125  "non-SCEVUnknown's.");
1126 
1127  // Otherwise, we've got some expression that is more complex than just a
1128  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1129  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1130  // only, and the expressions must otherwise be integer-typed.
1131  // So sink the cast down to the SCEVUnknown's.
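  // For instance (illustrative): for a pointer-typed expression such as
  // ((4 * %i)<nuw> + %base), the rewriter below produces
  // ((4 * %i)<nuw> + (ptrtoint i8* %base to i64)): the cast is applied only
  // to the SCEVUnknown %base, and the surrounding arithmetic stays in
  // integer form. (The i8*/i64 types here are purely for illustration.)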
1132 
1133  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1134  /// which computes a pointer-typed value, and rewrites the whole expression
1135  /// tree so that *all* the computations are done on integers, and the only
1136  /// pointer-typed operands in the expression are SCEVUnknown.
1137  class SCEVPtrToIntSinkingRewriter
1138  : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1139  using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1140 
1141  public:
1142  SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1143 
1144  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1145  SCEVPtrToIntSinkingRewriter Rewriter(SE);
1146  return Rewriter.visit(Scev);
1147  }
1148 
1149  const SCEV *visit(const SCEV *S) {
1150  Type *STy = S->getType();
1151  // If the expression is not pointer-typed, just keep it as-is.
1152  if (!STy->isPointerTy())
1153  return S;
1154  // Else, recursively sink the cast down into it.
1155  return Base::visit(S);
1156  }
1157 
1158  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1159  SmallVector<const SCEV *, 2> Operands;
1160  bool Changed = false;
1161  for (auto *Op : Expr->operands()) {
1162  Operands.push_back(visit(Op));
1163  Changed |= Op != Operands.back();
1164  }
1165  return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1166  }
1167 
1168  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1169  SmallVector<const SCEV *, 2> Operands;
1170  bool Changed = false;
1171  for (auto *Op : Expr->operands()) {
1172  Operands.push_back(visit(Op));
1173  Changed |= Op != Operands.back();
1174  }
1175  return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1176  }
1177 
1178  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1179  assert(Expr->getType()->isPointerTy() &&
1180  "Should only reach pointer-typed SCEVUnknown's.");
1181  return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1182  }
1183  };
1184 
1185  // And actually perform the cast sinking.
1186  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1187  assert(IntOp->getType()->isIntegerTy() &&
1188  "We must have succeeded in sinking the cast, "
1189  "and ending up with an integer-typed expression!");
1190  return IntOp;
1191 }
1192 
1193 const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1194  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1195 
1196  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1197  if (isa<SCEVCouldNotCompute>(IntOp))
1198  return IntOp;
1199 
1200  return getTruncateOrZeroExtend(IntOp, Ty);
1201 }
1202 
1203 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1204  unsigned Depth) {
1205  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1206  "This is not a truncating conversion!");
1207  assert(isSCEVable(Ty) &&
1208  "This is not a conversion to a SCEVable type!");
1209  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1210  Ty = getEffectiveSCEVType(Ty);
1211 
1212  FoldingSetNodeID ID;
1213  ID.AddInteger(scTruncate);
1214  ID.AddPointer(Op);
1215  ID.AddPointer(Ty);
1216  void *IP = nullptr;
1217  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1218 
1219  // Fold if the operand is constant.
1220  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1221  return getConstant(
1222  cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1223 
1224  // trunc(trunc(x)) --> trunc(x)
1225  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1226  return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1227 
1228  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1229  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1230  return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1231 
1232  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1233  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1234  return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1235 
1236  if (Depth > MaxCastDepth) {
1237  SCEV *S =
1238  new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1239  UniqueSCEVs.InsertNode(S, IP);
1240  registerUser(S, Op);
1241  return S;
1242  }
1243 
1244  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1245  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1246  // if after transforming we have at most one truncate, not counting truncates
1247  // that replace other casts.
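  // For example (illustrative): truncating ((zext i32 %a to i64) + %x) from
  // i64 to i32 gives (%a + (trunc i64 %x to i32)); only one truncate that
  // does not replace another cast remains, so the transform is applied.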
1248  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1249  auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1250  SmallVector<const SCEV *, 4> Operands;
1251  unsigned numTruncs = 0;
1252  for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1253  ++i) {
1254  const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1255  if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1256  isa<SCEVTruncateExpr>(S))
1257  numTruncs++;
1258  Operands.push_back(S);
1259  }
1260  if (numTruncs < 2) {
1261  if (isa<SCEVAddExpr>(Op))
1262  return getAddExpr(Operands);
1263  else if (isa<SCEVMulExpr>(Op))
1264  return getMulExpr(Operands);
1265  else
1266  llvm_unreachable("Unexpected SCEV type for Op.");
1267  }
1268  // Although we checked in the beginning that ID is not in the cache, it is
1269  // possible that during recursion and different modification ID was inserted
1270  // into the cache. So if we find it, just return it.
1271  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1272  return S;
1273  }
1274 
1275  // If the input value is a chrec scev, truncate the chrec's operands.
1276  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1277  SmallVector<const SCEV *, 4> Operands;
1278  for (const SCEV *Op : AddRec->operands())
1279  Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1280  return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1281  }
1282 
1283  // Return zero if truncating to known zeros.
1284  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
1285  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1286  return getZero(Ty);
1287 
1288  // The cast wasn't folded; create an explicit cast node. We can reuse
1289  // the existing insert position since if we get here, we won't have
1290  // made any changes which would invalidate it.
1291  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1292  Op, Ty);
1293  UniqueSCEVs.InsertNode(S, IP);
1294  registerUser(S, Op);
1295  return S;
1296 }
1297 
1298 // Get the limit of a recurrence such that incrementing by Step cannot cause
1299 // signed overflow as long as the value of the recurrence within the
1300 // loop does not exceed this limit before incrementing.
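// For example (illustrative): for an i8 recurrence with a known-positive
// step of 4, the limit returned below is SINT8_MAX - 4 = 123 with predicate
// "slt"; as long as the recurrence value is slt 123, adding 4 cannot cause
// signed overflow.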
1301 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1302  ICmpInst::Predicate *Pred,
1303  ScalarEvolution *SE) {
1304  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1305  if (SE->isKnownPositive(Step)) {
1306  *Pred = ICmpInst::ICMP_SLT;
1307  return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1308  SE->getSignedRangeMax(Step));
1309  }
1310  if (SE->isKnownNegative(Step)) {
1311  *Pred = ICmpInst::ICMP_SGT;
1312  return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1313  SE->getSignedRangeMin(Step));
1314  }
1315  return nullptr;
1316 }
1317 
1318 // Get the limit of a recurrence such that incrementing by Step cannot cause
1319 // unsigned overflow as long as the value of the recurrence within the loop does
1320 // not exceed this limit before incrementing.
1321 static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1322  ICmpInst::Predicate *Pred,
1323  ScalarEvolution *SE) {
1324  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1325  *Pred = ICmpInst::ICMP_ULT;
1326 
1327  return SE->getConstant(APInt::getMaxValue(BitWidth) -
1328  SE->getUnsignedRangeMax(Step));
1329 }
1330 
1331 namespace {
1332 
1333 struct ExtendOpTraitsBase {
1334  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1335  unsigned);
1336 };
1337 
1338 // Used to make code generic over signed and unsigned overflow.
1339 template <typename ExtendOp> struct ExtendOpTraits {
1340  // Members present:
1341  //
1342  // static const SCEV::NoWrapFlags WrapType;
1343  //
1344  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1345  //
1346  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1347  // ICmpInst::Predicate *Pred,
1348  // ScalarEvolution *SE);
1349 };
1350 
1351 template <>
1352 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1353  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1354 
1355  static const GetExtendExprTy GetExtendExpr;
1356 
1357  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1358  ICmpInst::Predicate *Pred,
1359  ScalarEvolution *SE) {
1360  return getSignedOverflowLimitForStep(Step, Pred, SE);
1361  }
1362 };
1363 
1364 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1365  SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1366 
1367 template <>
1368 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1369  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1370 
1371  static const GetExtendExprTy GetExtendExpr;
1372 
1373  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1374  ICmpInst::Predicate *Pred,
1375  ScalarEvolution *SE) {
1376  return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1377  }
1378 };
1379 
1380 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1381  SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1382 
1383 } // end anonymous namespace
1384 
1385 // The recurrence AR has been shown to have no signed/unsigned wrap or something
1386 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1387 // easily prove NSW/NUW for its preincrement or postincrement sibling. This
1388 // allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1389 // Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1390 // expression "Step + sext/zext(PreIncAR)" is congruent with
1391 // "sext/zext(PostIncAR)"
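// For instance (illustrative, not from the original comment): if
// AR = {(4 + %n),+,4} and the narrower recurrence {%n,+,4} can be shown not
// to wrap, then the extension of AR's start can be computed as
// (4 + sext/zext(%n)), i.e. the cast is applied to the pre-increment value
// %n (the "PreStart" below) rather than to the wider sum.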
1392 template <typename ExtendOpTy>
1393 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1394  ScalarEvolution *SE, unsigned Depth) {
1395  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1396  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1397 
1398  const Loop *L = AR->getLoop();
1399  const SCEV *Start = AR->getStart();
1400  const SCEV *Step = AR->getStepRecurrence(*SE);
1401 
1402  // Check for a simple looking step prior to loop entry.
1403  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1404  if (!SA)
1405  return nullptr;
1406 
1407  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1408  // subtraction is expensive. For this purpose, perform a quick and dirty
1409  // difference, by checking for Step in the operand list.
1410  SmallVector<const SCEV *, 4> DiffOps;
1411  for (const SCEV *Op : SA->operands())
1412  if (Op != Step)
1413  DiffOps.push_back(Op);
1414 
1415  if (DiffOps.size() == SA->getNumOperands())
1416  return nullptr;
1417 
1418  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1419  // `Step`:
1420 
1421  // 1. NSW/NUW flags on the step increment.
1422  auto PreStartFlags =
1423  ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1424  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1425  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1426  SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1427 
1428  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1429  // "S+X does not sign/unsign-overflow".
1430  //
1431 
1432  const SCEV *BECount = SE->getBackedgeTakenCount(L);
1433  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1434  !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1435  return PreStart;
1436 
1437  // 2. Direct overflow check on the step operation's expression.
1438  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1439  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1440  const SCEV *OperandExtendedStart =
1441  SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1442  (SE->*GetExtendExpr)(Step, WideTy, Depth));
1443  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1444  if (PreAR && AR->getNoWrapFlags(WrapType)) {
1445  // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1446  // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1447  // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1448  SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1449  }
1450  return PreStart;
1451  }
1452 
1453  // 3. Loop precondition.
1454  ICmpInst::Predicate Pred;
1455  const SCEV *OverflowLimit =
1456  ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1457 
1458  if (OverflowLimit &&
1459  SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1460  return PreStart;
1461 
1462  return nullptr;
1463 }
1464 
1465 // Get the normalized zero or sign extended expression for this AddRec's Start.
1466 template <typename ExtendOpTy>
1467 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1468  ScalarEvolution *SE,
1469  unsigned Depth) {
1470  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1471 
1472  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1473  if (!PreStart)
1474  return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1475 
1476  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1477  Depth),
1478  (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1479 }
1480 
1481 // Try to prove away overflow by looking at "nearby" add recurrences. A
1482 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1483 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1484 //
1485 // Formally:
1486 //
1487 // {S,+,X} == {S-T,+,X} + T
1488 // => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1489 //
1490 // If ({S-T,+,X} + T) does not overflow ... (1)
1491 //
1492 // RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1493 //
1494 // If {S-T,+,X} does not overflow ... (2)
1495 //
1496 // RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1497 // == {Ext(S-T)+Ext(T),+,Ext(X)}
1498 //
1499 // If (S-T)+T does not overflow ... (3)
1500 //
1501 // RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1502 // == {Ext(S),+,Ext(X)} == LHS
1503 //
1504 // Thus, if (1), (2) and (3) are true for some T, then
1505 // Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1506 //
1507 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1508 // does not overflow" restricted to the 0th iteration. Therefore we only need
1509 // to check for (1) and (2).
1510 //
1511 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1512 // is `Delta` (defined below).
1513 template <typename ExtendOpTy>
1514 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1515  const SCEV *Step,
1516  const Loop *L) {
1517  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1518 
1519  // We restrict `Start` to a constant to prevent SCEV from spending too much
1520  // time here. It is correct (but more expensive) to continue with a
1521  // non-constant `Start` and do a general SCEV subtraction to compute
1522  // `PreStart` below.
1523  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1524  if (!StartC)
1525  return false;
1526 
1527  APInt StartAI = StartC->getAPInt();
1528 
1529  for (unsigned Delta : {-2, -1, 1, 2}) {
1530  const SCEV *PreStart = getConstant(StartAI - Delta);
1531 
1532  FoldingSetNodeID ID;
1533  ID.AddInteger(scAddRecExpr);
1534  ID.AddPointer(PreStart);
1535  ID.AddPointer(Step);
1536  ID.AddPointer(L);
1537  void *IP = nullptr;
1538  const auto *PreAR =
1539  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1540 
1541  // Give up if we don't already have the add recurrence we need because
1542  // actually constructing an add recurrence is relatively expensive.
1543  if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1544  const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1545  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1546  const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1547  DeltaS, &Pred, this);
1548  if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1549  return true;
1550  }
1551  }
1552 
1553  return false;
1554 }
1555 
1556 // Finds an integer D for an expression (C + x + y + ...) such that the top
1557 // level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1558 // unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1559 // maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1560 // the (C + x + y + ...) expression is \p WholeAddExpr.
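// Worked example (illustrative, not from the original comment): with
// BitWidth = 8, C = 0b01100101 and an (x + y + ...) part known to have 4
// trailing zero bits, TZ = 4 and D = 0b00000101. (C - D + x + y + ...) then
// keeps at least 4 trailing zeros, so adding D back only fills low bits
// that are already zero and the top-level addition cannot wrap.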
1561 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1562  const SCEVConstant *ConstantTerm,
1563  const SCEVAddExpr *WholeAddExpr) {
1564  const APInt &C = ConstantTerm->getAPInt();
1565  const unsigned BitWidth = C.getBitWidth();
1566  // Find number of trailing zeros of (x + y + ...) w/o the C first:
1567  uint32_t TZ = BitWidth;
1568  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1569  TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
1570  if (TZ) {
1571  // Set D to be as many least significant bits of C as possible while still
1572  // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1573  return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1574  }
1575  return APInt(BitWidth, 0);
1576 }
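
// Illustrative sketch, not part of the original source: the same trailing-zero
// split on plain 64-bit integers. `TZ` plays the role of GetMinTrailingZeros
// applied to the non-constant operands; the returned value corresponds to D.
static uint64_t splitConstantByTrailingZerosSketch(uint64_t C, unsigned TZ) {
  if (TZ == 0)
    return 0; // nothing useful can be split off
  if (TZ >= 64)
    return C; // the other operands are all zero at this width; take all of C
  return C & ((uint64_t(1) << TZ) - 1); // the low TZ bits of C
}
// For example, C = 21 (0b10101) with TZ = 3 gives D = 5, so C - D = 16 keeps
// at least three trailing zeros and re-adding D on top cannot carry into them.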
1577 
1578 // Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1579 // level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1580 // number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1581 // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1582 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1583  const APInt &ConstantStart,
1584  const SCEV *Step) {
1585  const unsigned BitWidth = ConstantStart.getBitWidth();
1586  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
1587  if (TZ)
1588  return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1589  : ConstantStart;
1590  return APInt(BitWidth, 0);
1591 }
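
// Illustrative sketch, not part of the original source: one concrete instance
// of the split computed above, for the recurrence {7,+,4}. Step == 4 has two
// trailing zeros, so D is the low two bits of the start and the residual
// recurrence {4,+,4} only ever takes values that are multiples of 4.
static void splitAddRecStartExampleSketch() {
  const unsigned TZ = 2;                            // trailing zeros of Step == 4
  const uint64_t C = 7;                             // constant start of {7,+,4}
  const uint64_t D = C & ((uint64_t(1) << TZ) - 1); // == 3
  const uint64_t ResidualStart = C - D;             // == 4, start of {4,+,4}
  (void)D;
  (void)ResidualStart;
}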
1592 
1593 const SCEV *
1594 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1595  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1596  "This is not an extending conversion!");
1597  assert(isSCEVable(Ty) &&
1598  "This is not a conversion to a SCEVable type!");
1599  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1600  Ty = getEffectiveSCEVType(Ty);
1601 
1602  // Fold if the operand is constant.
1603  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1604  return getConstant(
1605  cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1606 
1607  // zext(zext(x)) --> zext(x)
1608  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1609  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1610 
1611  // Before doing any expensive analysis, check to see if we've already
1612  // computed a SCEV for this Op and Ty.
1613  FoldingSetNodeID ID;
1614  ID.AddInteger(scZeroExtend);
1615  ID.AddPointer(Op);
1616  ID.AddPointer(Ty);
1617  void *IP = nullptr;
1618  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1619  if (Depth > MaxCastDepth) {
1620  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1621  Op, Ty);
1622  UniqueSCEVs.InsertNode(S, IP);
1623  registerUser(S, Op);
1624  return S;
1625  }
1626 
1627  // zext(trunc(x)) --> zext(x) or x or trunc(x)
1628  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1629  // It's possible the bits taken off by the truncate were all zero bits. If
1630  // so, we should be able to simplify this further.
1631  const SCEV *X = ST->getOperand();
1632  ConstantRange CR = getUnsignedRange(X);
1633  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1634  unsigned NewBits = getTypeSizeInBits(Ty);
1635  if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1636  CR.zextOrTrunc(NewBits)))
1637  return getTruncateOrZeroExtend(X, Ty, Depth);
1638  }
1639 
1640  // If the input value is a chrec scev, and we can prove that the value
1641  // did not overflow the old, smaller, value, we can zero extend all of the
1642  // operands (often constants). This allows analysis of something like
1643  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1644  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1645  if (AR->isAffine()) {
1646  const SCEV *Start = AR->getStart();
1647  const SCEV *Step = AR->getStepRecurrence(*this);
1648  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1649  const Loop *L = AR->getLoop();
1650 
1651  if (!AR->hasNoUnsignedWrap()) {
1652  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1653  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1654  }
1655 
1656  // If we have special knowledge that this addrec won't overflow,
1657  // we don't need to do any further analysis.
1658  if (AR->hasNoUnsignedWrap()) {
1659  Start =
1660  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1661  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1662  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1663  }
1664 
1665  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1666  // Note that this serves two purposes: It filters out loops that are
1667  // simply not analyzable, and it covers the case where this code is
1668  // being called from within backedge-taken count analysis, such that
1669  // attempting to ask for the backedge-taken count would likely result
1670  // in infinite recursion. In the latter case, the analysis code will
1671  // cope with a conservative value, and it will take care to purge
1672  // that value once it has finished.
1673  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1674  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1675  // Manually compute the final value for AR, checking for overflow.
1676 
1677  // Check whether the backedge-taken count can be losslessly casted to
1678  // the addrec's type. The count is always unsigned.
1679  const SCEV *CastedMaxBECount =
1680  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1681  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1682  CastedMaxBECount, MaxBECount->getType(), Depth);
1683  if (MaxBECount == RecastedMaxBECount) {
1684  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1685  // Check whether Start+Step*MaxBECount has no unsigned overflow.
1686  const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1687  SCEV::FlagAnyWrap, Depth + 1);
1688  const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1689  SCEV::FlagAnyWrap,
1690  Depth + 1),
1691  WideTy, Depth + 1);
1692  const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1693  const SCEV *WideMaxBECount =
1694  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1695  const SCEV *OperandExtendedAdd =
1696  getAddExpr(WideStart,
1697  getMulExpr(WideMaxBECount,
1698  getZeroExtendExpr(Step, WideTy, Depth + 1),
1699  SCEV::FlagAnyWrap, Depth + 1),
1700  SCEV::FlagAnyWrap, Depth + 1);
1701  if (ZAdd == OperandExtendedAdd) {
1702  // Cache knowledge of AR NUW, which is propagated to this AddRec.
1703  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1704  // Return the expression with the addrec on the outside.
1705  Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1706  Depth + 1);
1707  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1708  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1709  }
1710  // Similar to above, only this time treat the step value as signed.
1711  // This covers loops that count down.
1712  OperandExtendedAdd =
1713  getAddExpr(WideStart,
1714  getMulExpr(WideMaxBECount,
1715  getSignExtendExpr(Step, WideTy, Depth + 1),
1716  SCEV::FlagAnyWrap, Depth + 1),
1717  SCEV::FlagAnyWrap, Depth + 1);
1718  if (ZAdd == OperandExtendedAdd) {
1719  // Cache knowledge of AR NW, which is propagated to this AddRec.
1720  // Negative step causes unsigned wrap, but it still can't self-wrap.
1721  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1722  // Return the expression with the addrec on the outside.
1723  Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1724  Depth + 1);
1725  Step = getSignExtendExpr(Step, Ty, Depth + 1);
1726  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1727  }
1728  }
1729  }
1730 
1731  // Normally, in the cases we can prove no-overflow via a
1732  // backedge guarding condition, we can also compute a backedge
1733  // taken count for the loop. The exceptions are assumptions and
1734  // guards present in the loop -- SCEV is not great at exploiting
1735  // these to compute max backedge taken counts, but can still use
1736  // these to prove lack of overflow. Use this fact to avoid
1737  // doing extra work that may not pay off.
1738  if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1739  !AC.assumptions().empty()) {
1740 
1741  auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1742  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1743  if (AR->hasNoUnsignedWrap()) {
1744  // Same as nuw case above - duplicated here to avoid a compile time
1745  // issue. It's not clear that the order of checks matters, but it's
1746  // one of two possible causes of an issue with a change which was
1747  // reverted. Be conservative for the moment.
1748  Start =
1749  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1750  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752  }
1753 
1754  // For a negative step, we can extend the operands iff doing so only
1755  // traverses values in the range zext([0,UINT_MAX]).
1756  if (isKnownNegative(Step)) {
1757  const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1758  getSignedRangeMin(Step));
1759  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1760  isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1761  // Cache knowledge of AR NW, which is propagated to this
1762  // AddRec. Negative step causes unsigned wrap, but it
1763  // still can't self-wrap.
1764  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1765  // Return the expression with the addrec on the outside.
1766  Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1767  Depth + 1);
1768  Step = getSignExtendExpr(Step, Ty, Depth + 1);
1769  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1770  }
1771  }
1772  }
1773 
1774  // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1775  // if D + (C - D + Step * n) could be proven to not unsigned wrap
1776  // where D maximizes the number of trailing zeros of (C - D + Step * n)
1777  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1778  const APInt &C = SC->getAPInt();
1779  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1780  if (D != 0) {
1781  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1782  const SCEV *SResidual =
1783  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1784  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1785  return getAddExpr(SZExtD, SZExtR,
1786  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1787  Depth + 1);
1788  }
1789  }
1790 
1791  if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1792  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793  Start =
1794  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1795  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1796  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1797  }
1798  }
1799 
1800  // zext(A % B) --> zext(A) % zext(B)
1801  {
1802  const SCEV *LHS;
1803  const SCEV *RHS;
1804  if (matchURem(Op, LHS, RHS))
1805  return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1806  getZeroExtendExpr(RHS, Ty, Depth + 1));
1807  }
1808 
1809  // zext(A / B) --> zext(A) / zext(B).
1810  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1811  return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1812  getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1813 
1814  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1815  // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1816  if (SA->hasNoUnsignedWrap()) {
1817  // If the addition does not unsign overflow then we can, by definition,
1818  // commute the zero extension with the addition operation.
1819  SmallVector<const SCEV *, 4> Ops;
1820  for (const auto *Op : SA->operands())
1821  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1822  return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1823  }
1824 
1825  // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1826  // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1827  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1828  //
1829  // Often address arithmetics contain expressions like
1830  // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1831  // This transformation is useful while proving that such expressions are
1832  // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1833  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1834  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1835  if (D != 0) {
1836  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1837  const SCEV *SResidual =
1838  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1839  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1840  return getAddExpr(SZExtD, SZExtR,
1841  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1842  Depth + 1);
1843  }
1844  }
1845  }
1846 
1847  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1848  // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1849  if (SM->hasNoUnsignedWrap()) {
1850  // If the multiply does not unsign overflow then we can, by definition,
1851  // commute the zero extension with the multiply operation.
1852  SmallVector<const SCEV *, 4> Ops;
1853  for (const auto *Op : SM->operands())
1854  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1855  return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1856  }
1857 
1858  // zext(2^K * (trunc X to iN)) to iM ->
1859  // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1860  //
1861  // Proof:
1862  //
1863  // zext(2^K * (trunc X to iN)) to iM
1864  // = zext((trunc X to iN) << K) to iM
1865  // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1866  // (because shl removes the top K bits)
1867  // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1868  // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1869  //
1870  if (SM->getNumOperands() == 2)
1871  if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1872  if (MulLHS->getAPInt().isPowerOf2())
1873  if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1874  int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1875  MulLHS->getAPInt().logBase2();
1876  Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1877  return getMulExpr(
1878  getZeroExtendExpr(MulLHS, Ty),
1879  getZeroExtendExpr(
1880  getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1881  SCEV::FlagNUW, Depth + 1);
1882  }
1883  }
1884 
1885  // The cast wasn't folded; create an explicit cast node.
1886  // Recompute the insert position, as it may have been invalidated.
1887  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1888  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1889  Op, Ty);
1890  UniqueSCEVs.InsertNode(S, IP);
1891  registerUser(S, Op);
1892  return S;
1893 }
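
// Illustrative sketch, not part of the original source: the
// zext(2^K * (trunc X to iN)) rewrite used in getZeroExtendExpr above, for the
// concrete widths K = 2, N = 8, M = 32, modelled with masks on plain integers.
// Both forms compute (4 * X) mod 256 for every X, which is what makes the
// fold safe.
static bool zextOfScaledTruncAgreesSketch(uint64_t X) {
  const uint32_t LHS = (uint32_t(X & 0xFF) * 4u) & 0xFF; // zext(4 * (trunc X to i8)) to i32
  const uint32_t RHS = uint32_t(X & 0x3F) * 4u;          // 4 * (zext(trunc X to i6) to i32)
  return LHS == RHS;
}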
1894 
1895 const SCEV *
1896 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1897  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1898  "This is not an extending conversion!");
1899  assert(isSCEVable(Ty) &&
1900  "This is not a conversion to a SCEVable type!");
1901  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1902  Ty = getEffectiveSCEVType(Ty);
1903 
1904  // Fold if the operand is constant.
1905  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1906  return getConstant(
1907  cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1908 
1909  // sext(sext(x)) --> sext(x)
1910  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1911  return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1912 
1913  // sext(zext(x)) --> zext(x)
1914  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1915  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1916 
1917  // Before doing any expensive analysis, check to see if we've already
1918  // computed a SCEV for this Op and Ty.
1919  FoldingSetNodeID ID;
1920  ID.AddInteger(scSignExtend);
1921  ID.AddPointer(Op);
1922  ID.AddPointer(Ty);
1923  void *IP = nullptr;
1924  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1925  // Limit recursion depth.
1926  if (Depth > MaxCastDepth) {
1927  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1928  Op, Ty);
1929  UniqueSCEVs.InsertNode(S, IP);
1930  registerUser(S, Op);
1931  return S;
1932  }
1933 
1934  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1935  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1936  // It's possible the bits taken off by the truncate were all sign bits. If
1937  // so, we should be able to simplify this further.
1938  const SCEV *X = ST->getOperand();
1939  ConstantRange CR = getSignedRange(X);
1940  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1941  unsigned NewBits = getTypeSizeInBits(Ty);
1942  if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1943  CR.sextOrTrunc(NewBits)))
1944  return getTruncateOrSignExtend(X, Ty, Depth);
1945  }
1946 
1947  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1948  // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1949  if (SA->hasNoSignedWrap()) {
1950  // If the addition does not sign overflow then we can, by definition,
1951  // commute the sign extension with the addition operation.
1952  SmallVector<const SCEV *, 4> Ops;
1953  for (const auto *Op : SA->operands())
1954  Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1955  return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1956  }
1957 
1958  // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1959  // if D + (C - D + x + y + ...) could be proven to not signed wrap
1960  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1961  //
1962  // For instance, this will bring two seemingly different expressions:
1963  // 1 + sext(5 + 20 * %x + 24 * %y) and
1964  // sext(6 + 20 * %x + 24 * %y)
1965  // to the same form:
1966  // 2 + sext(4 + 20 * %x + 24 * %y)
1967  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1968  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1969  if (D != 0) {
1970  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1971  const SCEV *SResidual =
1972  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1973  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1974  return getAddExpr(SSExtD, SSExtR,
1975  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1976  Depth + 1);
1977  }
1978  }
1979  }
1980  // If the input value is a chrec scev, and we can prove that the value
1981  // did not overflow the old, smaller, value, we can sign extend all of the
1982  // operands (often constants). This allows analysis of something like
1983  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1984  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1985  if (AR->isAffine()) {
1986  const SCEV *Start = AR->getStart();
1987  const SCEV *Step = AR->getStepRecurrence(*this);
1988  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1989  const Loop *L = AR->getLoop();
1990 
1991  if (!AR->hasNoSignedWrap()) {
1992  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1993  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1994  }
1995 
1996  // If we have special knowledge that this addrec won't overflow,
1997  // we don't need to do any further analysis.
1998  if (AR->hasNoSignedWrap()) {
1999  Start =
2000  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2001  Step = getSignExtendExpr(Step, Ty, Depth + 1);
2002  return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2003  }
2004 
2005  // Check whether the backedge-taken count is SCEVCouldNotCompute.
2006  // Note that this serves two purposes: It filters out loops that are
2007  // simply not analyzable, and it covers the case where this code is
2008  // being called from within backedge-taken count analysis, such that
2009  // attempting to ask for the backedge-taken count would likely result
2010  // in infinite recursion. In the latter case, the analysis code will
2011  // cope with a conservative value, and it will take care to purge
2012  // that value once it has finished.
2013  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2014  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2015  // Manually compute the final value for AR, checking for
2016  // overflow.
2017 
2018  // Check whether the backedge-taken count can be losslessly casted to
2019  // the addrec's type. The count is always unsigned.
2020  const SCEV *CastedMaxBECount =
2021  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2022  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2023  CastedMaxBECount, MaxBECount->getType(), Depth);
2024  if (MaxBECount == RecastedMaxBECount) {
2025  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2026  // Check whether Start+Step*MaxBECount has no signed overflow.
2027  const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2028  SCEV::FlagAnyWrap, Depth + 1);
2029  const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2030  SCEV::FlagAnyWrap,
2031  Depth + 1),
2032  WideTy, Depth + 1);
2033  const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2034  const SCEV *WideMaxBECount =
2035  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2036  const SCEV *OperandExtendedAdd =
2037  getAddExpr(WideStart,
2038  getMulExpr(WideMaxBECount,
2039  getSignExtendExpr(Step, WideTy, Depth + 1),
2040  SCEV::FlagAnyWrap, Depth + 1),
2041  SCEV::FlagAnyWrap, Depth + 1);
2042  if (SAdd == OperandExtendedAdd) {
2043  // Cache knowledge of AR NSW, which is propagated to this AddRec.
2044  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2045  // Return the expression with the addrec on the outside.
2046  Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2047  Depth + 1);
2048  Step = getSignExtendExpr(Step, Ty, Depth + 1);
2049  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2050  }
2051  // Similar to above, only this time treat the step value as unsigned.
2052  // This covers loops that count up with an unsigned step.
2053  OperandExtendedAdd =
2054  getAddExpr(WideStart,
2055  getMulExpr(WideMaxBECount,
2056  getZeroExtendExpr(Step, WideTy, Depth + 1),
2057  SCEV::FlagAnyWrap, Depth + 1),
2058  SCEV::FlagAnyWrap, Depth + 1);
2059  if (SAdd == OperandExtendedAdd) {
2060  // If AR wraps around then
2061  //
2062  // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2063  // => SAdd != OperandExtendedAdd
2064  //
2065  // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2066  // (SAdd == OperandExtendedAdd => AR is NW)
2067 
2068  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2069 
2070  // Return the expression with the addrec on the outside.
2071  Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2072  Depth + 1);
2073  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2074  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2075  }
2076  }
2077  }
2078 
2079  auto NewFlags = proveNoSignedWrapViaInduction(AR);
2080  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2081  if (AR->hasNoSignedWrap()) {
2082  // Same as nsw case above - duplicated here to avoid a compile time
2083  // issue. It's not clear that the order of checks matters, but it's
2084  // one of two possible causes of an issue with a change which was
2085  // reverted. Be conservative for the moment.
2086  Start =
2087  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2088  Step = getSignExtendExpr(Step, Ty, Depth + 1);
2089  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2090  }
2091 
2092  // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2093  // if D + (C - D + Step * n) could be proven to not signed wrap
2094  // where D maximizes the number of trailing zeros of (C - D + Step * n)
2095  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2096  const APInt &C = SC->getAPInt();
2097  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2098  if (D != 0) {
2099  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2100  const SCEV *SResidual =
2101  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2102  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2103  return getAddExpr(SSExtD, SSExtR,
2104  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2105  Depth + 1);
2106  }
2107  }
2108 
2109  if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2110  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2111  Start =
2112  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2113  Step = getSignExtendExpr(Step, Ty, Depth + 1);
2114  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2115  }
2116  }
2117 
2118  // If the input value is provably non-negative and we could not simplify
2119  // away the sext, build a zext instead.
2120  if (isKnownNonNegative(Op))
2121  return getZeroExtendExpr(Op, Ty, Depth + 1);
2122 
2123  // The cast wasn't folded; create an explicit cast node.
2124  // Recompute the insert position, as it may have been invalidated.
2125  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2126  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2127  Op, Ty);
2128  UniqueSCEVs.InsertNode(S, IP);
2129  registerUser(S, { Op });
2130  return S;
2131 }
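
// Illustrative sketch, not part of the original source: the D values behind
// the sext(5 + 20 * %x + 24 * %y) / sext(6 + 20 * %x + 24 * %y) example in the
// comments above. The minimum trailing-zero count of 20 and 24 is two, so both
// expressions split down to the same residual 4 + 20 * %x + 24 * %y.
static void sextConstantSplitExampleSketch() {
  const unsigned TZ = 2;                         // min trailing zeros of 20, 24
  const uint64_t Mask = (uint64_t(1) << TZ) - 1;
  const uint64_t DFromSix = 6 & Mask;            // == 2, residual start 6 - 2 == 4
  const uint64_t DFromFive = 5 & Mask;           // == 1, residual start 5 - 1 == 4
  (void)DFromSix;
  (void)DFromFive;
}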
2132 
2133 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2134  Type *Ty) {
2135  switch (Kind) {
2136  case scTruncate:
2137  return getTruncateExpr(Op, Ty);
2138  case scZeroExtend:
2139  return getZeroExtendExpr(Op, Ty);
2140  case scSignExtend:
2141  return getSignExtendExpr(Op, Ty);
2142  case scPtrToInt:
2143  return getPtrToIntExpr(Op, Ty);
2144  default:
2145  llvm_unreachable("Not a SCEV cast expression!");
2146  }
2147 }
2148 
2149 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2150 /// unspecified bits out to the given type.
2151 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2152  Type *Ty) {
2153  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2154  "This is not an extending conversion!");
2155  assert(isSCEVable(Ty) &&
2156  "This is not a conversion to a SCEVable type!");
2157  Ty = getEffectiveSCEVType(Ty);
2158 
2159  // Sign-extend negative constants.
2160  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2161  if (SC->getAPInt().isNegative())
2162  return getSignExtendExpr(Op, Ty);
2163 
2164  // Peel off a truncate cast.
2165  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2166  const SCEV *NewOp = T->getOperand();
2167  if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2168  return getAnyExtendExpr(NewOp, Ty);
2169  return getTruncateOrNoop(NewOp, Ty);
2170  }
2171 
2172  // Next try a zext cast. If the cast is folded, use it.
2173  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2174  if (!isa<SCEVZeroExtendExpr>(ZExt))
2175  return ZExt;
2176 
2177  // Next try a sext cast. If the cast is folded, use it.
2178  const SCEV *SExt = getSignExtendExpr(Op, Ty);
2179  if (!isa<SCEVSignExtendExpr>(SExt))
2180  return SExt;
2181 
2182  // Force the cast to be folded into the operands of an addrec.
2183  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2184  SmallVector<const SCEV *, 4> Ops;
2185  for (const SCEV *Op : AR->operands())
2186  Ops.push_back(getAnyExtendExpr(Op, Ty));
2187  return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2188  }
2189 
2190  // If the expression is obviously signed, use the sext cast value.
2191  if (isa<SCEVSMaxExpr>(Op))
2192  return SExt;
2193 
2194  // Absent any other information, use the zext cast value.
2195  return ZExt;
2196 }
2197 
2198 /// Process the given Ops list, which is a list of operands to be added under
2199 /// the given scale, update the given map. This is a helper function for
2200 /// getAddRecExpr. As an example of what it does, given a sequence of operands
2201 /// that would form an add expression like this:
2202 ///
2203 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2204 ///
2205 /// where A and B are constants, update the map with these values:
2206 ///
2207 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2208 ///
2209 /// and add 13 + A*B*29 to AccumulatedConstant.
2210 /// This will allow getAddRecExpr to produce this:
2211 ///
2212 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2213 ///
2214 /// This form often exposes folding opportunities that are hidden in
2215 /// the original operand list.
2216 ///
2217 /// Return true iff it appears that any interesting folding opportunities
2218 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
2219 /// the common case where no interesting opportunities are present, and
2220 /// is also used as a check to avoid infinite recursion.
2221 static bool
2222 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2223  SmallVectorImpl<const SCEV *> &NewOps,
2224  APInt &AccumulatedConstant,
2225  const SCEV *const *Ops, size_t NumOperands,
2226  const APInt &Scale,
2227  ScalarEvolution &SE) {
2228  bool Interesting = false;
2229 
2230  // Iterate over the add operands. They are sorted, with constants first.
2231  unsigned i = 0;
2232  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2233  ++i;
2234  // Pull a buried constant out to the outside.
2235  if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2236  Interesting = true;
2237  AccumulatedConstant += Scale * C->getAPInt();
2238  }
2239 
2240  // Next comes everything else. We're especially interested in multiplies
2241  // here, but they're in the middle, so just visit the rest with one loop.
2242  for (; i != NumOperands; ++i) {
2243  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2244  if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2245  APInt NewScale =
2246  Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2247  if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2248  // A multiplication of a constant with another add; recurse.
2249  const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2250  Interesting |=
2251  CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2252  Add->op_begin(), Add->getNumOperands(),
2253  NewScale, SE);
2254  } else {
2255  // A multiplication of a constant with some other value. Update
2256  // the map.
2257  SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2258  const SCEV *Key = SE.getMulExpr(MulOps);
2259  auto Pair = M.insert({Key, NewScale});
2260  if (Pair.second) {
2261  NewOps.push_back(Pair.first->first);
2262  } else {
2263  Pair.first->second += NewScale;
2264  // The map already had an entry for this value, which may indicate
2265  // a folding opportunity.
2266  Interesting = true;
2267  }
2268  }
2269  } else {
2270  // An ordinary operand. Update the map.
2271  std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2272  M.insert({Ops[i], Scale});
2273  if (Pair.second) {
2274  NewOps.push_back(Pair.first->first);
2275  } else {
2276  Pair.first->second += Scale;
2277  // The map already had an entry for this value, which may indicate
2278  // a folding opportunity.
2279  Interesting = true;
2280  }
2281  }
2282  }
2283 
2284  return Interesting;
2285 }
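
// Illustrative sketch, not part of the original source: the running example
// from the comment above CollectAddOperandsWithScales, with A = 2 and B = 3,
// evaluated both ways in plain integers to show the regrouping is exact.
static bool collectScalesRegroupingAgreesSketch(int M, int N, int O, int P,
                                                int Q, int R) {
  const int A = 2, B = 3;
  const int Original =
      M + N + 13 + (A * (O + P + (B * (Q + M + 29)))) + R + (-1 * R);
  const int Regrouped = 13 + A * B * 29 + N + M * (1 + A * B) + (O + P) * A +
                        Q * (A * B);
  return Original == Regrouped; // holds for all inputs (barring int overflow)
}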
2286 
2287 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2288  const SCEV *LHS, const SCEV *RHS) {
2289  const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2290  SCEV::NoWrapFlags, unsigned);
2291  switch (BinOp) {
2292  default:
2293  llvm_unreachable("Unsupported binary op");
2294  case Instruction::Add:
2295  Operation = &ScalarEvolution::getAddExpr;
2296  break;
2297  case Instruction::Sub:
2298  Operation = &ScalarEvolution::getMinusSCEV;
2299  break;
2300  case Instruction::Mul:
2301  Operation = &ScalarEvolution::getMulExpr;
2302  break;
2303  }
2304 
2305  const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2306  Signed ? &ScalarEvolution::getSignExtendExpr
2307  : &ScalarEvolution::getZeroExtendExpr;
2308 
2309  // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2310  auto *NarrowTy = cast<IntegerType>(LHS->getType());
2311  auto *WideTy =
2312  IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2313 
2314  const SCEV *A = (this->*Extension)(
2315  (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2316  const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2317  const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2318  const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2319  return A == B;
2320 }
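
// Illustrative sketch, not part of the original source: the widening test used
// by willNotOverflow above, specialized to unsigned 8-bit addition. The narrow
// operation wraps exactly when redoing it in a type twice as wide disagrees
// with extending the narrow result.
static bool u8AddDoesNotWrapSketch(uint8_t L, uint8_t R) {
  const uint16_t WideOp = uint16_t(L) + uint16_t(R);       // ext(L) + ext(R)
  const uint16_t NarrowThenExt = uint16_t(uint8_t(L + R)); // ext(L + R)
  return WideOp == NarrowThenExt; // equal => the i8 add has no unsigned wrap
}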
2321 
2322 std::pair<SCEV::NoWrapFlags, bool /*Deduced*/>
2323 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2324  const OverflowingBinaryOperator *OBO) {
2325  SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2326 
2327  if (OBO->hasNoUnsignedWrap())
2328  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2329  if (OBO->hasNoSignedWrap())
2330  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2331 
2332  bool Deduced = false;
2333 
2334  if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2335  return {Flags, Deduced};
2336 
2337  if (OBO->getOpcode() != Instruction::Add &&
2338  OBO->getOpcode() != Instruction::Sub &&
2339  OBO->getOpcode() != Instruction::Mul)
2340  return {Flags, Deduced};
2341 
2342  const SCEV *LHS = getSCEV(OBO->getOperand(0));
2343  const SCEV *RHS = getSCEV(OBO->getOperand(1));
2344 
2345  if (!OBO->hasNoUnsignedWrap() &&
2346  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2347  /* Signed */ false, LHS, RHS)) {
2348  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2349  Deduced = true;
2350  }
2351 
2352  if (!OBO->hasNoSignedWrap() &&
2353  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2354  /* Signed */ true, LHS, RHS)) {
2355  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2356  Deduced = true;
2357  }
2358 
2359  return {Flags, Deduced};
2360 }
2361 
2362 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2363 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2364 // can't-overflow flags for the operation if possible.
2365 static SCEV::NoWrapFlags
2366 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2367  const ArrayRef<const SCEV *> Ops,
2368  SCEV::NoWrapFlags Flags) {
2369  using namespace std::placeholders;
2370 
2371  using OBO = OverflowingBinaryOperator;
2372 
2373  bool CanAnalyze =
2374  Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2375  (void)CanAnalyze;
2376  assert(CanAnalyze && "don't call from other places!");
2377 
2378  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2379  SCEV::NoWrapFlags SignOrUnsignWrap =
2380  ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2381 
2382  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2383  auto IsKnownNonNegative = [&](const SCEV *S) {
2384  return SE->isKnownNonNegative(S);
2385  };
2386 
2387  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2388  Flags =
2389  ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2390 
2391  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2392 
2393  if (SignOrUnsignWrap != SignOrUnsignMask &&
2394  (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2395  isa<SCEVConstant>(Ops[0])) {
2396 
2397  auto Opcode = [&] {
2398  switch (Type) {
2399  case scAddExpr:
2400  return Instruction::Add;
2401  case scMulExpr:
2402  return Instruction::Mul;
2403  default:
2404  llvm_unreachable("Unexpected SCEV op.");
2405  }
2406  }();
2407 
2408  const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2409 
2410  // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2411  if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2412  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2413  Opcode, C, OBO::NoSignedWrap);
2414  if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2415  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2416  }
2417 
2418  // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2419  if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2420  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2421  Opcode, C, OBO::NoUnsignedWrap);
2422  if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2423  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2424  }
2425  }
2426 
2427  // <0,+,nonnegative><nw> is also nuw
2428  // TODO: Add corresponding nsw case
2429  if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2430  !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2431  Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2432  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2433 
2434  // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2435  if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2436  Ops.size() == 2) {
2437  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2438  if (UDiv->getOperand(1) == Ops[1])
2439  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2440  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2441  if (UDiv->getOperand(1) == Ops[0])
2442  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2443  }
2444 
2445  return Flags;
2446 }
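
// Illustrative sketch, not part of the original source: the first inference in
// StrengthenNoWrapFlags above ("nsw plus all-non-negative operands implies
// nuw"), stated for 8-bit addition. When neither operand is negative, a signed
// add that stays within [0, 127] also stays within [0, 255], so the unsigned
// interpretation cannot wrap either.
static bool nswAndNonNegativeImplyNuwSketch(int8_t A, int8_t B) {
  if (A < 0 || B < 0)
    return true; // the inference only applies to non-negative operands
  const int Wide = int(A) + int(B);
  const bool NSW = Wide <= 127;                      // no signed overflow
  const bool NUW = unsigned(A) + unsigned(B) <= 255; // no unsigned wrap
  return !NSW || NUW; // nsw => nuw under the non-negativity assumption
}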
2447 
2448 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2449  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2450 }
2451 
2452 /// Get a canonical add expression, or something simpler if possible.
2453 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2454  SCEV::NoWrapFlags OrigFlags,
2455  unsigned Depth) {
2456  assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2457  "only nuw or nsw allowed");
2458  assert(!Ops.empty() && "Cannot get empty add!");
2459  if (Ops.size() == 1) return Ops[0];
2460 #ifndef NDEBUG
2461  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2462  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2463  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2464  "SCEVAddExpr operand types don't match!");
2465  unsigned NumPtrs = count_if(
2466  Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2467  assert(NumPtrs <= 1 && "add has at most one pointer operand");
2468 #endif
2469 
2470  // Sort by complexity, this groups all similar expression types together.
2471  GroupByComplexity(Ops, &LI, DT);
2472 
2473  // If there are any constants, fold them together.
2474  unsigned Idx = 0;
2475  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2476  ++Idx;
2477  assert(Idx < Ops.size());
2478  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2479  // We found two constants, fold them together!
2480  Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2481  if (Ops.size() == 2) return Ops[0];
2482  Ops.erase(Ops.begin()+1); // Erase the folded element
2483  LHSC = cast<SCEVConstant>(Ops[0]);
2484  }
2485 
2486  // If we are left with a constant zero being added, strip it off.
2487  if (LHSC->getValue()->isZero()) {
2488  Ops.erase(Ops.begin());
2489  --Idx;
2490  }
2491 
2492  if (Ops.size() == 1) return Ops[0];
2493  }
2494 
2495  // Delay expensive flag strengthening until necessary.
2496  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2497  return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2498  };
2499 
2500  // Limit recursion calls depth.
2501  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2502  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2503 
2504  if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2505  // Don't strengthen flags if we have no new information.
2506  SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2507  if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2508  Add->setNoWrapFlags(ComputeFlags(Ops));
2509  return S;
2510  }
2511 
2512  // Okay, check to see if the same value occurs in the operand list more than
2513  // once. If so, merge them together into a multiply expression. Since we
2514  // sorted the list, these values are required to be adjacent.
2515  Type *Ty = Ops[0]->getType();
2516  bool FoundMatch = false;
2517  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2518  if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2519  // Scan ahead to count how many equal operands there are.
2520  unsigned Count = 2;
2521  while (i+Count != e && Ops[i+Count] == Ops[i])
2522  ++Count;
2523  // Merge the values into a multiply.
2524  const SCEV *Scale = getConstant(Ty, Count);
2525  const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2526  if (Ops.size() == Count)
2527  return Mul;
2528  Ops[i] = Mul;
2529  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2530  --i; e -= Count - 1;
2531  FoundMatch = true;
2532  }
2533  if (FoundMatch)
2534  return getAddExpr(Ops, OrigFlags, Depth + 1);
2535 
2536  // Check for truncates. If all the operands are truncated from the same
2537  // type, see if factoring out the truncate would permit the result to be
2538  // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2539  // if the contents of the resulting outer trunc fold to something simple.
2540  auto FindTruncSrcType = [&]() -> Type * {
2541  // We're ultimately looking to fold an addrec of truncs and muls of only
2542  // constants and truncs, so if we find any other types of SCEV
2543  // as operands of the addrec then we bail and return nullptr here.
2544  // Otherwise, we return the type of the operand of a trunc that we find.
2545  if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2546  return T->getOperand()->getType();
2547  if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2548  const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2549  if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2550  return T->getOperand()->getType();
2551  }
2552  return nullptr;
2553  };
2554  if (auto *SrcType = FindTruncSrcType()) {
2555  SmallVector<const SCEV *, 8> LargeOps;
2556  bool Ok = true;
2557  // Check all the operands to see if they can be represented in the
2558  // source type of the truncate.
2559  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2560  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2561  if (T->getOperand()->getType() != SrcType) {
2562  Ok = false;
2563  break;
2564  }
2565  LargeOps.push_back(T->getOperand());
2566  } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2567  LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2568  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2569  SmallVector<const SCEV *, 8> LargeMulOps;
2570  for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2571  if (const SCEVTruncateExpr *T =
2572  dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2573  if (T->getOperand()->getType() != SrcType) {
2574  Ok = false;
2575  break;
2576  }
2577  LargeMulOps.push_back(T->getOperand());
2578  } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2579  LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2580  } else {
2581  Ok = false;
2582  break;
2583  }
2584  }
2585  if (Ok)
2586  LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2587  } else {
2588  Ok = false;
2589  break;
2590  }
2591  }
2592  if (Ok) {
2593  // Evaluate the expression in the larger type.
2594  const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2595  // If it folds to something simple, use it. Otherwise, don't.
2596  if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2597  return getTruncateExpr(Fold, Ty);
2598  }
2599  }
2600 
2601  if (Ops.size() == 2) {
2602  // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2603  // C2 can be folded in a way that allows retaining wrapping flags of (X +
2604  // C1).
2605  const SCEV *A = Ops[0];
2606  const SCEV *B = Ops[1];
2607  auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2608  auto *C = dyn_cast<SCEVConstant>(A);
2609  if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2610  auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2611  auto C2 = C->getAPInt();
2612  SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2613 
2614  APInt ConstAdd = C1 + C2;
2615  auto AddFlags = AddExpr->getNoWrapFlags();
2616  // Adding a smaller constant is NUW if the original AddExpr was NUW.
2617  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2618  ConstAdd.ule(C1)) {
2619  PreservedFlags =
2620  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2621  }
2622 
2623  // Adding a constant with the same sign and small magnitude is NSW, if the
2624  // original AddExpr was NSW.
2625  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2626  C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2627  ConstAdd.abs().ule(C1.abs())) {
2628  PreservedFlags =
2629  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2630  }
2631 
2632  if (PreservedFlags != SCEV::FlagAnyWrap) {
2633  SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2634  NewOps[0] = getConstant(ConstAdd);
2635  return getAddExpr(NewOps, PreservedFlags);
2636  }
2637  }
2638  }
2639 
2640  // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2641  if (Ops.size() == 2) {
2642  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2643  if (Mul && Mul->getNumOperands() == 2 &&
2644  Mul->getOperand(0)->isAllOnesValue()) {
2645  const SCEV *X;
2646  const SCEV *Y;
2647  if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2648  return getMulExpr(Y, getUDivExpr(X, Y));
2649  }
2650  }
2651  }
2652 
2653  // Skip past any other cast SCEVs.
2654  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2655  ++Idx;
2656 
2657  // If there are add operands they would be next.
2658  if (Idx < Ops.size()) {
2659  bool DeletedAdd = false;
2660  // If the original flags and all inlined SCEVAddExprs are NUW, use the
2661  // common NUW flag for the expression after inlining. Other flags cannot be
2662  // preserved, because they may depend on the original order of operations.
2663  SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2664  while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2665  if (Ops.size() > AddOpsInlineThreshold ||
2666  Add->getNumOperands() > AddOpsInlineThreshold)
2667  break;
2668  // If we have an add, expand the add operands onto the end of the operands
2669  // list.
2670  Ops.erase(Ops.begin()+Idx);
2671  Ops.append(Add->op_begin(), Add->op_end());
2672  DeletedAdd = true;
2673  CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2674  }
2675 
2676  // If we deleted at least one add, we added operands to the end of the list,
2677  // and they are not necessarily sorted. Recurse to resort and resimplify
2678  // any operands we just acquired.
2679  if (DeletedAdd)
2680  return getAddExpr(Ops, CommonFlags, Depth + 1);
2681  }
2682 
2683  // Skip over the add expression until we get to a multiply.
2684  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2685  ++Idx;
2686 
2687  // Check to see if there are any folding opportunities present with
2688  // operands multiplied by constant values.
2689  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2690  uint64_t BitWidth = getTypeSizeInBits(Ty);
2691  DenseMap<const SCEV *, APInt> M;
2692  SmallVector<const SCEV *, 8> NewOps;
2693  APInt AccumulatedConstant(BitWidth, 0);
2694  if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2695  Ops.data(), Ops.size(),
2696  APInt(BitWidth, 1), *this)) {
2697  struct APIntCompare {
2698  bool operator()(const APInt &LHS, const APInt &RHS) const {
2699  return LHS.ult(RHS);
2700  }
2701  };
2702 
2703  // Some interesting folding opportunity is present, so it's worthwhile to
2704  // re-generate the operands list. Group the operands by constant scale,
2705  // to avoid multiplying by the same constant scale multiple times.
2706  std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2707  for (const SCEV *NewOp : NewOps)
2708  MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2709  // Re-generate the operands list.
2710  Ops.clear();
2711  if (AccumulatedConstant != 0)
2712  Ops.push_back(getConstant(AccumulatedConstant));
2713  for (auto &MulOp : MulOpLists) {
2714  if (MulOp.first == 1) {
2715  Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2716  } else if (MulOp.first != 0) {
2717  Ops.push_back(getMulExpr(
2718  getConstant(MulOp.first),
2719  getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2720  SCEV::FlagAnyWrap, Depth + 1));
2721  }
2722  }
2723  if (Ops.empty())
2724  return getZero(Ty);
2725  if (Ops.size() == 1)
2726  return Ops[0];
2727  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2728  }
2729  }
2730 
2731  // If we are adding something to a multiply expression, make sure the
2732  // something is not already an operand of the multiply. If so, merge it into
2733  // the multiply.
2734  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2735  const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2736  for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2737  const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2738  if (isa<SCEVConstant>(MulOpSCEV))
2739  continue;
2740  for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2741  if (MulOpSCEV == Ops[AddOp]) {
2742  // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2743  const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2744  if (Mul->getNumOperands() != 2) {
2745  // If the multiply has more than two operands, we must get the
2746  // Y*Z term.
2747  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2748  Mul->op_begin()+MulOp);
2749  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2750  InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2751  }
2752  SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2753  const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2754  const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2755  SCEV::FlagAnyWrap, Depth + 1);
2756  if (Ops.size() == 2) return OuterMul;
2757  if (AddOp < Idx) {
2758  Ops.erase(Ops.begin()+AddOp);
2759  Ops.erase(Ops.begin()+Idx-1);
2760  } else {
2761  Ops.erase(Ops.begin()+Idx);
2762  Ops.erase(Ops.begin()+AddOp-1);
2763  }
2764  Ops.push_back(OuterMul);
2765  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2766  }
2767 
2768  // Check this multiply against other multiplies being added together.
2769  for (unsigned OtherMulIdx = Idx+1;
2770  OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2771  ++OtherMulIdx) {
2772  const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2773  // If MulOp occurs in OtherMul, we can fold the two multiplies
2774  // together.
2775  for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2776  OMulOp != e; ++OMulOp)
2777  if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2778  // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2779  const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2780  if (Mul->getNumOperands() != 2) {
2781  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2782  Mul->op_begin()+MulOp);
2783  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2784  InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2785  }
2786  const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2787  if (OtherMul->getNumOperands() != 2) {
2788  SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2789  OtherMul->op_begin()+OMulOp);
2790  MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2791  InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2792  }
2793  SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2794  const SCEV *InnerMulSum =
2795  getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2796  const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2797  SCEV::FlagAnyWrap, Depth + 1);
2798  if (Ops.size() == 2) return OuterMul;
2799  Ops.erase(Ops.begin()+Idx);
2800  Ops.erase(Ops.begin()+OtherMulIdx-1);
2801  Ops.push_back(OuterMul);
2802  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2803  }
2804  }
2805  }
2806  }
2807 
2808  // If there are any add recurrences in the operands list, see if any other
2809  // added values are loop invariant. If so, we can fold them into the
2810  // recurrence.
2811  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2812  ++Idx;
2813 
2814  // Scan over all recurrences, trying to fold loop invariants into them.
2815  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2816  // Scan all of the other operands to this add and add them to the vector if
2817  // they are loop invariant w.r.t. the recurrence.
2818  SmallVector<const SCEV *, 8> LIOps;
2819  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2820  const Loop *AddRecLoop = AddRec->getLoop();
2821  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2822  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2823  LIOps.push_back(Ops[i]);
2824  Ops.erase(Ops.begin()+i);
2825  --i; --e;
2826  }
2827 
2828  // If we found some loop invariants, fold them into the recurrence.
2829  if (!LIOps.empty()) {
2830  // Compute nowrap flags for the addition of the loop-invariant ops and
2831  // the addrec. Temporarily push it as an operand for that purpose. These
2832  // flags are valid in the scope of the addrec only.
2833  LIOps.push_back(AddRec);
2834  SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2835  LIOps.pop_back();
2836 
2837  // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2838  LIOps.push_back(AddRec->getStart());
2839 
2840  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2841 
2842  // It is not in general safe to propagate flags valid on an add within
2843  // the addrec scope to one outside it. We must prove that the inner
2844  // scope is guaranteed to execute if the outer one does to be able to
2845  // safely propagate. We know the program is undefined if poison is
2846  // produced on the inner scoped addrec. We also know that *for this use*
2847  // the outer scoped add can't overflow (because of the flags we just
2848  // computed for the inner scoped add) without the program being undefined.
2849  // Proving that entry to the outer scope necessitates entry to the inner
2850  // scope, thus proves the program undefined if the flags would be violated
2851  // in the outer scope.
2852  SCEV::NoWrapFlags AddFlags = Flags;
2853  if (AddFlags != SCEV::FlagAnyWrap) {
2854  auto *DefI = getDefiningScopeBound(LIOps);
2855  auto *ReachI = &*AddRecLoop->getHeader()->begin();
2856  if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2857  AddFlags = SCEV::FlagAnyWrap;
2858  }
2859  AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2860 
2861  // Build the new addrec. Propagate the NUW and NSW flags if both the
2862  // outer add and the inner addrec are guaranteed to have no overflow.
2863  // Always propagate NW.
2864  Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2865  const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2866 
2867  // If all of the other operands were loop invariant, we are done.
2868  if (Ops.size() == 1) return NewRec;
2869 
2870  // Otherwise, add the folded AddRec by the non-invariant parts.
2871  for (unsigned i = 0;; ++i)
2872  if (Ops[i] == AddRec) {
2873  Ops[i] = NewRec;
2874  break;
2875  }
2876  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2877  }
2878 
2879  // Okay, if there weren't any loop invariants to be folded, check to see if
2880  // there are multiple AddRec's with the same loop induction variable being
2881  // added together. If so, we can fold them.
2882  for (unsigned OtherIdx = Idx+1;
2883  OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2884  ++OtherIdx) {
2885  // We expect the AddRecExpr's to be sorted in reverse dominance order,
2886  // so that the 1st found AddRecExpr is dominated by all others.
2887  assert(DT.dominates(
2888  cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2889  AddRec->getLoop()->getHeader()) &&
2890  "AddRecExprs are not sorted in reverse dominance order?");
2891  if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2892  // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2893  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2894  for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2895  ++OtherIdx) {
2896  const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2897  if (OtherAddRec->getLoop() == AddRecLoop) {
2898  for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2899  i != e; ++i) {
2900  if (i >= AddRecOps.size()) {
2901  AddRecOps.append(OtherAddRec->op_begin()+i,
2902  OtherAddRec->op_end());
2903  break;
2904  }
2905  SmallVector<const SCEV *, 2> TwoOps = {
2906  AddRecOps[i], OtherAddRec->getOperand(i)};
2907  AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2908  }
2909  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2910  }
2911  }
2912  // Step size has changed, so we cannot guarantee no self-wraparound.
2913  Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2914  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2915  }
2916  }
2917 
2918  // Otherwise couldn't fold anything into this recurrence. Move onto the
2919  // next one.
2920  }
2921 
2922  // Okay, it looks like we really DO need an add expr. Check to see if we
2923  // already have one, otherwise create a new one.
2924  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2925 }
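
// Illustrative sketch, not part of the original source: the arithmetic fact
// behind the "(-1 * urem X, Y) + X --> (Y * X/Y)" canonicalization inside
// getAddExpr above, stated for plain unsigned integers.
static bool uremCanonicalizationHoldsSketch(uint64_t X, uint64_t Y) {
  if (Y == 0)
    return true; // division by zero is outside the scope of the fold
  return X - (X % Y) == Y * (X / Y);
}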
2926 
2927 const SCEV *
2928 ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2929  SCEV::NoWrapFlags Flags) {
2930  FoldingSetNodeID ID;
2931  ID.AddInteger(scAddExpr);
2932  for (const SCEV *Op : Ops)
2933  ID.AddPointer(Op);
2934  void *IP = nullptr;
2935  SCEVAddExpr *S =
2936  static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2937  if (!S) {
2938  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2939  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2940  S = new (SCEVAllocator)
2941  SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2942  UniqueSCEVs.InsertNode(S, IP);
2943  registerUser(S, Ops);
2944  }
2945  S->setNoWrapFlags(Flags);
2946  return S;
2947 }
2948 
2949 const SCEV *
2950 ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2951  const Loop *L, SCEV::NoWrapFlags Flags) {
2952  FoldingSetNodeID ID;
2953  ID.AddInteger(scAddRecExpr);
2954  for (const SCEV *Op : Ops)
2955  ID.AddPointer(Op);
2956  ID.AddPointer(L);
2957  void *IP = nullptr;
2958  SCEVAddRecExpr *S =
2959  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2960  if (!S) {
2961  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2962  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2963  S = new (SCEVAllocator)
2964  SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
2965  UniqueSCEVs.InsertNode(S, IP);
2966  LoopUsers[L].push_back(S);
2967  registerUser(S, Ops);
2968  }
2969  setNoWrapFlags(S, Flags);
2970  return S;
2971 }
2972 
2973 const SCEV *
2974 ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
2975  SCEV::NoWrapFlags Flags) {
2976  FoldingSetNodeID ID;
2977  ID.AddInteger(scMulExpr);
2978  for (const SCEV *Op : Ops)
2979  ID.AddPointer(Op);
2980  void *IP = nullptr;
2981  SCEVMulExpr *S =
2982  static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2983  if (!S) {
2984  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2985  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2986  S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2987  O, Ops.size());
2988  UniqueSCEVs.InsertNode(S, IP);
2989  registerUser(S, Ops);
2990  }
2991  S->setNoWrapFlags(Flags);
2992  return S;
2993 }
2994 
2995 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2996  uint64_t k = i*j;
2997  if (j > 1 && k / j != i) Overflow = true;
2998  return k;
2999 }
3000 
3001 /// Compute the result of "n choose k", the binomial coefficient. If an
3002 /// intermediate computation overflows, Overflow will be set and the return will
3003 /// be garbage. Overflow is not cleared on absence of overflow.
3004 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3005  // We use the multiplicative formula:
3006  // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3007  // At each iteration, we take the n-th term of the numerator and divide by the
3008  // (k-n)th term of the denominator. This division will always produce an
3009  // integral result, and helps reduce the chance of overflow in the
3010  // intermediate computations. However, we can still overflow even when the
3011  // final result would fit.
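 // For example (illustrative only), Choose(5, 2) runs two iterations:
 //   i=1: r = 1*5/1 = 5;   i=2: r = 5*4/2 = 10, which is C(5,2).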
3012 
3013  if (n == 0 || n == k) return 1;
3014  if (k > n) return 0;
3015 
3016  if (k > n/2)
3017  k = n-k;
3018 
3019  uint64_t r = 1;
3020  for (uint64_t i = 1; i <= k; ++i) {
3021  r = umul_ov(r, n-(i-1), Overflow);
3022  r /= i;
3023  }
3024  return r;
3025 }
3026 
3027 /// Determine if any of the operands in this SCEV are a constant or if
3028 /// any of the add or multiply expressions in this SCEV contain a constant.
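/// For instance, (3 + %x) * %y has a constant reachable through the add/mul
/// chain, while (%x /u 3) * %y does not, because the visitor below never
/// descends into the udiv.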
3029 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3030  struct FindConstantInAddMulChain {
3031  bool FoundConstant = false;
3032 
3033  bool follow(const SCEV *S) {
3034  FoundConstant |= isa<SCEVConstant>(S);
3035  return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3036  }
3037 
3038  bool isDone() const {
3039  return FoundConstant;
3040  }
3041  };
3042 
3043  FindConstantInAddMulChain F;
3044  SCEVTraversal<FindConstantInAddMulChain> ST(F);
3045  ST.visitAll(StartExpr);
3046  return F.FoundConstant;
3047 }
3048 
3049 /// Get a canonical multiply expression, or something simpler if possible.
3050 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3051  SCEV::NoWrapFlags OrigFlags,
3052  unsigned Depth) {
3053  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3054  "only nuw or nsw allowed");
3055  assert(!Ops.empty() && "Cannot get empty mul!");
3056  if (Ops.size() == 1) return Ops[0];
3057 #ifndef NDEBUG
3058  Type *ETy = Ops[0]->getType();
3059  assert(!ETy->isPointerTy());
3060  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3061  assert(Ops[i]->getType() == ETy &&
3062  "SCEVMulExpr operand types don't match!");
3063 #endif
3064 
3065  // Sort by complexity, this groups all similar expression types together.
3066  GroupByComplexity(Ops, &LI, DT);
3067 
3068  // If there are any constants, fold them together.
3069  unsigned Idx = 0;
3070  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3071  ++Idx;
3072  assert(Idx < Ops.size());
3073  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3074  // We found two constants, fold them together!
3075  Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3076  if (Ops.size() == 2) return Ops[0];
3077  Ops.erase(Ops.begin()+1); // Erase the folded element
3078  LHSC = cast<SCEVConstant>(Ops[0]);
3079  }
3080 
3081  // If we have a multiply of zero, it will always be zero.
3082  if (LHSC->getValue()->isZero())
3083  return LHSC;
3084 
3085  // If we are left with a constant one being multiplied, strip it off.
3086  if (LHSC->getValue()->isOne()) {
3087  Ops.erase(Ops.begin());
3088  --Idx;
3089  }
3090 
3091  if (Ops.size() == 1)
3092  return Ops[0];
3093  }
3094 
3095  // Delay expensive flag strengthening until necessary.
3096  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3097  return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3098  };
3099 
3100  // Limit recursion calls depth.
3101  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3102  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3103 
3104  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3105  // Don't strengthen flags if we have no new information.
3106  SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3107  if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3108  Mul->setNoWrapFlags(ComputeFlags(Ops));
3109  return S;
3110  }
3111 
3112  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3113  if (Ops.size() == 2) {
3114  // C1*(C2+V) -> C1*C2 + C1*V
3115  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3116  // If any of Add's ops are Adds or Muls with a constant, apply this
3117  // transformation as well.
3118  //
3119  // TODO: There are some cases where this transformation is not
3120  // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3121  // this transformation should be narrowed down.
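 // For instance, 5 * (3 + %x) is rewritten below to 15 + 5 * %x, exposing the
 // folded constant to further simplification.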
3122  if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3123  const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3124  SCEV::FlagAnyWrap, Depth + 1);
3125  const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3126  SCEV::FlagAnyWrap, Depth + 1);
3127  return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3128  }
3129 
3130  if (Ops[0]->isAllOnesValue()) {
3131  // If we have a mul by -1 of an add, try distributing the -1 among the
3132  // add operands.
3133  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3134  SmallVector<const SCEV *, 4> NewOps;
3135  bool AnyFolded = false;
3136  for (const SCEV *AddOp : Add->operands()) {
3137  const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3138  Depth + 1);
3139  if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3140  NewOps.push_back(Mul);
3141  }
3142  if (AnyFolded)
3143  return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3144  } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3145  // Negation preserves a recurrence's no self-wrap property.
3146  SmallVector<const SCEV *, 4> Operands;
3147  for (const SCEV *AddRecOp : AddRec->operands())
3148  Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3149  Depth + 1));
3150 
3151  return getAddRecExpr(Operands, AddRec->getLoop(),
3152  AddRec->getNoWrapFlags(SCEV::FlagNW));
3153  }
3154  }
3155  }
3156  }
3157 
3158  // Skip over the add expression until we get to a multiply.
3159  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3160  ++Idx;
3161 
3162  // If there are mul operands inline them all into this expression.
3163  if (Idx < Ops.size()) {
3164  bool DeletedMul = false;
3165  while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3166  if (Ops.size() > MulOpsInlineThreshold)
3167  break;
3168  // If we have a mul, expand the mul operands onto the end of the
3169  // operands list.
3170  Ops.erase(Ops.begin()+Idx);
3171  Ops.append(Mul->op_begin(), Mul->op_end());
3172  DeletedMul = true;
3173  }
3174 
3175  // If we deleted at least one mul, we added operands to the end of the
3176  // list, and they are not necessarily sorted. Recurse to resort and
3177  // resimplify any operands we just acquired.
3178  if (DeletedMul)
3179  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3180  }
3181 
3182  // If there are any add recurrences in the operands list, see if any other
3183  // added values are loop invariant. If so, we can fold them into the
3184  // recurrence.
3185  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3186  ++Idx;
3187 
3188  // Scan over all recurrences, trying to fold loop invariants into them.
3189  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3190  // Scan all of the other operands to this mul and add them to the vector
3191  // if they are loop invariant w.r.t. the recurrence.
3192  SmallVector<const SCEV *, 8> LIOps;
3193  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3194  const Loop *AddRecLoop = AddRec->getLoop();
3195  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3196  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3197  LIOps.push_back(Ops[i]);
3198  Ops.erase(Ops.begin()+i);
3199  --i; --e;
3200  }
3201 
3202  // If we found some loop invariants, fold them into the recurrence.
3203  if (!LIOps.empty()) {
3204  // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3205  SmallVector<const SCEV *, 4> NewOps;
3206  NewOps.reserve(AddRec->getNumOperands());
3207  const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3208  for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3209  NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3210  SCEV::FlagAnyWrap, Depth + 1));
3211 
3212  // Build the new addrec. Propagate the NUW and NSW flags if both the
3213  // outer mul and the inner addrec are guaranteed to have no overflow.
3214  //
3215  // No self-wrap cannot be guaranteed after changing the step size, but
3216  // will be inferred if either NUW or NSW is true.
3217  SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3218  const SCEV *NewRec = getAddRecExpr(
3219  NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
3220 
3221  // If all of the other operands were loop invariant, we are done.
3222  if (Ops.size() == 1) return NewRec;
3223 
3224  // Otherwise, multiply the folded AddRec by the non-invariant parts.
3225  for (unsigned i = 0;; ++i)
3226  if (Ops[i] == AddRec) {
3227  Ops[i] = NewRec;
3228  break;
3229  }
3230  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3231  }
3232 
3233  // Okay, if there weren't any loop invariants to be folded, check to see
3234  // if there are multiple AddRec's with the same loop induction variable
3235  // being multiplied together. If so, we can fold them.
3236 
3237  // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3238  // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3239  // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3240  // ]]],+,...up to x=2n}.
3241  // Note that the arguments to choose() are always integers with values
3242  // known at compile time, never SCEV objects.
3243  //
3244  // The implementation avoids pointless extra computations when the two
3245  // addrec's are of different length (mathematically, it's equivalent to
3246  // an infinite stream of zeros on the right).
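 // In the simplest (affine) case this reduces to
 //   {A1,+,A2}<L> * {B1,+,B2}<L> = {A1*B1,+,A1*B2+A2*B1+A2*B2,+,2*A2*B2}<L>,
 // i.e. the product of two degree-1 recurrences is a degree-2 recurrence.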
3247  bool OpsModified = false;
3248  for (unsigned OtherIdx = Idx+1;
3249  OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3250  ++OtherIdx) {
3251  const SCEVAddRecExpr *OtherAddRec =
3252  dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3253  if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
3254  continue;
3255 
3256  // Limit max number of arguments to avoid creation of unreasonably big
3257  // SCEVAddRecs with very complex operands.
3258  if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3259  MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3260  continue;
3261 
3262  bool Overflow = false;
3263  Type *Ty = AddRec->getType();
3264  bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3265  SmallVector<const SCEV*, 7> AddRecOps;
3266  for (int x = 0, xe = AddRec->getNumOperands() +
3267  OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3268  SmallVector<const SCEV *, 7> SumOps;
3269  for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3270  uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3271  for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3272  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3273  z < ze && !Overflow; ++z) {
3274  uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3275  uint64_t Coeff;
3276  if (LargerThan64Bits)
3277  Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3278  else
3279  Coeff = Coeff1*Coeff2;
3280  const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3281  const SCEV *Term1 = AddRec->getOperand(y-z);
3282  const SCEV *Term2 = OtherAddRec->getOperand(z);
3283  SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3284  SCEV::FlagAnyWrap, Depth + 1));
3285  }
3286  }
3287  if (SumOps.empty())
3288  SumOps.push_back(getZero(Ty));
3289  AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3290  }
3291  if (!Overflow) {
3292  const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
3293  SCEV::FlagAnyWrap);
3294  if (Ops.size() == 2) return NewAddRec;
3295  Ops[Idx] = NewAddRec;
3296  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3297  OpsModified = true;
3298  AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3299  if (!AddRec)
3300  break;
3301  }
3302  }
3303  if (OpsModified)
3304  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3305 
3306  // Otherwise couldn't fold anything into this recurrence. Move onto the
3307  // next one.
3308  }
3309 
3310  // Okay, it looks like we really DO need an mul expr. Check to see if we
3311  // already have one, otherwise create a new one.
3312  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3313 }
3314 
3315 /// Represents an unsigned remainder expression based on unsigned division.
3316 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3317  const SCEV *RHS) {
3318  assert(getEffectiveSCEVType(LHS->getType()) ==
3319  getEffectiveSCEVType(RHS->getType()) &&
3320  "SCEVURemExpr operand types don't match!");
3321 
3322  // Short-circuit easy cases
3323  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3324  // If constant is one, the result is trivial
3325  if (RHSC->getValue()->isOne())
3326  return getZero(LHS->getType()); // X urem 1 --> 0
3327 
3328  // If constant is a power of two, fold into a zext(trunc(LHS)).
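 // For example, %x urem 16 on i32 becomes (zext i4 (trunc i32 %x to i4) to
 // i32), i.e. only the low log2(16) = 4 bits are kept.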
3329  if (RHSC->getAPInt().isPowerOf2()) {
3330  Type *FullTy = LHS->getType();
3331  Type *TruncTy =
3332  IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3333  return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3334  }
3335  }
3336 
3337  // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3338  const SCEV *UDiv = getUDivExpr(LHS, RHS);
3339  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3340  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3341 }
3342 
3343 /// Get a canonical unsigned division expression, or something simpler if
3344 /// possible.
3345 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3346  const SCEV *RHS) {
3347  assert(!LHS->getType()->isPointerTy() &&
3348  "SCEVUDivExpr operand can't be pointer!");
3349  assert(LHS->getType() == RHS->getType() &&
3350  "SCEVUDivExpr operand types don't match!");
3351 
3352  FoldingSetNodeID ID;
3353  ID.AddInteger(scUDivExpr);
3354  ID.AddPointer(LHS);
3355  ID.AddPointer(RHS);
3356  void *IP = nullptr;
3357  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3358  return S;
3359 
3360  // 0 udiv Y == 0
3361  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3362  if (LHSC->getValue()->isZero())
3363  return LHS;
3364 
3365  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3366  if (RHSC->getValue()->isOne())
3367  return LHS; // X udiv 1 --> x
3368  // If the denominator is zero, the result of the udiv is undefined. Don't
3369  // try to analyze it, because the resolution chosen here may differ from
3370  // the resolution chosen in other parts of the compiler.
3371  if (!RHSC->getValue()->isZero()) {
3372  // Determine if the division can be folded into the operands of
3373  // the dividend (LHS).
3374  // TODO: Generalize this to non-constants by using known-bits information.
3375  Type *Ty = LHS->getType();
3376  unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3377  unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3378  // For non-power-of-two values, effectively round the value up to the
3379  // nearest power of two.
3380  if (!RHSC->getAPInt().isPowerOf2())
3381  ++MaxShiftAmt;
3382  IntegerType *ExtTy =
3383  IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3384  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3385  if (const SCEVConstant *Step =
3386  dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3387  // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3388  const APInt &StepInt = Step->getAPInt();
3389  const APInt &DivInt = RHSC->getAPInt();
3390  if (!StepInt.urem(DivInt) &&
3391  getZeroExtendExpr(AR, ExtTy) ==
3392  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3393  getZeroExtendExpr(Step, ExtTy),
3394  AR->getLoop(), SCEV::FlagAnyWrap)) {
3395  SmallVector<const SCEV *, 4> Operands;
3396  for (const SCEV *Op : AR->operands())
3397  Operands.push_back(getUDivExpr(Op, RHS));
3398  return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3399  }
3400  /// Get a canonical UDivExpr for a recurrence.
3401  /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3402  // We can currently only fold X%N if X is constant.
3403  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3404  if (StartC && !DivInt.urem(StepInt) &&
3405  getZeroExtendExpr(AR, ExtTy) ==
3406  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3407  getZeroExtendExpr(Step, ExtTy),
3408  AR->getLoop(), SCEV::FlagAnyWrap)) {
3409  const APInt &StartInt = StartC->getAPInt();
3410  const APInt &StartRem = StartInt.urem(StepInt);
3411  if (StartRem != 0) {
3412  const SCEV *NewLHS =
3413  getAddRecExpr(getConstant(StartInt - StartRem), Step,
3414  AR->getLoop(), SCEV::FlagNW);
3415  if (LHS != NewLHS) {
3416  LHS = NewLHS;
3417 
3418  // Reset the ID to include the new LHS, and check if it is
3419  // already cached.
3420  ID.clear();
3421  ID.AddInteger(scUDivExpr);
3422  ID.AddPointer(LHS);
3423  ID.AddPointer(RHS);
3424  IP = nullptr;
3425  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3426  return S;
3427  }
3428  }
3429  }
3430  }
3431  // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3432  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3433  SmallVector<const SCEV *, 4> Operands;
3434  for (const SCEV *Op : M->operands())
3435  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3436  if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3437  // Find an operand that's safely divisible.
3438  for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3439  const SCEV *Op = M->getOperand(i);
3440  const SCEV *Div = getUDivExpr(Op, RHSC);
3441  if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3442  Operands = SmallVector<const SCEV *, 4>(M->operands());
3443  Operands[i] = Div;
3444  return getMulExpr(Operands);
3445  }
3446  }
3447  }
3448 
3449  // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3450  if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3451  if (auto *DivisorConstant =
3452  dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3453  bool Overflow = false;
3454  APInt NewRHS =
3455  DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3456  if (Overflow) {
3457  return getConstant(RHSC->getType(), 0, false);
3458  }
3459  return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3460  }
3461  }
3462 
3463  // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3464  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3465  SmallVector<const SCEV *, 4> Operands;
3466  for (const SCEV *Op : A->operands())
3467  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3468  if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3469  Operands.clear();
3470  for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3471  const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3472  if (isa<SCEVUDivExpr>(Op) ||
3473  getMulExpr(Op, RHS) != A->getOperand(i))
3474  break;
3475  Operands.push_back(Op);
3476  }
3477  if (Operands.size() == A->getNumOperands())
3478  return getAddExpr(Operands);
3479  }
3480  }
3481 
3482  // Fold if both operands are constant.
3483  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3484  return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3485  }
3486  }
3487 
3488  // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3489  // changes). Make sure we get a new one.
3490  IP = nullptr;
3491  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3492  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3493  LHS, RHS);
3494  UniqueSCEVs.InsertNode(S, IP);
3495  registerUser(S, {LHS, RHS});
3496  return S;
3497 }
3498 
3499 APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3500  APInt A = C1->getAPInt().abs();
3501  APInt B = C2->getAPInt().abs();
3502  uint32_t ABW = A.getBitWidth();
3503  uint32_t BBW = B.getBitWidth();
3504 
3505  if (ABW > BBW)
3506  B = B.zext(ABW);
3507  else if (ABW < BBW)
3508  A = A.zext(BBW);
3509 
3510  return APIntOps::GreatestCommonDivisor(A, B);
3511 }
3512 
3513 /// Get a canonical unsigned division expression, or something simpler if
3514 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3515 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3516 /// it's not exact because the udiv may be clearing bits.
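/// For example, an exact udiv of (%a * %b)<nuw> by %b can be folded to %a by
/// dropping the matching factor, which is what the loop at the end of this
/// function does.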
3517 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3518  const SCEV *RHS) {
3519  // TODO: we could try to find factors in all sorts of things, but for now we
3520  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3521  // end of this file for inspiration.
3522 
3523  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3524  if (!Mul || !Mul->hasNoUnsignedWrap())
3525  return getUDivExpr(LHS, RHS);
3526 
3527  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3528  // If the mulexpr multiplies by a constant, then that constant must be the
3529  // first element of the mulexpr.
3530  if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3531  if (LHSCst == RHSCst) {
3532  SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3533  return getMulExpr(Operands);
3534  }
3535 
3536  // We can't just assume that LHSCst divides RHSCst cleanly; it could be
3537  // that there's a factor provided by one of the other terms. We need to
3538  // check.
3539  APInt Factor = gcd(LHSCst, RHSCst);
3540  if (!Factor.isIntN(1)) {
3541  LHSCst =
3542  cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3543  RHSCst =
3544  cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3545  SmallVector<const SCEV *, 2> Operands;
3546  Operands.push_back(LHSCst);
3547  Operands.append(Mul->op_begin() + 1, Mul->op_end());
3548  LHS = getMulExpr(Operands);
3549  RHS = RHSCst;
3550  Mul = dyn_cast<SCEVMulExpr>(LHS);
3551  if (!Mul)
3552  return getUDivExactExpr(LHS, RHS);
3553  }
3554  }
3555  }
3556 
3557  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3558  if (Mul->getOperand(i) == RHS) {
3559  SmallVector<const SCEV *, 2> Operands;
3560  Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3561  Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3562  return getMulExpr(Operands);
3563  }
3564  }
3565 
3566  return getUDivExpr(LHS, RHS);
3567 }
3568 
3569 /// Get an add recurrence expression for the specified loop. Simplify the
3570 /// expression as much as possible.
3571 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3572  const Loop *L,
3573  SCEV::NoWrapFlags Flags) {
3574  SmallVector<const SCEV *, 4> Operands;
3575  Operands.push_back(Start);
3576  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3577  if (StepChrec->getLoop() == L) {
3578  Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3579  return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3580  }
3581 
3582  Operands.push_back(Step);
3583  return getAddRecExpr(Operands, L, Flags);
3584 }
3585 
3586 /// Get an add recurrence expression for the specified loop. Simplify the
3587 /// expression as much as possible.
3588 const SCEV *
3589 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3590  const Loop *L, SCEV::NoWrapFlags Flags) {
3591  if (Operands.size() == 1) return Operands[0];
3592 #ifndef NDEBUG
3593  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3594  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3595  assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3596  "SCEVAddRecExpr operand types don't match!");
3597  assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3598  }
3599  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3600  assert(isLoopInvariant(Operands[i], L) &&
3601  "SCEVAddRecExpr operand is not loop-invariant!");
3602 #endif
3603 
3604  if (Operands.back()->isZero()) {
3605  Operands.pop_back();
3606  return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3607  }
3608 
3609  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3610  // use that information to infer NUW and NSW flags. However, computing a
3611  // BE count requires calling getAddRecExpr, so we may not yet have a
3612  // meaningful BE count at this point (and if we don't, we'd be stuck
3613  // with a SCEVCouldNotCompute as the cached BE count).
3614 
3615  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3616 
3617  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3618  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3619  const Loop *NestedLoop = NestedAR->getLoop();
3620  if (L->contains(NestedLoop)
3621  ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3622  : (!NestedLoop->contains(L) &&
3623  DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3624  SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3625  Operands[0] = NestedAR->getStart();
3626  // AddRecs require their operands be loop-invariant with respect to their
3627  // loops. Don't perform this transformation if it would break this
3628  // requirement.
3629  bool AllInvariant = all_of(
3630  Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3631 
3632  if (AllInvariant) {
3633  // Create a recurrence for the outer loop with the same step size.
3634  //
3635  // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3636  // inner recurrence has the same property.
3637  SCEV::NoWrapFlags OuterFlags =
3638  maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3639 
3640  NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3641  AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3642  return isLoopInvariant(Op, NestedLoop);
3643  });
3644 
3645  if (AllInvariant) {
3646  // Ok, both add recurrences are valid after the transformation.
3647  //
3648  // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3649  // the outer recurrence has the same property.
3650  SCEV::NoWrapFlags InnerFlags =
3651  maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3652  return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3653  }
3654  }
3655  // Reset Operands to its original state.
3656  Operands[0] = NestedAR;
3657  }
3658  }
3659 
3660  // Okay, it looks like we really DO need an addrec expr. Check to see if we
3661  // already have one, otherwise create a new one.
3662  return getOrCreateAddRecExpr(Operands, L, Flags);
3663 }
3664 
3665 const SCEV *
3666 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3667  const SmallVectorImpl<const SCEV *> &IndexExprs) {
3668  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3669  // getSCEV(Base)->getType() has the same address space as Base->getType()
3670  // because SCEV::getType() preserves the address space.
3671  Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3672  const bool AssumeInBoundsFlags = [&]() {
3673  if (!GEP->isInBounds())
3674  return false;
3675 
3676  // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3677  // but to do that, we have to ensure that said flag is valid in the entire
3678  // defined scope of the SCEV.
3679  auto *GEPI = dyn_cast<Instruction>(GEP);
3680  // TODO: non-instructions have global scope. We might be able to prove
3681  // some global scope cases
3682  return GEPI && isSCEVExprNeverPoison(GEPI);
3683  }();
3684 
3685  SCEV::NoWrapFlags OffsetWrap =
3686  AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3687 
3688  Type *CurTy = GEP->getType();
3689  bool FirstIter = true;
3690  SmallVector<const SCEV *, 4> Offsets;
3691  for (const SCEV *IndexExpr : IndexExprs) {
3692  // Compute the (potentially symbolic) offset in bytes for this index.
3693  if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3694  // For a struct, add the member offset.
3695  ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3696  unsigned FieldNo = Index->getZExtValue();
3697  const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3698  Offsets.push_back(FieldOffset);
3699 
3700  // Update CurTy to the type of the field at Index.
3701  CurTy = STy->getTypeAtIndex(Index);
3702  } else {
3703  // Update CurTy to its element type.
3704  if (FirstIter) {
3705  assert(isa<PointerType>(CurTy) &&
3706  "The first index of a GEP indexes a pointer");
3707  CurTy = GEP->getSourceElementType();
3708  FirstIter = false;
3709  } else {
3710  CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3711  }
3712  // For an array, add the element offset, explicitly scaled.
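 // For example, with a typical DataLayout an index %i into an array of i32
 // contributes the byte offset %i * 4 here.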
3713  const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3714  // Getelementptr indices are signed.
3715  IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3716 
3717  // Multiply the index by the element size to compute the element offset.
3718  const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3719  Offsets.push_back(LocalOffset);
3720  }
3721  }
3722 
3723  // Handle degenerate case of GEP without offsets.
3724  if (Offsets.empty())
3725  return BaseExpr;
3726 
3727  // Add the offsets together, assuming nsw if inbounds.
3728  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3729  // Add the base address and the offset. We cannot use the nsw flag, as the
3730  // base address is unsigned. However, if we know that the offset is
3731  // non-negative, we can use nuw.
3732  SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3733  ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3734  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3735  assert(BaseExpr->getType() == GEPExpr->getType() &&
3736  "GEP should not change type mid-flight.");
3737  return GEPExpr;
3738 }
3739 
3740 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3741  ArrayRef<const SCEV *> Ops) {
3742  FoldingSetNodeID ID;
3743  ID.AddInteger(SCEVType);
3744  for (const SCEV *Op : Ops)
3745  ID.AddPointer(Op);
3746  void *IP = nullptr;
3747  return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3748 }
3749 
3750 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3751  SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3752  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3753 }
3754 
3755 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3756  SmallVectorImpl<const SCEV *> &Ops) {
3757  assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3758  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3759  if (Ops.size() == 1) return Ops[0];
3760 #ifndef NDEBUG
3761  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3762  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3763  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3764  "Operand types don't match!");
3765  assert(Ops[0]->getType()->isPointerTy() ==
3766  Ops[i]->getType()->isPointerTy() &&
3767  "min/max should be consistently pointerish");
3768  }
3769 #endif
3770 
3771  bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3772  bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3773 
3774  // Sort by complexity, this groups all similar expression types together.
3775  GroupByComplexity(Ops, &LI, DT);
3776 
3777  // Check if we have created the same expression before.
3778  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3779  return S;
3780  }
3781 
3782  // If there are any constants, fold them together.
3783  unsigned Idx = 0;
3784  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3785  ++Idx;
3786  assert(Idx < Ops.size());
3787  auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3788  if (Kind == scSMaxExpr)
3789  return APIntOps::smax(LHS, RHS);
3790  else if (Kind == scSMinExpr)
3791  return APIntOps::smin(LHS, RHS);
3792  else if (Kind == scUMaxExpr)
3793  return APIntOps::umax(LHS, RHS);
3794  else if (Kind == scUMinExpr)
3795  return APIntOps::umin(LHS, RHS);
3796  llvm_unreachable("Unknown SCEV min/max opcode");
3797  };
3798 
3799  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3800  // We found two constants, fold them together!
3801  ConstantInt *Fold = ConstantInt::get(
3802  getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3803  Ops[0] = getConstant(Fold);
3804  Ops.erase(Ops.begin()+1); // Erase the folded element
3805  if (Ops.size() == 1) return Ops[0];
3806  LHSC = cast<SCEVConstant>(Ops[0]);
3807  }
3808 
3809  bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3810  bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3811 
3812  if (IsMax ? IsMinV : IsMaxV) {
3813  // If we are left with a constant minimum(/maximum)-int, strip it off.
3814  Ops.erase(Ops.begin());
3815  --Idx;
3816  } else if (IsMax ? IsMaxV : IsMinV) {
3817  // If we have a max(/min) with a constant maximum(/minimum)-int,
3818  // it will always be the extremum.
3819  return LHSC;
3820  }
3821 
3822  if (Ops.size() == 1) return Ops[0];
3823  }
3824 
3825  // Find the first operation of the same kind
3826  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3827  ++Idx;
3828 
3829  // Check to see if one of the operands is of the same kind. If so, expand its
3830  // operands onto our operand list, and recurse to simplify.
3831  if (Idx < Ops.size()) {
3832  bool DeletedAny = false;
3833  while (Ops[Idx]->getSCEVType() == Kind) {
3834  const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3835  Ops.erase(Ops.begin()+Idx);
3836  Ops.append(SMME->op_begin(), SMME->op_end());
3837  DeletedAny = true;
3838  }
3839 
3840  if (DeletedAny)
3841  return getMinMaxExpr(Kind, Ops);
3842  }
3843 
3844  // Okay, check to see if the same value occurs in the operand list twice. If
3845  // so, delete one. Since we sorted the list, these values are required to
3846  // be adjacent.
3847  llvm::CmpInst::Predicate GEPred =
3848  IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3849  llvm::CmpInst::Predicate LEPred =
3850  IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3851  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3852  llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3853  for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3854  if (Ops[i] == Ops[i + 1] ||
3855  isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3856  // X op Y op Y --> X op Y
3857  // X op Y --> X, if we know X, Y are ordered appropriately
3858  Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3859  --i;
3860  --e;
3861  } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3862  Ops[i + 1])) {
3863  // X op Y --> Y, if we know X, Y are ordered appropriately
3864  Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3865  --i;
3866  --e;
3867  }
3868  }
3869 
3870  if (Ops.size() == 1) return Ops[0];
3871 
3872  assert(!Ops.empty() && "Reduced smax down to nothing!");
3873 
3874  // Okay, it looks like we really DO need an expr. Check to see if we
3875  // already have one, otherwise create a new one.
3876  FoldingSetNodeID ID;
3877  ID.AddInteger(Kind);
3878  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3879  ID.AddPointer(Ops[i]);
3880  void *IP = nullptr;
3881  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3882  if (ExistingSCEV)
3883  return ExistingSCEV;
3884  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3885  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3886  SCEV *S = new (SCEVAllocator)
3887  SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3888 
3889  UniqueSCEVs.InsertNode(S, IP);
3890  registerUser(S, Ops);
3891  return S;
3892 }
3893 
3894 namespace {
3895 
3896 class SCEVSequentialMinMaxDeduplicatingVisitor final
3897  : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3898  Optional<const SCEV *>> {
3899  using RetVal = Optional<const SCEV *>;
3900  using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3901 
3902  ScalarEvolution &SE;
3903  const SCEVTypes RootKind; // Must be a sequential min/max expression.
3904  const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3905  SmallPtrSet<const SCEV *, 16> SeenOps;
3906 
3907  bool canRecurseInto(SCEVTypes Kind) const {
3908  // We can only recurse into the SCEV expression of the same effective type
3909  // as the type of our root SCEV expression.
3910  return RootKind == Kind || NonSequentialRootKind == Kind;
3911  };
3912 
3913  RetVal visitAnyMinMaxExpr(const SCEV *S) {
3914  assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3915  "Only for min/max expressions.");
3916  SCEVTypes Kind = S->getSCEVType();
3917 
3918  if (!canRecurseInto(Kind))
3919  return S;
3920 
3921  auto *NAry = cast<SCEVNAryExpr>(S);
3922  SmallVector<const SCEV *> NewOps;
3923  bool Changed =
3924  visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
3925 
3926  if (!Changed)
3927  return S;
3928  if (NewOps.empty())
3929  return None;
3930 
3931  return isa<SCEVSequentialMinMaxExpr>(S)
3932  ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3933  : SE.getMinMaxExpr(Kind, NewOps);
3934  }
3935 
3936  RetVal visit(const SCEV *S) {
3937  // Has the whole operand been seen already?
3938  if (!SeenOps.insert(S).second)
3939  return None;
3940  return Base::visit(S);
3941  }
3942 
3943 public:
3944  SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
3945  SCEVTypes RootKind)
3946  : SE(SE), RootKind(RootKind),
3947  NonSequentialRootKind(
3948  SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
3949  RootKind)) {}
3950 
3951  bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
3952  SmallVectorImpl<const SCEV *> &NewOps) {
3953  bool Changed = false;
3954  SmallVector<const SCEV *> Ops;
3955  Ops.reserve(OrigOps.size());
3956 
3957  for (const SCEV *Op : OrigOps) {
3958  RetVal NewOp = visit(Op);
3959  if (NewOp != Op)
3960  Changed = true;
3961  if (NewOp)
3962  Ops.emplace_back(*NewOp);
3963  }
3964 
3965  if (Changed)
3966  NewOps = std::move(Ops);
3967  return Changed;
3968  }
3969 
3970  RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
3971 
3972  RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
3973 
3974  RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
3975 
3976  RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
3977 
3978  RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
3979 
3980  RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
3981 
3982  RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
3983 
3984  RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
3985 
3986  RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
3987 
3988  RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
3989  return visitAnyMinMaxExpr(Expr);
3990  }
3991 
3992  RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
3993  return visitAnyMinMaxExpr(Expr);
3994  }
3995 
3996  RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
3997  return visitAnyMinMaxExpr(Expr);
3998  }
3999 
4000  RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4001  return visitAnyMinMaxExpr(Expr);
4002  }
4003 
4004  RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4005  return visitAnyMinMaxExpr(Expr);
4006  }
4007 
4008  RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4009 
4010  RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4011 };
4012 
4013 } // namespace
4014 
4015 /// Return true if S is poison given that AssumedPoison is already poison.
4016 static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4017  // The only way poison may be introduced in a SCEV expression is from a
4018  // poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4019  // not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4020  // introduce poison -- they encode guaranteed, non-speculated knowledge.
4021  //
4022  // Additionally, all SCEV nodes propagate poison from inputs to outputs,
4023  // with the notable exception of umin_seq, where only poison from the first
4024  // operand is (unconditionally) propagated.
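 // For example, in %x umin_seq %y the result is poison whenever %x is poison,
 // but a poison %y only makes the result poison when %x is not the saturating
 // value zero.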
4025  struct SCEVPoisonCollector {
4026  bool LookThroughSeq;
4027  SmallPtrSet<const SCEV *, 4> MaybePoison;
4028  SCEVPoisonCollector(bool LookThroughSeq) : LookThroughSeq(LookThroughSeq) {}
4029 
4030  bool follow(const SCEV *S) {
4031  // TODO: We can always follow the first operand, but the SCEVTraversal
4032  // API doesn't support this.
4033  if (!LookThroughSeq && isa<SCEVSequentialMinMaxExpr>(S))
4034  return false;
4035 
4036  if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4037  if (!isGuaranteedNotToBePoison(SU->getValue()))
4038  MaybePoison.insert(S);
4039  }
4040  return true;
4041  }
4042  bool isDone() const { return false; }
4043  };
4044 
4045  // First collect all SCEVs that might result in AssumedPoison to be poison.
4046  // We need to look through umin_seq here, because we want to find all SCEVs
4047  // that *might* result in poison, not only those that are *required* to.
4048  SCEVPoisonCollector PC1(/* LookThroughSeq */ true);
4049  visitAll(AssumedPoison, PC1);
4050 
4051  // AssumedPoison is never poison. As the assumption is false, the implication
4052  // is true. Don't bother walking the other SCEV in this case.
4053  if (PC1.MaybePoison.empty())
4054  return true;
4055 
4056  // Collect all SCEVs in S that, if poison, *will* result in S being poison
4057  // as well. We cannot look through umin_seq here, as its argument only *may*
4058  // make the result poison.
4059  SCEVPoisonCollector PC2(/* LookThroughSeq */ false);
4060  visitAll(S, PC2);
4061 
4062  // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4063  // it will also make S poison by being part of PC2.MaybePoison.
4064  return all_of(PC1.MaybePoison,
4065  [&](const SCEV *S) { return PC2.MaybePoison.contains(S); });
4066 }
4067 
4068 const SCEV *
4069 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4070  SmallVectorImpl<const SCEV *> &Ops) {
4071  assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4072  "Not a SCEVSequentialMinMaxExpr!");
4073  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4074  if (Ops.size() == 1)
4075  return Ops[0];
4076 #ifndef NDEBUG
4077  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4078  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4079  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4080  "Operand types don't match!");
4081  assert(Ops[0]->getType()->isPointerTy() ==
4082  Ops[i]->getType()->isPointerTy() &&
4083  "min/max should be consistently pointerish");
4084  }
4085 #endif
4086 
4087  // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4088  // so we can *NOT* do any kind of sorting of the expressions!
4089 
4090  // Check if we have created the same expression before.
4091  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4092  return S;
4093 
4094  // FIXME: there are *some* simplifications that we can do here.
4095 
4096  // Keep only the first instance of an operand.
4097  {
4098  SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4099  bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4100  if (Changed)
4101  return getSequentialMinMaxExpr(Kind, Ops);
4102  }
4103 
4104  // Check to see if one of the operands is of the same kind. If so, expand its
4105  // operands onto our operand list, and recurse to simplify.
4106  {
4107  unsigned Idx = 0;
4108  bool DeletedAny = false;
4109  while (Idx < Ops.size()) {
4110  if (Ops[Idx]->getSCEVType() != Kind) {
4111  ++Idx;
4112  continue;
4113  }
4114  const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4115  Ops.erase(Ops.begin() + Idx);
4116  Ops.insert(Ops.begin() + Idx, SMME->op_begin(), SMME->op_end());
4117  DeletedAny = true;
4118  }
4119 
4120  if (DeletedAny)
4121  return getSequentialMinMaxExpr(Kind, Ops);
4122  }
4123 
4124  const SCEV *SaturationPoint;
4125  ICmpInst::Predicate Pred;
4126  switch (Kind) {
4127  case scSequentialUMinExpr:
4128  SaturationPoint = getZero(Ops[0]->getType());
4129  Pred = ICmpInst::ICMP_ULE;
4130  break;
4131  default:
4132  llvm_unreachable("Not a sequential min/max type.");
4133  }
4134 
4135  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4136  // We can replace %x umin_seq %y with %x umin %y if either:
4137  // * %y being poison implies %x is also poison.
4138  // * %x cannot be the saturating value (e.g. zero for umin).
4139  if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4140  isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4141  SaturationPoint)) {
4142  SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4143  Ops[i - 1] = getMinMaxExpr(
4144  SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4145  SeqOps);
4146  Ops.erase(Ops.begin() + i);
4147  return getSequentialMinMaxExpr(Kind, Ops);
4148  }
4149  // Fold %x umin_seq %y to %x if %x ule %y.
4150  // TODO: We might be able to prove the predicate for a later operand.
4151  if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4152  Ops.erase(Ops.begin() + i);
4153  return getSequentialMinMaxExpr(Kind, Ops);
4154  }
4155  }
4156 
4157  // Okay, it looks like we really DO need an expr. Check to see if we
4158  // already have one, otherwise create a new one.
4159  FoldingSetNodeID ID;
4160  ID.AddInteger(Kind);
4161  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4162  ID.AddPointer(Ops[i]);
4163  void *IP = nullptr;
4164  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4165  if (ExistingSCEV)
4166  return ExistingSCEV;
4167 
4168  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4169  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4170  SCEV *S = new (SCEVAllocator)
4171  SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4172 
4173  UniqueSCEVs.InsertNode(S, IP);
4174  registerUser(S, Ops);
4175  return S;
4176 }
4177 
4178 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4179  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4180  return getSMaxExpr(Ops);
4181 }
4182 
4183 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4184  return getMinMaxExpr(scSMaxExpr, Ops);
4185 }
4186 
4187 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4188  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4189  return getUMaxExpr(Ops);
4190 }
4191 
4192 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4193  return getMinMaxExpr(scUMaxExpr, Ops);
4194 }
4195 
4196 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4197  const SCEV *RHS) {
4198  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4199  return getSMinExpr(Ops);
4200 }
4201 
4202 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4203  return getMinMaxExpr(scSMinExpr, Ops);
4204 }
4205 
4206 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4207  bool Sequential) {
4208  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4209  return getUMinExpr(Ops, Sequential);
4210 }
4211 
4212 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4213  bool Sequential) {
4214  return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4215  : getMinMaxExpr(scUMinExpr, Ops);
4216 }
4217 
4218 const SCEV *
4219 ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
4220  ScalableVectorType *ScalableTy) {
4221  Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
4222  Constant *One = ConstantInt::get(IntTy, 1);
4223  Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
4224  // Note that the expression we created is the final expression; we don't
4225  // want to simplify it any further. Also, if we call a normal getSCEV(),
4226  // we'll end up in an endless recursion, so just create an SCEVUnknown.
4227  return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
4228 }
4229 
4230 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4231  if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
4232  return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
4233  // We can bypass creating a target-independent constant expression and then
4234  // folding it back into a ConstantInt. This is just a compile-time
4235  // optimization.
4236  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4237 }
4238 
4239 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4240  if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
4241  return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
4242  // We can bypass creating a target-independent constant expression and then
4243  // folding it back into a ConstantInt. This is just a compile-time
4244  // optimization.
4245  return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4246 }
4247 
4248 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4249  StructType *STy,
4250  unsigned FieldNo) {
4251  // We can bypass creating a target-independent constant expression and then
4252  // folding it back into a ConstantInt. This is just a compile-time
4253  // optimization.
4254  return getConstant(
4255  IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
4256 }
4257 
4258 const SCEV *ScalarEvolution::getUnknown(Value *V) {
4259  // Don't attempt to do anything other than create a SCEVUnknown object
4260  // here. createSCEV only calls getUnknown after checking for all other
4261  // interesting possibilities, and any other code that calls getUnknown
4262  // is doing so in order to hide a value from SCEV canonicalization.
4263 
4264  FoldingSetNodeID ID;
4265  ID.AddInteger(scUnknown);
4266  ID.AddPointer(V);
4267  void *IP = nullptr;
4268  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4269  assert(cast<SCEVUnknown>(S)->getValue() == V &&
4270  "Stale SCEVUnknown in uniquing map!");
4271  return S;
4272  }
4273  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4274  FirstUnknown);
4275  FirstUnknown = cast<SCEVUnknown>(S);
4276  UniqueSCEVs.InsertNode(S, IP);
4277  return S;
4278 }
4279 
4280 //===----------------------------------------------------------------------===//
4281 // Basic SCEV Analysis and PHI Idiom Recognition Code
4282 //
4283 
4284 /// Test if values of the given type are analyzable within the SCEV
4285 /// framework. This primarily includes integer types, and it can optionally
4286 /// include pointer types if the ScalarEvolution class has access to
4287 /// target-specific information.
4288 bool ScalarEvolution::isSCEVable(Type *Ty) const {
4289  // Integers and pointers are always SCEVable.
4290  return Ty->isIntOrPtrTy();
4291 }
4292 
4293 /// Return the size in bits of the specified type, for which isSCEVable must
4294 /// return true.
4295 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4296  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4297  if (Ty->isPointerTy())
4298  return getDataLayout().getIndexTypeSizeInBits(Ty);
4299  return getDataLayout().getTypeSizeInBits(Ty);
4300 }
4301 
4302 /// Return a type with the same bitwidth as the given type and which represents
4303 /// how SCEV will treat the given type, for which isSCEVable must return
4304 /// true. For pointer types, this is the pointer index sized integer type.
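/// For example, with a typical 64-bit DataLayout a pointer in address space 0
/// is treated as i64 by SCEV.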
4305 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4306  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4307 
4308  if (Ty->isIntegerTy())
4309  return Ty;
4310 
4311  // The only other supported type is pointer.
4312  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4313  return getDataLayout().getIndexType(Ty);
4314 }
4315 
4316 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4317  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4318 }
4319 
4321  const SCEV *B) {
4322  /// For a valid use point to exist, the defining scope of one operand
4323  /// must dominate the other.
4324  bool PreciseA, PreciseB;
4325  auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4326  auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4327  if (!PreciseA || !PreciseB)
4328  // Can't tell.
4329  return false;
4330  return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4331  DT.dominates(ScopeB, ScopeA);
4332 }
4333 
4334 
4335 const SCEV *ScalarEvolution::getCouldNotCompute() {
4336  return CouldNotCompute.get();
4337 }
4338 
4339 bool ScalarEvolution::checkValidity(const SCEV *S) const {
4340  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4341  auto *SU = dyn_cast<SCEVUnknown>(S);
4342  return SU && SU->getValue() == nullptr;
4343  });
4344 
4345  return !ContainsNulls;
4346 }
4347 
4348 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4349  HasRecMapType::iterator I = HasRecMap.find(S);
4350  if (I != HasRecMap.end())
4351  return I->second;
4352 
4353  bool FoundAddRec =
4354  SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4355  HasRecMap.insert({S, FoundAddRec});
4356  return FoundAddRec;
4357 }
4358 
4359 /// Return the ValueOffsetPair set for \p S. \p S can be represented
4360 /// by the value and offset from any ValueOffsetPair in the set.
4361 ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4362  ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4363  if (SI == ExprValueMap.end())
4364  return None;
4365 #ifndef NDEBUG
4366  if (VerifySCEVMap) {
4367  // Check there is no dangling Value in the set returned.
4368  for (Value *V : SI->second)
4369  assert(ValueExprMap.count(V));
4370  }
4371 #endif
4372  return SI->second.getArrayRef();
4373 }
4374 
4375 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4376 /// cannot be used separately. eraseValueFromMap should be used to remove
4377 /// V from ValueExprMap and ExprValueMap at the same time.
4378 void ScalarEvolution::eraseValueFromMap(Value *V) {
4379  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4380  if (I != ValueExprMap.end()) {
4381  auto EVIt = ExprValueMap.find(I->second);
4382  bool Removed = EVIt->second.remove(V);
4383  (void) Removed;
4384  assert(Removed && "Value not in ExprValueMap?");
4385  ValueExprMap.erase(I);
4386  }
4387 }
4388 
4389 void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4390  // A recursive query may have already computed the SCEV. It should be
4391  // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4392  // inferred nowrap flags.
4393  auto It = ValueExprMap.find_as(V);
4394  if (It == ValueExprMap.end()) {
4395  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4396  ExprValueMap[S].insert(V);
4397  }
4398 }
4399 
4400 /// Return an existing SCEV if it exists, otherwise analyze the expression and
4401 /// create a new one.
4402 const SCEV *ScalarEvolution::getSCEV(Value *V) {
4403  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4404 
4405  if (const SCEV *S = getExistingSCEV(V))
4406  return S;
4407  return createSCEVIter(V);
4408 }
4409 
4410 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4411  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4412 
4413  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4414  if (I != ValueExprMap.end()) {
4415  const SCEV *S = I->second;
4416  assert(checkValidity(S) &&
4417  "existing SCEV has not been properly invalidated");
4418  return S;
4419  }
4420  return nullptr;
4421 }
4422 
4423 /// Return a SCEV corresponding to -V = -1*V
4424 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4425  SCEV::NoWrapFlags Flags) {
4426  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4427  return getConstant(
4428  cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4429 
4430  Type *Ty = V->getType();
4431  Ty = getEffectiveSCEVType(Ty);
4432  return getMulExpr(V, getMinusOne(Ty), Flags);
4433 }
4434 
4435 /// If Expr computes ~A, return A else return nullptr
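/// In SCEV form, ~A is built as -1 - A, i.e. the add expression
/// (-1) + (-1) * A, which is the shape matched below.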
4436 static const SCEV *MatchNotExpr(const SCEV *Expr) {
4437  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4438  if (!Add || Add->getNumOperands() != 2 ||
4439  !Add->getOperand(0)->isAllOnesValue())
4440  return nullptr;
4441 
4442  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4443  if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4444  !AddRHS->getOperand(0)->isAllOnesValue())
4445  return nullptr;
4446 
4447  return AddRHS->getOperand(1);
4448 }
4449 
4450 /// Return a SCEV corresponding to ~V = -1-V
4451 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4452  assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4453 
4454  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4455  return getConstant(
4456  cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4457 
4458  // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
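 // This works because ~x = -1 - x reverses both the signed and unsigned
 // order, so e.g. ~smin(~a, ~b) == smax(a, b).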
4459  if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4460  auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4461  SmallVector<const SCEV *, 2> MatchedOperands;
4462  for (const SCEV *Operand : MME->operands()) {
4463  const SCEV *Matched = MatchNotExpr(Operand);
4464  if (!Matched)
4465  return (const SCEV *)nullptr;
4466  MatchedOperands.push_back(Matched);
4467  }
4468  return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4469  MatchedOperands);
4470  };
4471  if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4472  return Replaced;
4473  }
4474 
4475  Type *Ty = V->getType();
4476  Ty = getEffectiveSCEVType(Ty);
4477  return getMinusSCEV(getMinusOne(Ty), V);
4478 }
4479 
4480 const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4481  assert(P->getType()->isPointerTy());
4482 
4483  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4484  // The base of an AddRec is the first operand.
4485  SmallVector<const SCEV *> Ops{AddRec->operands()};
4486  Ops[0] = removePointerBase(Ops[0]);
4487  // Don't try to transfer nowrap flags for now. We could in some cases
4488  // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4489  return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4490  }
4491  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4492  // The base of an Add is the pointer operand.
4493  SmallVector<const SCEV *> Ops{Add->operands()};
4494  const SCEV **PtrOp = nullptr;
4495  for (const SCEV *&AddOp : Ops) {
4496  if (AddOp->getType()->isPointerTy()) {
4497  assert(!PtrOp && "Cannot have multiple pointer ops");
4498  PtrOp = &AddOp;
4499  }
4500  }
4501  *PtrOp = removePointerBase(*PtrOp);
4502  // Don't try to transfer nowrap flags for now. We could in some cases
4503  // (for example, if the pointer operand of the Add is a SCEVUnknown).
4504  return getAddExpr(Ops);
4505  }
4506  // Any other expression must be a pointer base.
4507  return getZero(P->getType());
4508 }
4509 
4510 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4511  SCEV::NoWrapFlags Flags,
4512  unsigned Depth) {
4513  // Fast path: X - X --> 0.
4514  if (LHS == RHS)
4515  return getZero(LHS->getType());
4516 
4517  // If we subtract two pointers with different pointer bases, bail.
4518  // Eventually, we're going to add an assertion to getMulExpr that we
4519  // can't multiply by a pointer.
4520  if (RHS->getType()->isPointerTy()) {
4521  if (!LHS->getType()->isPointerTy() ||
4522  getPointerBase(LHS) != getPointerBase(RHS))
4523  return getCouldNotCompute();
4524  LHS = removePointerBase(LHS);
4525  RHS = removePointerBase(RHS);
4526  }
4527 
4528  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4529  // makes it so that we cannot make much use of NUW.
4530  auto AddFlags = SCEV::FlagAnyWrap;
4531  const bool RHSIsNotMinSigned =
4532  !getSignedRangeMin(RHS).isMinSignedValue();
4533  if (hasFlags(Flags, SCEV::FlagNSW)) {
4534  // Let M be the minimum representable signed value. Then (-1)*RHS
4535  // signed-wraps if and only if RHS is M. That can happen even for
4536  // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4537  // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4538  // (-1)*RHS, we need to prove that RHS != M.
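  // (As a concrete example, for i8 M = -128: (-1) * (-128) signed-wraps,
  // while -1 - (-128) = 127 does not.)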
4539  //
4540  // If LHS is non-negative and we know that LHS - RHS does not
4541  // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4542  // either by proving that RHS > M or that LHS >= 0.
4543  if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4544  AddFlags = SCEV::FlagNSW;
4545  }
4546  }
4547 
4548  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4549  // RHS is NSW and LHS >= 0.
4550  //
4551  // The difficulty here is that the NSW flag may have been proven
4552  // relative to a loop that is to be found in a recurrence in LHS and
4553  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4554  // larger scope than intended.
4555  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4556 
4557  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4558 }
4559 
4560 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4561  unsigned Depth) {
4562  Type *SrcTy = V->getType();
4563  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4564  "Cannot truncate or zero extend with non-integer arguments!");
4565  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4566  return V; // No conversion
4567  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4568  return getTruncateExpr(V, Ty, Depth);
4569  return getZeroExtendExpr(V, Ty, Depth);
4570 }
4571 
4572 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4573  unsigned Depth) {
4574  Type *SrcTy = V->getType();
4575  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4576  "Cannot truncate or zero extend with non-integer arguments!");
4577  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4578  return V; // No conversion
4579  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4580  return getTruncateExpr(V, Ty, Depth);
4581  return getSignExtendExpr(V, Ty, Depth);
4582 }
4583 
4584 const SCEV *
4585 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4586  Type *SrcTy = V->getType();
4587  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4588  "Cannot noop or zero extend with non-integer arguments!");
4589  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4590  "getNoopOrZeroExtend cannot truncate!");
4591  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4592  return V; // No conversion
4593  return getZeroExtendExpr(V, Ty);
4594 }
4595 
4596 const SCEV *
4597 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4598  Type *SrcTy = V->getType();
4599  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4600  "Cannot noop or sign extend with non-integer arguments!");
4601  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4602  "getNoopOrSignExtend cannot truncate!");
4603  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4604  return V; // No conversion
4605  return getSignExtendExpr(V, Ty);
4606 }
4607 
4608 const SCEV *
4609 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4610  Type *SrcTy = V->getType();
4611  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4612  "Cannot noop or any extend with non-integer arguments!");
4613  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4614  "getNoopOrAnyExtend cannot truncate!");
4615  if (