ScalarEvolution.cpp
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
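//
// A minimal usage sketch (assuming an existing ScalarEvolution &SE and the
// loop's induction PHI *IV; these names are illustrative and not part of
// this file): a client asks for the SCEV of an induction variable and then
// inspects the resulting add recurrence.
//
//   const SCEV *S = SE.getSCEV(IV);
//   if (const auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
//     const SCEV *Start = AR->getStart();           // e.g. 0
//     const SCEV *Step = AR->getStepRecurrence(SE); // e.g. 4
//     const SCEV *BTC = SE.getBackedgeTakenCount(AR->getLoop());
//   }
//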
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
84#include "llvm/Config/llvm-config.h"
85#include "llvm/IR/Argument.h"
86#include "llvm/IR/BasicBlock.h"
87#include "llvm/IR/CFG.h"
88#include "llvm/IR/Constant.h"
90#include "llvm/IR/Constants.h"
91#include "llvm/IR/DataLayout.h"
93#include "llvm/IR/Dominators.h"
94#include "llvm/IR/Function.h"
95#include "llvm/IR/GlobalAlias.h"
96#include "llvm/IR/GlobalValue.h"
98#include "llvm/IR/InstrTypes.h"
99#include "llvm/IR/Instruction.h"
100#include "llvm/IR/Instructions.h"
102#include "llvm/IR/Intrinsics.h"
103#include "llvm/IR/LLVMContext.h"
104#include "llvm/IR/Operator.h"
105#include "llvm/IR/PatternMatch.h"
106#include "llvm/IR/Type.h"
107#include "llvm/IR/Use.h"
108#include "llvm/IR/User.h"
109#include "llvm/IR/Value.h"
110#include "llvm/IR/Verifier.h"
112#include "llvm/Pass.h"
113#include "llvm/Support/Casting.h"
116#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136
137#define DEBUG_TYPE "scalar-evolution"
138
139STATISTIC(NumExitCountsComputed,
140 "Number of loop exits with predictable exit counts");
141STATISTIC(NumExitCountsNotComputed,
142 "Number of loop exits without predictable exit counts");
143STATISTIC(NumBruteForceTripCountsComputed,
144 "Number of loops with trip counts computed by force");
145
146#ifdef EXPENSIVE_CHECKS
147bool llvm::VerifySCEV = true;
148#else
149bool llvm::VerifySCEV = false;
150#endif
151
153 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
154 cl::desc("Maximum number of iterations SCEV will "
155 "symbolically execute a constant "
156 "derived loop"),
157 cl::init(100));
158
160 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
161 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163 "verify-scev-strict", cl::Hidden,
164 cl::desc("Enable stricter verification with -verify-scev is passed"));
165
167 "scev-verify-ir", cl::Hidden,
168 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
169 cl::init(false));
170
172 "scev-mulops-inline-threshold", cl::Hidden,
173 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
174 cl::init(32));
175
177 "scev-addops-inline-threshold", cl::Hidden,
178 cl::desc("Threshold for inlining addition operands into a SCEV"),
179 cl::init(500));
180
182 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
183 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
184 cl::init(32));
185
187 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
188 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
189 cl::init(2));
190
192 "scalar-evolution-max-value-compare-depth", cl::Hidden,
193 cl::desc("Maximum depth of recursive value complexity comparisons"),
194 cl::init(2));
195
197 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
198 cl::desc("Maximum depth of recursive arithmetics"),
199 cl::init(32));
200
202 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
203 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
204
206 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
207 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
208 cl::init(8));
209
211 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
212 cl::desc("Max coefficients in AddRec during evolving"),
213 cl::init(8));
214
216 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
217 cl::desc("Size of the expression which is considered huge"),
218 cl::init(4096));
219
221 "scev-range-iter-threshold", cl::Hidden,
222 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223 cl::init(32));
224
225static cl::opt<bool>
226ClassifyExpressions("scalar-evolution-classify-expressions",
227 cl::Hidden, cl::init(true),
228 cl::desc("When printing analysis, include information on every instruction"));
229
231 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
232 cl::init(false),
233 cl::desc("Use more powerful methods of sharpening expression ranges. May "
234 "be costly in terms of compile time"));
235
237 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
238 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
239 "Phi strongly connected components"),
240 cl::init(8));
241
242static cl::opt<bool>
243 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
244 cl::desc("Handle <= and >= in finite loops"),
245 cl::init(true));
246
248 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
249 cl::desc("Infer nuw/nsw flags using context where suitable"),
250 cl::init(true));
251
252//===----------------------------------------------------------------------===//
253// SCEV class definitions
254//===----------------------------------------------------------------------===//
255
256//===----------------------------------------------------------------------===//
257// Implementation of the SCEV class.
258//
259
260#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 261LLVM_DUMP_METHOD void SCEV::dump() const {
 262 print(dbgs());
263 dbgs() << '\n';
264}
265#endif
266
 267void SCEV::print(raw_ostream &OS) const {
 268 switch (getSCEVType()) {
269 case scConstant:
270 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
271 return;
272 case scVScale:
273 OS << "vscale";
274 return;
275 case scPtrToInt: {
276 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
277 const SCEV *Op = PtrToInt->getOperand();
278 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
279 << *PtrToInt->getType() << ")";
280 return;
281 }
282 case scTruncate: {
283 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
284 const SCEV *Op = Trunc->getOperand();
285 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
286 << *Trunc->getType() << ")";
287 return;
288 }
289 case scZeroExtend: {
290 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
291 const SCEV *Op = ZExt->getOperand();
292 OS << "(zext " << *Op->getType() << " " << *Op << " to "
293 << *ZExt->getType() << ")";
294 return;
295 }
296 case scSignExtend: {
297 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
298 const SCEV *Op = SExt->getOperand();
299 OS << "(sext " << *Op->getType() << " " << *Op << " to "
300 << *SExt->getType() << ")";
301 return;
302 }
303 case scAddRecExpr: {
304 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
305 OS << "{" << *AR->getOperand(0);
306 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
307 OS << ",+," << *AR->getOperand(i);
308 OS << "}<";
309 if (AR->hasNoUnsignedWrap())
310 OS << "nuw><";
311 if (AR->hasNoSignedWrap())
312 OS << "nsw><";
313 if (AR->hasNoSelfWrap() &&
 314 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
 315 OS << "nw><";
316 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
317 OS << ">";
318 return;
319 }
320 case scAddExpr:
321 case scMulExpr:
322 case scUMaxExpr:
323 case scSMaxExpr:
324 case scUMinExpr:
 325 case scSMinExpr:
 326 case scSequentialUMinExpr: {
327 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
328 const char *OpStr = nullptr;
329 switch (NAry->getSCEVType()) {
330 case scAddExpr: OpStr = " + "; break;
331 case scMulExpr: OpStr = " * "; break;
332 case scUMaxExpr: OpStr = " umax "; break;
333 case scSMaxExpr: OpStr = " smax "; break;
334 case scUMinExpr:
335 OpStr = " umin ";
336 break;
337 case scSMinExpr:
338 OpStr = " smin ";
339 break;
 340 case scSequentialUMinExpr:
 341 OpStr = " umin_seq ";
342 break;
343 default:
344 llvm_unreachable("There are no other nary expression types.");
345 }
346 OS << "(";
347 ListSeparator LS(OpStr);
348 for (const SCEV *Op : NAry->operands())
349 OS << LS << *Op;
350 OS << ")";
351 switch (NAry->getSCEVType()) {
352 case scAddExpr:
353 case scMulExpr:
354 if (NAry->hasNoUnsignedWrap())
355 OS << "<nuw>";
356 if (NAry->hasNoSignedWrap())
357 OS << "<nsw>";
358 break;
359 default:
360 // Nothing to print for other nary expressions.
361 break;
362 }
363 return;
364 }
365 case scUDivExpr: {
366 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
367 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
368 return;
369 }
370 case scUnknown:
371 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
372 return;
 373 case scCouldNotCompute:
 374 OS << "***COULDNOTCOMPUTE***";
375 return;
376 }
377 llvm_unreachable("Unknown SCEV kind!");
378}
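//
// For reference, a few shapes the printer above produces (a sketch; the
// exact names depend on the input IR): an affine recurrence prints as
// "{0,+,4}<nuw><%loop>", an add as "(%a + %b)<nsw>", an unsigned divide as
// "(%a /u 8)", and a cast as "(zext i8 %x to i32)".
//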
379
 380Type *SCEV::getType() const {
 381 switch (getSCEVType()) {
382 case scConstant:
383 return cast<SCEVConstant>(this)->getType();
384 case scVScale:
385 return cast<SCEVVScale>(this)->getType();
386 case scPtrToInt:
387 case scTruncate:
388 case scZeroExtend:
389 case scSignExtend:
390 return cast<SCEVCastExpr>(this)->getType();
391 case scAddRecExpr:
392 return cast<SCEVAddRecExpr>(this)->getType();
393 case scMulExpr:
394 return cast<SCEVMulExpr>(this)->getType();
395 case scUMaxExpr:
396 case scSMaxExpr:
397 case scUMinExpr:
398 case scSMinExpr:
399 return cast<SCEVMinMaxExpr>(this)->getType();
 400 case scSequentialUMinExpr:
 401 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
402 case scAddExpr:
403 return cast<SCEVAddExpr>(this)->getType();
404 case scUDivExpr:
405 return cast<SCEVUDivExpr>(this)->getType();
406 case scUnknown:
407 return cast<SCEVUnknown>(this)->getType();
 408 case scCouldNotCompute:
 409 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
410 }
411 llvm_unreachable("Unknown SCEV kind!");
412}
413
 414ArrayRef<const SCEV *> SCEV::operands() const {
 415 switch (getSCEVType()) {
416 case scConstant:
417 case scVScale:
418 case scUnknown:
419 return {};
420 case scPtrToInt:
421 case scTruncate:
422 case scZeroExtend:
423 case scSignExtend:
424 return cast<SCEVCastExpr>(this)->operands();
425 case scAddRecExpr:
426 case scAddExpr:
427 case scMulExpr:
428 case scUMaxExpr:
429 case scSMaxExpr:
430 case scUMinExpr:
431 case scSMinExpr:
 432 case scSequentialUMinExpr:
 433 return cast<SCEVNAryExpr>(this)->operands();
434 case scUDivExpr:
435 return cast<SCEVUDivExpr>(this)->operands();
 436 case scCouldNotCompute:
 437 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
438 }
439 llvm_unreachable("Unknown SCEV kind!");
440}
441
442bool SCEV::isZero() const {
443 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
444 return SC->getValue()->isZero();
445 return false;
446}
447
448bool SCEV::isOne() const {
449 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
450 return SC->getValue()->isOne();
451 return false;
452}
453
 454bool SCEV::isMinusOne() const {
 455 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
456 return SC->getValue()->isMinusOne();
457 return false;
458}
459
 460bool SCEV::isNonConstantNegative() const {
 461 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
462 if (!Mul) return false;
463
464 // If there is a constant factor, it will be first.
465 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
466 if (!SC) return false;
467
 468 // Return true if the value is negative; this matches things like (-42 * V).
469 return SC->getAPInt().isNegative();
470}
471
474
 475bool SCEVCouldNotCompute::classof(const SCEV *S) {
 476 return S->getSCEVType() == scCouldNotCompute;
477}
478
 479const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
 480 FoldingSetNodeID ID;
 481 ID.AddInteger(scConstant);
482 ID.AddPointer(V);
483 void *IP = nullptr;
484 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
485 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
486 UniqueSCEVs.InsertNode(S, IP);
487 return S;
488}
489
491 return getConstant(ConstantInt::get(getContext(), Val));
492}
493
494const SCEV *
496 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
497 return getConstant(ConstantInt::get(ITy, V, isSigned));
498}
499
 500const SCEV *ScalarEvolution::getVScale(Type *Ty) {
 501 FoldingSetNodeID ID;
 502 ID.AddInteger(scVScale);
503 ID.AddPointer(Ty);
504 void *IP = nullptr;
505 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
506 return S;
507 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
508 UniqueSCEVs.InsertNode(S, IP);
509 return S;
510}
511
513 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
514 if (EC.isScalable())
515 Res = getMulExpr(Res, getVScale(Ty));
516 return Res;
517}
518
520 const SCEV *op, Type *ty)
521 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
522
523SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
524 Type *ITy)
525 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
527 "Must be a non-bit-width-changing pointer-to-integer cast!");
528}
529
531 SCEVTypes SCEVTy, const SCEV *op,
532 Type *ty)
533 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
534
535SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
536 Type *ty)
 537 : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
 538 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
539 "Cannot truncate non-integer value!");
540}
541
542SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
543 const SCEV *op, Type *ty)
 544 : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
 545 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
546 "Cannot zero extend non-integer value!");
547}
548
549SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
550 const SCEV *op, Type *ty)
 551 : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
 552 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
553 "Cannot sign extend non-integer value!");
554}
555
556void SCEVUnknown::deleted() {
557 // Clear this SCEVUnknown from various maps.
558 SE->forgetMemoizedResults(this);
559
560 // Remove this SCEVUnknown from the uniquing map.
561 SE->UniqueSCEVs.RemoveNode(this);
562
563 // Release the value.
564 setValPtr(nullptr);
565}
566
567void SCEVUnknown::allUsesReplacedWith(Value *New) {
568 // Clear this SCEVUnknown from various maps.
569 SE->forgetMemoizedResults(this);
570
571 // Remove this SCEVUnknown from the uniquing map.
572 SE->UniqueSCEVs.RemoveNode(this);
573
574 // Replace the value pointer in case someone is still using this SCEVUnknown.
575 setValPtr(New);
576}
577
578//===----------------------------------------------------------------------===//
579// SCEV Utilities
580//===----------------------------------------------------------------------===//
581
582/// Compare the two values \p LV and \p RV in terms of their "complexity" where
583/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
584/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
585/// have been previously deemed to be "equally complex" by this routine. It is
586/// intended to avoid exponential time complexity in cases like:
587///
588/// %a = f(%x, %y)
589/// %b = f(%a, %a)
590/// %c = f(%b, %b)
591///
592/// %d = f(%x, %y)
593/// %e = f(%d, %d)
594/// %f = f(%e, %e)
595///
596/// CompareValueComplexity(%f, %c)
597///
598/// Since we do not continue running this routine on expression trees once we
599/// have seen unequal values, there is no need to track them in the cache.
600static int
 601CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
 602 const LoopInfo *const LI, Value *LV, Value *RV,
603 unsigned Depth) {
604 if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
605 return 0;
606
607 // Order pointer values after integer values. This helps SCEVExpander form
608 // GEPs.
609 bool LIsPointer = LV->getType()->isPointerTy(),
610 RIsPointer = RV->getType()->isPointerTy();
611 if (LIsPointer != RIsPointer)
612 return (int)LIsPointer - (int)RIsPointer;
613
614 // Compare getValueID values.
615 unsigned LID = LV->getValueID(), RID = RV->getValueID();
616 if (LID != RID)
617 return (int)LID - (int)RID;
618
619 // Sort arguments by their position.
620 if (const auto *LA = dyn_cast<Argument>(LV)) {
621 const auto *RA = cast<Argument>(RV);
622 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
623 return (int)LArgNo - (int)RArgNo;
624 }
625
626 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
627 const auto *RGV = cast<GlobalValue>(RV);
628
629 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
630 auto LT = GV->getLinkage();
631 return !(GlobalValue::isPrivateLinkage(LT) ||
 632 GlobalValue::isInternalLinkage(LT));
 633 };
634
635 // Use the names to distinguish the two values, but only if the
636 // names are semantically important.
637 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
638 return LGV->getName().compare(RGV->getName());
639 }
640
641 // For instructions, compare their loop depth, and their operand count. This
642 // is pretty loose.
643 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
644 const auto *RInst = cast<Instruction>(RV);
645
646 // Compare loop depths.
647 const BasicBlock *LParent = LInst->getParent(),
648 *RParent = RInst->getParent();
649 if (LParent != RParent) {
650 unsigned LDepth = LI->getLoopDepth(LParent),
651 RDepth = LI->getLoopDepth(RParent);
652 if (LDepth != RDepth)
653 return (int)LDepth - (int)RDepth;
654 }
655
656 // Compare the number of operands.
657 unsigned LNumOps = LInst->getNumOperands(),
658 RNumOps = RInst->getNumOperands();
659 if (LNumOps != RNumOps)
660 return (int)LNumOps - (int)RNumOps;
661
662 for (unsigned Idx : seq(LNumOps)) {
663 int Result =
664 CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
665 RInst->getOperand(Idx), Depth + 1);
666 if (Result != 0)
667 return Result;
668 }
669 }
670
671 EqCacheValue.unionSets(LV, RV);
672 return 0;
673}
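//
// A small example of the ordering this produces (a sketch): for an integer
// argument %a and a pointer argument %p, %a sorts first because integer
// values order ahead of pointer values; for two integer arguments, the one
// with the lower argument number sorts first.
//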
674
675// Return negative, zero, or positive, if LHS is less than, equal to, or greater
676// than RHS, respectively. A three-way result allows recursive comparisons to be
677// more efficient.
678// If the max analysis depth was reached, return std::nullopt, assuming we do
679// not know if they are equivalent for sure.
680static std::optional<int>
 681CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
 682 EquivalenceClasses<const Value *> &EqCacheValue,
 683 const LoopInfo *const LI, const SCEV *LHS,
684 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
685 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
686 if (LHS == RHS)
687 return 0;
688
689 // Primarily, sort the SCEVs by their getSCEVType().
690 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
691 if (LType != RType)
692 return (int)LType - (int)RType;
693
694 if (EqCacheSCEV.isEquivalent(LHS, RHS))
695 return 0;
696
 697 if (Depth > MaxSCEVCompareDepth)
 698 return std::nullopt;
699
700 // Aside from the getSCEVType() ordering, the particular ordering
701 // isn't very important except that it's beneficial to be consistent,
702 // so that (a + b) and (b + a) don't end up as different expressions.
703 switch (LType) {
704 case scUnknown: {
705 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
706 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
707
708 int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
709 RU->getValue(), Depth + 1);
710 if (X == 0)
711 EqCacheSCEV.unionSets(LHS, RHS);
712 return X;
713 }
714
715 case scConstant: {
716 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
717 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
718
719 // Compare constant values.
720 const APInt &LA = LC->getAPInt();
721 const APInt &RA = RC->getAPInt();
722 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
723 if (LBitWidth != RBitWidth)
724 return (int)LBitWidth - (int)RBitWidth;
725 return LA.ult(RA) ? -1 : 1;
726 }
727
728 case scVScale: {
729 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
730 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
731 return LTy->getBitWidth() - RTy->getBitWidth();
732 }
733
734 case scAddRecExpr: {
735 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
736 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
737
738 // There is always a dominance between two recs that are used by one SCEV,
739 // so we can safely sort recs by loop header dominance. We require such
740 // order in getAddExpr.
741 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
742 if (LLoop != RLoop) {
743 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
744 assert(LHead != RHead && "Two loops share the same header?");
745 if (DT.dominates(LHead, RHead))
746 return 1;
747 assert(DT.dominates(RHead, LHead) &&
748 "No dominance between recurrences used by one SCEV?");
749 return -1;
750 }
751
752 [[fallthrough]];
753 }
754
755 case scTruncate:
756 case scZeroExtend:
757 case scSignExtend:
758 case scPtrToInt:
759 case scAddExpr:
760 case scMulExpr:
761 case scUDivExpr:
762 case scSMaxExpr:
763 case scUMaxExpr:
764 case scSMinExpr:
765 case scUMinExpr:
 766 case scSequentialUMinExpr: {
 767 ArrayRef<const SCEV *> LOps = LHS->operands();
768 ArrayRef<const SCEV *> ROps = RHS->operands();
769
770 // Lexicographically compare n-ary-like expressions.
771 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
772 if (LNumOps != RNumOps)
773 return (int)LNumOps - (int)RNumOps;
774
775 for (unsigned i = 0; i != LNumOps; ++i) {
776 auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LOps[i],
777 ROps[i], DT, Depth + 1);
778 if (X != 0)
779 return X;
780 }
781 EqCacheSCEV.unionSets(LHS, RHS);
782 return 0;
783 }
784
 785 case scCouldNotCompute:
 786 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
787 }
788 llvm_unreachable("Unknown SCEV kind!");
789}
790
791/// Given a list of SCEV objects, order them by their complexity, and group
792/// objects of the same complexity together by value. When this routine is
793/// finished, we know that any duplicates in the vector are consecutive and that
794/// complexity is monotonically increasing.
795///
 796/// Note that we take special precautions to ensure that we get deterministic
797/// results from this routine. In other words, we don't want the results of
798/// this to depend on where the addresses of various SCEV objects happened to
799/// land in memory.
 800static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
 801 LoopInfo *LI, DominatorTree &DT) {
802 if (Ops.size() < 2) return; // Noop
803
 804 EquivalenceClasses<const SCEV *> EqCacheSCEV;
 805 EquivalenceClasses<const Value *> EqCacheValue;
 806
807 // Whether LHS has provably less complexity than RHS.
808 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
809 auto Complexity =
810 CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
811 return Complexity && *Complexity < 0;
812 };
813 if (Ops.size() == 2) {
814 // This is the common case, which also happens to be trivially simple.
815 // Special case it.
816 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
817 if (IsLessComplex(RHS, LHS))
818 std::swap(LHS, RHS);
819 return;
820 }
821
822 // Do the rough sort by complexity.
823 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
824 return IsLessComplex(LHS, RHS);
825 });
826
827 // Now that we are sorted by complexity, group elements of the same
828 // complexity. Note that this is, at worst, N^2, but the vector is likely to
829 // be extremely short in practice. Note that we take this approach because we
830 // do not want to depend on the addresses of the objects we are grouping.
831 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
832 const SCEV *S = Ops[i];
833 unsigned Complexity = S->getSCEVType();
834
835 // If there are any objects of the same complexity and same value as this
836 // one, group them.
837 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
838 if (Ops[j] == S) { // Found a duplicate.
839 // Move it to immediately after i'th element.
840 std::swap(Ops[i+1], Ops[j]);
841 ++i; // no need to rescan it.
842 if (i == e-2) return; // Done!
843 }
844 }
845 }
846}
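//
// Sketch of why this matters: because (%a + %b) and (%b + %a) are put into
// the same operand order before uniquing, they fold to the same SCEV
// object, and duplicate operands such as %x, ..., %x end up adjacent so the
// folders can combine them (e.g. into 2 * %x) with a single linear scan.
//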
847
848/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
849/// least HugeExprThreshold nodes).
 850static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
 851 return any_of(Ops, [](const SCEV *S) {
 852 return S->getExpressionSize() >= HugeExprThreshold;
 853 });
854}
855
856//===----------------------------------------------------------------------===//
857// Simple SCEV method implementations
858//===----------------------------------------------------------------------===//
859
860/// Compute BC(It, K). The result has width W. Assume, K > 0.
861static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
862 ScalarEvolution &SE,
863 Type *ResultTy) {
864 // Handle the simplest case efficiently.
865 if (K == 1)
866 return SE.getTruncateOrZeroExtend(It, ResultTy);
867
868 // We are using the following formula for BC(It, K):
869 //
870 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
871 //
872 // Suppose, W is the bitwidth of the return value. We must be prepared for
873 // overflow. Hence, we must assure that the result of our computation is
874 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
875 // safe in modular arithmetic.
876 //
877 // However, this code doesn't use exactly that formula; the formula it uses
878 // is something like the following, where T is the number of factors of 2 in
879 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
880 // exponentiation:
881 //
882 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
883 //
884 // This formula is trivially equivalent to the previous formula. However,
885 // this formula can be implemented much more efficiently. The trick is that
886 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
887 // arithmetic. To do exact division in modular arithmetic, all we have
888 // to do is multiply by the inverse. Therefore, this step can be done at
889 // width W.
890 //
891 // The next issue is how to safely do the division by 2^T. The way this
892 // is done is by doing the multiplication step at a width of at least W + T
893 // bits. This way, the bottom W+T bits of the product are accurate. Then,
894 // when we perform the division by 2^T (which is equivalent to a right shift
895 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
896 // truncated out after the division by 2^T.
897 //
898 // In comparison to just directly using the first formula, this technique
899 // is much more efficient; using the first formula requires W * K bits,
 900// but this formula requires less than W + K bits. Also, the first formula requires
901 // a division step, whereas this formula only requires multiplies and shifts.
902 //
903 // It doesn't matter whether the subtraction step is done in the calculation
904 // width or the input iteration count's width; if the subtraction overflows,
905 // the result must be zero anyway. We prefer here to do it in the width of
906 // the induction variable because it helps a lot for certain cases; CodeGen
907 // isn't smart enough to ignore the overflow, which leads to much less
908 // efficient code if the width of the subtraction is wider than the native
909 // register width.
910 //
911 // (It's possible to not widen at all by pulling out factors of 2 before
912 // the multiplication; for example, K=2 can be calculated as
913 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
914 // extra arithmetic, so it's not an obvious win, and it gets
915 // much more complicated for K > 3.)
916
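 // Worked instance of the scheme above (a sketch): for K = 4 and W = 32,
 // K! = 24 = 2^3 * 3, so T = 3 and the odd factor is 3. The product
 // It*(It-1)*(It-2)*(It-3) is formed at W + T = 35 bits, divided by 2^3,
 // truncated back to 32 bits, and multiplied by the multiplicative inverse
 // of 3 modulo 2^32 (0xAAAAAAAB), which performs the exact division by the
 // odd factor.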
917 // Protection from insane SCEVs; this bound is conservative,
918 // but it probably doesn't matter.
919 if (K > 1000)
920 return SE.getCouldNotCompute();
921
922 unsigned W = SE.getTypeSizeInBits(ResultTy);
923
924 // Calculate K! / 2^T and T; we divide out the factors of two before
925 // multiplying for calculating K! / 2^T to avoid overflow.
926 // Other overflow doesn't matter because we only care about the bottom
927 // W bits of the result.
928 APInt OddFactorial(W, 1);
929 unsigned T = 1;
930 for (unsigned i = 3; i <= K; ++i) {
931 unsigned TwoFactors = countr_zero(i);
932 T += TwoFactors;
933 OddFactorial *= (i >> TwoFactors);
934 }
935
936 // We need at least W + T bits for the multiplication step
937 unsigned CalculationBits = W + T;
938
939 // Calculate 2^T, at width T+W.
940 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
941
942 // Calculate the multiplicative inverse of K! / 2^T;
943 // this multiplication factor will perform the exact division by
944 // K! / 2^T.
945 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
946
947 // Calculate the product, at width T+W
948 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
949 CalculationBits);
950 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
951 for (unsigned i = 1; i != K; ++i) {
952 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
953 Dividend = SE.getMulExpr(Dividend,
954 SE.getTruncateOrZeroExtend(S, CalculationTy));
955 }
956
957 // Divide by 2^T
958 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
959
960 // Truncate the result, and divide by K! / 2^T.
961
962 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
963 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
964}
965
966/// Return the value of this chain of recurrences at the specified iteration
967/// number. We can evaluate this recurrence by multiplying each element in the
968/// chain by the binomial coefficient corresponding to it. In other words, we
969/// can evaluate {A,+,B,+,C,+,D} as:
970///
971/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
972///
973/// where BC(It, k) stands for binomial coefficient.
975 ScalarEvolution &SE) const {
976 return evaluateAtIteration(operands(), It, SE);
977}
978
979const SCEV *
981 const SCEV *It, ScalarEvolution &SE) {
982 assert(Operands.size() > 0);
983 const SCEV *Result = Operands[0];
984 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
985 // The computation is correct in the face of overflow provided that the
986 // multiplication is performed _after_ the evaluation of the binomial
987 // coefficient.
988 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
989 if (isa<SCEVCouldNotCompute>(Coeff))
990 return Coeff;
991
992 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
993 }
994 return Result;
995}
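//
// Worked example (a sketch): evaluating {1,+,2,+,2} at iteration It = 3
// gives 1*BC(3,0) + 2*BC(3,1) + 2*BC(3,2) = 1*1 + 2*3 + 2*3 = 13, matching
// the sequence 1, 3, 7, 13 obtained by stepping the recurrence directly.
//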
996
997//===----------------------------------------------------------------------===//
998// SCEV Expression folder implementations
999//===----------------------------------------------------------------------===//
1000
1002 unsigned Depth) {
1003 assert(Depth <= 1 &&
1004 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1005
 1006 // We could be called with an integer-typed operand during SCEV rewrites.
1007 // Since the operand is an integer already, just perform zext/trunc/self cast.
1008 if (!Op->getType()->isPointerTy())
1009 return Op;
1010
1011 // What would be an ID for such a SCEV cast expression?
 1012 FoldingSetNodeID ID;
 1013 ID.AddInteger(scPtrToInt);
1014 ID.AddPointer(Op);
1015
1016 void *IP = nullptr;
1017
1018 // Is there already an expression for such a cast?
1019 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1020 return S;
1021
1022 // It isn't legal for optimizations to construct new ptrtoint expressions
1023 // for non-integral pointers.
1024 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1025 return getCouldNotCompute();
1026
1027 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1028
1029 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1030 // is sufficiently wide to represent all possible pointer values.
1031 // We could theoretically teach SCEV to truncate wider pointers, but
1032 // that isn't implemented for now.
1034 getDataLayout().getTypeSizeInBits(IntPtrTy))
1035 return getCouldNotCompute();
1036
1037 // If not, is this expression something we can't reduce any further?
1038 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1039 // Perform some basic constant folding. If the operand of the ptr2int cast
1040 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1041 // left as-is), but produce a zero constant.
1042 // NOTE: We could handle a more general case, but lack motivational cases.
1043 if (isa<ConstantPointerNull>(U->getValue()))
1044 return getZero(IntPtrTy);
1045
1046 // Create an explicit cast node.
1047 // We can reuse the existing insert position since if we get here,
1048 // we won't have made any changes which would invalidate it.
1049 SCEV *S = new (SCEVAllocator)
1050 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1051 UniqueSCEVs.InsertNode(S, IP);
1052 registerUser(S, Op);
1053 return S;
1054 }
1055
1056 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1057 "non-SCEVUnknown's.");
1058
1059 // Otherwise, we've got some expression that is more complex than just a
1060 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1061 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1062 // only, and the expressions must otherwise be integer-typed.
1063 // So sink the cast down to the SCEVUnknown's.
1064
1065 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1066 /// which computes a pointer-typed value, and rewrites the whole expression
1067 /// tree so that *all* the computations are done on integers, and the only
1068 /// pointer-typed operands in the expression are SCEVUnknown.
1069 class SCEVPtrToIntSinkingRewriter
1070 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
 1071 using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
 1072
1073 public:
1074 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1075
1076 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1077 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1078 return Rewriter.visit(Scev);
1079 }
1080
1081 const SCEV *visit(const SCEV *S) {
1082 Type *STy = S->getType();
1083 // If the expression is not pointer-typed, just keep it as-is.
1084 if (!STy->isPointerTy())
1085 return S;
1086 // Else, recursively sink the cast down into it.
1087 return Base::visit(S);
1088 }
1089
1090 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1092 bool Changed = false;
1093 for (const auto *Op : Expr->operands()) {
1094 Operands.push_back(visit(Op));
1095 Changed |= Op != Operands.back();
1096 }
1097 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1098 }
1099
1100 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1102 bool Changed = false;
1103 for (const auto *Op : Expr->operands()) {
1104 Operands.push_back(visit(Op));
1105 Changed |= Op != Operands.back();
1106 }
1107 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1108 }
1109
1110 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1111 assert(Expr->getType()->isPointerTy() &&
1112 "Should only reach pointer-typed SCEVUnknown's.");
1113 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1114 }
1115 };
1116
1117 // And actually perform the cast sinking.
1118 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1119 assert(IntOp->getType()->isIntegerTy() &&
1120 "We must have succeeded in sinking the cast, "
1121 "and ending up with an integer-typed expression!");
1122 return IntOp;
1123}
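//
// Sketch of the sinking transform (assuming %base is a pointer-typed
// SCEVUnknown and %i an integer): the ptrtoint of the pointer expression
// (%base + 4 * %i) is rewritten to ((ptrtoint %base) + 4 * %i), so the only
// ptrtoint left in the tree wraps the SCEVUnknown and every other node is
// integer-typed.
//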
1124
1126 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1127
1128 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1129 if (isa<SCEVCouldNotCompute>(IntOp))
1130 return IntOp;
1131
1132 return getTruncateOrZeroExtend(IntOp, Ty);
1133}
1134
1136 unsigned Depth) {
1137 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1138 "This is not a truncating conversion!");
1139 assert(isSCEVable(Ty) &&
1140 "This is not a conversion to a SCEVable type!");
1141 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1142 Ty = getEffectiveSCEVType(Ty);
1143
 1144 FoldingSetNodeID ID;
 1145 ID.AddInteger(scTruncate);
1146 ID.AddPointer(Op);
1147 ID.AddPointer(Ty);
1148 void *IP = nullptr;
1149 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1150
1151 // Fold if the operand is constant.
1152 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1153 return getConstant(
1154 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1155
1156 // trunc(trunc(x)) --> trunc(x)
1157 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1158 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1159
1160 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1161 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1162 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1163
1164 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1165 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1166 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1167
1168 if (Depth > MaxCastDepth) {
1169 SCEV *S =
1170 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1171 UniqueSCEVs.InsertNode(S, IP);
1172 registerUser(S, Op);
1173 return S;
1174 }
1175
1176 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1177 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1178 // if after transforming we have at most one truncate, not counting truncates
1179 // that replace other casts.
1180 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1181 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1183 unsigned numTruncs = 0;
1184 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1185 ++i) {
1186 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1187 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1188 isa<SCEVTruncateExpr>(S))
1189 numTruncs++;
1190 Operands.push_back(S);
1191 }
1192 if (numTruncs < 2) {
1193 if (isa<SCEVAddExpr>(Op))
1194 return getAddExpr(Operands);
1195 if (isa<SCEVMulExpr>(Op))
1196 return getMulExpr(Operands);
1197 llvm_unreachable("Unexpected SCEV type for Op.");
1198 }
1199 // Although we checked in the beginning that ID is not in the cache, it is
 1200 // possible that ID was inserted into the cache during recursion and other
 1201 // modifications. So if we find it, just return it.
1202 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1203 return S;
1204 }
1205
1206 // If the input value is a chrec scev, truncate the chrec's operands.
1207 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1209 for (const SCEV *Op : AddRec->operands())
1210 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1211 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1212 }
1213
1214 // Return zero if truncating to known zeros.
1215 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1216 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1217 return getZero(Ty);
1218
1219 // The cast wasn't folded; create an explicit cast node. We can reuse
1220 // the existing insert position since if we get here, we won't have
1221 // made any changes which would invalidate it.
1222 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1223 Op, Ty);
1224 UniqueSCEVs.InsertNode(S, IP);
1225 registerUser(S, Op);
1226 return S;
1227}
1228
1229// Get the limit of a recurrence such that incrementing by Step cannot cause
1230// signed overflow as long as the value of the recurrence within the
1231// loop does not exceed this limit before incrementing.
1232static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1233 ICmpInst::Predicate *Pred,
1234 ScalarEvolution *SE) {
1235 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1236 if (SE->isKnownPositive(Step)) {
1237 *Pred = ICmpInst::ICMP_SLT;
 1238 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
 1239 SE->getSignedRangeMax(Step));
1240 }
1241 if (SE->isKnownNegative(Step)) {
1242 *Pred = ICmpInst::ICMP_SGT;
 1243 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
 1244 SE->getSignedRangeMin(Step));
1245 }
1246 return nullptr;
1247}
1248
1249// Get the limit of a recurrence such that incrementing by Step cannot cause
1250// unsigned overflow as long as the value of the recurrence within the loop does
1251// not exceed this limit before incrementing.
 1252static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
 1253 ICmpInst::Predicate *Pred,
1254 ScalarEvolution *SE) {
1255 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1256 *Pred = ICmpInst::ICMP_ULT;
1257
 1258 return SE->getConstant(APInt::getMaxValue(BitWidth) -
 1259 SE->getUnsignedRangeMax(Step));
1260}
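//
// Example of the limits these helpers compute (a sketch): for an i8
// recurrence with Step = 1, the unsigned limit is 255 - 1 = 254 with
// predicate ULT; as long as the recurrence is known to stay below that
// limit before the increment, adding the step cannot wrap. The signed
// variant uses the signed extremes adjusted by the step's range in the
// same way.
//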
1261
1262namespace {
1263
1264struct ExtendOpTraitsBase {
1265 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1266 unsigned);
1267};
1268
1269// Used to make code generic over signed and unsigned overflow.
1270template <typename ExtendOp> struct ExtendOpTraits {
1271 // Members present:
1272 //
1273 // static const SCEV::NoWrapFlags WrapType;
1274 //
1275 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1276 //
1277 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1278 // ICmpInst::Predicate *Pred,
1279 // ScalarEvolution *SE);
1280};
1281
1282template <>
1283struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1284 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1285
1286 static const GetExtendExprTy GetExtendExpr;
1287
1288 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1289 ICmpInst::Predicate *Pred,
1290 ScalarEvolution *SE) {
1291 return getSignedOverflowLimitForStep(Step, Pred, SE);
1292 }
1293};
1294
 1295const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
 1296 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1297
1298template <>
1299struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1300 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1301
1302 static const GetExtendExprTy GetExtendExpr;
1303
1304 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1305 ICmpInst::Predicate *Pred,
1306 ScalarEvolution *SE) {
1307 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1308 }
1309};
1310
 1311const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
 1312 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1313
1314} // end anonymous namespace
1315
1316// The recurrence AR has been shown to have no signed/unsigned wrap or something
1317// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1318// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1319// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
 1320// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1321// expression "Step + sext/zext(PreIncAR)" is congruent with
1322// "sext/zext(PostIncAR)"
1323template <typename ExtendOpTy>
1324static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1325 ScalarEvolution *SE, unsigned Depth) {
1326 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1327 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1328
1329 const Loop *L = AR->getLoop();
1330 const SCEV *Start = AR->getStart();
1331 const SCEV *Step = AR->getStepRecurrence(*SE);
1332
1333 // Check for a simple looking step prior to loop entry.
1334 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1335 if (!SA)
1336 return nullptr;
1337
1338 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1339 // subtraction is expensive. For this purpose, perform a quick and dirty
1340 // difference, by checking for Step in the operand list. Note, that
1341 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1343 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1344 if (*It == Step) {
1345 DiffOps.erase(It);
1346 break;
1347 }
1348
1349 if (DiffOps.size() == SA->getNumOperands())
1350 return nullptr;
1351
1352 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1353 // `Step`:
1354
1355 // 1. NSW/NUW flags on the step increment.
1356 auto PreStartFlags =
1358 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1359 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1360 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1361
1362 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1363 // "S+X does not sign/unsign-overflow".
1364 //
1365
1366 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1367 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1368 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1369 return PreStart;
1370
1371 // 2. Direct overflow check on the step operation's expression.
1372 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1373 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1374 const SCEV *OperandExtendedStart =
1375 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1376 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1377 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1378 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1379 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1380 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1381 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1382 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1383 }
1384 return PreStart;
1385 }
1386
1387 // 3. Loop precondition.
 1388 ICmpInst::Predicate Pred;
 1389 const SCEV *OverflowLimit =
1390 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1391
1392 if (OverflowLimit &&
1393 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1394 return PreStart;
1395
1396 return nullptr;
1397}
1398
1399// Get the normalized zero or sign extended expression for this AddRec's Start.
1400template <typename ExtendOpTy>
1401static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1402 ScalarEvolution *SE,
1403 unsigned Depth) {
1404 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1405
1406 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1407 if (!PreStart)
1408 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1409
1410 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1411 Depth),
1412 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1413}
1414
1415// Try to prove away overflow by looking at "nearby" add recurrences. A
1416// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1417// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1418//
1419// Formally:
1420//
1421// {S,+,X} == {S-T,+,X} + T
1422// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1423//
1424// If ({S-T,+,X} + T) does not overflow ... (1)
1425//
1426// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1427//
1428// If {S-T,+,X} does not overflow ... (2)
1429//
1430// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1431// == {Ext(S-T)+Ext(T),+,Ext(X)}
1432//
1433// If (S-T)+T does not overflow ... (3)
1434//
1435// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1436// == {Ext(S),+,Ext(X)} == LHS
1437//
1438// Thus, if (1), (2) and (3) are true for some T, then
1439// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1440//
1441// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1442// does not overflow" restricted to the 0th iteration. Therefore we only need
1443// to check for (1) and (2).
1444//
1445// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1446// is `Delta` (defined below).
1447template <typename ExtendOpTy>
1448bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1449 const SCEV *Step,
1450 const Loop *L) {
1451 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1452
1453 // We restrict `Start` to a constant to prevent SCEV from spending too much
1454 // time here. It is correct (but more expensive) to continue with a
1455 // non-constant `Start` and do a general SCEV subtraction to compute
1456 // `PreStart` below.
1457 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1458 if (!StartC)
1459 return false;
1460
1461 APInt StartAI = StartC->getAPInt();
1462
1463 for (unsigned Delta : {-2, -1, 1, 2}) {
1464 const SCEV *PreStart = getConstant(StartAI - Delta);
1465
1467 ID.AddInteger(scAddRecExpr);
1468 ID.AddPointer(PreStart);
1469 ID.AddPointer(Step);
1470 ID.AddPointer(L);
1471 void *IP = nullptr;
1472 const auto *PreAR =
1473 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1474
1475 // Give up if we don't already have the add recurrence we need because
1476 // actually constructing an add recurrence is relatively expensive.
1477 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1478 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1480 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1481 DeltaS, &Pred, this);
1482 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1483 return true;
1484 }
1485 }
1486
1487 return false;
1488}
1489
1490// Finds an integer D for an expression (C + x + y + ...) such that the top
1491// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1492// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1493// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1494// the (C + x + y + ...) expression is \p WholeAddExpr.
1496 const SCEVConstant *ConstantTerm,
1497 const SCEVAddExpr *WholeAddExpr) {
1498 const APInt &C = ConstantTerm->getAPInt();
1499 const unsigned BitWidth = C.getBitWidth();
1500 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1501 uint32_t TZ = BitWidth;
1502 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1503 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1504 if (TZ) {
1505 // Set D to be as many least significant bits of C as possible while still
1506 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1507 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1508 }
1509 return APInt(BitWidth, 0);
1510}
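//
// Worked example (a sketch): with 8-bit operands, if C = 100 (0b01100100)
// and the remaining operands are all multiples of 16 (TZ = 4), then D is
// the low four bits of C, i.e. 4. (C - D + x + y + ...) = (96 + ...) has
// its low four bits clear, so adding D back cannot carry out of those bits
// and the top-level addition cannot wrap.
//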
1511
1512// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1513// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1514// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1515// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1517 const APInt &ConstantStart,
1518 const SCEV *Step) {
1519 const unsigned BitWidth = ConstantStart.getBitWidth();
1520 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1521 if (TZ)
1522 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1523 : ConstantStart;
1524 return APInt(BitWidth, 0);
1525}
1526
1528 const ScalarEvolution::FoldID &ID, const SCEV *S,
1531 &FoldCacheUser) {
1532 auto I = FoldCache.insert({ID, S});
1533 if (!I.second) {
1534 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1535 // entry.
1536 auto &UserIDs = FoldCacheUser[I.first->second];
1537 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1538 for (unsigned I = 0; I != UserIDs.size(); ++I)
1539 if (UserIDs[I] == ID) {
1540 std::swap(UserIDs[I], UserIDs.back());
1541 break;
1542 }
1543 UserIDs.pop_back();
1544 I.first->second = S;
1545 }
1546 auto R = FoldCacheUser.insert({S, {}});
1547 R.first->second.push_back(ID);
1548}
1549
1550const SCEV *
1552 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1553 "This is not an extending conversion!");
1554 assert(isSCEVable(Ty) &&
1555 "This is not a conversion to a SCEVable type!");
1556 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1557 Ty = getEffectiveSCEVType(Ty);
1558
1559 FoldID ID(scZeroExtend, Op, Ty);
1560 auto Iter = FoldCache.find(ID);
1561 if (Iter != FoldCache.end())
1562 return Iter->second;
1563
1564 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1565 if (!isa<SCEVZeroExtendExpr>(S))
1566 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1567 return S;
1568}
1569
1571 unsigned Depth) {
1572 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1573 "This is not an extending conversion!");
1574 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1575 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1576
1577 // Fold if the operand is constant.
1578 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1579 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1580
1581 // zext(zext(x)) --> zext(x)
1582 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1583 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1584
1585 // Before doing any expensive analysis, check to see if we've already
1586 // computed a SCEV for this Op and Ty.
1588 ID.AddInteger(scZeroExtend);
1589 ID.AddPointer(Op);
1590 ID.AddPointer(Ty);
1591 void *IP = nullptr;
1592 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1593 if (Depth > MaxCastDepth) {
1594 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1595 Op, Ty);
1596 UniqueSCEVs.InsertNode(S, IP);
1597 registerUser(S, Op);
1598 return S;
1599 }
1600
1601 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1602 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1603 // It's possible the bits taken off by the truncate were all zero bits. If
1604 // so, we should be able to simplify this further.
1605 const SCEV *X = ST->getOperand();
1607 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1608 unsigned NewBits = getTypeSizeInBits(Ty);
1609 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1610 CR.zextOrTrunc(NewBits)))
1611 return getTruncateOrZeroExtend(X, Ty, Depth);
1612 }
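 // Example of the range check above (a sketch): if X = (zext i8 %v to i64),
 // then zext(trunc X to i32) to i64 folds back to X, because the unsigned
 // range of X fits in 8 bits and the bits removed by the truncate are known
 // to be zero.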
1613
1614 // If the input value is a chrec scev, and we can prove that the value
1615 // did not overflow the old, smaller, value, we can zero extend all of the
1616 // operands (often constants). This allows analysis of something like
1617 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1618 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1619 if (AR->isAffine()) {
1620 const SCEV *Start = AR->getStart();
1621 const SCEV *Step = AR->getStepRecurrence(*this);
1622 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1623 const Loop *L = AR->getLoop();
1624
1625 // If we have special knowledge that this addrec won't overflow,
1626 // we don't need to do any further analysis.
1627 if (AR->hasNoUnsignedWrap()) {
1628 Start =
1629 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1630 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1631 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1632 }
1633
1634 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1635 // Note that this serves two purposes: It filters out loops that are
1636 // simply not analyzable, and it covers the case where this code is
1637 // being called from within backedge-taken count analysis, such that
1638 // attempting to ask for the backedge-taken count would likely result
 1639 // in infinite recursion. In the latter case, the analysis code will
1640 // cope with a conservative value, and it will take care to purge
1641 // that value once it has finished.
1642 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1643 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1644 // Manually compute the final value for AR, checking for overflow.
1645
1646 // Check whether the backedge-taken count can be losslessly casted to
1647 // the addrec's type. The count is always unsigned.
1648 const SCEV *CastedMaxBECount =
1649 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1650 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1651 CastedMaxBECount, MaxBECount->getType(), Depth);
1652 if (MaxBECount == RecastedMaxBECount) {
1653 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1654 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1655 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1656 SCEV::FlagAnyWrap, Depth + 1);
1657 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1658 SCEV::FlagAnyWrap,
1659 Depth + 1),
1660 WideTy, Depth + 1);
1661 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1662 const SCEV *WideMaxBECount =
1663 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1664 const SCEV *OperandExtendedAdd =
1665 getAddExpr(WideStart,
1666 getMulExpr(WideMaxBECount,
1667 getZeroExtendExpr(Step, WideTy, Depth + 1),
1668 SCEV::FlagAnyWrap, Depth + 1),
1669 SCEV::FlagAnyWrap, Depth + 1);
1670 if (ZAdd == OperandExtendedAdd) {
1671 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1672 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1673 // Return the expression with the addrec on the outside.
1674 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1675 Depth + 1);
1676 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1677 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1678 }
1679 // Similar to above, only this time treat the step value as signed.
1680 // This covers loops that count down.
1681 OperandExtendedAdd =
1682 getAddExpr(WideStart,
1683 getMulExpr(WideMaxBECount,
1684 getSignExtendExpr(Step, WideTy, Depth + 1),
1685 SCEV::FlagAnyWrap, Depth + 1),
1686 SCEV::FlagAnyWrap, Depth + 1);
1687 if (ZAdd == OperandExtendedAdd) {
1688 // Cache knowledge of AR NW, which is propagated to this AddRec.
1689 // Negative step causes unsigned wrap, but it still can't self-wrap.
1690 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1691 // Return the expression with the addrec on the outside.
1692 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1693 Depth + 1);
1694 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1695 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1696 }
1697 }
1698 }
1699
1700 // Normally, in the cases we can prove no-overflow via a
1701 // backedge guarding condition, we can also compute a backedge
1702 // taken count for the loop. The exceptions are assumptions and
1703 // guards present in the loop -- SCEV is not great at exploiting
1704 // these to compute max backedge taken counts, but can still use
1705 // these to prove lack of overflow. Use this fact to avoid
1706 // doing extra work that may not pay off.
1707 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1708 !AC.assumptions().empty()) {
1709
1710 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1711 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1712 if (AR->hasNoUnsignedWrap()) {
1713 // Same as nuw case above - duplicated here to avoid a compile time
1714 // issue. It's not clear that the order of checks does matter, but
1715 // it's one of two possible causes for a change which was
1716 // reverted. Be conservative for the moment.
1717 Start =
1718 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1719 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1720 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1721 }
1722
1723 // For a negative step, we can extend the operands iff doing so only
1724 // traverses values in the range zext([0,UINT_MAX]).
1725 if (isKnownNegative(Step)) {
1726 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1727 getSignedRangeMin(Step));
1728 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1729 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1730 // Cache knowledge of AR NW, which is propagated to this
1731 // AddRec. Negative step causes unsigned wrap, but it
1732 // still can't self-wrap.
1733 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1734 // Return the expression with the addrec on the outside.
1735 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1736 Depth + 1);
1737 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1738 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1739 }
1740 }
1741 }
1742
1743 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1744 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1745 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1746 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1747 const APInt &C = SC->getAPInt();
1748 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1749 if (D != 0) {
1750 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1751 const SCEV *SResidual =
1752 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1753 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1754 return getAddExpr(SZExtD, SZExtR,
1755 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1756 Depth + 1);
1757 }
1758 }
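// Worked example (illustrative sketch, not part of the upstream source):
// assuming extractConstantWithoutWrapping returns the low bits of C below
// the step's guaranteed trailing zeros, zext({7,+,4}<L>) of an i8 addrec
// splits with D = 7 & 3 = 3 into (zext(3) + zext({4,+,4}<L>))<nuw><nsw>:
// every value of the residual {4,+,4} keeps its low two bits clear, so
// adding D = 3 back only fills those bits and cannot wrap.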
1759
1760 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1761 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1762 Start =
1763 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1764 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1765 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1766 }
1767 }
1768
1769 // zext(A % B) --> zext(A) % zext(B)
1770 {
1771 const SCEV *LHS;
1772 const SCEV *RHS;
1773 if (matchURem(Op, LHS, RHS))
1774 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1775 getZeroExtendExpr(RHS, Ty, Depth + 1));
1776 }
1777
1778 // zext(A / B) --> zext(A) / zext(B).
1779 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1780 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1781 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1782
1783 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1784 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1785 if (SA->hasNoUnsignedWrap()) {
1786 // If the addition does not unsign overflow then we can, by definition,
1787 // commute the zero extension with the addition operation.
1788 SmallVector<const SCEV *, 4> Ops;
1789 for (const auto *Op : SA->operands())
1790 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1791 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1792 }
1793
1794 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1795 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1796 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1797 //
1798 // Often address arithmetics contain expressions like
1799 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1800 // This transformation is useful while proving that such expressions are
1801 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1802 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1803 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1804 if (D != 0) {
1805 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1806 const SCEV *SResidual =
1807 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1808 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1809 return getAddExpr(SZExtD, SZExtR,
1810 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1811 Depth + 1);
1812 }
1813 }
1814 }
1815
1816 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1817 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1818 if (SM->hasNoUnsignedWrap()) {
1819 // If the multiply does not unsign overflow then we can, by definition,
1820 // commute the zero extension with the multiply operation.
1821 SmallVector<const SCEV *, 4> Ops;
1822 for (const auto *Op : SM->operands())
1823 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1824 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1825 }
1826
1827 // zext(2^K * (trunc X to iN)) to iM ->
1828 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1829 //
1830 // Proof:
1831 //
1832 // zext(2^K * (trunc X to iN)) to iM
1833 // = zext((trunc X to iN) << K) to iM
1834 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1835 // (because shl removes the top K bits)
1836 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1837 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1838 //
1839 if (SM->getNumOperands() == 2)
1840 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1841 if (MulLHS->getAPInt().isPowerOf2())
1842 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1843 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1844 MulLHS->getAPInt().logBase2();
1845 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1846 return getMulExpr(
1847 getZeroExtendExpr(MulLHS, Ty),
1848 getZeroExtendExpr(
1849 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1850 SCEV::FlagNUW, Depth + 1);
1851 }
1852 }
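// Worked example (illustrative sketch, not part of the upstream source):
// instantiating the proof above with K = 3, N = 32, M = 64,
// zext(8 * (trunc X to i32)) to i64 becomes
// 8 * (zext (trunc X to i29) to i64)<nuw>, since NewTruncBits = 32 - 3 = 29
// and multiplying by 8 merely shifts the narrower truncate back into the
// three bits it discarded.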
1853
1854 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1855 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1856 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1857 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1858 SmallVector<const SCEV *, 4> Operands;
1859 for (auto *Operand : MinMax->operands())
1860 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1861 if (isa<SCEVUMinExpr>(MinMax))
1862 return getUMinExpr(Operands);
1863 return getUMaxExpr(Operands);
1864 }
1865
1866 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1867 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1868 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1869 SmallVector<const SCEV *, 4> Operands;
1870 for (auto *Operand : MinMax->operands())
1871 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1872 return getUMinExpr(Operands, /*Sequential*/ true);
1873 }
1874
1875 // The cast wasn't folded; create an explicit cast node.
1876 // Recompute the insert position, as it may have been invalidated.
1877 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1878 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1879 Op, Ty);
1880 UniqueSCEVs.InsertNode(S, IP);
1881 registerUser(S, Op);
1882 return S;
1883}
1884
1885const SCEV *
1886ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1887 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1888 "This is not an extending conversion!");
1889 assert(isSCEVable(Ty) &&
1890 "This is not a conversion to a SCEVable type!");
1891 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1892 Ty = getEffectiveSCEVType(Ty);
1893
1894 FoldID ID(scSignExtend, Op, Ty);
1895 auto Iter = FoldCache.find(ID);
1896 if (Iter != FoldCache.end())
1897 return Iter->second;
1898
1899 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1900 if (!isa<SCEVSignExtendExpr>(S))
1901 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1902 return S;
1903}
1904
1905const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1906 unsigned Depth) {
1907 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1908 "This is not an extending conversion!");
1909 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1910 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1911 Ty = getEffectiveSCEVType(Ty);
1912
1913 // Fold if the operand is constant.
1914 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1915 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1916
1917 // sext(sext(x)) --> sext(x)
1918 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1919 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1920
1921 // sext(zext(x)) --> zext(x)
1922 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1923 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1924
1925 // Before doing any expensive analysis, check to see if we've already
1926 // computed a SCEV for this Op and Ty.
1927 FoldingSetNodeID ID;
1928 ID.AddInteger(scSignExtend);
1929 ID.AddPointer(Op);
1930 ID.AddPointer(Ty);
1931 void *IP = nullptr;
1932 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1933 // Limit recursion depth.
1934 if (Depth > MaxCastDepth) {
1935 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1936 Op, Ty);
1937 UniqueSCEVs.InsertNode(S, IP);
1938 registerUser(S, Op);
1939 return S;
1940 }
1941
1942 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1943 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1944 // It's possible the bits taken off by the truncate were all sign bits. If
1945 // so, we should be able to simplify this further.
1946 const SCEV *X = ST->getOperand();
1947 ConstantRange CR = getSignedRange(X);
1948 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1949 unsigned NewBits = getTypeSizeInBits(Ty);
1950 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1951 CR.sextOrTrunc(NewBits)))
1952 return getTruncateOrSignExtend(X, Ty, Depth);
1953 }
1954
1955 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1956 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1957 if (SA->hasNoSignedWrap()) {
1958 // If the addition does not sign overflow then we can, by definition,
1959 // commute the sign extension with the addition operation.
1960 SmallVector<const SCEV *, 4> Ops;
1961 for (const auto *Op : SA->operands())
1962 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1963 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1964 }
1965
1966 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1967 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1968 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1969 //
1970 // For instance, this will bring two seemingly different expressions:
1971 // 1 + sext(5 + 20 * %x + 24 * %y) and
1972 // sext(6 + 20 * %x + 24 * %y)
1973 // to the same form:
1974 // 2 + sext(4 + 20 * %x + 24 * %y)
1975 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1976 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1977 if (D != 0) {
1978 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1979 const SCEV *SResidual =
1980 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1981 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1982 return getAddExpr(SSExtD, SSExtR,
1983 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1984 Depth + 1);
1985 }
1986 }
1987 }
1988 // If the input value is a chrec scev, and we can prove that the value
1989 // did not overflow the old, smaller, value, we can sign extend all of the
1990 // operands (often constants). This allows analysis of something like
1991 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1992 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1993 if (AR->isAffine()) {
1994 const SCEV *Start = AR->getStart();
1995 const SCEV *Step = AR->getStepRecurrence(*this);
1996 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1997 const Loop *L = AR->getLoop();
1998
1999 // If we have special knowledge that this addrec won't overflow,
2000 // we don't need to do any further analysis.
2001 if (AR->hasNoSignedWrap()) {
2002 Start =
2003 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2004 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2005 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2006 }
2007
2008 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2009 // Note that this serves two purposes: It filters out loops that are
2010 // simply not analyzable, and it covers the case where this code is
2011 // being called from within backedge-taken count analysis, such that
2012 // attempting to ask for the backedge-taken count would likely result
2013 // in infinite recursion. In the latter case, the analysis code will
2014 // cope with a conservative value, and it will take care to purge
2015 // that value once it has finished.
2016 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2017 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2018 // Manually compute the final value for AR, checking for
2019 // overflow.
2020
2021 // Check whether the backedge-taken count can be losslessly casted to
2022 // the addrec's type. The count is always unsigned.
2023 const SCEV *CastedMaxBECount =
2024 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2025 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2026 CastedMaxBECount, MaxBECount->getType(), Depth);
2027 if (MaxBECount == RecastedMaxBECount) {
2028 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2029 // Check whether Start+Step*MaxBECount has no signed overflow.
2030 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2031 SCEV::FlagAnyWrap, Depth + 1);
2032 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2033 SCEV::FlagAnyWrap,
2034 Depth + 1),
2035 WideTy, Depth + 1);
2036 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2037 const SCEV *WideMaxBECount =
2038 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2039 const SCEV *OperandExtendedAdd =
2040 getAddExpr(WideStart,
2041 getMulExpr(WideMaxBECount,
2042 getSignExtendExpr(Step, WideTy, Depth + 1),
2043 SCEV::FlagAnyWrap, Depth + 1),
2044 SCEV::FlagAnyWrap, Depth + 1);
2045 if (SAdd == OperandExtendedAdd) {
2046 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2047 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2048 // Return the expression with the addrec on the outside.
2049 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2050 Depth + 1);
2051 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2052 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2053 }
2054 // Similar to above, only this time treat the step value as unsigned.
2055 // This covers loops that count up with an unsigned step.
2056 OperandExtendedAdd =
2057 getAddExpr(WideStart,
2058 getMulExpr(WideMaxBECount,
2059 getZeroExtendExpr(Step, WideTy, Depth + 1),
2060 SCEV::FlagAnyWrap, Depth + 1),
2061 SCEV::FlagAnyWrap, Depth + 1);
2062 if (SAdd == OperandExtendedAdd) {
2063 // If AR wraps around then
2064 //
2065 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2066 // => SAdd != OperandExtendedAdd
2067 //
2068 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2069 // (SAdd == OperandExtendedAdd => AR is NW)
2070
2071 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2072
2073 // Return the expression with the addrec on the outside.
2074 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2075 Depth + 1);
2076 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2077 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2078 }
2079 }
2080 }
2081
2082 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2083 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2084 if (AR->hasNoSignedWrap()) {
2085 // Same as nsw case above - duplicated here to avoid a compile time
2086 // issue. It's not clear that the order of checks does matter, but
2087 // it's one of two possible causes for a change which was
2088 // reverted. Be conservative for the moment.
2089 Start =
2090 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2091 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2092 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2093 }
2094
2095 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2096 // if D + (C - D + Step * n) could be proven to not signed wrap
2097 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2098 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2099 const APInt &C = SC->getAPInt();
2100 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2101 if (D != 0) {
2102 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2103 const SCEV *SResidual =
2104 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2105 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2106 return getAddExpr(SSExtD, SSExtR,
2107 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2108 Depth + 1);
2109 }
2110 }
2111
2112 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2113 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2114 Start =
2115 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2116 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2117 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2118 }
2119 }
2120
2121 // If the input value is provably positive and we could not simplify
2122 // away the sext, build a zext instead.
2123 if (isKnownNonNegative(Op))
2124 return getZeroExtendExpr(Op, Ty, Depth + 1);
2125
2126 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2127 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2128 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2129 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2130 SmallVector<const SCEV *, 4> Operands;
2131 for (auto *Operand : MinMax->operands())
2132 Operands.push_back(getSignExtendExpr(Operand, Ty));
2133 if (isa<SCEVSMinExpr>(MinMax))
2134 return getSMinExpr(Operands);
2135 return getSMaxExpr(Operands);
2136 }
2137
2138 // The cast wasn't folded; create an explicit cast node.
2139 // Recompute the insert position, as it may have been invalidated.
2140 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2141 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2142 Op, Ty);
2143 UniqueSCEVs.InsertNode(S, IP);
2144 registerUser(S, { Op });
2145 return S;
2146}
2147
2148const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2149 Type *Ty) {
2150 switch (Kind) {
2151 case scTruncate:
2152 return getTruncateExpr(Op, Ty);
2153 case scZeroExtend:
2154 return getZeroExtendExpr(Op, Ty);
2155 case scSignExtend:
2156 return getSignExtendExpr(Op, Ty);
2157 case scPtrToInt:
2158 return getPtrToIntExpr(Op, Ty);
2159 default:
2160 llvm_unreachable("Not a SCEV cast expression!");
2161 }
2162}
2163
2164/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2165/// unspecified bits out to the given type.
2166const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2167 Type *Ty) {
2168 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2169 "This is not an extending conversion!");
2170 assert(isSCEVable(Ty) &&
2171 "This is not a conversion to a SCEVable type!");
2172 Ty = getEffectiveSCEVType(Ty);
2173
2174 // Sign-extend negative constants.
2175 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2176 if (SC->getAPInt().isNegative())
2177 return getSignExtendExpr(Op, Ty);
2178
2179 // Peel off a truncate cast.
2180 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2181 const SCEV *NewOp = T->getOperand();
2182 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2183 return getAnyExtendExpr(NewOp, Ty);
2184 return getTruncateOrNoop(NewOp, Ty);
2185 }
2186
2187 // Next try a zext cast. If the cast is folded, use it.
2188 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2189 if (!isa<SCEVZeroExtendExpr>(ZExt))
2190 return ZExt;
2191
2192 // Next try a sext cast. If the cast is folded, use it.
2193 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2194 if (!isa<SCEVSignExtendExpr>(SExt))
2195 return SExt;
2196
2197 // Force the cast to be folded into the operands of an addrec.
2198 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2199 SmallVector<const SCEV *, 4> Ops;
2200 for (const SCEV *Op : AR->operands())
2201 Ops.push_back(getAnyExtendExpr(Op, Ty));
2202 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2203 }
2204
2205 // If the expression is obviously signed, use the sext cast value.
2206 if (isa<SCEVSMaxExpr>(Op))
2207 return SExt;
2208
2209 // Absent any other information, use the zext cast value.
2210 return ZExt;
2211}
2212
2213/// Process the given Ops list, which is a list of operands to be added under
2214/// the given scale, update the given map. This is a helper function for
2215/// getAddRecExpr. As an example of what it does, given a sequence of operands
2216/// that would form an add expression like this:
2217///
2218/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2219///
2220/// where A and B are constants, update the map with these values:
2221///
2222/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2223///
2224/// and add 13 + A*B*29 to AccumulatedConstant.
2225/// This will allow getAddRecExpr to produce this:
2226///
2227/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2228///
2229/// This form often exposes folding opportunities that are hidden in
2230/// the original operand list.
2231///
2232/// Return true iff it appears that any interesting folding opportunities
2233/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2234/// the common case where no interesting opportunities are present, and
2235/// is also used as a check to avoid infinite recursion.
2236static bool
2237CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2238 SmallVectorImpl<const SCEV *> &NewOps,
2239 APInt &AccumulatedConstant,
2240 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2241 ScalarEvolution &SE) {
2242 bool Interesting = false;
2243
2244 // Iterate over the add operands. They are sorted, with constants first.
2245 unsigned i = 0;
2246 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2247 ++i;
2248 // Pull a buried constant out to the outside.
2249 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2250 Interesting = true;
2251 AccumulatedConstant += Scale * C->getAPInt();
2252 }
2253
2254 // Next comes everything else. We're especially interested in multiplies
2255 // here, but they're in the middle, so just visit the rest with one loop.
2256 for (; i != Ops.size(); ++i) {
2257 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2258 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2259 APInt NewScale =
2260 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2261 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2262 // A multiplication of a constant with another add; recurse.
2263 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2264 Interesting |=
2265 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2266 Add->operands(), NewScale, SE);
2267 } else {
2268 // A multiplication of a constant with some other value. Update
2269 // the map.
2270 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2271 const SCEV *Key = SE.getMulExpr(MulOps);
2272 auto Pair = M.insert({Key, NewScale});
2273 if (Pair.second) {
2274 NewOps.push_back(Pair.first->first);
2275 } else {
2276 Pair.first->second += NewScale;
2277 // The map already had an entry for this value, which may indicate
2278 // a folding opportunity.
2279 Interesting = true;
2280 }
2281 }
2282 } else {
2283 // An ordinary operand. Update the map.
2284 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2285 M.insert({Ops[i], Scale});
2286 if (Pair.second) {
2287 NewOps.push_back(Pair.first->first);
2288 } else {
2289 Pair.first->second += Scale;
2290 // The map already had an entry for this value, which may indicate
2291 // a folding opportunity.
2292 Interesting = true;
2293 }
2294 }
2295 }
2296
2297 return Interesting;
2298}
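// Worked example (illustrative sketch, not part of the upstream source):
// with Ops = {5, x, (2 * x), (3 * (7 + y))} and Scale = 1, the constant 5
// goes into AccumulatedConstant, x is inserted into M with scale 1,
// (2 * x) finds x already present and bumps its scale to 3 (marking the
// result interesting), and (3 * (7 + y)) recurses with Scale = 3, adding 21
// to AccumulatedConstant and (y, 3) to M. getAddExpr can then rebuild the
// sum as 26 + 3 * (x + y).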
2299
2300bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2301 const SCEV *LHS, const SCEV *RHS,
2302 const Instruction *CtxI) {
2303 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2304 SCEV::NoWrapFlags, unsigned);
2305 switch (BinOp) {
2306 default:
2307 llvm_unreachable("Unsupported binary op");
2308 case Instruction::Add:
2309 Operation = &ScalarEvolution::getAddExpr;
2310 break;
2311 case Instruction::Sub:
2312 Operation = &ScalarEvolution::getMinusSCEV;
2313 break;
2314 case Instruction::Mul:
2315 Operation = &ScalarEvolution::getMulExpr;
2316 break;
2317 }
2318
2319 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2320 Signed ? &ScalarEvolution::getSignExtendExpr
2321 : &ScalarEvolution::getZeroExtendExpr;
2322
2323 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2324 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2325 auto *WideTy =
2326 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2327
2328 const SCEV *A = (this->*Extension)(
2329 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2330 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2331 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2332 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2333 if (A == B)
2334 return true;
2335 // Can we use context to prove the fact we need?
2336 if (!CtxI)
2337 return false;
2338 // TODO: Support mul.
2339 if (BinOp == Instruction::Mul)
2340 return false;
2341 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2342 // TODO: Lift this limitation.
2343 if (!RHSC)
2344 return false;
2345 APInt C = RHSC->getAPInt();
2346 unsigned NumBits = C.getBitWidth();
2347 bool IsSub = (BinOp == Instruction::Sub);
2348 bool IsNegativeConst = (Signed && C.isNegative());
2349 // Compute the direction and magnitude by which we need to check overflow.
2350 bool OverflowDown = IsSub ^ IsNegativeConst;
2351 APInt Magnitude = C;
2352 if (IsNegativeConst) {
2353 if (C == APInt::getSignedMinValue(NumBits))
2354 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2355 // want to deal with that.
2356 return false;
2357 Magnitude = -C;
2358 }
2359
2360 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2361 if (OverflowDown) {
2362 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2363 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2364 : APInt::getMinValue(NumBits);
2365 APInt Limit = Min + Magnitude;
2366 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2367 } else {
2368 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2369 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2370 : APInt::getMaxValue(NumBits);
2371 APInt Limit = Max - Magnitude;
2372 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2373 }
2374}
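// Worked example (illustrative sketch, not part of the upstream source):
// for an unsigned i8 add of LHS and the constant 20, IsSub and
// IsNegativeConst are false, so OverflowDown is false, Max = 255 and
// Limit = 235; the add cannot wrap if isKnownPredicateAt shows LHS <= 235
// at CtxI. For LHS - 20 the direction flips: Min = 0, Limit = 20, and the
// requirement becomes 20 <= LHS.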
2375
2376std::optional<SCEV::NoWrapFlags>
2377ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2378 const OverflowingBinaryOperator *OBO) {
2379 // It cannot be done any better.
2380 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2381 return std::nullopt;
2382
2383 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2384
2385 if (OBO->hasNoUnsignedWrap())
2386 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2387 if (OBO->hasNoSignedWrap())
2388 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2389
2390 bool Deduced = false;
2391
2392 if (OBO->getOpcode() != Instruction::Add &&
2393 OBO->getOpcode() != Instruction::Sub &&
2394 OBO->getOpcode() != Instruction::Mul)
2395 return std::nullopt;
2396
2397 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2398 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2399
2400 const Instruction *CtxI =
2401 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2402 if (!OBO->hasNoUnsignedWrap() &&
2403 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2404 /* Signed */ false, LHS, RHS, CtxI)) {
2405 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2406 Deduced = true;
2407 }
2408
2409 if (!OBO->hasNoSignedWrap() &&
2410 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2411 /* Signed */ true, LHS, RHS, CtxI)) {
2412 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2413 Deduced = true;
2414 }
2415
2416 if (Deduced)
2417 return Flags;
2418 return std::nullopt;
2419}
2420
2421// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2422// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2423// can't-overflow flags for the operation if possible.
2424static SCEV::NoWrapFlags
2425StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2426 const ArrayRef<const SCEV *> Ops,
2427 SCEV::NoWrapFlags Flags) {
2428 using namespace std::placeholders;
2429
2430 using OBO = OverflowingBinaryOperator;
2431
2432 bool CanAnalyze =
2433 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2434 (void)CanAnalyze;
2435 assert(CanAnalyze && "don't call from other places!");
2436
2437 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2438 SCEV::NoWrapFlags SignOrUnsignWrap =
2439 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2440
2441 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2442 auto IsKnownNonNegative = [&](const SCEV *S) {
2443 return SE->isKnownNonNegative(S);
2444 };
2445
2446 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2447 Flags =
2448 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2449
2450 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2451
2452 if (SignOrUnsignWrap != SignOrUnsignMask &&
2453 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2454 isa<SCEVConstant>(Ops[0])) {
2455
2456 auto Opcode = [&] {
2457 switch (Type) {
2458 case scAddExpr:
2459 return Instruction::Add;
2460 case scMulExpr:
2461 return Instruction::Mul;
2462 default:
2463 llvm_unreachable("Unexpected SCEV op.");
2464 }
2465 }();
2466
2467 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2468
2469 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2470 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2471 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2472 Opcode, C, OBO::NoSignedWrap);
2473 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2474 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2475 }
2476
2477 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2478 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2479 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2480 Opcode, C, OBO::NoUnsignedWrap);
2481 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2482 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2483 }
2484 }
2485
2486 // <0,+,nonnegative><nw> is also nuw
2487 // TODO: Add corresponding nsw case
2488 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2489 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2490 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2491 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2492
2493 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2494 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2495 Ops.size() == 2) {
2496 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2497 if (UDiv->getOperand(1) == Ops[1])
2498 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2500 if (UDiv->getOperand(1) == Ops[0])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 }
2503
2504 return Flags;
2505}
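// Worked example (illustrative sketch, not part of the upstream source):
// for an i8 add with Ops = {10, %v} and no flags yet,
// makeGuaranteedNoWrapRegion(Add, 10, OBO::NoSignedWrap) is the set of
// values whose sum with 10 stays within [-128, 127], i.e. [-128, 118); if
// the signed range of %v lies inside it, FlagNSW is added. The unsigned
// check is analogous with the region [0, 246) and FlagNUW.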
2506
2507bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2508 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2509}
2510
2511/// Get a canonical add expression, or something simpler if possible.
2512const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2513 SCEV::NoWrapFlags OrigFlags,
2514 unsigned Depth) {
2515 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2516 "only nuw or nsw allowed");
2517 assert(!Ops.empty() && "Cannot get empty add!");
2518 if (Ops.size() == 1) return Ops[0];
2519#ifndef NDEBUG
2520 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2521 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2522 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2523 "SCEVAddExpr operand types don't match!");
2524 unsigned NumPtrs = count_if(
2525 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2526 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2527#endif
2528
2529 // Sort by complexity, this groups all similar expression types together.
2530 GroupByComplexity(Ops, &LI, DT);
2531
2532 // If there are any constants, fold them together.
2533 unsigned Idx = 0;
2534 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2535 ++Idx;
2536 assert(Idx < Ops.size());
2537 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2538 // We found two constants, fold them together!
2539 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2540 if (Ops.size() == 2) return Ops[0];
2541 Ops.erase(Ops.begin()+1); // Erase the folded element
2542 LHSC = cast<SCEVConstant>(Ops[0]);
2543 }
2544
2545 // If we are left with a constant zero being added, strip it off.
2546 if (LHSC->getValue()->isZero()) {
2547 Ops.erase(Ops.begin());
2548 --Idx;
2549 }
2550
2551 if (Ops.size() == 1) return Ops[0];
2552 }
2553
2554 // Delay expensive flag strengthening until necessary.
2555 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2556 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2557 };
2558
2559 // Limit recursion calls depth.
2560 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2561 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2562
2563 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2564 // Don't strengthen flags if we have no new information.
2565 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2566 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2567 Add->setNoWrapFlags(ComputeFlags(Ops));
2568 return S;
2569 }
2570
2571 // Okay, check to see if the same value occurs in the operand list more than
2572 // once. If so, merge them together into a multiply expression. Since we
2573 // sorted the list, these values are required to be adjacent.
2574 Type *Ty = Ops[0]->getType();
2575 bool FoundMatch = false;
2576 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2577 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2578 // Scan ahead to count how many equal operands there are.
2579 unsigned Count = 2;
2580 while (i+Count != e && Ops[i+Count] == Ops[i])
2581 ++Count;
2582 // Merge the values into a multiply.
2583 const SCEV *Scale = getConstant(Ty, Count);
2584 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2585 if (Ops.size() == Count)
2586 return Mul;
2587 Ops[i] = Mul;
2588 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2589 --i; e -= Count - 1;
2590 FoundMatch = true;
2591 }
2592 if (FoundMatch)
2593 return getAddExpr(Ops, OrigFlags, Depth + 1);
2594
2595 // Check for truncates. If all the operands are truncated from the same
2596 // type, see if factoring out the truncate would permit the result to be
2597 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2598 // if the contents of the resulting outer trunc fold to something simple.
2599 auto FindTruncSrcType = [&]() -> Type * {
2600 // We're ultimately looking to fold an addrec of truncs and muls of only
2601 // constants and truncs, so if we find any other types of SCEV
2602 // as operands of the addrec then we bail and return nullptr here.
2603 // Otherwise, we return the type of the operand of a trunc that we find.
2604 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2605 return T->getOperand()->getType();
2606 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2607 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2608 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2609 return T->getOperand()->getType();
2610 }
2611 return nullptr;
2612 };
2613 if (auto *SrcType = FindTruncSrcType()) {
2614 SmallVector<const SCEV *, 8> LargeOps;
2615 bool Ok = true;
2616 // Check all the operands to see if they can be represented in the
2617 // source type of the truncate.
2618 for (const SCEV *Op : Ops) {
2619 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeOps.push_back(T->getOperand());
2625 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2626 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2628 SmallVector<const SCEV *, 8> LargeMulOps;
2629 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2630 if (const SCEVTruncateExpr *T =
2631 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2632 if (T->getOperand()->getType() != SrcType) {
2633 Ok = false;
2634 break;
2635 }
2636 LargeMulOps.push_back(T->getOperand());
2637 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2638 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2639 } else {
2640 Ok = false;
2641 break;
2642 }
2643 }
2644 if (Ok)
2645 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2646 } else {
2647 Ok = false;
2648 break;
2649 }
2650 }
2651 if (Ok) {
2652 // Evaluate the expression in the larger type.
2653 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2654 // If it folds to something simple, use it. Otherwise, don't.
2655 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2656 return getTruncateExpr(Fold, Ty);
2657 }
2658 }
2659
2660 if (Ops.size() == 2) {
2661 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2662 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2663 // C1).
2664 const SCEV *A = Ops[0];
2665 const SCEV *B = Ops[1];
2666 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2667 auto *C = dyn_cast<SCEVConstant>(A);
2668 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2669 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2670 auto C2 = C->getAPInt();
2671 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2672
2673 APInt ConstAdd = C1 + C2;
2674 auto AddFlags = AddExpr->getNoWrapFlags();
2675 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2676 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2677 ConstAdd.ule(C1)) {
2678 PreservedFlags =
2679 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2680 }
2681
2682 // Adding a constant with the same sign and small magnitude is NSW, if the
2683 // original AddExpr was NSW.
2684 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2685 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2686 ConstAdd.abs().ule(C1.abs())) {
2687 PreservedFlags =
2688 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2689 }
2690
2691 if (PreservedFlags != SCEV::FlagAnyWrap) {
2692 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2693 NewOps[0] = getConstant(ConstAdd);
2694 return getAddExpr(NewOps, PreservedFlags);
2695 }
2696 }
2697 }
2698
2699 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2700 if (Ops.size() == 2) {
2701 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2702 if (Mul && Mul->getNumOperands() == 2 &&
2703 Mul->getOperand(0)->isAllOnesValue()) {
2704 const SCEV *X;
2705 const SCEV *Y;
2706 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2707 return getMulExpr(Y, getUDivExpr(X, Y));
2708 }
2709 }
2710 }
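// Worked example (illustrative sketch, not part of the upstream source):
// the canonicalization above rewrites X + (-1 * (X urem Y)) as
// Y * (X /u Y); with X = 23 and Y = 5 both sides evaluate to 20, since
// 23 - (23 urem 5) = 23 - 3 = 20 = 5 * (23 /u 5).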
2711
2712 // Skip past any other cast SCEVs.
2713 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2714 ++Idx;
2715
2716 // If there are add operands they would be next.
2717 if (Idx < Ops.size()) {
2718 bool DeletedAdd = false;
2719 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2720 // common NUW flag for expression after inlining. Other flags cannot be
2721 // preserved, because they may depend on the original order of operations.
2722 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2723 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2724 if (Ops.size() > AddOpsInlineThreshold ||
2725 Add->getNumOperands() > AddOpsInlineThreshold)
2726 break;
2727 // If we have an add, expand the add operands onto the end of the operands
2728 // list.
2729 Ops.erase(Ops.begin()+Idx);
2730 append_range(Ops, Add->operands());
2731 DeletedAdd = true;
2732 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2733 }
2734
2735 // If we deleted at least one add, we added operands to the end of the list,
2736 // and they are not necessarily sorted. Recurse to resort and resimplify
2737 // any operands we just acquired.
2738 if (DeletedAdd)
2739 return getAddExpr(Ops, CommonFlags, Depth + 1);
2740 }
2741
2742 // Skip over the add expression until we get to a multiply.
2743 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2744 ++Idx;
2745
2746 // Check to see if there are any folding opportunities present with
2747 // operands multiplied by constant values.
2748 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2749 uint64_t BitWidth = getTypeSizeInBits(Ty);
2750 DenseMap<const SCEV *, APInt> M;
2751 SmallVector<const SCEV *, 8> NewOps;
2752 APInt AccumulatedConstant(BitWidth, 0);
2753 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2754 Ops, APInt(BitWidth, 1), *this)) {
2755 struct APIntCompare {
2756 bool operator()(const APInt &LHS, const APInt &RHS) const {
2757 return LHS.ult(RHS);
2758 }
2759 };
2760
2761 // Some interesting folding opportunity is present, so it's worthwhile to
2762 // re-generate the operands list. Group the operands by constant scale,
2763 // to avoid multiplying by the same constant scale multiple times.
2764 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2765 for (const SCEV *NewOp : NewOps)
2766 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2767 // Re-generate the operands list.
2768 Ops.clear();
2769 if (AccumulatedConstant != 0)
2770 Ops.push_back(getConstant(AccumulatedConstant));
2771 for (auto &MulOp : MulOpLists) {
2772 if (MulOp.first == 1) {
2773 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2774 } else if (MulOp.first != 0) {
2775 Ops.push_back(getMulExpr(
2776 getConstant(MulOp.first),
2777 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2778 SCEV::FlagAnyWrap, Depth + 1));
2779 }
2780 }
2781 if (Ops.empty())
2782 return getZero(Ty);
2783 if (Ops.size() == 1)
2784 return Ops[0];
2785 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2786 }
2787 }
2788
2789 // If we are adding something to a multiply expression, make sure the
2790 // something is not already an operand of the multiply. If so, merge it into
2791 // the multiply.
2792 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2793 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2794 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2795 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2796 if (isa<SCEVConstant>(MulOpSCEV))
2797 continue;
2798 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2799 if (MulOpSCEV == Ops[AddOp]) {
2800 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2801 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2802 if (Mul->getNumOperands() != 2) {
2803 // If the multiply has more than two operands, we must get the
2804 // Y*Z term.
2805 SmallVector<const SCEV *, 4> MulOps(
2806 Mul->operands().take_front(MulOp));
2807 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2808 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2809 }
2810 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2811 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2812 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 SCEV::FlagAnyWrap, Depth + 1);
2814 if (Ops.size() == 2) return OuterMul;
2815 if (AddOp < Idx) {
2816 Ops.erase(Ops.begin()+AddOp);
2817 Ops.erase(Ops.begin()+Idx-1);
2818 } else {
2819 Ops.erase(Ops.begin()+Idx);
2820 Ops.erase(Ops.begin()+AddOp-1);
2821 }
2822 Ops.push_back(OuterMul);
2823 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2824 }
2825
2826 // Check this multiply against other multiplies being added together.
2827 for (unsigned OtherMulIdx = Idx+1;
2828 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2829 ++OtherMulIdx) {
2830 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 // If MulOp occurs in OtherMul, we can fold the two multiplies
2832 // together.
2833 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2834 OMulOp != e; ++OMulOp)
2835 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2836 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2837 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2838 if (Mul->getNumOperands() != 2) {
2839 SmallVector<const SCEV *, 4> MulOps(
2840 Mul->operands().take_front(MulOp));
2841 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2842 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2843 }
2844 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2845 if (OtherMul->getNumOperands() != 2) {
2846 SmallVector<const SCEV *, 4> MulOps(
2847 OtherMul->operands().take_front(OMulOp));
2848 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2849 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2850 }
2851 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2852 const SCEV *InnerMulSum =
2853 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2854 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 SCEV::FlagAnyWrap, Depth + 1);
2856 if (Ops.size() == 2) return OuterMul;
2857 Ops.erase(Ops.begin()+Idx);
2858 Ops.erase(Ops.begin()+OtherMulIdx-1);
2859 Ops.push_back(OuterMul);
2860 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2861 }
2862 }
2863 }
2864 }
2865
2866 // If there are any add recurrences in the operands list, see if any other
2867 // added values are loop invariant. If so, we can fold them into the
2868 // recurrence.
2869 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2870 ++Idx;
2871
2872 // Scan over all recurrences, trying to fold loop invariants into them.
2873 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2874 // Scan all of the other operands to this add and add them to the vector if
2875 // they are loop invariant w.r.t. the recurrence.
2876 SmallVector<const SCEV *, 8> LIOps;
2877 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2878 const Loop *AddRecLoop = AddRec->getLoop();
2879 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2880 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2881 LIOps.push_back(Ops[i]);
2882 Ops.erase(Ops.begin()+i);
2883 --i; --e;
2884 }
2885
2886 // If we found some loop invariants, fold them into the recurrence.
2887 if (!LIOps.empty()) {
2888 // Compute nowrap flags for the addition of the loop-invariant ops and
2889 // the addrec. Temporarily push it as an operand for that purpose. These
2890 // flags are valid in the scope of the addrec only.
2891 LIOps.push_back(AddRec);
2892 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2893 LIOps.pop_back();
2894
2895 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2896 LIOps.push_back(AddRec->getStart());
2897
2898 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2899
2900 // It is not in general safe to propagate flags valid on an add within
2901 // the addrec scope to one outside it. We must prove that the inner
2902 // scope is guaranteed to execute if the outer one does to be able to
2903 // safely propagate. We know the program is undefined if poison is
2904 // produced on the inner scoped addrec. We also know that *for this use*
2905 // the outer scoped add can't overflow (because of the flags we just
2906 // computed for the inner scoped add) without the program being undefined.
2907 // Proving that entry to the outer scope necessitates entry to the inner
2908 // scope, thus proves the program undefined if the flags would be violated
2909 // in the outer scope.
2910 SCEV::NoWrapFlags AddFlags = Flags;
2911 if (AddFlags != SCEV::FlagAnyWrap) {
2912 auto *DefI = getDefiningScopeBound(LIOps);
2913 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2914 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2915 AddFlags = SCEV::FlagAnyWrap;
2916 }
2917 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2918
2919 // Build the new addrec. Propagate the NUW and NSW flags if both the
2920 // outer add and the inner addrec are guaranteed to have no overflow.
2921 // Always propagate NW.
2922 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2923 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2924
2925 // If all of the other operands were loop invariant, we are done.
2926 if (Ops.size() == 1) return NewRec;
2927
2928 // Otherwise, add the folded AddRec by the non-invariant parts.
2929 for (unsigned i = 0;; ++i)
2930 if (Ops[i] == AddRec) {
2931 Ops[i] = NewRec;
2932 break;
2933 }
2934 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2935 }
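// Worked example (illustrative sketch, not part of the upstream source):
// adding a loop-invariant %a to {1,+,2}<L> folds the invariant into the
// start, giving {1 + %a,+,2}<L>; roughly speaking, the no-wrap flags
// computed for that outer add are kept only when executing the definitions
// of the loop-invariant operands guarantees that L's header is reached,
// per the isGuaranteedToTransferExecutionTo check above.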
2936
2937 // Okay, if there weren't any loop invariants to be folded, check to see if
2938 // there are multiple AddRec's with the same loop induction variable being
2939 // added together. If so, we can fold them.
2940 for (unsigned OtherIdx = Idx+1;
2941 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2942 ++OtherIdx) {
2943 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2944 // so that the 1st found AddRecExpr is dominated by all others.
2945 assert(DT.dominates(
2946 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2947 AddRec->getLoop()->getHeader()) &&
2948 "AddRecExprs are not sorted in reverse dominance order?");
2949 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2950 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2951 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2952 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2953 ++OtherIdx) {
2954 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 if (OtherAddRec->getLoop() == AddRecLoop) {
2956 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2957 i != e; ++i) {
2958 if (i >= AddRecOps.size()) {
2959 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2960 break;
2961 }
2962 SmallVector<const SCEV *, 2> TwoOps = {
2963 AddRecOps[i], OtherAddRec->getOperand(i)};
2964 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2965 }
2966 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2967 }
2968 }
2969 // Step size has changed, so we cannot guarantee no self-wraparound.
2970 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2971 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2972 }
2973 }
2974
2975 // Otherwise couldn't fold anything into this recurrence. Move onto the
2976 // next one.
2977 }
2978
2979 // Okay, it looks like we really DO need an add expr. Check to see if we
2980 // already have one, otherwise create a new one.
2981 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2982}
2983
2984const SCEV *
2985ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2986 SCEV::NoWrapFlags Flags) {
2987 FoldingSetNodeID ID;
2988 ID.AddInteger(scAddExpr);
2989 for (const SCEV *Op : Ops)
2990 ID.AddPointer(Op);
2991 void *IP = nullptr;
2992 SCEVAddExpr *S =
2993 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2994 if (!S) {
2995 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2996 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2997 S = new (SCEVAllocator)
2998 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2999 UniqueSCEVs.InsertNode(S, IP);
3000 registerUser(S, Ops);
3001 }
3002 S->setNoWrapFlags(Flags);
3003 return S;
3004}
3005
3006const SCEV *
3007ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3008 const Loop *L, SCEV::NoWrapFlags Flags) {
3009 FoldingSetNodeID ID;
3010 ID.AddInteger(scAddRecExpr);
3011 for (const SCEV *Op : Ops)
3012 ID.AddPointer(Op);
3013 ID.AddPointer(L);
3014 void *IP = nullptr;
3015 SCEVAddRecExpr *S =
3016 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3017 if (!S) {
3018 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3019 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3020 S = new (SCEVAllocator)
3021 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3022 UniqueSCEVs.InsertNode(S, IP);
3023 LoopUsers[L].push_back(S);
3024 registerUser(S, Ops);
3025 }
3026 setNoWrapFlags(S, Flags);
3027 return S;
3028}
3029
3030const SCEV *
3031ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3032 SCEV::NoWrapFlags Flags) {
3033 FoldingSetNodeID ID;
3034 ID.AddInteger(scMulExpr);
3035 for (const SCEV *Op : Ops)
3036 ID.AddPointer(Op);
3037 void *IP = nullptr;
3038 SCEVMulExpr *S =
3039 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3040 if (!S) {
3041 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3042 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3043 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3044 O, Ops.size());
3045 UniqueSCEVs.InsertNode(S, IP);
3046 registerUser(S, Ops);
3047 }
3048 S->setNoWrapFlags(Flags);
3049 return S;
3050}
3051
3052static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3053 uint64_t k = i*j;
3054 if (j > 1 && k / j != i) Overflow = true;
3055 return k;
3056}
3057
3058/// Compute the result of "n choose k", the binomial coefficient. If an
3059/// intermediate computation overflows, Overflow will be set and the return will
3060/// be garbage. Overflow is not cleared on absence of overflow.
3061static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3062 // We use the multiplicative formula:
3063 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3064 // At each iteration, we take the n-th term of the numerator and divide by the
3065 // (k-n)th term of the denominator. This division will always produce an
3066 // integral result, and helps reduce the chance of overflow in the
3067 // intermediate computations. However, we can still overflow even when the
3068 // final result would fit.
3069
3070 if (n == 0 || n == k) return 1;
3071 if (k > n) return 0;
3072
3073 if (k > n/2)
3074 k = n-k;
3075
3076 uint64_t r = 1;
3077 for (uint64_t i = 1; i <= k; ++i) {
3078 r = umul_ov(r, n-(i-1), Overflow);
3079 r /= i;
3080 }
3081 return r;
3082}
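// Worked example (illustrative sketch, not part of the upstream source):
// Choose(6, 2) leaves k alone since k <= n/2, then computes r = 1*6 = 6,
// r /= 1, r = 6*5 = 30, r /= 2, and returns 15 = C(6,2). Each division is
// exact because the product of i consecutive integers is divisible by i!,
// which keeps the intermediate values as small as possible; umul_ov still
// flags Overflow if a partial product exceeds 64 bits.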
3083
3084/// Determine if any of the operands in this SCEV are a constant or if
3085/// any of the add or multiply expressions in this SCEV contain a constant.
3086static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3087 struct FindConstantInAddMulChain {
3088 bool FoundConstant = false;
3089
3090 bool follow(const SCEV *S) {
3091 FoundConstant |= isa<SCEVConstant>(S);
3092 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3093 }
3094
3095 bool isDone() const {
3096 return FoundConstant;
3097 }
3098 };
3099
3100 FindConstantInAddMulChain F;
3101 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3102 ST.visitAll(StartExpr);
3103 return F.FoundConstant;
3104}
3105
3106/// Get a canonical multiply expression, or something simpler if possible.
3107const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3108 SCEV::NoWrapFlags OrigFlags,
3109 unsigned Depth) {
3110 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3111 "only nuw or nsw allowed");
3112 assert(!Ops.empty() && "Cannot get empty mul!");
3113 if (Ops.size() == 1) return Ops[0];
3114#ifndef NDEBUG
3115 Type *ETy = Ops[0]->getType();
3116 assert(!ETy->isPointerTy());
3117 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3118 assert(Ops[i]->getType() == ETy &&
3119 "SCEVMulExpr operand types don't match!");
3120#endif
3121
3122 // Sort by complexity, this groups all similar expression types together.
3123 GroupByComplexity(Ops, &LI, DT);
3124
3125 // If there are any constants, fold them together.
3126 unsigned Idx = 0;
3127 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3128 ++Idx;
3129 assert(Idx < Ops.size());
3130 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3131 // We found two constants, fold them together!
3132 Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3133 if (Ops.size() == 2) return Ops[0];
3134 Ops.erase(Ops.begin()+1); // Erase the folded element
3135 LHSC = cast<SCEVConstant>(Ops[0]);
3136 }
3137
3138 // If we have a multiply of zero, it will always be zero.
3139 if (LHSC->getValue()->isZero())
3140 return LHSC;
3141
3142 // If we are left with a constant one being multiplied, strip it off.
3143 if (LHSC->getValue()->isOne()) {
3144 Ops.erase(Ops.begin());
3145 --Idx;
3146 }
3147
3148 if (Ops.size() == 1)
3149 return Ops[0];
3150 }
3151
3152 // Delay expensive flag strengthening until necessary.
3153 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3154 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3155 };
3156
3157 // Limit recursion call depth.
3158 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3159 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3160
3161 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3162 // Don't strengthen flags if we have no new information.
3163 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3164 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3165 Mul->setNoWrapFlags(ComputeFlags(Ops));
3166 return S;
3167 }
3168
3169 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3170 if (Ops.size() == 2) {
3171 // C1*(C2+V) -> C1*C2 + C1*V
3172 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3173 // If any of Add's ops are Adds or Muls with a constant, apply this
3174 // transformation as well.
3175 //
3176 // TODO: There are some cases where this transformation is not
3177 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3178 // this transformation should be narrowed down.
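// Editorial example (not in the upstream source): with C1 = 2 and
// Add = (3 + %x), the rewrite below yields (2 * 3) + (2 * %x) = 6 + 2 * %x,
// exposing the folded constant 6 to the surrounding add simplification.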
3179 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3180 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3181 SCEV::FlagAnyWrap, Depth + 1);
3182 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3183 SCEV::FlagAnyWrap, Depth + 1);
3184 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3185 }
3186
3187 if (Ops[0]->isAllOnesValue()) {
3188 // If we have a mul by -1 of an add, try distributing the -1 among the
3189 // add operands.
3190 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3191 SmallVector<const SCEV *, 4> NewOps;
3192 bool AnyFolded = false;
3193 for (const SCEV *AddOp : Add->operands()) {
3194 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3195 Depth + 1);
3196 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3197 NewOps.push_back(Mul);
3198 }
3199 if (AnyFolded)
3200 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3201 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3202 // Negation preserves a recurrence's no self-wrap property.
3203 SmallVector<const SCEV *, 4> Operands;
3204 for (const SCEV *AddRecOp : AddRec->operands())
3205 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3206 Depth + 1));
3207 // Let M be the minimum representable signed value. AddRec with nsw
3208 // multiplied by -1 can have signed overflow if and only if it takes a
3209 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3210 // maximum signed value. In all other cases signed overflow is
3211 // impossible.
3212 auto FlagsMask = SCEV::FlagNW;
3213 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3214 auto MinInt =
3215 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3216 if (getSignedRangeMin(AddRec) != MinInt)
3217 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3218 }
3219 return getAddRecExpr(Operands, AddRec->getLoop(),
3220 AddRec->getNoWrapFlags(FlagsMask));
3221 }
3222 }
3223 }
3224 }
3225
3226 // Skip over the add expression until we get to a multiply.
3227 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3228 ++Idx;
3229
3230 // If there are mul operands inline them all into this expression.
3231 if (Idx < Ops.size()) {
3232 bool DeletedMul = false;
3233 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3234 if (Ops.size() > MulOpsInlineThreshold)
3235 break;
3236 // If we have a mul, expand the mul operands onto the end of the
3237 // operands list.
3238 Ops.erase(Ops.begin()+Idx);
3239 append_range(Ops, Mul->operands());
3240 DeletedMul = true;
3241 }
3242
3243 // If we deleted at least one mul, we added operands to the end of the
3244 // list, and they are not necessarily sorted. Recurse to resort and
3245 // resimplify any operands we just acquired.
3246 if (DeletedMul)
3247 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3248 }
3249
3250 // If there are any add recurrences in the operands list, see if any other
3251 // added values are loop invariant. If so, we can fold them into the
3252 // recurrence.
3253 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3254 ++Idx;
3255
3256 // Scan over all recurrences, trying to fold loop invariants into them.
3257 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3258 // Scan all of the other operands to this mul and add them to the vector
3259 // if they are loop invariant w.r.t. the recurrence.
3260 SmallVector<const SCEV *, 8> LIOps;
3261 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3262 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3263 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3264 LIOps.push_back(Ops[i]);
3265 Ops.erase(Ops.begin()+i);
3266 --i; --e;
3267 }
3268
3269 // If we found some loop invariants, fold them into the recurrence.
3270 if (!LIOps.empty()) {
3271 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3272 SmallVector<const SCEV *, 4> NewOps;
3273 NewOps.reserve(AddRec->getNumOperands());
3274 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3275
3276 // If both the mul and addrec are nuw, we can preserve nuw.
3277 // If both the mul and addrec are nsw, we can only preserve nsw if either
3278 // a) they are also nuw, or
3279 // b) all multiplications of addrec operands with scale are nsw.
3280 SCEV::NoWrapFlags Flags =
3281 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3282
3283 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3284 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3285 SCEV::FlagAnyWrap, Depth + 1));
3286
3287 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3288 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3289 Instruction::Mul, getSignedRange(Scale),
3290 OverflowingBinaryOperator::NoSignedWrap);
3291 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3292 Flags = clearFlags(Flags, SCEV::FlagNSW);
3293 }
3294 }
3295
3296 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3297
3298 // If all of the other operands were loop invariant, we are done.
3299 if (Ops.size() == 1) return NewRec;
3300
3301 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3302 for (unsigned i = 0;; ++i)
3303 if (Ops[i] == AddRec) {
3304 Ops[i] = NewRec;
3305 break;
3306 }
3307 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3308 }
3309
3310 // Okay, if there weren't any loop invariants to be folded, check to see
3311 // if there are multiple AddRec's with the same loop induction variable
3312 // being multiplied together. If so, we can fold them.
3313
3314 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3315 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3316 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3317 // ]]],+,...up to x=2n}.
3318 // Note that the arguments to choose() are always integers with values
3319 // known at compile time, never SCEV objects.
3320 //
3321 // The implementation avoids pointless extra computations when the two
3322 // addrec's are of different length (mathematically, it's equivalent to
3323 // an infinite stream of zeros on the right).
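// Editorial worked example (not in the upstream source): for two affine
// recurrences over the same loop the formula above reduces to
// {A1,+,A2}<L> * {B1,+,B2}<L>
// = {A1*B1, +, A1*B2 + A2*B1 + A2*B2, +, 2*A2*B2}<L>,
// i.e. the product of two degree-1 chains is a degree-2 chain whose last
// coefficient is twice the product of the step values.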
3324 bool OpsModified = false;
3325 for (unsigned OtherIdx = Idx+1;
3326 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3327 ++OtherIdx) {
3328 const SCEVAddRecExpr *OtherAddRec =
3329 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3330 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3331 continue;
3332
3333 // Limit max number of arguments to avoid creation of unreasonably big
3334 // SCEVAddRecs with very complex operands.
3335 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3336 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3337 continue;
3338
3339 bool Overflow = false;
3340 Type *Ty = AddRec->getType();
3341 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3342 SmallVector<const SCEV *, 7> AddRecOps;
3343 for (int x = 0, xe = AddRec->getNumOperands() +
3344 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3345 SmallVector <const SCEV *, 7> SumOps;
3346 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3347 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3348 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3349 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3350 z < ze && !Overflow; ++z) {
3351 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3352 uint64_t Coeff;
3353 if (LargerThan64Bits)
3354 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3355 else
3356 Coeff = Coeff1*Coeff2;
3357 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3358 const SCEV *Term1 = AddRec->getOperand(y-z);
3359 const SCEV *Term2 = OtherAddRec->getOperand(z);
3360 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3361 SCEV::FlagAnyWrap, Depth + 1));
3362 }
3363 }
3364 if (SumOps.empty())
3365 SumOps.push_back(getZero(Ty));
3366 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3367 }
3368 if (!Overflow) {
3369 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3370 SCEV::FlagAnyWrap);
3371 if (Ops.size() == 2) return NewAddRec;
3372 Ops[Idx] = NewAddRec;
3373 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3374 OpsModified = true;
3375 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3376 if (!AddRec)
3377 break;
3378 }
3379 }
3380 if (OpsModified)
3381 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3382
3383 // Otherwise couldn't fold anything into this recurrence. Move onto the
3384 // next one.
3385 }
3386
3387 // Okay, it looks like we really DO need a mul expr. Check to see if we
3388 // already have one, otherwise create a new one.
3389 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3390}
3391
3392/// Represents an unsigned remainder expression based on unsigned division.
3393const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3394 const SCEV *RHS) {
3395 assert(getEffectiveSCEVType(LHS->getType()) ==
3396 getEffectiveSCEVType(RHS->getType()) &&
3397 "SCEVURemExpr operand types don't match!");
3398
3399 // Short-circuit easy cases
3400 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3401 // If constant is one, the result is trivial
3402 if (RHSC->getValue()->isOne())
3403 return getZero(LHS->getType()); // X urem 1 --> 0
3404
3405 // If constant is a power of two, fold into a zext(trunc(LHS)).
3406 if (RHSC->getAPInt().isPowerOf2()) {
3407 Type *FullTy = LHS->getType();
3408 Type *TruncTy =
3409 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3410 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3411 }
3412 }
3413
3414 // Fall back to the generic lowering: %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3415 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3416 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3417 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3418}
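// Editorial sketch (not part of the upstream implementation): the
// power-of-two fold in getURemExpr() above. For a 32-bit %x, "%x urem 8" is
// expected to become zext i3 (trunc %x to i3) to i32, i.e. the low three
// bits of %x; the constant 8 and the helper name are illustrative only.
[[maybe_unused]] static const SCEV *uremByPowerOfTwoSketch(ScalarEvolution &SE,
                                                           const SCEV *X) {
  // Assumes X is an integer SCEV (e.g. i32).
  const SCEV *Eight = SE.getConstant(X->getType(), 8);
  return SE.getURemExpr(X, Eight);
}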
3419
3420/// Get a canonical unsigned division expression, or something simpler if
3421/// possible.
3422const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3423 const SCEV *RHS) {
3424 assert(!LHS->getType()->isPointerTy() &&
3425 "SCEVUDivExpr operand can't be pointer!");
3426 assert(LHS->getType() == RHS->getType() &&
3427 "SCEVUDivExpr operand types don't match!");
3428
3429 FoldingSetNodeID ID;
3430 ID.AddInteger(scUDivExpr);
3431 ID.AddPointer(LHS);
3432 ID.AddPointer(RHS);
3433 void *IP = nullptr;
3434 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3435 return S;
3436
3437 // 0 udiv Y == 0
3438 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3439 if (LHSC->getValue()->isZero())
3440 return LHS;
3441
3442 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3443 if (RHSC->getValue()->isOne())
3444 return LHS; // X udiv 1 --> x
3445 // If the denominator is zero, the result of the udiv is undefined. Don't
3446 // try to analyze it, because the resolution chosen here may differ from
3447 // the resolution chosen in other parts of the compiler.
3448 if (!RHSC->getValue()->isZero()) {
3449 // Determine if the division can be folded into the operands of
3450 // the dividend (LHS).
3451 // TODO: Generalize this to non-constants by using known-bits information.
3452 Type *Ty = LHS->getType();
3453 unsigned LZ = RHSC->getAPInt().countl_zero();
3454 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3455 // For non-power-of-two values, effectively round the value up to the
3456 // nearest power of two.
3457 if (!RHSC->getAPInt().isPowerOf2())
3458 ++MaxShiftAmt;
3459 IntegerType *ExtTy =
3460 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3461 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3462 if (const SCEVConstant *Step =
3463 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3464 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3465 const APInt &StepInt = Step->getAPInt();
3466 const APInt &DivInt = RHSC->getAPInt();
3467 if (!StepInt.urem(DivInt) &&
3468 getZeroExtendExpr(AR, ExtTy) ==
3469 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3470 getZeroExtendExpr(Step, ExtTy),
3471 AR->getLoop(), SCEV::FlagAnyWrap)) {
3472 SmallVector<const SCEV *, 4> Operands;
3473 for (const SCEV *Op : AR->operands())
3474 Operands.push_back(getUDivExpr(Op, RHS));
3475 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3476 }
3477 // Get a canonical UDivExpr for a recurrence.
3478 // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3479 // We can currently only fold X%N if X is constant.
3480 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3481 if (StartC && !DivInt.urem(StepInt) &&
3482 getZeroExtendExpr(AR, ExtTy) ==
3483 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3484 getZeroExtendExpr(Step, ExtTy),
3485 AR->getLoop(), SCEV::FlagAnyWrap)) {
3486 const APInt &StartInt = StartC->getAPInt();
3487 const APInt &StartRem = StartInt.urem(StepInt);
3488 if (StartRem != 0) {
3489 const SCEV *NewLHS =
3490 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3491 AR->getLoop(), SCEV::FlagNW);
3492 if (LHS != NewLHS) {
3493 LHS = NewLHS;
3494
3495 // Reset the ID to include the new LHS, and check if it is
3496 // already cached.
3497 ID.clear();
3498 ID.AddInteger(scUDivExpr);
3499 ID.AddPointer(LHS);
3500 ID.AddPointer(RHS);
3501 IP = nullptr;
3502 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3503 return S;
3504 }
3505 }
3506 }
3507 }
3508 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3509 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3510 SmallVector<const SCEV *, 4> Operands;
3511 for (const SCEV *Op : M->operands())
3512 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3513 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3514 // Find an operand that's safely divisible.
3515 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3516 const SCEV *Op = M->getOperand(i);
3517 const SCEV *Div = getUDivExpr(Op, RHSC);
3518 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3519 Operands = SmallVector<const SCEV *, 4>(M->operands());
3520 Operands[i] = Div;
3521 return getMulExpr(Operands);
3522 }
3523 }
3524 }
3525
3526 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3527 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3528 if (auto *DivisorConstant =
3529 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3530 bool Overflow = false;
3531 APInt NewRHS =
3532 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3533 if (Overflow) {
3534 return getConstant(RHSC->getType(), 0, false);
3535 }
3536 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3537 }
3538 }
3539
3540 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3541 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3542 SmallVector<const SCEV *, 4> Operands;
3543 for (const SCEV *Op : A->operands())
3544 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3545 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3546 Operands.clear();
3547 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3548 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3549 if (isa<SCEVUDivExpr>(Op) ||
3550 getMulExpr(Op, RHS) != A->getOperand(i))
3551 break;
3552 Operands.push_back(Op);
3553 }
3554 if (Operands.size() == A->getNumOperands())
3555 return getAddExpr(Operands);
3556 }
3557 }
3558
3559 // Fold if both operands are constant.
3560 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3561 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3562 }
3563 }
3564
3565 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3566 // changes). Make sure we get a new one.
3567 IP = nullptr;
3568 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3569 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3570 LHS, RHS);
3571 UniqueSCEVs.InsertNode(S, IP);
3572 registerUser(S, {LHS, RHS});
3573 return S;
3574}
3575
3576APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3577 APInt A = C1->getAPInt().abs();
3578 APInt B = C2->getAPInt().abs();
3579 uint32_t ABW = A.getBitWidth();
3580 uint32_t BBW = B.getBitWidth();
3581
3582 if (ABW > BBW)
3583 B = B.zext(ABW);
3584 else if (ABW < BBW)
3585 A = A.zext(BBW);
3586
3587 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3588}
3589
3590/// Get a canonical unsigned division expression, or something simpler if
3591/// possible. There is no representation for an exact udiv in SCEV IR, but we
3592/// can attempt to remove factors from the LHS and RHS. We can't do this when
3593/// it's not exact because the udiv may be clearing bits.
3594const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3595 const SCEV *RHS) {
3596 // TODO: we could try to find factors in all sorts of things, but for now we
3597 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3598 // end of this file for inspiration.
3599
3600 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3601 if (!Mul || !Mul->hasNoUnsignedWrap())
3602 return getUDivExpr(LHS, RHS);
3603
3604 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3605 // If the mulexpr multiplies by a constant, then that constant must be the
3606 // first element of the mulexpr.
3607 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3608 if (LHSCst == RHSCst) {
3609 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3610 return getMulExpr(Operands);
3611 }
3612
3613 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3614 // that there's a factor provided by one of the other terms. We need to
3615 // check.
3616 APInt Factor = gcd(LHSCst, RHSCst);
3617 if (!Factor.isIntN(1)) {
3618 LHSCst =
3619 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3620 RHSCst =
3621 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3622 SmallVector<const SCEV *, 2> Operands;
3623 Operands.push_back(LHSCst);
3624 append_range(Operands, Mul->operands().drop_front());
3625 LHS = getMulExpr(Operands);
3626 RHS = RHSCst;
3627 Mul = dyn_cast<SCEVMulExpr>(LHS);
3628 if (!Mul)
3629 return getUDivExactExpr(LHS, RHS);
3630 }
3631 }
3632 }
3633
3634 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3635 if (Mul->getOperand(i) == RHS) {
3636 SmallVector<const SCEV *, 2> Operands;
3637 append_range(Operands, Mul->operands().take_front(i));
3638 append_range(Operands, Mul->operands().drop_front(i + 1));
3639 return getMulExpr(Operands);
3640 }
3641 }
3642
3643 return getUDivExpr(LHS, RHS);
3644}
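// Editorial sketch (not part of the upstream implementation): the
// factor-cancelling fold in getUDivExactExpr() above. Dividing the nuw
// product "6 * %x" exactly by 3 is expected to cancel the common factor and
// yield "2 * %x" instead of an opaque udiv; names here are illustrative.
[[maybe_unused]] static const SCEV *exactUDivSketch(ScalarEvolution &SE,
                                                    const SCEV *X) {
  // Assumes X is an integer SCEV.
  Type *Ty = X->getType();
  const SCEV *Mul = SE.getMulExpr(SE.getConstant(Ty, 6), X, SCEV::FlagNUW);
  return SE.getUDivExactExpr(Mul, SE.getConstant(Ty, 3));
}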
3645
3646/// Get an add recurrence expression for the specified loop. Simplify the
3647/// expression as much as possible.
3648const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3649 const Loop *L,
3650 SCEV::NoWrapFlags Flags) {
3651 SmallVector<const SCEV *, 4> Operands;
3652 Operands.push_back(Start);
3653 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3654 if (StepChrec->getLoop() == L) {
3655 append_range(Operands, StepChrec->operands());
3656 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3657 }
3658
3659 Operands.push_back(Step);
3660 return getAddRecExpr(Operands, L, Flags);
3661}
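// Editorial sketch (not part of the upstream implementation): building the
// canonical induction variable {0,+,1}<L> through the two-operand form
// above; the helper name and flag choice are illustrative.
[[maybe_unused]] static const SCEV *canonicalIVSketch(ScalarEvolution &SE,
                                                      Type *Ty, const Loop *L) {
  // {0,+,1}<nuw><nsw><L>: starts at 0 and steps by 1 each iteration.
  return SE.getAddRecExpr(SE.getZero(Ty), SE.getOne(Ty), L,
                          ScalarEvolution::setFlags(SCEV::FlagNUW,
                                                    SCEV::FlagNSW));
}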
3662
3663/// Get an add recurrence expression for the specified loop. Simplify the
3664/// expression as much as possible.
3665const SCEV *
3666ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3667 const Loop *L, SCEV::NoWrapFlags Flags) {
3668 if (Operands.size() == 1) return Operands[0];
3669#ifndef NDEBUG
3670 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3671 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3672 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3673 "SCEVAddRecExpr operand types don't match!");
3674 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3675 }
3676 for (const SCEV *Op : Operands)
3678 "SCEVAddRecExpr operand is not available at loop entry!");
3679#endif
3680
3681 if (Operands.back()->isZero()) {
3682 Operands.pop_back();
3683 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3684 }
3685
3686 // It's tempting to call getConstantMaxBackedgeTakenCount here and
3687 // use that information to infer NUW and NSW flags. However, computing a
3688 // BE count requires calling getAddRecExpr, so we may not yet have a
3689 // meaningful BE count at this point (and if we don't, we'd be stuck
3690 // with a SCEVCouldNotCompute as the cached BE count).
3691
3692 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3693
3694 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3695 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3696 const Loop *NestedLoop = NestedAR->getLoop();
3697 if (L->contains(NestedLoop)
3698 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3699 : (!NestedLoop->contains(L) &&
3700 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3701 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3702 Operands[0] = NestedAR->getStart();
3703 // AddRecs require their operands be loop-invariant with respect to their
3704 // loops. Don't perform this transformation if it would break this
3705 // requirement.
3706 bool AllInvariant = all_of(
3707 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3708
3709 if (AllInvariant) {
3710 // Create a recurrence for the outer loop with the same step size.
3711 //
3712 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3713 // inner recurrence has the same property.
3714 SCEV::NoWrapFlags OuterFlags =
3715 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3716
3717 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3718 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3719 return isLoopInvariant(Op, NestedLoop);
3720 });
3721
3722 if (AllInvariant) {
3723 // Ok, both add recurrences are valid after the transformation.
3724 //
3725 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3726 // the outer recurrence has the same property.
3727 SCEV::NoWrapFlags InnerFlags =
3728 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3729 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3730 }
3731 }
3732 // Reset Operands to its original state.
3733 Operands[0] = NestedAR;
3734 }
3735 }
3736
3737 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3738 // already have one, otherwise create a new one.
3739 return getOrCreateAddRecExpr(Operands, L, Flags);
3740}
3741
3742const SCEV *
3743ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3744 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3745 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3746 // getSCEV(Base)->getType() has the same address space as Base->getType()
3747 // because SCEV::getType() preserves the address space.
3748 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3749 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3750 if (NW != GEPNoWrapFlags::none()) {
3751 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3752 // but to do that, we have to ensure that said flag is valid in the entire
3753 // defined scope of the SCEV.
3754 // TODO: non-instructions have global scope. We might be able to prove
3755 // some global scope cases
3756 auto *GEPI = dyn_cast<Instruction>(GEP);
3757 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3758 NW = GEPNoWrapFlags::none();
3759 }
3760
3761 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3762 if (NW.hasNoUnsignedSignedWrap())
3763 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3764 if (NW.hasNoUnsignedWrap())
3765 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3766
3767 Type *CurTy = GEP->getType();
3768 bool FirstIter = true;
3769 SmallVector<const SCEV *, 4> Offsets;
3770 for (const SCEV *IndexExpr : IndexExprs) {
3771 // Compute the (potentially symbolic) offset in bytes for this index.
3772 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3773 // For a struct, add the member offset.
3774 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3775 unsigned FieldNo = Index->getZExtValue();
3776 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3777 Offsets.push_back(FieldOffset);
3778
3779 // Update CurTy to the type of the field at Index.
3780 CurTy = STy->getTypeAtIndex(Index);
3781 } else {
3782 // Update CurTy to its element type.
3783 if (FirstIter) {
3784 assert(isa<PointerType>(CurTy) &&
3785 "The first index of a GEP indexes a pointer");
3786 CurTy = GEP->getSourceElementType();
3787 FirstIter = false;
3788 } else {
3789 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3790 }
3791 // For an array, add the element offset, explicitly scaled.
3792 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3793 // Getelementptr indices are signed.
3794 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3795
3796 // Multiply the index by the element size to compute the element offset.
3797 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3798 Offsets.push_back(LocalOffset);
3799 }
3800 }
3801
3802 // Handle degenerate case of GEP without offsets.
3803 if (Offsets.empty())
3804 return BaseExpr;
3805
3806 // Add the offsets together, assuming nsw if inbounds.
3807 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3808 // Add the base address and the offset. We cannot use the nsw flag, as the
3809 // base address is unsigned. However, if we know that the offset is
3810 // non-negative, we can use nuw.
3811 bool NUW = NW.hasNoUnsignedWrap() ||
3812 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3813 auto BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3814 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3815 assert(BaseExpr->getType() == GEPExpr->getType() &&
3816 "GEP should not change type mid-flight.");
3817 return GEPExpr;
3818}
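// Editorial sketch (not part of the upstream implementation): the essence of
// the offset computation above for "getelementptr i32, ptr %p, i64 %i" is
// %p + sext-or-trunc(%i) * sizeof(i32); nuw is applied to the final add only
// when the offset is known non-negative. Names are illustrative.
[[maybe_unused]] static const SCEV *gepOffsetSketch(ScalarEvolution &SE,
                                                    const SCEV *PtrBase,
                                                    const SCEV *Index,
                                                    Type *IntIdxTy) {
  // Assumes PtrBase is a pointer SCEV and Index an integer SCEV.
  const SCEV *ElemSize = SE.getConstant(IntIdxTy, 4); // sizeof(i32) in bytes.
  const SCEV *Idx = SE.getTruncateOrSignExtend(Index, IntIdxTy);
  const SCEV *Offset = SE.getMulExpr(Idx, ElemSize, SCEV::FlagNSW);
  return SE.getAddExpr(PtrBase, Offset);
}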
3819
3820SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3821 ArrayRef<const SCEV *> Ops) {
3822 FoldingSetNodeID ID;
3823 ID.AddInteger(SCEVType);
3824 for (const SCEV *Op : Ops)
3825 ID.AddPointer(Op);
3826 void *IP = nullptr;
3827 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3828}
3829
3830const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3831 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3832 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3833}
3834
3835const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3836 SmallVectorImpl<const SCEV *> &Ops) {
3837 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3838 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3839 if (Ops.size() == 1) return Ops[0];
3840#ifndef NDEBUG
3841 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3842 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3843 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3844 "Operand types don't match!");
3845 assert(Ops[0]->getType()->isPointerTy() ==
3846 Ops[i]->getType()->isPointerTy() &&
3847 "min/max should be consistently pointerish");
3848 }
3849#endif
3850
3851 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3852 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3853
3854 // Sort by complexity, this groups all similar expression types together.
3855 GroupByComplexity(Ops, &LI, DT);
3856
3857 // Check if we have created the same expression before.
3858 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3859 return S;
3860 }
3861
3862 // If there are any constants, fold them together.
3863 unsigned Idx = 0;
3864 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3865 ++Idx;
3866 assert(Idx < Ops.size());
3867 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3868 switch (Kind) {
3869 case scSMaxExpr:
3870 return APIntOps::smax(LHS, RHS);
3871 case scSMinExpr:
3872 return APIntOps::smin(LHS, RHS);
3873 case scUMaxExpr:
3874 return APIntOps::umax(LHS, RHS);
3875 case scUMinExpr:
3876 return APIntOps::umin(LHS, RHS);
3877 default:
3878 llvm_unreachable("Unknown SCEV min/max opcode");
3879 }
3880 };
3881
3882 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3883 // We found two constants, fold them together!
3884 ConstantInt *Fold = ConstantInt::get(
3885 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3886 Ops[0] = getConstant(Fold);
3887 Ops.erase(Ops.begin()+1); // Erase the folded element
3888 if (Ops.size() == 1) return Ops[0];
3889 LHSC = cast<SCEVConstant>(Ops[0]);
3890 }
3891
3892 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3893 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3894
3895 if (IsMax ? IsMinV : IsMaxV) {
3896 // If we are left with a constant minimum(/maximum)-int, strip it off.
3897 Ops.erase(Ops.begin());
3898 --Idx;
3899 } else if (IsMax ? IsMaxV : IsMinV) {
3900 // If we have a max(/min) with a constant maximum(/minimum)-int,
3901 // it will always be the extremum.
3902 return LHSC;
3903 }
3904
3905 if (Ops.size() == 1) return Ops[0];
3906 }
3907
3908 // Find the first operation of the same kind
3909 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3910 ++Idx;
3911
3912 // Check to see if one of the operands is of the same kind. If so, expand its
3913 // operands onto our operand list, and recurse to simplify.
3914 if (Idx < Ops.size()) {
3915 bool DeletedAny = false;
3916 while (Ops[Idx]->getSCEVType() == Kind) {
3917 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3918 Ops.erase(Ops.begin()+Idx);
3919 append_range(Ops, SMME->operands());
3920 DeletedAny = true;
3921 }
3922
3923 if (DeletedAny)
3924 return getMinMaxExpr(Kind, Ops);
3925 }
3926
3927 // Okay, check to see if the same value occurs in the operand list twice. If
3928 // so, delete one. Since we sorted the list, these values are required to
3929 // be adjacent.
3930 llvm::CmpInst::Predicate GEPred =
3931 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3932 llvm::CmpInst::Predicate LEPred =
3933 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3934 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3935 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3936 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3937 if (Ops[i] == Ops[i + 1] ||
3938 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3939 // X op Y op Y --> X op Y
3940 // X op Y --> X, if we know X, Y are ordered appropriately
3941 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3942 --i;
3943 --e;
3944 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3945 Ops[i + 1])) {
3946 // X op Y --> Y, if we know X, Y are ordered appropriately
3947 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3948 --i;
3949 --e;
3950 }
3951 }
3952
3953 if (Ops.size() == 1) return Ops[0];
3954
3955 assert(!Ops.empty() && "Reduced smax down to nothing!");
3956
3957 // Okay, it looks like we really DO need an expr. Check to see if we
3958 // already have one, otherwise create a new one.
3959 FoldingSetNodeID ID;
3960 ID.AddInteger(Kind);
3961 for (const SCEV *Op : Ops)
3962 ID.AddPointer(Op);
3963 void *IP = nullptr;
3964 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3965 if (ExistingSCEV)
3966 return ExistingSCEV;
3967 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3968 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3969 SCEV *S = new (SCEVAllocator)
3970 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3971
3972 UniqueSCEVs.InsertNode(S, IP);
3973 registerUser(S, Ops);
3974 return S;
3975}
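// Editorial sketch (not part of the upstream implementation): the flattening
// and deduplication performed by getMinMaxExpr() above. Nested operands of
// the same kind are inlined and duplicates dropped, so smax(%x, smax(5, %x))
// is expected to simplify to smax(5, %x). Names are illustrative.
[[maybe_unused]] static const SCEV *smaxFlattenSketch(ScalarEvolution &SE,
                                                      const SCEV *X) {
  // Assumes X is an integer SCEV.
  const SCEV *Inner = SE.getSMaxExpr(SE.getConstant(X->getType(), 5), X);
  return SE.getSMaxExpr(X, Inner);
}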
3976
3977namespace {
3978
3979class SCEVSequentialMinMaxDeduplicatingVisitor final
3980 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3981 std::optional<const SCEV *>> {
3982 using RetVal = std::optional<const SCEV *>;
3983 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3984
3985 ScalarEvolution &SE;
3986 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3987 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3988 SmallPtrSet<const SCEV *, 16> SeenOps;
3989
3990 bool canRecurseInto(SCEVTypes Kind) const {
3991 // We can only recurse into the SCEV expression of the same effective type
3992 // as the type of our root SCEV expression.
3993 return RootKind == Kind || NonSequentialRootKind == Kind;
3994 };
3995
3996 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3997 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3998 "Only for min/max expressions.");
3999 SCEVTypes Kind = S->getSCEVType();
4000
4001 if (!canRecurseInto(Kind))
4002 return S;
4003
4004 auto *NAry = cast<SCEVNAryExpr>(S);
4005 SmallVector<const SCEV *> NewOps;
4006 bool Changed = visit(Kind, NAry->operands(), NewOps);
4007
4008 if (!Changed)
4009 return S;
4010 if (NewOps.empty())
4011 return std::nullopt;
4012
4013 return isa<SCEVSequentialMinMaxExpr>(S)
4014 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4015 : SE.getMinMaxExpr(Kind, NewOps);
4016 }
4017
4018 RetVal visit(const SCEV *S) {
4019 // Has the whole operand been seen already?
4020 if (!SeenOps.insert(S).second)
4021 return std::nullopt;
4022 return Base::visit(S);
4023 }
4024
4025public:
4026 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4027 SCEVTypes RootKind)
4028 : SE(SE), RootKind(RootKind),
4029 NonSequentialRootKind(
4030 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4031 RootKind)) {}
4032
4033 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4034 SmallVectorImpl<const SCEV *> &NewOps) {
4035 bool Changed = false;
4036 SmallVector<const SCEV *> Ops;
4037 Ops.reserve(OrigOps.size());
4038
4039 for (const SCEV *Op : OrigOps) {
4040 RetVal NewOp = visit(Op);
4041 if (NewOp != Op)
4042 Changed = true;
4043 if (NewOp)
4044 Ops.emplace_back(*NewOp);
4045 }
4046
4047 if (Changed)
4048 NewOps = std::move(Ops);
4049 return Changed;
4050 }
4051
4052 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4053
4054 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4055
4056 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4057
4058 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4059
4060 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4061
4062 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4063
4064 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4065
4066 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4067
4068 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4069
4070 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4071
4072 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4073 return visitAnyMinMaxExpr(Expr);
4074 }
4075
4076 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4077 return visitAnyMinMaxExpr(Expr);
4078 }
4079
4080 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4081 return visitAnyMinMaxExpr(Expr);
4082 }
4083
4084 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4085 return visitAnyMinMaxExpr(Expr);
4086 }
4087
4088 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4089 return visitAnyMinMaxExpr(Expr);
4090 }
4091
4092 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4093
4094 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4095};
4096
4097} // namespace
4098
4099static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4100 switch (Kind) {
4101 case scConstant:
4102 case scVScale:
4103 case scTruncate:
4104 case scZeroExtend:
4105 case scSignExtend:
4106 case scPtrToInt:
4107 case scAddExpr:
4108 case scMulExpr:
4109 case scUDivExpr:
4110 case scAddRecExpr:
4111 case scUMaxExpr:
4112 case scSMaxExpr:
4113 case scUMinExpr:
4114 case scSMinExpr:
4115 case scUnknown:
4116 // If any operand is poison, the whole expression is poison.
4117 return true;
4118 case scSequentialUMinExpr:
4119 // FIXME: if the *first* operand is poison, the whole expression is poison.
4120 return false; // Pessimistically, say that it does not propagate poison.
4121 case scCouldNotCompute:
4122 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4123 }
4124 llvm_unreachable("Unknown SCEV kind!");
4125}
4126
4127namespace {
4128// The only way poison may be introduced in a SCEV expression is from a
4129// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4130// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4131// introduce poison -- they encode guaranteed, non-speculated knowledge.
4132//
4133// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4134// with the notable exception of umin_seq, where only poison from the first
4135// operand is (unconditionally) propagated.
4136struct SCEVPoisonCollector {
4137 bool LookThroughMaybePoisonBlocking;
4138 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4139 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4140 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4141
4142 bool follow(const SCEV *S) {
4143 if (!LookThroughMaybePoisonBlocking &&
4144 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4145 return false;
4146
4147 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4148 if (!isGuaranteedNotToBePoison(SU->getValue()))
4149 MaybePoison.insert(SU);
4150 }
4151 return true;
4152 }
4153 bool isDone() const { return false; }
4154};
4155} // namespace
4156
4157/// Return true if V is poison given that AssumedPoison is already poison.
4158static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4159 // First collect all SCEVs that might result in AssumedPoison to be poison.
4160 // We need to look through potentially poison-blocking operations here,
4161 // because we want to find all SCEVs that *might* result in poison, not only
4162 // those that are *required* to.
4163 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4164 visitAll(AssumedPoison, PC1);
4165
4166 // AssumedPoison is never poison. As the assumption is false, the implication
4167 // is true. Don't bother walking the other SCEV in this case.
4168 if (PC1.MaybePoison.empty())
4169 return true;
4170
4171 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4172 // as well. We cannot look through potentially poison-blocking operations
4173 // here, as their arguments only *may* make the result poison.
4174 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4175 visitAll(S, PC2);
4176
4177 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4178 // it will also make S poison by being part of PC2.MaybePoison.
4179 return all_of(PC1.MaybePoison, [&](const SCEVUnknown *S) {
4180 return PC2.MaybePoison.contains(S);
4181 });
4182}
4183
4184void ScalarEvolution::getPoisonGeneratingValues(
4185 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4186 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4187 visitAll(S, PC);
4188 for (const SCEVUnknown *SU : PC.MaybePoison)
4189 Result.insert(SU->getValue());
4190}
4191
4192bool ScalarEvolution::canReuseInstruction(
4193 const SCEV *S, Instruction *I,
4194 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4195 // If the instruction cannot be poison, it's always safe to reuse.
4196 if (isGuaranteedNotToBePoison(I))
4197 return true;
4198
4199 // Otherwise, it is possible that I is more poisonous than S. Collect the
4200 // poison-contributors of S, and then check whether I has any additional
4201 // poison-contributors. Poison that is contributed through poison-generating
4202 // flags is handled by dropping those flags instead.
4203 SmallPtrSet<const Value *, 8> PoisonVals;
4204 getPoisonGeneratingValues(PoisonVals, S);
4205
4206 SmallVector<Value *> Worklist;
4207 SmallPtrSet<Value *, 8> Visited;
4208 Worklist.push_back(I);
4209 while (!Worklist.empty()) {
4210 Value *V = Worklist.pop_back_val();
4211 if (!Visited.insert(V).second)
4212 continue;
4213
4214 // Avoid walking large instruction graphs.
4215 if (Visited.size() > 16)
4216 return false;
4217
4218 // Either the value can't be poison, or the S would also be poison if it
4219 // is.
4220 if (PoisonVals.contains(V) || isGuaranteedNotToBePoison(V))
4221 continue;
4222
4223 auto *I = dyn_cast<Instruction>(V);
4224 if (!I)
4225 return false;
4226
4227 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4228 // can't replace an arbitrary add with disjoint or, even if we drop the
4229 // flag. We would need to convert the or into an add.
4230 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4231 if (PDI->isDisjoint())
4232 return false;
4233
4234 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4235 // because SCEV currently assumes it can't be poison. Remove this special
4236 // case once we properly model when vscale can be poison.
4237 if (auto *II = dyn_cast<IntrinsicInst>(I);
4238 II && II->getIntrinsicID() == Intrinsic::vscale)
4239 continue;
4240
4241 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4242 return false;
4243
4244 // If the instruction can't create poison, we can recurse to its operands.
4245 if (I->hasPoisonGeneratingAnnotations())
4246 DropPoisonGeneratingInsts.push_back(I);
4247
4248 for (Value *Op : I->operands())
4249 Worklist.push_back(Op);
4250 }
4251 return true;
4252}
4253
4254const SCEV *
4255ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4256 SmallVectorImpl<const SCEV *> &Ops) {
4257 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4258 "Not a SCEVSequentialMinMaxExpr!");
4259 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4260 if (Ops.size() == 1)
4261 return Ops[0];
4262#ifndef NDEBUG
4263 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4264 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4265 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4266 "Operand types don't match!");
4267 assert(Ops[0]->getType()->isPointerTy() ==
4268 Ops[i]->getType()->isPointerTy() &&
4269 "min/max should be consistently pointerish");
4270 }
4271#endif
4272
4273 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4274 // so we can *NOT* do any kind of sorting of the expressions!
4275
4276 // Check if we have created the same expression before.
4277 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4278 return S;
4279
4280 // FIXME: there are *some* simplifications that we can do here.
4281
4282 // Keep only the first instance of an operand.
4283 {
4284 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4285 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4286 if (Changed)
4287 return getSequentialMinMaxExpr(Kind, Ops);
4288 }
4289
4290 // Check to see if one of the operands is of the same kind. If so, expand its
4291 // operands onto our operand list, and recurse to simplify.
4292 {
4293 unsigned Idx = 0;
4294 bool DeletedAny = false;
4295 while (Idx < Ops.size()) {
4296 if (Ops[Idx]->getSCEVType() != Kind) {
4297 ++Idx;
4298 continue;
4299 }
4300 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4301 Ops.erase(Ops.begin() + Idx);
4302 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4303 SMME->operands().end());
4304 DeletedAny = true;
4305 }
4306
4307 if (DeletedAny)
4308 return getSequentialMinMaxExpr(Kind, Ops);
4309 }
4310
4311 const SCEV *SaturationPoint;
4312 ICmpInst::Predicate Pred;
4313 switch (Kind) {
4314 case scSequentialUMinExpr:
4315 SaturationPoint = getZero(Ops[0]->getType());
4316 Pred = ICmpInst::ICMP_ULE;
4317 break;
4318 default:
4319 llvm_unreachable("Not a sequential min/max type.");
4320 }
4321
4322 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4323 // We can replace %x umin_seq %y with %x umin %y if either:
4324 // * %y being poison implies %x is also poison.
4325 // * %x cannot be the saturating value (e.g. zero for umin).
4326 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4327 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4328 SaturationPoint)) {
4329 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4330 Ops[i - 1] = getMinMaxExpr(
4331 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4332 SeqOps);
4333 Ops.erase(Ops.begin() + i);
4334 return getSequentialMinMaxExpr(Kind, Ops);
4335 }
4336 // Fold %x umin_seq %y to %x if %x ule %y.
4337 // TODO: We might be able to prove the predicate for a later operand.
4338 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4339 Ops.erase(Ops.begin() + i);
4340 return getSequentialMinMaxExpr(Kind, Ops);
4341 }
4342 }
4343
4344 // Okay, it looks like we really DO need an expr. Check to see if we
4345 // already have one, otherwise create a new one.
4346 FoldingSetNodeID ID;
4347 ID.AddInteger(Kind);
4348 for (const SCEV *Op : Ops)
4349 ID.AddPointer(Op);
4350 void *IP = nullptr;
4351 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4352 if (ExistingSCEV)
4353 return ExistingSCEV;
4354
4355 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4356 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4357 SCEV *S = new (SCEVAllocator)
4358 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4359
4360 UniqueSCEVs.InsertNode(S, IP);
4361 registerUser(S, Ops);
4362 return S;
4363}
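// Editorial sketch (not part of the upstream implementation): requesting a
// sequential umin through the public API. getSequentialMinMaxExpr() above
// may strengthen "%x umin_seq %y" to a plain "%x umin %y" when %y being
// poison implies %x is poison, or when %x is known non-zero (it can never be
// the saturating value of umin). Names are illustrative.
[[maybe_unused]] static const SCEV *uminSeqSketch(ScalarEvolution &SE,
                                                  const SCEV *X,
                                                  const SCEV *Y) {
  // Assumes X and Y are integer SCEVs of the same type.
  return SE.getUMinExpr(X, Y, /*Sequential=*/true);
}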
4364
4365const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4366 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4367 return getSMaxExpr(Ops);
4368}
4369
4370const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4371 return getMinMaxExpr(scSMaxExpr, Ops);
4372}
4373
4374const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4375 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4376 return getUMaxExpr(Ops);
4377}
4378
4379const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4380 return getMinMaxExpr(scUMaxExpr, Ops);
4381}
4382
4383const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4384 const SCEV *RHS) {
4385 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4386 return getSMinExpr(Ops);
4387}
4388
4389const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4390 return getMinMaxExpr(scSMinExpr, Ops);
4391}
4392
4393const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4394 bool Sequential) {
4395 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4396 return getUMinExpr(Ops, Sequential);
4397}
4398
4399const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4400 bool Sequential) {
4401 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4402 : getMinMaxExpr(scUMinExpr, Ops);
4403}
4404
4405const SCEV *
4406ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4407 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4408 if (Size.isScalable())
4409 Res = getMulExpr(Res, getVScale(IntTy));
4410 return Res;
4411}
4412
4413const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4414 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4415}
4416
4417const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4418 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4419}
4420
4421const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4422 StructType *STy,
4423 unsigned FieldNo) {
4424 // We can bypass creating a target-independent constant expression and then
4425 // folding it back into a ConstantInt. This is just a compile-time
4426 // optimization.
4427 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4428 assert(!SL->getSizeInBits().isScalable() &&
4429 "Cannot get offset for structure containing scalable vector types");
4430 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4431}
4432
4433const SCEV *ScalarEvolution::getUnknown(Value *V) {
4434 // Don't attempt to do anything other than create a SCEVUnknown object
4435 // here. createSCEV only calls getUnknown after checking for all other
4436 // interesting possibilities, and any other code that calls getUnknown
4437 // is doing so in order to hide a value from SCEV canonicalization.
4438
4439 FoldingSetNodeID ID;
4440 ID.AddInteger(scUnknown);
4441 ID.AddPointer(V);
4442 void *IP = nullptr;
4443 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4444 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4445 "Stale SCEVUnknown in uniquing map!");
4446 return S;
4447 }
4448 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4449 FirstUnknown);
4450 FirstUnknown = cast<SCEVUnknown>(S);
4451 UniqueSCEVs.InsertNode(S, IP);
4452 return S;
4453}
4454
4455//===----------------------------------------------------------------------===//
4456// Basic SCEV Analysis and PHI Idiom Recognition Code
4457//
4458
4459/// Test if values of the given type are analyzable within the SCEV
4460/// framework. This primarily includes integer types, and it can optionally
4461/// include pointer types if the ScalarEvolution class has access to
4462/// target-specific information.
4463bool ScalarEvolution::isSCEVable(Type *Ty) const {
4464 // Integers and pointers are always SCEVable.
4465 return Ty->isIntOrPtrTy();
4466}
4467
4468/// Return the size in bits of the specified type, for which isSCEVable must
4469/// return true.
4471 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4472 if (Ty->isPointerTy())
4473 return getDataLayout().getIndexTypeSizeInBits(Ty);
4474 return getDataLayout().getTypeSizeInBits(Ty);
4475}
4476
4477/// Return a type with the same bitwidth as the given type and which represents
4478/// how SCEV will treat the given type, for which isSCEVable must return
4479/// true. For pointer types, this is the pointer index sized integer type.
4481 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4482
4483 if (Ty->isIntegerTy())
4484 return Ty;
4485
4486 // The only other support type is pointer.
4487 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4488 return getDataLayout().getIndexType(Ty);
4489}
4490
4491Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4492 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4493}
4494
4495bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4496 const SCEV *B) {
4497 /// For a valid use point to exist, the defining scope of one operand
4498 /// must dominate the other.
4499 bool PreciseA, PreciseB;
4500 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4501 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4502 if (!PreciseA || !PreciseB)
4503 // Can't tell.
4504 return false;
4505 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4506 DT.dominates(ScopeB, ScopeA);
4507}
4508
4509const SCEV *ScalarEvolution::getCouldNotCompute() {
4510 return CouldNotCompute.get();
4511}
4512
4513bool ScalarEvolution::checkValidity(const SCEV *S) const {
4514 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4515 auto *SU = dyn_cast<SCEVUnknown>(S);
4516 return SU && SU->getValue() == nullptr;
4517 });
4518
4519 return !ContainsNulls;
4520}
4521
4522bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4523 HasRecMapType::iterator I = HasRecMap.find(S);
4524 if (I != HasRecMap.end())
4525 return I->second;
4526
4527 bool FoundAddRec =
4528 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4529 HasRecMap.insert({S, FoundAddRec});
4530 return FoundAddRec;
4531}
4532
4533/// Return the ValueOffsetPair set for \p S. \p S can be represented
4534/// by the value and offset from any ValueOffsetPair in the set.
4535ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4536 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4537 if (SI == ExprValueMap.end())
4538 return std::nullopt;
4539 return SI->second.getArrayRef();
4540}
4541
4542/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4543/// cannot be used separately. eraseValueFromMap should be used to remove
4544/// V from ValueExprMap and ExprValueMap at the same time.
4545void ScalarEvolution::eraseValueFromMap(Value *V) {
4546 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4547 if (I != ValueExprMap.end()) {
4548 auto EVIt = ExprValueMap.find(I->second);
4549 bool Removed = EVIt->second.remove(V);
4550 (void) Removed;
4551 assert(Removed && "Value not in ExprValueMap?");
4552 ValueExprMap.erase(I);
4553 }
4554}
4555
4556void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4557 // A recursive query may have already computed the SCEV. It should be
4558 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4559 // inferred nowrap flags.
4560 auto It = ValueExprMap.find_as(V);
4561 if (It == ValueExprMap.end()) {
4562 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4563 ExprValueMap[S].insert(V);
4564 }
4565}
4566
4567/// Return an existing SCEV if it exists, otherwise analyze the expression and
4568/// create a new one.
4570 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4571
4572 if (const SCEV *S = getExistingSCEV(V))
4573 return S;
4574 return createSCEVIter(V);
4575}
4576
4578 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4579
4580 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4581 if (I != ValueExprMap.end()) {
4582 const SCEV *S = I->second;
4583 assert(checkValidity(S) &&
4584 "existing SCEV has not been properly invalidated");
4585 return S;
4586 }
4587 return nullptr;
4588}
4589
4590/// Return a SCEV corresponding to -V = -1*V
4591const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4592 SCEV::NoWrapFlags Flags) {
4593 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4594 return getConstant(
4595 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4596
4597 Type *Ty = V->getType();
4598 Ty = getEffectiveSCEVType(Ty);
4599 return getMulExpr(V, getMinusOne(Ty), Flags);
4600}
4601
4602/// If Expr computes ~A, return A else return nullptr
4603static const SCEV *MatchNotExpr(const SCEV *Expr) {
4604 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4605 if (!Add || Add->getNumOperands() != 2 ||
4606 !Add->getOperand(0)->isAllOnesValue())
4607 return nullptr;
4608
4609 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4610 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4611 !AddRHS->getOperand(0)->isAllOnesValue())
4612 return nullptr;
4613
4614 return AddRHS->getOperand(1);
4615}
4616
4617/// Return a SCEV corresponding to ~V = -1-V
4619 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4620
4621 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4622 return getConstant(
4623 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4624
4625 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4626 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4627 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4628 SmallVector<const SCEV *, 2> MatchedOperands;
4629 for (const SCEV *Operand : MME->operands()) {
4630 const SCEV *Matched = MatchNotExpr(Operand);
4631 if (!Matched)
4632 return (const SCEV *)nullptr;
4633 MatchedOperands.push_back(Matched);
4634 }
4635 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4636 MatchedOperands);
4637 };
4638 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4639 return Replaced;
4640 }
4641
4642 Type *Ty = V->getType();
4643 Ty = getEffectiveSCEVType(Ty);
4644 return getMinusSCEV(getMinusOne(Ty), V);
4645}
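// Editorial sketch (not part of the upstream implementation): the De
// Morgan-style fold in getNotSCEV() above. ~umin(~x, ~y) is rewritten to
// umax(x, y), so repeated negation around a min/max does not pile up extra
// (-1 - V) nodes. Names are illustrative.
[[maybe_unused]] static const SCEV *notOfUMinSketch(ScalarEvolution &SE,
                                                    const SCEV *X,
                                                    const SCEV *Y) {
  // Assumes X and Y are integer (non-pointer) SCEVs of the same type.
  const SCEV *Inner = SE.getUMinExpr(SE.getNotSCEV(X), SE.getNotSCEV(Y));
  return SE.getNotSCEV(Inner); // Expected to fold to umax(X, Y).
}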
4646
4647const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4648 assert(P->getType()->isPointerTy());
4649
4650 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4651 // The base of an AddRec is the first operand.
4652 SmallVector<const SCEV *> Ops{AddRec->operands()};
4653 Ops[0] = removePointerBase(Ops[0]);
4654 // Don't try to transfer nowrap flags for now. We could in some cases
4655 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4656 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4657 }
4658 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4659 // The base of an Add is the pointer operand.
4660 SmallVector<const SCEV *> Ops{Add->operands()};
4661 const SCEV **PtrOp = nullptr;
4662 for (const SCEV *&AddOp : Ops) {
4663 if (AddOp->getType()->isPointerTy()) {
4664 assert(!PtrOp && "Cannot have multiple pointer ops");
4665 PtrOp = &AddOp;
4666 }
4667 }
4668 *PtrOp = removePointerBase(*PtrOp);
4669 // Don't try to transfer nowrap flags for now. We could in some cases
4670 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4671 return getAddExpr(Ops);
4672 }
4673 // Any other expression must be a pointer base.
4674 return getZero(P->getType());
4675}
4676
4677const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4678 SCEV::NoWrapFlags Flags,
4679 unsigned Depth) {
4680 // Fast path: X - X --> 0.
4681 if (LHS == RHS)
4682 return getZero(LHS->getType());
4683
4684 // If we subtract two pointers with different pointer bases, bail.
4685 // Eventually, we're going to add an assertion to getMulExpr that we
4686 // can't multiply by a pointer.
4687 if (RHS->getType()->isPointerTy()) {
4688 if (!LHS->getType()->isPointerTy() ||
4689 getPointerBase(LHS) != getPointerBase(RHS))
4690 return getCouldNotCompute();
4691 LHS = removePointerBase(LHS);
4692 RHS = removePointerBase(RHS);
4693 }
4694
4695 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4696 // makes it so that we cannot make much use of NUW.
4697 auto AddFlags = SCEV::FlagAnyWrap;
4698 const bool RHSIsNotMinSigned =
4699 !getSignedRangeMin(RHS).isMinSignedValue();
4700 if (hasFlags(Flags, SCEV::FlagNSW)) {
4701 // Let M be the minimum representable signed value. Then (-1)*RHS
4702 // signed-wraps if and only if RHS is M. That can happen even for
4703 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4704 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4705 // (-1)*RHS, we need to prove that RHS != M.
4706 //
4707 // If LHS is non-negative and we know that LHS - RHS does not
4708 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4709 // either by proving that RHS > M or that LHS >= 0.
4710 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4711 AddFlags = SCEV::FlagNSW;
4712 }
4713 }
4714
4715 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4716 // RHS is NSW and LHS >= 0.
4717 //
4718 // The difficulty here is that the NSW flag may have been proven
4719 // relative to a loop that is to be found in a recurrence in LHS and
4720 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4721 // larger scope than intended.
4722 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4723
4724 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4725}
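// Editorial sketch (not part of the upstream implementation): pointer
// difference via getMinusSCEV() above. The subtraction is modeled as
// LHS + (-1) * RHS; for pointers it only succeeds when both operands share
// a pointer base, in which case the bases cancel and the result is the
// integer distance between the offsets. Names are illustrative.
[[maybe_unused]] static const SCEV *pointerDiffSketch(ScalarEvolution &SE,
                                                      const SCEV *P1,
                                                      const SCEV *P2) {
  // Returns SCEVCouldNotCompute when the pointer bases differ.
  return SE.getMinusSCEV(P2, P1);
}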
4726
4727const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4728 unsigned Depth) {
4729 Type *SrcTy = V->getType();
4730 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4731 "Cannot truncate or zero extend with non-integer arguments!");
4732 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4733 return V; // No conversion
4734 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4735 return getTruncateExpr(V, Ty, Depth);
4736 return getZeroExtendExpr(V, Ty, Depth);
4737}
4738
4739const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4740 unsigned Depth) {
4741 Type *SrcTy = V->getType();
4742 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4743 "Cannot truncate or zero extend with non-integer arguments!");
4744 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4745 return V; // No conversion
4746 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4747 return getTruncateExpr(V, Ty, Depth);
4748 return getSignExtendExpr(V, Ty, Depth);
4749}
4750
4751const SCEV *
4752ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4753 Type *SrcTy = V->getType();
4754 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4755 "Cannot noop or zero extend with non-integer arguments!");
4757 "getNoopOrZeroExtend cannot truncate!");
4758 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4759 return V; // No conversion
4760 return getZeroExtendExpr(V, Ty);
4761}
4762
4763const SCEV *
4764ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4765  Type *SrcTy = V->getType();
4766  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4767         "Cannot noop or sign extend with non-integer arguments!");
4768  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4769         "getNoopOrSignExtend cannot truncate!");
4770 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4771 return V; // No conversion
4772 return getSignExtendExpr(V, Ty);
4773}
4774
4775const SCEV *
4776ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4777  Type *SrcTy = V->getType();
4778  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4779         "Cannot noop or any extend with non-integer arguments!");
4780  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4781         "getNoopOrAnyExtend cannot truncate!");
4782 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4783 return V; // No conversion
4784 return getAnyExtendExpr(V, Ty);
4785}
4786
4787const SCEV *
4788ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4789  Type *SrcTy = V->getType();
4790  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4791         "Cannot truncate or noop with non-integer arguments!");
4792  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4793         "getTruncateOrNoop cannot extend!");
4794 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4795 return V; // No conversion
4796 return getTruncateExpr(V, Ty);
4797}
4798
4799const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4800                                                         const SCEV *RHS) {
4801 const SCEV *PromotedLHS = LHS;
4802 const SCEV *PromotedRHS = RHS;
4803
4804  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4805    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4806 else
4807 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4808
4809 return getUMaxExpr(PromotedLHS, PromotedRHS);
4810}
4811
4812const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4813                                                         const SCEV *RHS,
4814                                                         bool Sequential) {
4815  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4816  return getUMinFromMismatchedTypes(Ops, Sequential);
4817}
4818
4819const SCEV *
4820ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4821                                            bool Sequential) {
4822 assert(!Ops.empty() && "At least one operand must be!");
4823 // Trivial case.
4824 if (Ops.size() == 1)
4825 return Ops[0];
4826
4827 // Find the max type first.
4828 Type *MaxType = nullptr;
4829 for (const auto *S : Ops)
4830 if (MaxType)
4831 MaxType = getWiderType(MaxType, S->getType());
4832 else
4833 MaxType = S->getType();
4834 assert(MaxType && "Failed to find maximum type!");
4835
4836 // Extend all ops to max type.
4837 SmallVector<const SCEV *, 2> PromotedOps;
4838 for (const auto *S : Ops)
4839 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4840
4841 // Generate umin.
4842 return getUMinExpr(PromotedOps, Sequential);
4843}
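// For example, given operands {i32 %a, i64 %b} the widest type is i64, so
// both operands are brought to i64 (zero-extending %a, which preserves
// unsigned order) and the result is (umin (zext i32 %a to i64), i64 %b).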
4844
4845const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4846  // A pointer operand may evaluate to a nonpointer expression, such as null.
4847 if (!V->getType()->isPointerTy())
4848 return V;
4849
4850 while (true) {
4851 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4852 V = AddRec->getStart();
4853 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4854 const SCEV *PtrOp = nullptr;
4855 for (const SCEV *AddOp : Add->operands()) {
4856 if (AddOp->getType()->isPointerTy()) {
4857 assert(!PtrOp && "Cannot have multiple pointer ops");
4858 PtrOp = AddOp;
4859 }
4860 }
4861 assert(PtrOp && "Must have pointer op");
4862 V = PtrOp;
4863 } else // Not something we can look further into.
4864 return V;
4865 }
4866}
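// For example, for V = {(4 + %ptr),+,8}<%loop> this first strips the AddRec to
// its start (4 + %ptr), then picks the single pointer operand of that add, and
// returns %ptr.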
4867
4868/// Push users of the given Instruction onto the given Worklist.
4869static void
4870PushDefUseChildren(Instruction *I, SmallVectorImpl<Instruction *> &Worklist,
4871                   SmallPtrSetImpl<Instruction *> &Visited) {
4872  // Push the def-use children onto the Worklist stack.
4873 for (User *U : I->users()) {
4874 auto *UserInsn = cast<Instruction>(U);
4875 if (Visited.insert(UserInsn).second)
4876 Worklist.push_back(UserInsn);
4877 }
4878}
4879
4880namespace {
4881
4882/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4883/// expression if its loop is L. If its loop is not L, the AddRec itself is
4884/// used when IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
4885/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
4886/// done.
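/// For example, rewriting ({%start,+,%step}<%L> + %inv) with respect to L
/// yields (%start + %inv) when %inv is invariant in L; if the expression
/// contains a SCEVUnknown that varies in L, getCouldNotCompute() is returned
/// instead.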
4887class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4888public:
4889 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4890 bool IgnoreOtherLoops = true) {
4891 SCEVInitRewriter Rewriter(L, SE);
4892 const SCEV *Result = Rewriter.visit(S);
4893 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4894 return SE.getCouldNotCompute();
4895 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4896 ? SE.getCouldNotCompute()
4897 : Result;
4898 }
4899
4900 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4901 if (!SE.isLoopInvariant(Expr, L))
4902 SeenLoopVariantSCEVUnknown = true;
4903 return Expr;
4904 }
4905
4906 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4907 // Only re-write AddRecExprs for this loop.
4908 if (Expr->getLoop() == L)
4909 return Expr->getStart();
4910 SeenOtherLoops = true;
4911 return Expr;
4912 }
4913
4914 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4915
4916 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4917
4918private:
4919 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4920 : SCEVRewriteVisitor(SE), L(L) {}
4921
4922 const Loop *L;
4923 bool SeenLoopVariantSCEVUnknown = false;
4924 bool SeenOtherLoops = false;
4925};
4926
4927/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4928/// increment expression in case its Loop is L. If it is not L then
4929/// use AddRec itself.
4930/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
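/// For example, {%start,+,%step}<%L> is rewritten to its post-increment form
/// {(%start + %step),+,%step}<%L>, i.e. the value the recurrence takes after
/// one more increment.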
4931class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4932public:
4933 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4934 SCEVPostIncRewriter Rewriter(L, SE);
4935 const SCEV *Result = Rewriter.visit(S);
4936 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4937 ? SE.getCouldNotCompute()
4938 : Result;
4939 }
4940
4941 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4942 if (!SE.isLoopInvariant(Expr, L))
4943 SeenLoopVariantSCEVUnknown = true;
4944 return Expr;
4945 }
4946
4947 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4948 // Only re-write AddRecExprs for this loop.
4949 if (Expr->getLoop() == L)
4950 return Expr->getPostIncExpr(SE);
4951 SeenOtherLoops = true;
4952 return Expr;
4953 }
4954
4955 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4956
4957 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4958
4959private:
4960 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4961 : SCEVRewriteVisitor(SE), L(L) {}
4962
4963 const Loop *L;
4964 bool SeenLoopVariantSCEVUnknown = false;
4965 bool SeenOtherLoops = false;
4966};
4967
4968/// This class evaluates the compare condition by matching it against the
4969/// condition of loop latch. If there is a match we assume a true value
4970/// for the condition while building SCEV nodes.
4971class SCEVBackedgeConditionFolder
4972 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4973public:
4974 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4975 ScalarEvolution &SE) {
4976 bool IsPosBECond = false;
4977 Value *BECond = nullptr;
4978 if (BasicBlock *Latch = L->getLoopLatch()) {
4979 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4980 if (BI && BI->isConditional()) {
4981 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4982 "Both outgoing branches should not target same header!");
4983 BECond = BI->getCondition();
4984 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4985 } else {
4986 return S;
4987 }
4988 }
4989 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4990 return Rewriter.visit(S);
4991 }
4992
4993 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4994 const SCEV *Result = Expr;
4995 bool InvariantF = SE.isLoopInvariant(Expr, L);
4996
4997 if (!InvariantF) {
4998 Instruction *I = cast<Instruction>(Expr->getValue());
4999 switch (I->getOpcode()) {
5000 case Instruction::Select: {
5001 SelectInst *SI = cast<SelectInst>(I);
5002 std::optional<const SCEV *> Res =
5003 compareWithBackedgeCondition(SI->getCondition());
5004 if (Res) {
5005 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5006 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5007 }
5008 break;
5009 }
5010 default: {
5011 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5012 if (Res)
5013 Result = *Res;
5014 break;
5015 }
5016 }
5017 }
5018 return Result;
5019 }
5020
5021private:
5022 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5023 bool IsPosBECond, ScalarEvolution &SE)
5024 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5025 IsPositiveBECond(IsPosBECond) {}
5026
5027 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5028
5029 const Loop *L;
5030 /// Loop back condition.
5031 Value *BackedgeCond = nullptr;
5032 /// Set to true if loop back is on positive branch condition.
5033 bool IsPositiveBECond;
5034};
5035
5036std::optional<const SCEV *>
5037SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5038
5039 // If value matches the backedge condition for loop latch,
5040 // then return a constant evolution node based on loopback
5041 // branch taken.
5042 if (BackedgeCond == IC)
5043 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5044                            : SE.getZero(Type::getInt1Ty(SE.getContext()));
5045  return std::nullopt;
5046}
5047
5048class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5049public:
5050 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5051 ScalarEvolution &SE) {
5052 SCEVShiftRewriter Rewriter(L, SE);
5053 const SCEV *Result = Rewriter.visit(S);
5054 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5055 }
5056
5057 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5058 // Only allow AddRecExprs for this loop.
5059 if (!SE.isLoopInvariant(Expr, L))
5060 Valid = false;
5061 return Expr;
5062 }
5063
5064 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5065 if (Expr->getLoop() == L && Expr->isAffine())
5066 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5067 Valid = false;
5068 return Expr;
5069 }
5070
5071 bool isValid() { return Valid; }
5072
5073private:
5074 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5075 : SCEVRewriteVisitor(SE), L(L) {}
5076
5077 const Loop *L;
5078 bool Valid = true;
5079};
5080
5081} // end anonymous namespace
5082
5083SCEV::NoWrapFlags
5084ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5085 if (!AR->isAffine())
5086 return SCEV::FlagAnyWrap;
5087
5088 using OBO = OverflowingBinaryOperator;
5089
5090  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5091
5092 if (!AR->hasNoSelfWrap()) {
5093 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5094 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5095 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5096 const APInt &BECountAP = BECountMax->getAPInt();
5097 unsigned NoOverflowBitWidth =
5098 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5099 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5100        Result = setFlags(Result, SCEV::FlagNW);
5101    }
5102 }
5103
5104 if (!AR->hasNoSignedWrap()) {
5105 ConstantRange AddRecRange = getSignedRange(AR);
5106 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5107
5108    auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5109        Instruction::Add, IncRange, OBO::NoSignedWrap);
5110    if (NSWRegion.contains(AddRecRange))
5111      Result = setFlags(Result, SCEV::FlagNSW);
5112  }
5113
5114 if (!AR->hasNoUnsignedWrap()) {
5115 ConstantRange AddRecRange = getUnsignedRange(AR);
5116 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5117
5118    auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5119        Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5120    if (NUWRegion.contains(AddRecRange))
5121      Result = setFlags(Result, SCEV::FlagNUW);
5122  }
5123
5124 return Result;
5125}
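// For example, for an i32 AddRec whose step's signed range needs at most 8
// bits and whose constant max backedge-taken count has at most 16 active bits,
// the accumulated increment needs at most 8 + 16 = 24 < 32 bits, so the
// recurrence can never wrap back onto itself and the no-self-wrap flag (NW)
// can be added.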
5126
5127SCEV::NoWrapFlags
5128ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5129  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5130
5131 if (AR->hasNoSignedWrap())
5132 return Result;
5133
5134 if (!AR->isAffine())
5135 return Result;
5136
5137 // This function can be expensive, only try to prove NSW once per AddRec.
5138 if (!SignedWrapViaInductionTried.insert(AR).second)
5139 return Result;
5140
5141 const SCEV *Step = AR->getStepRecurrence(*this);
5142 const Loop *L = AR->getLoop();
5143
5144 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5145 // Note that this serves two purposes: It filters out loops that are
5146 // simply not analyzable, and it covers the case where this code is
5147 // being called from within backedge-taken count analysis, such that
5148 // attempting to ask for the backedge-taken count would likely result
5149  // in infinite recursion. In the latter case, the analysis code will
5150 // cope with a conservative value, and it will take care to purge
5151 // that value once it has finished.
5152 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5153
5154 // Normally, in the cases we can prove no-overflow via a
5155 // backedge guarding condition, we can also compute a backedge
5156 // taken count for the loop. The exceptions are assumptions and
5157 // guards present in the loop -- SCEV is not great at exploiting
5158 // these to compute max backedge taken counts, but can still use
5159 // these to prove lack of overflow. Use this fact to avoid
5160 // doing extra work that may not pay off.
5161
5162 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5163 AC.assumptions().empty())
5164 return Result;
5165
5166 // If the backedge is guarded by a comparison with the pre-inc value the
5167 // addrec is safe. Also, if the entry is guarded by a comparison with the
5168 // start value and the backedge is guarded by a comparison with the post-inc
5169 // value, the addrec is safe.
5170  ICmpInst::Predicate Pred;
5171  const SCEV *OverflowLimit =
5172 getSignedOverflowLimitForStep(Step, &Pred, this);
5173 if (OverflowLimit &&
5174 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5175 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5176 Result = setFlags(Result, SCEV::FlagNSW);
5177 }
5178 return Result;
5179}
5180SCEV::NoWrapFlags
5181ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5182  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5183
5184 if (AR->hasNoUnsignedWrap())
5185 return Result;
5186
5187 if (!AR->isAffine())
5188 return Result;
5189
5190 // This function can be expensive, only try to prove NUW once per AddRec.
5191 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5192 return Result;
5193
5194 const SCEV *Step = AR->getStepRecurrence(*this);
5195 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5196 const Loop *L = AR->getLoop();
5197
5198 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5199 // Note that this serves two purposes: It filters out loops that are
5200 // simply not analyzable, and it covers the case where this code is
5201 // being called from within backedge-taken count analysis, such that
5202 // attempting to ask for the backedge-taken count would likely result
5203  // in infinite recursion. In the latter case, the analysis code will
5204 // cope with a conservative value, and it will take care to purge
5205 // that value once it has finished.
5206 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5207
5208 // Normally, in the cases we can prove no-overflow via a
5209 // backedge guarding condition, we can also compute a backedge
5210 // taken count for the loop. The exceptions are assumptions and
5211 // guards present in the loop -- SCEV is not great at exploiting
5212 // these to compute max backedge taken counts, but can still use
5213 // these to prove lack of overflow. Use this fact to avoid
5214 // doing extra work that may not pay off.
5215
5216 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5217 AC.assumptions().empty())
5218 return Result;
5219
5220 // If the backedge is guarded by a comparison with the pre-inc value the
5221 // addrec is safe. Also, if the entry is guarded by a comparison with the
5222 // start value and the backedge is guarded by a comparison with the post-inc
5223 // value, the addrec is safe.
5224 if (isKnownPositive(Step)) {
5225    const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5226                                getUnsignedRangeMax(Step));
5227    if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5228        isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5229      Result = setFlags(Result, SCEV::FlagNUW);
5230    }
5231  }
5232
5233 return Result;
5234}
5235
5236namespace {
5237
5238/// Represents an abstract binary operation. This may exist as a
5239/// normal instruction or constant expression, or may have been
5240/// derived from an expression tree.
5241struct BinaryOp {
5242 unsigned Opcode;
5243 Value *LHS;
5244 Value *RHS;
5245 bool IsNSW = false;
5246 bool IsNUW = false;
5247
5248 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5249 /// constant expression.
5250 Operator *Op = nullptr;
5251
5252 explicit BinaryOp(Operator *Op)
5253 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5254 Op(Op) {
5255 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5256 IsNSW = OBO->hasNoSignedWrap();
5257 IsNUW = OBO->hasNoUnsignedWrap();
5258 }
5259 }
5260
5261 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5262 bool IsNUW = false)
5263 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5264};
5265
5266} // end anonymous namespace
5267
5268/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5269static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5270 AssumptionCache &AC,
5271 const DominatorTree &DT,
5272 const Instruction *CxtI) {
5273 auto *Op = dyn_cast<Operator>(V);
5274 if (!Op)
5275 return std::nullopt;
5276
5277 // Implementation detail: all the cleverness here should happen without
5278  // creating new SCEV expressions -- our caller knows tricks to avoid creating
5279 // SCEV expressions when possible, and we should not break that.
5280
5281 switch (Op->getOpcode()) {
5282 case Instruction::Add:
5283 case Instruction::Sub:
5284 case Instruction::Mul:
5285 case Instruction::UDiv:
5286 case Instruction::URem:
5287 case Instruction::And:
5288 case Instruction::AShr:
5289 case Instruction::Shl:
5290 return BinaryOp(Op);
5291
5292 case Instruction::Or: {
5293 // Convert or disjoint into add nuw nsw.
5294 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5295 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5296 /*IsNSW=*/true, /*IsNUW=*/true);
5297 return BinaryOp(Op);
5298 }
5299
5300 case Instruction::Xor:
5301 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5302 // If the RHS of the xor is a signmask, then this is just an add.
5303 // Instcombine turns add of signmask into xor as a strength reduction step.
5304 if (RHSC->getValue().isSignMask())
5305 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5306 // Binary `xor` is a bit-wise `add`.
5307 if (V->getType()->isIntegerTy(1))
5308 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5309 return BinaryOp(Op);
5310
5311 case Instruction::LShr:
5312    // Turn logical shift right of a constant into an unsigned divide.
5313 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5314 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5315
5316 // If the shift count is not less than the bitwidth, the result of
5317 // the shift is undefined. Don't try to analyze it, because the
5318 // resolution chosen here may differ from the resolution chosen in
5319 // other parts of the compiler.
5320 if (SA->getValue().ult(BitWidth)) {
5321 Constant *X =
5322 ConstantInt::get(SA->getContext(),
5323 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5324 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5325 }
5326 }
5327 return BinaryOp(Op);
5328
5329 case Instruction::ExtractValue: {
5330 auto *EVI = cast<ExtractValueInst>(Op);
5331 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5332 break;
5333
5334 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5335 if (!WO)
5336 break;
5337
5338 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5339 bool Signed = WO->isSigned();
5340 // TODO: Should add nuw/nsw flags for mul as well.
5341 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5342 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5343
5344 // Now that we know that all uses of the arithmetic-result component of
5345 // CI are guarded by the overflow check, we can go ahead and pretend
5346 // that the arithmetic is non-overflowing.
5347 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5348 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5349 }
5350
5351 default:
5352 break;
5353 }
5354
5355 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5356 // semantics as a Sub, return a binary sub expression.
5357 if (auto *II = dyn_cast<IntrinsicInst>(V))
5358 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5359 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5360
5361 return std::nullopt;
5362}
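// For example, MatchBinaryOp maps "lshr i32 %x, 4" to the triple
// (UDiv, %x, 16), and "or disjoint i32 %x, %y" to (Add, %x, %y) with both
// nuw and nsw set, so callers can reason about them with the ordinary add and
// udiv SCEV builders.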
5363
5364/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5365/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5366/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5367/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5368/// follows one of the following patterns:
5369/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5370/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5371/// If the SCEV expression of \p Op conforms with one of the expected patterns
5372/// we return the type of the truncation operation, and indicate whether the
5373/// truncated type should be treated as signed/unsigned by setting
5374/// \p Signed to true/false, respectively.
5375static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5376 bool &Signed, ScalarEvolution &SE) {
5377 // The case where Op == SymbolicPHI (that is, with no type conversions on
5378 // the way) is handled by the regular add recurrence creating logic and
5379 // would have already been triggered in createAddRecForPHI. Reaching it here
5380 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5381 // because one of the other operands of the SCEVAddExpr updating this PHI is
5382 // not invariant).
5383 //
5384 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5385 // this case predicates that allow us to prove that Op == SymbolicPHI will
5386 // be added.
5387 if (Op == SymbolicPHI)
5388 return nullptr;
5389
5390 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5391 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5392 if (SourceBits != NewBits)
5393 return nullptr;
5394
5395 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5396 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5397 if (!SExt && !ZExt)
5398 return nullptr;
5399 const SCEVTruncateExpr *Trunc =
5400 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5401 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5402 if (!Trunc)
5403 return nullptr;
5404 const SCEV *X = Trunc->getOperand();
5405 if (X != SymbolicPHI)
5406 return nullptr;
5407 Signed = SExt != nullptr;
5408 return Trunc->getType();
5409}
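// For example, with an i64 phi %X, Op = (sext i32 (trunc i64 %X to i32) to i64)
// matches the first pattern above: the function returns the truncation type
// i32 and sets Signed to true.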
5410
5411static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5412 if (!PN->getType()->isIntegerTy())
5413 return nullptr;
5414 const Loop *L = LI.getLoopFor(PN->getParent());
5415 if (!L || L->getHeader() != PN->getParent())
5416 return nullptr;
5417 return L;
5418}
5419
5420// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5421// computation that updates the phi follows the following pattern:
5422// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5423// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5424// If so, try to see if it can be rewritten as an AddRecExpr under some
5425// Predicates. If successful, return them as a pair. Also cache the results
5426// of the analysis.
5427//
5428// Example usage scenario:
5429// Say the Rewriter is called for the following SCEV:
5430// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5431// where:
5432// %X = phi i64 (%Start, %BEValue)
5433// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5434// and call this function with %SymbolicPHI = %X.
5435//
5436// The analysis will find that the value coming around the backedge has
5437// the following SCEV:
5438// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5439// Upon concluding that this matches the desired pattern, the function
5440// will return the pair {NewAddRec, SmallPredsVec} where:
5441// NewAddRec = {%Start,+,%Step}
5442// SmallPredsVec = {P1, P2, P3} as follows:
5443// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5444// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5445// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5446// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5447// under the predicates {P1,P2,P3}.
5448// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5449// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5450//
5451// TODO's:
5452//
5453// 1) Extend the Induction descriptor to also support inductions that involve
5454// casts: When needed (namely, when we are called in the context of the
5455// vectorizer induction analysis), a Set of cast instructions will be
5456// populated by this method, and provided back to isInductionPHI. This is
5457// needed to allow the vectorizer to properly record them to be ignored by
5458// the cost model and to avoid vectorizing them (otherwise these casts,
5459// which are redundant under the runtime overflow checks, will be
5460// vectorized, which can be costly).
5461//
5462// 2) Support additional induction/PHISCEV patterns: We also want to support
5463// inductions where the sext-trunc / zext-trunc operations (partly) occur
5464// after the induction update operation (the induction increment):
5465//
5466// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5467// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5468//
5469// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5470// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5471//
5472// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5473std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5474ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5475  SmallVector<const SCEVPredicate *, 3> Predicates;
5476
5477 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5478 // return an AddRec expression under some predicate.
5479
5480 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5481 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5482 assert(L && "Expecting an integer loop header phi");
5483
5484 // The loop may have multiple entrances or multiple exits; we can analyze
5485 // this phi as an addrec if it has a unique entry value and a unique
5486 // backedge value.
5487 Value *BEValueV = nullptr, *StartValueV = nullptr;
5488 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5489 Value *V = PN->getIncomingValue(i);
5490 if (L->contains(PN->getIncomingBlock(i))) {
5491 if (!BEValueV) {
5492 BEValueV = V;
5493 } else if (BEValueV != V) {
5494 BEValueV = nullptr;
5495 break;
5496 }
5497 } else if (!StartValueV) {
5498 StartValueV = V;
5499 } else if (StartValueV != V) {
5500 StartValueV = nullptr;
5501 break;
5502 }
5503 }
5504 if (!BEValueV || !StartValueV)
5505 return std::nullopt;
5506
5507 const SCEV *BEValue = getSCEV(BEValueV);
5508
5509 // If the value coming around the backedge is an add with the symbolic
5510 // value we just inserted, possibly with casts that we can ignore under
5511 // an appropriate runtime guard, then we found a simple induction variable!
5512 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5513 if (!Add)
5514 return std::nullopt;
5515
5516 // If there is a single occurrence of the symbolic value, possibly
5517 // casted, replace it with a recurrence.
5518 unsigned FoundIndex = Add->getNumOperands();
5519 Type *TruncTy = nullptr;
5520 bool Signed;
5521 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5522 if ((TruncTy =
5523 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5524 if (FoundIndex == e) {
5525 FoundIndex = i;
5526 break;
5527 }
5528
5529 if (FoundIndex == Add->getNumOperands())
5530 return std::nullopt;
5531
5532 // Create an add with everything but the specified operand.
5533  SmallVector<const SCEV *, 8> Ops;
5534  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5535 if (i != FoundIndex)
5536 Ops.push_back(Add->getOperand(i));
5537 const SCEV *Accum = getAddExpr(Ops);
5538
5539 // The runtime checks will not be valid if the step amount is
5540 // varying inside the loop.
5541 if (!isLoopInvariant(Accum, L))
5542 return std::nullopt;
5543
5544 // *** Part2: Create the predicates
5545
5546 // Analysis was successful: we have a phi-with-cast pattern for which we
5547 // can return an AddRec expression under the following predicates:
5548 //
5549 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5550 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5551 // P2: An Equal predicate that guarantees that
5552 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5553 // P3: An Equal predicate that guarantees that
5554 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5555 //
5556 // As we next prove, the above predicates guarantee that:
5557 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5558 //
5559 //
5560 // More formally, we want to prove that:
5561 // Expr(i+1) = Start + (i+1) * Accum
5562 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5563 //
5564 // Given that:
5565 // 1) Expr(0) = Start
5566 // 2) Expr(1) = Start + Accum
5567 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5568 // 3) Induction hypothesis (step i):
5569 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5570 //
5571 // Proof:
5572 // Expr(i+1) =
5573 // = Start + (i+1)*Accum
5574 // = (Start + i*Accum) + Accum
5575 // = Expr(i) + Accum
5576 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5577 // :: from step i
5578 //
5579 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5580 //
5581 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5582 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5583 // + Accum :: from P3
5584 //
5585 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5586 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5587 //
5588 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5589 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5590 //
5591 // By induction, the same applies to all iterations 1<=i<n:
5592 //
5593
5594 // Create a truncated addrec for which we will add a no overflow check (P1).
5595 const SCEV *StartVal = getSCEV(StartValueV);
5596 const SCEV *PHISCEV =
5597 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5598 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5599
5600 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5601 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5602 // will be constant.
5603 //
5604 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5605 // add P1.
5606 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5607    SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5608        Signed ? SCEVWrapPredicate::IncrementNSSW
5609               : SCEVWrapPredicate::IncrementNUSW;
5610    const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5611 Predicates.push_back(AddRecPred);
5612 }
5613
5614 // Create the Equal Predicates P2,P3:
5615
5616 // It is possible that the predicates P2 and/or P3 are computable at
5617 // compile time due to StartVal and/or Accum being constants.
5618 // If either one is, then we can check that now and escape if either P2
5619 // or P3 is false.
5620
5621 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5622 // for each of StartVal and Accum
5623 auto getExtendedExpr = [&](const SCEV *Expr,
5624 bool CreateSignExtend) -> const SCEV * {
5625 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5626 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5627 const SCEV *ExtendedExpr =
5628 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5629 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5630 return ExtendedExpr;
5631 };
5632
5633 // Given:
5634 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5635 // = getExtendedExpr(Expr)
5636 // Determine whether the predicate P: Expr == ExtendedExpr
5637 // is known to be false at compile time
5638 auto PredIsKnownFalse = [&](const SCEV *Expr,
5639 const SCEV *ExtendedExpr) -> bool {
5640 return Expr != ExtendedExpr &&
5641 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5642 };
5643
5644 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5645 if (PredIsKnownFalse(StartVal, StartExtended)) {
5646 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5647 return std::nullopt;
5648 }
5649
5650 // The Step is always Signed (because the overflow checks are either
5651 // NSSW or NUSW)
5652 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5653 if (PredIsKnownFalse(Accum, AccumExtended)) {
5654 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5655 return std::nullopt;
5656 }
5657
5658 auto AppendPredicate = [&](const SCEV *Expr,
5659 const SCEV *ExtendedExpr) -> void {
5660 if (Expr != ExtendedExpr &&
5661 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5662 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5663 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5664 Predicates.push_back(Pred);
5665 }
5666 };
5667
5668 AppendPredicate(StartVal, StartExtended);
5669 AppendPredicate(Accum, AccumExtended);
5670
5671 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5672 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5673 // into NewAR if it will also add the runtime overflow checks specified in
5674 // Predicates.
5675 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5676
5677 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5678 std::make_pair(NewAR, Predicates);
5679  // Remember the result of the analysis for this SCEV at this location.
5680 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5681 return PredRewrite;
5682}
5683
5684std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5685ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5686  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5687 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5688 if (!L)
5689 return std::nullopt;
5690
5691 // Check to see if we already analyzed this PHI.
5692 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5693 if (I != PredicatedSCEVRewrites.end()) {
5694 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5695 I->second;
5696 // Analysis was done before and failed to create an AddRec:
5697 if (Rewrite.first == SymbolicPHI)
5698 return std::nullopt;
5699 // Analysis was done before and succeeded to create an AddRec under
5700 // a predicate:
5701 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5702 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5703 return Rewrite;
5704 }
5705
5706 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5707 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5708
5709 // Record in the cache that the analysis failed
5710 if (!Rewrite) {
5711    SmallVector<const SCEVPredicate *, 3> Predicates;
5712    PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5713 return std::nullopt;
5714 }
5715
5716 return Rewrite;
5717}
5718
5719// FIXME: This utility is currently required because the Rewriter currently
5720// does not rewrite this expression:
5721// {0, +, (sext ix (trunc iy to ix) to iy)}
5722// into {0, +, %step},
5723// even when the following Equal predicate exists:
5724// "%step == (sext ix (trunc iy to ix) to iy)".
5725bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5726    const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5727 if (AR1 == AR2)
5728 return true;
5729
5730 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5731 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5732 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5733 return false;
5734 return true;
5735 };
5736
5737 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5738 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5739 return false;
5740 return true;
5741}
5742
5743/// A helper function for createAddRecFromPHI to handle simple cases.
5744///
5745/// This function tries to find an AddRec expression for the simplest (yet most
5746/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5747/// If it fails, createAddRecFromPHI will use a more general, but slow,
5748/// technique for finding the AddRec expression.
5749const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5750 Value *BEValueV,
5751 Value *StartValueV) {
5752 const Loop *L = LI.getLoopFor(PN->getParent());
5753 assert(L && L->getHeader() == PN->getParent());
5754 assert(BEValueV && StartValueV);
5755
5756 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5757 if (!BO)
5758 return nullptr;
5759
5760 if (BO->Opcode != Instruction::Add)
5761 return nullptr;
5762
5763 const SCEV *Accum = nullptr;
5764 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5765 Accum = getSCEV(BO->RHS);
5766 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5767 Accum = getSCEV(BO->LHS);
5768
5769 if (!Accum)
5770 return nullptr;
5771
5772  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5773  if (BO->IsNUW)
5774 Flags = setFlags(Flags, SCEV::FlagNUW);
5775 if (BO->IsNSW)
5776 Flags = setFlags(Flags, SCEV::FlagNSW);
5777
5778 const SCEV *StartVal = getSCEV(StartValueV);
5779 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5780 insertValueToMap(PN, PHISCEV);
5781
5782 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5783 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5784                   (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5785                                       proveNoWrapViaConstantRanges(AR)));
5786 }
5787
5788 // We can add Flags to the post-inc expression only if we
5789 // know that it is *undefined behavior* for BEValueV to
5790 // overflow.
5791 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5792 assert(isLoopInvariant(Accum, L) &&
5793 "Accum is defined outside L, but is not invariant?");
5794 if (isAddRecNeverPoison(BEInst, L))
5795 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5796 }
5797
5798 return PHISCEV;
5799}
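// For example, for the loop-header phi
//   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
//   %iv.next = add nuw nsw i64 %iv, 4
// this returns {0,+,4}<nuw><nsw><%loop> directly, without creating the
// temporary symbolic SCEV that the general createAddRecFromPHI path uses.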
5800
5801const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5802 const Loop *L = LI.getLoopFor(PN->getParent());
5803 if (!L || L->getHeader() != PN->getParent())
5804 return nullptr;
5805
5806 // The loop may have multiple entrances or multiple exits; we can analyze
5807 // this phi as an addrec if it has a unique entry value and a unique
5808 // backedge value.
5809 Value *BEValueV = nullptr, *StartValueV = nullptr;
5810 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5811 Value *V = PN->getIncomingValue(i);
5812 if (L->contains(PN->getIncomingBlock(i))) {
5813 if (!BEValueV) {
5814 BEValueV = V;
5815 } else if (BEValueV != V) {
5816 BEValueV = nullptr;
5817 break;
5818 }
5819 } else if (!StartValueV) {
5820 StartValueV = V;
5821 } else if (StartValueV != V) {
5822 StartValueV = nullptr;
5823 break;
5824 }
5825 }
5826 if (!BEValueV || !StartValueV)
5827 return nullptr;
5828
5829 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5830 "PHI node already processed?");
5831
5832  // First, try to find an AddRec expression without creating a fictitious symbolic
5833 // value for PN.
5834 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5835 return S;
5836
5837 // Handle PHI node value symbolically.
5838 const SCEV *SymbolicName = getUnknown(PN);
5839 insertValueToMap(PN, SymbolicName);
5840
5841 // Using this symbolic name for the PHI, analyze the value coming around
5842 // the back-edge.
5843 const SCEV *BEValue = getSCEV(BEValueV);
5844
5845 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5846 // has a special value for the first iteration of the loop.
5847
5848 // If the value coming around the backedge is an add with the symbolic
5849 // value we just inserted, then we found a simple induction variable!
5850 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5851 // If there is a single occurrence of the symbolic value, replace it
5852 // with a recurrence.
5853 unsigned FoundIndex = Add->getNumOperands();
5854 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5855 if (Add->getOperand(i) == SymbolicName)
5856 if (FoundIndex == e) {
5857 FoundIndex = i;
5858 break;
5859 }
5860
5861 if (FoundIndex != Add->getNumOperands()) {
5862 // Create an add with everything but the specified operand.
5863        SmallVector<const SCEV *, 8> Ops;
5864        for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5865 if (i != FoundIndex)
5866 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5867 L, *this));
5868 const SCEV *Accum = getAddExpr(Ops);
5869
5870 // This is not a valid addrec if the step amount is varying each
5871 // loop iteration, but is not itself an addrec in this loop.
5872 if (isLoopInvariant(Accum, L) ||
5873 (isa<SCEVAddRecExpr>(Accum) &&
5874 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5875          SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5876
5877 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5878 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5879 if (BO->IsNUW)
5880 Flags = setFlags(Flags, SCEV::FlagNUW);
5881 if (BO->IsNSW)
5882 Flags = setFlags(Flags, SCEV::FlagNSW);
5883 }
5884 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5885 if (GEP->getOperand(0) == PN) {
5886 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5887 // If the increment has any nowrap flags, then we know the address
5888 // space cannot be wrapped around.
5889 if (NW != GEPNoWrapFlags::none())
5890 Flags = setFlags(Flags, SCEV::FlagNW);
5891 // If the GEP is nuw or nusw with non-negative offset, we know that
5892 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5893 // offset is treated as signed, while the base is unsigned.
5894 if (NW.hasNoUnsignedWrap() ||
5895                (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5896              Flags = setFlags(Flags, SCEV::FlagNUW);
5897 }
5898
5899 // We cannot transfer nuw and nsw flags from subtraction
5900 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5901 // for instance.
5902 }
5903
5904 const SCEV *StartVal = getSCEV(StartValueV);
5905 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5906
5907 // Okay, for the entire analysis of this edge we assumed the PHI
5908 // to be symbolic. We now need to go back and purge all of the
5909 // entries for the scalars that use the symbolic expression.
5910 forgetMemoizedResults(SymbolicName);
5911 insertValueToMap(PN, PHISCEV);
5912
5913 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5914 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5915                         (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5916                                             proveNoWrapViaConstantRanges(AR)));
5917 }
5918
5919 // We can add Flags to the post-inc expression only if we
5920 // know that it is *undefined behavior* for BEValueV to
5921 // overflow.
5922 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5923 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5924 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5925
5926 return PHISCEV;
5927 }
5928 }
5929 } else {
5930 // Otherwise, this could be a loop like this:
5931 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5932 // In this case, j = {1,+,1} and BEValue is j.
5933    // Because the other in-value of i (0) fits the evolution of BEValue,
5934 // i really is an addrec evolution.
5935 //
5936 // We can generalize this saying that i is the shifted value of BEValue
5937 // by one iteration:
5938 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5939 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5940 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5941 if (Shifted != getCouldNotCompute() &&
5942 Start != getCouldNotCompute()) {
5943 const SCEV *StartVal = getSCEV(StartValueV);
5944 if (Start == StartVal) {
5945 // Okay, for the entire analysis of this edge we assumed the PHI
5946 // to be symbolic. We now need to go back and purge all of the
5947 // entries for the scalars that use the symbolic expression.
5948 forgetMemoizedResults(SymbolicName);
5949 insertValueToMap(PN, Shifted);
5950 return Shifted;
5951 }
5952 }
5953 }
5954
5955 // Remove the temporary PHI node SCEV that has been inserted while intending
5956  // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5957  // as it will prevent later (possibly simpler) SCEV expressions from being added
5958 // to the ValueExprMap.
5959 eraseValueFromMap(PN);
5960
5961 return nullptr;
5962}
5963
5964// Try to match a control flow sequence that branches out at BI and merges back
5965// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5966// match.
5967static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5968                          Value *&C, Value *&LHS, Value *&RHS) {
5969 C = BI->getCondition();
5970
5971 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5972 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5973
5974 if (!LeftEdge.isSingleEdge())
5975 return false;
5976
5977 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5978
5979 Use &LeftUse = Merge->getOperandUse(0);
5980 Use &RightUse = Merge->getOperandUse(1);
5981
5982 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5983 LHS = LeftUse;
5984 RHS = RightUse;
5985 return true;
5986 }
5987
5988 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5989 LHS = RightUse;
5990 RHS = LeftUse;
5991 return true;
5992 }
5993
5994 return false;
5995}
5996
5997const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5998 auto IsReachable =
5999 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6000 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6001 // Try to match
6002 //
6003 // br %cond, label %left, label %right
6004 // left:
6005 // br label %merge
6006 // right:
6007 // br label %merge
6008 // merge:
6009 // V = phi [ %x, %left ], [ %y, %right ]
6010 //
6011 // as "select %cond, %x, %y"
6012
6013 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6014 assert(IDom && "At least the entry block should dominate PN");
6015
6016 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6017 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6018
6019 if (BI && BI->isConditional() &&
6020 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6021 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6022 properlyDominates(getSCEV(RHS), PN->getParent()))
6023 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6024 }
6025
6026 return nullptr;
6027}
6028
6029const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6030 if (const SCEV *S = createAddRecFromPHI(PN))
6031 return S;
6032
6033 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
6034 return getSCEV(V);
6035
6036 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6037 return S;
6038
6039 // If it's not a loop phi, we can't handle it yet.
6040 return getUnknown(PN);
6041}
6042
6043bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6044 SCEVTypes RootKind) {
6045 struct FindClosure {
6046 const SCEV *OperandToFind;
6047 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6048 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6049
6050 bool Found = false;
6051
6052 bool canRecurseInto(SCEVTypes Kind) const {
6053 // We can only recurse into the SCEV expression of the same effective type
6054 // as the type of our root SCEV expression, and into zero-extensions.
6055 return RootKind == Kind || NonSequentialRootKind == Kind ||
6056 scZeroExtend == Kind;
6057 };
6058
6059 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6060 : OperandToFind(OperandToFind), RootKind(RootKind),
6061 NonSequentialRootKind(
6062              SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6063                  RootKind)) {}
6064
6065 bool follow(const SCEV *S) {
6066 Found = S == OperandToFind;
6067
6068 return !isDone() && canRecurseInto(S->getSCEVType());
6069 }
6070
6071 bool isDone() const { return Found; }
6072 };
6073
6074 FindClosure FC(OperandToFind, RootKind);
6075 visitAll(Root, FC);
6076 return FC.Found;
6077}
6078
6079std::optional<const SCEV *>
6080ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6081 ICmpInst *Cond,
6082 Value *TrueVal,
6083 Value *FalseVal) {
6084 // Try to match some simple smax or umax patterns.
6085 auto *ICI = Cond;
6086
6087 Value *LHS = ICI->getOperand(0);
6088 Value *RHS = ICI->getOperand(1);
6089
6090 switch (ICI->getPredicate()) {
6091 case ICmpInst::ICMP_SLT:
6092 case ICmpInst::ICMP_SLE:
6093 case ICmpInst::ICMP_ULT:
6094 case ICmpInst::ICMP_ULE:
6095 std::swap(LHS, RHS);
6096 [[fallthrough]];
6097 case ICmpInst::ICMP_SGT:
6098 case ICmpInst::ICMP_SGE:
6099 case ICmpInst::ICMP_UGT:
6100 case ICmpInst::ICMP_UGE:
6101 // a > b ? a+x : b+x -> max(a, b)+x
6102 // a > b ? b+x : a+x -> min(a, b)+x
6103    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6104      bool Signed = ICI->isSigned();
6105 const SCEV *LA = getSCEV(TrueVal);
6106 const SCEV *RA = getSCEV(FalseVal);
6107 const SCEV *LS = getSCEV(LHS);
6108 const SCEV *RS = getSCEV(RHS);
6109 if (LA->getType()->isPointerTy()) {
6110 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6111 // Need to make sure we can't produce weird expressions involving
6112 // negated pointers.
6113 if (LA == LS && RA == RS)
6114 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6115 if (LA == RS && RA == LS)
6116 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6117 }
6118 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6119 if (Op->getType()->isPointerTy()) {
6120        Op = getLosslessPtrToIntExpr(Op);
6121        if (isa<SCEVCouldNotCompute>(Op))
6122 return Op;
6123 }
6124 if (Signed)
6125 Op = getNoopOrSignExtend(Op, Ty);
6126 else
6127 Op = getNoopOrZeroExtend(Op, Ty);
6128 return Op;
6129 };
6130 LS = CoerceOperand(LS);
6131 RS = CoerceOperand(RS);
6132 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6133 break;
6134 const SCEV *LDiff = getMinusSCEV(LA, LS);
6135 const SCEV *RDiff = getMinusSCEV(RA, RS);
6136 if (LDiff == RDiff)
6137 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6138 LDiff);
6139 LDiff = getMinusSCEV(LA, RS);
6140 RDiff = getMinusSCEV(RA, LS);
6141 if (LDiff == RDiff)
6142 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6143 LDiff);
6144 }
6145 break;
6146 case ICmpInst::ICMP_NE:
6147 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6148 std::swap(TrueVal, FalseVal);
6149 [[fallthrough]];
6150 case ICmpInst::ICMP_EQ:
6151 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6152    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6153        isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6154 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6155 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6156 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6157 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6158 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6159 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6160 return getAddExpr(getUMaxExpr(X, C), Y);
6161 }
6162 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6163 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6164 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6165 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6166 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6167 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6168 const SCEV *X = getSCEV(LHS);
6169 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6170 X = ZExt->getOperand();
6171 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6172 const SCEV *FalseValExpr = getSCEV(FalseVal);
6173 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6174 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6175 /*Sequential=*/true);
6176 }
6177 }
6178 break;
6179 default:
6180 break;
6181 }
6182
6183 return std::nullopt;
6184}
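// For example, "(%a ugt %b) ? (%a + 4) : (%b + 4)" becomes (umax %a, %b) + 4:
// both hands differ from the compared values by the same amount, so LDiff ==
// RDiff == 4 in the code above.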
6185
6186static std::optional<const SCEV *>
6187createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6188                              const SCEV *TrueExpr, const SCEV *FalseExpr) {
6189 assert(CondExpr->getType()->isIntegerTy(1) &&
6190 TrueExpr->getType() == FalseExpr->getType() &&
6191 TrueExpr->getType()->isIntegerTy(1) &&
6192 "Unexpected operands of a select.");
6193
6194 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6195 // --> C + (umin_seq cond, x - C)
6196 //
6197 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6198 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6199 // --> C + (umin_seq ~cond, x - C)
6200
6201 // FIXME: while we can't legally model the case where both of the hands
6202 // are fully variable, we only require that the *difference* is constant.
6203 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6204 return std::nullopt;
6205
6206 const SCEV *X, *C;
6207 if (isa<SCEVConstant>(TrueExpr)) {
6208 CondExpr = SE->getNotSCEV(CondExpr);
6209 X = FalseExpr;
6210 C = TrueExpr;
6211 } else {
6212 X = TrueExpr;
6213 C = FalseExpr;
6214 }
6215 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6216 /*Sequential=*/true));
6217}
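// For example, (i1 %cond ? i1 true : i1 %x) becomes
//   true + (umin_seq (not %cond), (%x - true))
// where all of the arithmetic is performed in i1, i.e. modulo 2.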
6218
6219static std::optional<const SCEV *>
6220createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6221                              Value *FalseVal) {
6222 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6223 return std::nullopt;
6224
6225 const auto *SECond = SE->getSCEV(Cond);
6226 const auto *SETrue = SE->getSCEV(TrueVal);
6227 const auto *SEFalse = SE->getSCEV(FalseVal);
6228 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6229}
6230
6231const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6232 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6233 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6234 assert(TrueVal->getType() == FalseVal->getType() &&
6235 V->getType() == TrueVal->getType() &&
6236 "Types of select hands and of the result must match.");
6237
6238 // For now, only deal with i1-typed `select`s.
6239 if (!V->getType()->isIntegerTy(1))
6240 return getUnknown(V);
6241
6242 if (std::optional<const SCEV *> S =
6243 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6244 return *S;
6245
6246 return getUnknown(V);
6247}
6248
6249const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6250 Value *TrueVal,
6251 Value *FalseVal) {
6252 // Handle "constant" branch or select. This can occur for instance when a
6253 // loop pass transforms an inner loop and moves on to process the outer loop.
6254 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6255 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6256
6257 if (auto *I = dyn_cast<Instruction>(V)) {
6258 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6259 if (std::optional<const SCEV *> S =
6260 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6261 TrueVal, FalseVal))
6262 return *S;
6263 }
6264 }
6265
6266 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6267}
6268
6269/// Expand GEP instructions into add and multiply operations. This allows them
6270/// to be analyzed by regular SCEV code.
6271const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6272 assert(GEP->getSourceElementType()->isSized() &&
6273 "GEP source element type must be sized");
6274
6275  SmallVector<const SCEV *, 4> IndexExprs;
6276  for (Value *Index : GEP->indices())
6277 IndexExprs.push_back(getSCEV(Index));
6278 return getGEPExpr(GEP, IndexExprs);
6279}
6280
6281APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6282  uint32_t BitWidth = getTypeSizeInBits(S->getType());
6283  auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6284    return TrailingZeros >= BitWidth
6285               ? APInt::getZero(BitWidth)
6286               : APInt::getOneBitSet(BitWidth, TrailingZeros);
6287 };
6288 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6289 // The result is GCD of all operands results.
6290 APInt Res = getConstantMultiple(N->getOperand(0));
6291 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6292      Res = APIntOps::GreatestCommonDivisor(
6293          Res, getConstantMultiple(N->getOperand(I)));
6294 return Res;
6295 };
6296
6297 switch (S->getSCEVType()) {
6298 case scConstant:
6299 return cast<SCEVConstant>(S)->getAPInt();
6300 case scPtrToInt:
6301 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6302 case scUDivExpr:
6303 case scVScale:
6304 return APInt(BitWidth, 1);
6305 case scTruncate: {
6306 // Only multiples that are a power of 2 will hold after truncation.
6307 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6308 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6309 return GetShiftedByZeros(TZ);
6310 }
6311 case scZeroExtend: {
6312 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6313 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6314 }
6315 case scSignExtend: {
6316 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6317    return getConstantMultiple(E->getOperand()).sext(BitWidth);
6318  }
6319 case scMulExpr: {
6320 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6321 if (M->hasNoUnsignedWrap()) {
6322 // The result is the product of all operand results.
6323 APInt Res = getConstantMultiple(M->getOperand(0));
6324 for (const SCEV *Operand : M->operands().drop_front())
6325 Res = Res * getConstantMultiple(Operand);
6326 return Res;
6327 }
6328
6329 // If there are no wrap guarantees, find the trailing zeros, which is the
6330 // sum of trailing zeros for all its operands.
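// E.g. multiplying a multiple of 4 by a multiple of 8 leaves at least
// 2 + 3 = 5 trailing zero bits, even if the multiplication wraps.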
6331 uint32_t TZ = 0;
6332 for (const SCEV *Operand : M->operands())
6333 TZ += getMinTrailingZeros(Operand);
6334 return GetShiftedByZeros(TZ);
6335 }
6336 case scAddExpr:
6337 case scAddRecExpr: {
6338 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6339 if (N->hasNoUnsignedWrap())
6340 return GetGCDMultiple(N);
6341 // Find the trailing bits, which is the minimum of its operands.
6342 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6343 for (const SCEV *Operand : N->operands().drop_front())
6344 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6345 return GetShiftedByZeros(TZ);
6346 }
6347 case scUMaxExpr:
6348 case scSMaxExpr:
6349 case scUMinExpr:
6350 case scSMinExpr:
6351 case scSequentialUMinExpr:
6352 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6353 case scUnknown: {
6354 // ask ValueTracking for known bits
6355 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6356 unsigned Known =
6357 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6358 .countMinTrailingZeros();
6359 return GetShiftedByZeros(Known);
6360 }
6361 case scCouldNotCompute:
6362 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6363 }
6364 llvm_unreachable("Unknown SCEV kind!");
6365}
6366
6367APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6368 auto I = ConstantMultipleCache.find(S);
6369 if (I != ConstantMultipleCache.end())
6370 return I->second;
6371
6372 APInt Result = getConstantMultipleImpl(S);
6373 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6374 assert(InsertPair.second && "Should insert a new key");
6375 return InsertPair.first->second;
6376}
6377
6378APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6379 APInt Multiple = getConstantMultiple(S);
6380 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6381}
6382
6383uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6384 return std::min(getConstantMultiple(S).countTrailingZeros(),
6385 (unsigned)getTypeSizeInBits(S->getType()));
6386}
6387
6388/// Helper method to assign a range to V from metadata present in the IR.
6389static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6390 if (Instruction *I = dyn_cast<Instruction>(V)) {
6391 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6392 return getConstantRangeFromMetadata(*MD);
6393 if (const auto *CB = dyn_cast<CallBase>(V))
6394 if (std::optional<ConstantRange> Range = CB->getRange())
6395 return Range;
6396 }
6397 if (auto *A = dyn_cast<Argument>(V))
6398 if (std::optional<ConstantRange> Range = A->getRange())
6399 return Range;
6400
6401 return std::nullopt;
6402}
6403
6404void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6405 SCEV::NoWrapFlags Flags) {
6406 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6407 AddRec->setNoWrapFlags(Flags);
6408 UnsignedRanges.erase(AddRec);
6409 SignedRanges.erase(AddRec);
6410 ConstantMultipleCache.erase(AddRec);
6411 }
6412}
6413
6414ConstantRange ScalarEvolution::
6415getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6416 const DataLayout &DL = getDataLayout();
6417
6418 unsigned BitWidth = getTypeSizeInBits(U->getType());
6419 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6420
6421 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6422 // use information about the trip count to improve our available range. Note
6423 // that the trip count independent cases are already handled by known bits.
6424 // WARNING: The definition of recurrence used here is subtly different than
6425 // the one used by AddRec (and thus most of this file). Step is allowed to
6426 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6427 // and other addrecs in the same loop (for non-affine addrecs). The code
6428 // below intentionally handles the case where step is not loop invariant.
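// For example, a header phi %iv whose backedge value is %iv.next = lshr %iv,
// %amt is such a recurrence, even if %amt itself varies between iterations.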
6429 auto *P = dyn_cast<PHINode>(U->getValue());
6430 if (!P)
6431 return FullSet;
6432
6433 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6434 // even the values that are not available in these blocks may come from them,
6435 // and this leads to a false-positive recurrence test.
6436 for (auto *Pred : predecessors(P->getParent()))
6437 if (!DT.isReachableFromEntry(Pred))
6438 return FullSet;
6439
6440 BinaryOperator *BO;
6441 Value *Start, *Step;
6442 if (!matchSimpleRecurrence(P, BO, Start, Step))
6443 return FullSet;
6444
6445 // If we found a recurrence in reachable code, we must be in a loop. Note
6446 // that BO might be in some subloop of L, and that's completely okay.
6447 auto *L = LI.getLoopFor(P->getParent());
6448 assert(L && L->getHeader() == P->getParent());
6449 if (!L->contains(BO->getParent()))
6450 // NOTE: This bailout should be an assert instead. However, asserting
6451 // the condition here exposes a case where LoopFusion is querying SCEV
6452 // with malformed loop information in the midst of the transform.
6453 // There doesn't appear to be an obvious fix, so for the moment bailout
6454 // until the caller issue can be fixed. PR49566 tracks the bug.
6455 return FullSet;
6456
6457 // TODO: Extend to other opcodes such as mul, and div
6458 switch (BO->getOpcode()) {
6459 default:
6460 return FullSet;
6461 case Instruction::AShr:
6462 case Instruction::LShr:
6463 case Instruction::Shl:
6464 break;
6465 };
6466
6467 if (BO->getOperand(0) != P)
6468 // TODO: Handle the power function forms some day.
6469 return FullSet;
6470
6471 unsigned TC = getSmallConstantMaxTripCount(L);
6472 if (!TC || TC >= BitWidth)
6473 return FullSet;
6474
6475 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6476 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6477 assert(KnownStart.getBitWidth() == BitWidth &&
6478 KnownStep.getBitWidth() == BitWidth);
6479
6480 // Compute total shift amount, being careful of overflow and bitwidths.
6481 auto MaxShiftAmt = KnownStep.getMaxValue();
6482 APInt TCAP(BitWidth, TC-1);
6483 bool Overflow = false;
6484 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6485 if (Overflow)
6486 return FullSet;
6487
6488 switch (BO->getOpcode()) {
6489 default:
6490 llvm_unreachable("filtered out above");
6491 case Instruction::AShr: {
6492 // For each ashr, three cases:
6493 // shift = 0 => unchanged value
6494 // saturation => 0 or -1
6495 // other => a value closer to zero (of the same sign)
6496 // Thus, the end value is closer to zero than the start.
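// E.g. with an 8-bit start known to be -32 and a total shift of at most 2,
// every value produced lies in [-32, -8], no further from zero than the start.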
6497 auto KnownEnd = KnownBits::ashr(KnownStart,
6498 KnownBits::makeConstant(TotalShift));
6499 if (KnownStart.isNonNegative())
6500 // Analogous to lshr (simply not yet canonicalized)
6501 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6502 KnownStart.getMaxValue() + 1);
6503 if (KnownStart.isNegative())
6504 // End >=u Start && End <=s Start
6505 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6506 KnownEnd.getMaxValue() + 1);
6507 break;
6508 }
6509 case Instruction::LShr: {
6510 // For each lshr, three cases:
6511 // shift = 0 => unchanged value
6512 // saturation => 0
6513 // other => a smaller positive number
6514 // Thus, the low end of the unsigned range is the last value produced.
6515 auto KnownEnd = KnownBits::lshr(KnownStart,
6516 KnownBits::makeConstant(TotalShift));
6517 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6518 KnownStart.getMaxValue() + 1);
6519 }
6520 case Instruction::Shl: {
6521 // Iff no bits are shifted out, value increases on every shift.
6522 auto KnownEnd = KnownBits::shl(KnownStart,
6523 KnownBits::makeConstant(TotalShift));
6524 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6525 return ConstantRange(KnownStart.getMinValue(),
6526 KnownEnd.getMaxValue() + 1);
6527 break;
6528 }
6529 };
6530 return FullSet;
6531}
6532
6533const ConstantRange &
6534ScalarEvolution::getRangeRefIter(const SCEV *S,
6535 ScalarEvolution::RangeSignHint SignHint) {
6536 DenseMap<const SCEV *, ConstantRange> &Cache =
6537 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6538 : SignedRanges;
6539 SmallVector<const SCEV *> WorkList;
6540 SmallPtrSet<const SCEV *, 8> Seen;
6541
6542 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6543 // SCEVUnknown PHI node.
6544 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6545 if (!Seen.insert(Expr).second)
6546 return;
6547 if (Cache.contains(Expr))
6548 return;
6549 switch (Expr->getSCEVType()) {
6550 case scUnknown:
6551 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6552 break;
6553 [[fallthrough]];
6554 case scConstant:
6555 case scVScale:
6556 case scTruncate:
6557 case scZeroExtend:
6558 case scSignExtend:
6559 case scPtrToInt:
6560 case scAddExpr:
6561 case scMulExpr:
6562 case scUDivExpr:
6563 case scAddRecExpr:
6564 case scUMaxExpr:
6565 case scSMaxExpr:
6566 case scUMinExpr:
6567 case scSMinExpr:
6568 case scSequentialUMinExpr:
6569 WorkList.push_back(Expr);
6570 break;
6571 case scCouldNotCompute:
6572 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6573 }
6574 };
6575 AddToWorklist(S);
6576
6577 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6578 for (unsigned I = 0; I != WorkList.size(); ++I) {
6579 const SCEV *P = WorkList[I];
6580 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6581 // If it is not a `SCEVUnknown`, just recurse into operands.
6582 if (!UnknownS) {
6583 for (const SCEV *Op : P->operands())
6584 AddToWorklist(Op);
6585 continue;
6586 }
6587 // `SCEVUnknown`'s require special treatment.
6588 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6589 if (!PendingPhiRangesIter.insert(P).second)
6590 continue;
6591 for (auto &Op : reverse(P->operands()))
6592 AddToWorklist(getSCEV(Op));
6593 }
6594 }
6595
6596 if (!WorkList.empty()) {
6597 // Use getRangeRef to compute ranges for items in the worklist in reverse
6598 // order. This will force ranges for earlier operands to be computed before
6599 // their users in most cases.
6600 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6601 getRangeRef(P, SignHint);
6602
6603 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6604 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6605 PendingPhiRangesIter.erase(P);
6606 }
6607 }
6608
6609 return getRangeRef(S, SignHint, 0);
6610}
6611
6612/// Determine the range for a particular SCEV. If SignHint is
6613/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6614/// with a "cleaner" unsigned (resp. signed) representation.
6615const ConstantRange &ScalarEvolution::getRangeRef(
6616 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6617 DenseMap<const SCEV *, ConstantRange> &Cache =
6618 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6619 : SignedRanges;
6620 ConstantRange::PreferredRangeType RangeType =
6621 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6622 : ConstantRange::Signed;
6623
6624 // See if we've computed this range already.
6625 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6626 if (I != Cache.end())
6627 return I->second;
6628
6629 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6630 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6631
6632 // Switch to iteratively computing the range for S, if it is part of a deeply
6633 // nested expression.
6634 if (Depth > RangeIterThreshold)
6635 return getRangeRefIter(S, SignHint);
6636
6637 unsigned BitWidth = getTypeSizeInBits(S->getType());
6638 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6639 using OBO = OverflowingBinaryOperator;
6640
6641 // If the value has known zeros, the maximum value will have those known zeros
6642 // as well.
6643 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6644 APInt Multiple = getNonZeroConstantMultiple(S);
6645 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6646 if (!Remainder.isZero())
6647 ConservativeResult =
6648 ConstantRange(APInt::getZero(BitWidth),
6649 APInt::getMaxValue(BitWidth) - Remainder + 1);
6650 }
6651 else {
6652 uint32_t TZ = getMinTrailingZeros(S);
6653 if (TZ != 0) {
6654 ConservativeResult = ConstantRange(
6655 APInt::getSignedMinValue(BitWidth).ashr(TZ).shl(TZ),
6656 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6657 }
6658 }
6659
6660 switch (S->getSCEVType()) {
6661 case scConstant:
6662 llvm_unreachable("Already handled above.");
6663 case scVScale:
6664 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6665 case scTruncate: {
6666 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6667 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6668 return setRange(
6669 Trunc, SignHint,
6670 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6671 }
6672 case scZeroExtend: {
6673 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6674 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6675 return setRange(
6676 ZExt, SignHint,
6677 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6678 }
6679 case scSignExtend: {
6680 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6681 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6682 return setRange(
6683 SExt, SignHint,
6684 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6685 }
6686 case scPtrToInt: {
6687 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6688 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6689 return setRange(PtrToInt, SignHint, X);
6690 }
6691 case scAddExpr: {
6692 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6693 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6694 unsigned WrapType = OBO::AnyWrap;
6695 if (Add->hasNoSignedWrap())
6696 WrapType |= OBO::NoSignedWrap;
6697 if (Add->hasNoUnsignedWrap())
6698 WrapType |= OBO::NoUnsignedWrap;
6699 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
6700 X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1),
6701 WrapType, RangeType);
6702 return setRange(Add, SignHint,
6703 ConservativeResult.intersectWith(X, RangeType));
6704 }
6705 case scMulExpr: {
6706 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6707 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6708 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
6709 X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
6710 return setRange(Mul, SignHint,
6711 ConservativeResult.intersectWith(X, RangeType));
6712 }
6713 case scUDivExpr: {
6714 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6715 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6716 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6717 return setRange(UDiv, SignHint,
6718 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6719 }
6720 case scAddRecExpr: {
6721 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6722 // If there's no unsigned wrap, the value will never be less than its
6723 // initial value.
6724 if (AddRec->hasNoUnsignedWrap()) {
6725 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6726 if (!UnsignedMinValue.isZero())
6727 ConservativeResult = ConservativeResult.intersectWith(
6728 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6729 }
6730
6731 // If there's no signed wrap, and all the operands except initial value have
6732 // the same sign or zero, the value won't ever be:
6733 // 1: smaller than the initial value if the operands are non-negative,
6734 // 2: bigger than the initial value if the operands are non-positive.
6735 // For both cases, the value cannot cross the signed min/max boundary.
6736 if (AddRec->hasNoSignedWrap()) {
6737 bool AllNonNeg = true;
6738 bool AllNonPos = true;
6739 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6740 if (!isKnownNonNegative(AddRec->getOperand(i)))
6741 AllNonNeg = false;
6742 if (!isKnownNonPositive(AddRec->getOperand(i)))
6743 AllNonPos = false;
6744 }
6745 if (AllNonNeg)
6746 ConservativeResult = ConservativeResult.intersectWith(
6747 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6748 APInt::getSignedMinValue(BitWidth)),
6749 RangeType);
6750 else if (AllNonPos)
6751 ConservativeResult = ConservativeResult.intersectWith(
6752 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6753 getSignedRangeMax(AddRec->getStart()) +
6754 1),
6755 RangeType);
6756 }
6757
6758 // TODO: non-affine addrec
6759 if (AddRec->isAffine()) {
6760 const SCEV *MaxBEScev =
6761 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6762 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6763 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6764
6765 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6766 // MaxBECount's active bits are all <= AddRec's bit width.
6767 if (MaxBECount.getBitWidth() > BitWidth &&
6768 MaxBECount.getActiveBits() <= BitWidth)
6769 MaxBECount = MaxBECount.trunc(BitWidth);
6770 else if (MaxBECount.getBitWidth() < BitWidth)
6771 MaxBECount = MaxBECount.zext(BitWidth);
6772
6773 if (MaxBECount.getBitWidth() == BitWidth) {
6774 auto RangeFromAffine = getRangeForAffineAR(
6775 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6776 ConservativeResult =
6777 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6778
6779 auto RangeFromFactoring = getRangeViaFactoring(
6780 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6781 ConservativeResult =
6782 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6783 }
6784 }
6785
6786 // Now try symbolic BE count and more powerful methods.
6787 if (UseExpensiveRangeSharpening) {
6788 const SCEV *SymbolicMaxBECount =
6789 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6790 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6791 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6792 AddRec->hasNoSelfWrap()) {
6793 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6794 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6795 ConservativeResult =
6796 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6797 }
6798 }
6799 }
6800
6801 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6802 }
6803 case scUMaxExpr:
6804 case scSMaxExpr:
6805 case scUMinExpr:
6806 case scSMinExpr:
6807 case scSequentialUMinExpr: {
6808 Intrinsic::ID ID;
6809 switch (S->getSCEVType()) {
6810 case scUMaxExpr:
6811 ID = Intrinsic::umax;
6812 break;
6813 case scSMaxExpr:
6814 ID = Intrinsic::smax;
6815 break;
6816 case scUMinExpr:
6817 case scSequentialUMinExpr:
6818 ID = Intrinsic::umin;
6819 break;
6820 case scSMinExpr:
6821 ID = Intrinsic::smin;
6822 break;
6823 default:
6824 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6825 }
6826
6827 const auto *NAry = cast<SCEVNAryExpr>(S);
6828 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6829 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6830 X = X.intrinsic(
6831 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6832 return setRange(S, SignHint,
6833 ConservativeResult.intersectWith(X, RangeType));
6834 }
6835 case scUnknown: {
6836 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6837 Value *V = U->getValue();
6838
6839 // Check if the IR explicitly contains !range metadata.
6840 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6841 if (MDRange)
6842 ConservativeResult =
6843 ConservativeResult.intersectWith(*MDRange, RangeType);
6844
6845 // Use facts about recurrences in the underlying IR. Note that add
6846 // recurrences are AddRecExprs and thus don't hit this path. This
6847 // primarily handles shift recurrences.
6848 auto CR = getRangeForUnknownRecurrence(U);
6849 ConservativeResult = ConservativeResult.intersectWith(CR);
6850
6851 // See if ValueTracking can give us a useful range.
6852 const DataLayout &DL = getDataLayout();
6853 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6854 if (Known.getBitWidth() != BitWidth)
6855 Known = Known.zextOrTrunc(BitWidth);
6856
6857 // ValueTracking may be able to compute a tighter result for the number of
6858 // sign bits than for the value of those sign bits.
6859 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6860 if (U->getType()->isPointerTy()) {
6861 // If the pointer size is larger than the index size type, this can cause
6862 // NS to be larger than BitWidth. So compensate for this.
6863 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6864 int ptrIdxDiff = ptrSize - BitWidth;
6865 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6866 NS -= ptrIdxDiff;
6867 }
6868
6869 if (NS > 1) {
6870 // If we know any of the sign bits, we know all of the sign bits.
6871 if (!Known.Zero.getHiBits(NS).isZero())
6872 Known.Zero.setHighBits(NS);
6873 if (!Known.One.getHiBits(NS).isZero())
6874 Known.One.setHighBits(NS);
6875 }
6876
6877 if (Known.getMinValue() != Known.getMaxValue() + 1)
6878 ConservativeResult = ConservativeResult.intersectWith(
6879 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6880 RangeType);
6881 if (NS > 1)
6882 ConservativeResult = ConservativeResult.intersectWith(
6883 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6884 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6885 RangeType);
6886
6887 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6888 // Strengthen the range if the underlying IR value is a
6889 // global/alloca/heap allocation using the size of the object.
6890 ObjectSizeOpts Opts;
6891 Opts.RoundToAlign = false;
6892 Opts.NullIsUnknownSize = true;
6893 uint64_t ObjSize;
6894 if ((isa<GlobalVariable>(V) || isa<AllocaInst>(V) ||
6895 isAllocationFn(V, &TLI)) &&
6896 getObjectSize(V, ObjSize, DL, &TLI, Opts) && ObjSize > 1) {
6897 // The highest address the object can start is ObjSize bytes before the
6898 // end (unsigned max value). If this value is not a multiple of the
6899 // alignment, the last possible start value is the next lowest multiple
6900 // of the alignment. Note: The computations below cannot overflow,
6901 // because if they would there's no possible start address for the
6902 // object.
6903 APInt MaxVal = APInt::getMaxValue(BitWidth) - APInt(BitWidth, ObjSize);
6904 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6905 uint64_t Rem = MaxVal.urem(Align);
6906 MaxVal -= APInt(BitWidth, Rem);
6907 APInt MinVal = APInt::getZero(BitWidth);
6908 if (llvm::isKnownNonZero(V, DL))
6909 MinVal = Align;
6910 ConservativeResult = ConservativeResult.intersectWith(
6911 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6912 }
6913 }
6914
6915 // The range of a Phi is a subset of the union of the ranges of its inputs.
6916 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6917 // Make sure that we do not run over cycled Phis.
6918 if (PendingPhiRanges.insert(Phi).second) {
6919 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6920
6921 for (const auto &Op : Phi->operands()) {
6922 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6923 RangeFromOps = RangeFromOps.unionWith(OpRange);
6924 // No point to continue if we already have a full set.
6925 if (RangeFromOps.isFullSet())
6926 break;
6927 }
6928 ConservativeResult =
6929 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6930 bool Erased = PendingPhiRanges.erase(Phi);
6931 assert(Erased && "Failed to erase Phi properly?");
6932 (void)Erased;
6933 }
6934 }
6935
6936 // vscale can't be equal to zero
6937 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6938 if (II->getIntrinsicID() == Intrinsic::vscale) {
6939 ConstantRange Disallowed = APInt::getZero(BitWidth);
6940 ConservativeResult = ConservativeResult.difference(Disallowed);
6941 }
6942
6943 return setRange(U, SignHint, std::move(ConservativeResult));
6944 }
6945 case scCouldNotCompute:
6946 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6947 }
6948
6949 return setRange(S, SignHint, std::move(ConservativeResult));
6950}
6951
6952// Given a StartRange, Step and MaxBECount for an expression compute a range of
6953// values that the expression can take. Initially, the expression has a value
6954// from StartRange and then is changed by Step up to MaxBECount times. Signed
6955// argument defines if we treat Step as signed or unsigned.
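// Illustrative example (not from the source): StartRange = [0, 10), Step = 3
// and MaxBECount = 5 yield [0, 10 + 3 * 5) = [0, 25), since the value can move
// by at most Step * MaxBECount from somewhere in StartRange.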
6956static ConstantRange getRangeForAffineARHelper(APInt Step,
6957 const ConstantRange &StartRange,
6958 const APInt &MaxBECount,
6959 bool Signed) {
6960 unsigned BitWidth = Step.getBitWidth();
6961 assert(BitWidth == StartRange.getBitWidth() &&
6962 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6963 // If either Step or MaxBECount is 0, then the expression won't change, and we
6964 // just need to return the initial range.
6965 if (Step == 0 || MaxBECount == 0)
6966 return StartRange;
6967
6968 // If we don't know anything about the initial value (i.e. StartRange is
6969 // FullRange), then we don't know anything about the final range either.
6970 // Return FullRange.
6971 if (StartRange.isFullSet())
6972 return ConstantRange::getFull(BitWidth);
6973
6974 // If Step is signed and negative, then we use its absolute value, but we also
6975 // note that we're moving in the opposite direction.
6976 bool Descending = Signed && Step.isNegative();
6977
6978 if (Signed)
6979 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6980 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6981 // These equations hold true due to the well-defined wrap-around behavior of
6982 // APInt.
6983 Step = Step.abs();
6984
6985 // Check if Offset is more than full span of BitWidth. If it is, the
6986 // expression is guaranteed to overflow.
6987 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6988 return ConstantRange::getFull(BitWidth);
6989
6990 // Offset is by how much the expression can change. Checks above guarantee no
6991 // overflow here.
6992 APInt Offset = Step * MaxBECount;
6993
6994 // Minimum value of the final range will match the minimal value of StartRange
6995 // if the expression is increasing and will be decreased by Offset otherwise.
6996 // Maximum value of the final range will match the maximal value of StartRange
6997 // if the expression is decreasing and will be increased by Offset otherwise.
6998 APInt StartLower = StartRange.getLower();
6999 APInt StartUpper = StartRange.getUpper() - 1;
7000 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7001 : (StartUpper + std::move(Offset));
7002
7003 // It's possible that the new minimum/maximum value will fall into the initial
7004 // range (due to wrap around). This means that the expression can take any
7005 // value in this bitwidth, and we have to return full range.
7006 if (StartRange.contains(MovedBoundary))
7007 return ConstantRange::getFull(BitWidth);
7008
7009 APInt NewLower =
7010 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7011 APInt NewUpper =
7012 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7013 NewUpper += 1;
7014
7015 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7016 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7017}
7018
7019ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7020 const SCEV *Step,
7021 const APInt &MaxBECount) {
7022 assert(getTypeSizeInBits(Start->getType()) ==
7023 getTypeSizeInBits(Step->getType()) &&
7024 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7025 "mismatched bit widths");
7026
7027 // First, consider step signed.
7028 ConstantRange StartSRange = getSignedRange(Start);
7029 ConstantRange StepSRange = getSignedRange(Step);
7030
7031 // If Step can be both positive and negative, we need to find ranges for the
7032 // maximum absolute step values in both directions and union them.
7033 ConstantRange SR = getRangeForAffineARHelper(
7034 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7035 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7036 StartSRange, MaxBECount,
7037 /* Signed = */ true));
7038
7039 // Next, consider step unsigned.
7040 ConstantRange UR = getRangeForAffineARHelper(
7041 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7042 /* Signed = */ false);
7043
7044 // Finally, intersect signed and unsigned ranges.
7045 return SR.intersectWith(UR, ConstantRange::Smallest);
7046}
7047
7048ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7049 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7050 ScalarEvolution::RangeSignHint SignHint) {
7051 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7052 assert(AddRec->hasNoSelfWrap() &&
7053 "This only works for non-self-wrapping AddRecs!");
7054 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7055 const SCEV *Step = AddRec->getStepRecurrence(*this);
7056 // Only deal with constant step to save compile time.
7057 if (!isa<SCEVConstant>(Step))
7058 return ConstantRange::getFull(BitWidth);
7059 // Let's make sure that we can prove that we do not self-wrap during
7060 // MaxBECount iterations. We need this because MaxBECount is a maximum
7061 // iteration count estimate, and we might infer nw from some exit for which we
7062 // do not know max exit count (or any other side reasoning).
7063 // TODO: Turn into assert at some point.
7064 if (getTypeSizeInBits(MaxBECount->getType()) >
7065 getTypeSizeInBits(AddRec->getType()))
7066 return ConstantRange::getFull(BitWidth);
7067 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7068 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7069 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7070 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7071 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7072 MaxItersWithoutWrap))
7073 return ConstantRange::getFull(BitWidth);
7074
7075 ICmpInst::Predicate LEPred =
7076 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7077 ICmpInst::Predicate GEPred =
7078 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7079 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7080
7081 // We know that there is no self-wrap. Let's take Start and End values and
7082 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7083 // the iteration. They either lie inside the range [Min(Start, End),
7084 // Max(Start, End)] or outside it:
7085 //
7086 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7087 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7088 //
7089 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7090 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7091 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7092 // Start <= End and step is positive, or Start >= End and step is negative.
7093 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7094 ConstantRange StartRange = getRangeRef(Start, SignHint);
7095 ConstantRange EndRange = getRangeRef(End, SignHint);
7096 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7097 // If they already cover full iteration space, we will know nothing useful
7098 // even if we prove what we want to prove.
7099 if (RangeBetween.isFullSet())
7100 return RangeBetween;
7101 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7102 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7103 : RangeBetween.isWrappedSet();
7104 if (IsWrappedSet)
7105 return ConstantRange::getFull(BitWidth);
7106
7107 if (isKnownPositive(Step) &&
7108 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7109 return RangeBetween;
7110 if (isKnownNegative(Step) &&
7111 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7112 return RangeBetween;
7113 return ConstantRange::getFull(BitWidth);
7114}
7115
7116ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7117 const SCEV *Step,
7118 const APInt &MaxBECount) {
7119 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7120 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
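// E.g. the range of {c ? 0 : 10,+,c ? 1 : 2} is computed as
// RangeOf({0,+,1}) union RangeOf({10,+,2}).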
7121
7122 unsigned BitWidth = MaxBECount.getBitWidth();
7123 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7124 getTypeSizeInBits(Step->getType()) == BitWidth &&
7125 "mismatched bit widths");
7126
7127 struct SelectPattern {
7128 Value *Condition = nullptr;
7129 APInt TrueValue;
7130 APInt FalseValue;
7131
7132 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7133 const SCEV *S) {
7134 std::optional<unsigned> CastOp;
7135 APInt Offset(BitWidth, 0);
7136
7137 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7138 "Should be!");
7139
7140 // Peel off a constant offset:
7141 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7142 // In the future we could consider being smarter here and handle
7143 // {Start+Step,+,Step} too.
7144 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7145 return;
7146
7147 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7148 S = SA->getOperand(1);
7149 }
7150
7151 // Peel off a cast operation
7152 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7153 CastOp = SCast->getSCEVType();
7154 S = SCast->getOperand();
7155 }
7156
7157 using namespace llvm::PatternMatch;
7158
7159 auto *SU = dyn_cast<SCEVUnknown>(S);
7160 const APInt *TrueVal, *FalseVal;
7161 if (!SU ||
7162 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7163 m_APInt(FalseVal)))) {
7164 Condition = nullptr;
7165 return;
7166 }
7167
7168 TrueValue = *TrueVal;
7169 FalseValue = *FalseVal;
7170
7171 // Re-apply the cast we peeled off earlier
7172 if (CastOp)
7173 switch (*CastOp) {
7174 default:
7175 llvm_unreachable("Unknown SCEV cast type!");
7176
7177 case scTruncate:
7178 TrueValue = TrueValue.trunc(BitWidth);
7179 FalseValue = FalseValue.trunc(BitWidth);
7180 break;
7181 case scZeroExtend:
7182 TrueValue = TrueValue.zext(BitWidth);
7183 FalseValue = FalseValue.zext(BitWidth);
7184 break;
7185 case scSignExtend:
7186 TrueValue = TrueValue.sext(BitWidth);
7187 FalseValue = FalseValue.sext(BitWidth);
7188 break;
7189 }
7190
7191 // Re-apply the constant offset we peeled off earlier
7192 TrueValue += Offset;
7193 FalseValue += Offset;
7194 }
7195
7196 bool isRecognized() { return Condition != nullptr; }
7197 };
7198
7199 SelectPattern StartPattern(*this, BitWidth, Start);
7200 if (!StartPattern.isRecognized())
7201 return ConstantRange::getFull(BitWidth);
7202
7203 SelectPattern StepPattern(*this, BitWidth, Step);
7204 if (!StepPattern.isRecognized())
7205 return ConstantRange::getFull(BitWidth);
7206
7207 if (StartPattern.Condition != StepPattern.Condition) {
7208 // We don't handle this case today; but we could, by considering four
7209 // possibilities below instead of two. I'm not sure if there are cases where
7210 // that will help over what getRange already does, though.
7211 return ConstantRange::getFull(BitWidth);
7212 }
7213
7214 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7215 // construct arbitrary general SCEV expressions here. This function is called
7216 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7217 // say) can end up caching a suboptimal value.
7218
7219 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7220 // C2352 and C2512 (otherwise it isn't needed).
7221
7222 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7223 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7224 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7225 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7226
7227 ConstantRange TrueRange =
7228 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7229 ConstantRange FalseRange =
7230 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7231
7232 return TrueRange.unionWith(FalseRange);
7233}
7234
7235SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7236 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7237 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7238
7239 // Return early if there are no flags to propagate to the SCEV.
7240 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7241 if (BinOp->hasNoUnsignedWrap())
7242 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7243 if (BinOp->hasNoSignedWrap())
7244 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7245 if (Flags == SCEV::FlagAnyWrap)
7246 return SCEV::FlagAnyWrap;
7247
7248 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7249}
7250
7251const Instruction *
7252ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7253 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7254 return &*AddRec->getLoop()->getHeader()->begin();
7255 if (auto *U = dyn_cast<SCEVUnknown>(S))
7256 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7257 return I;
7258 return nullptr;
7259}
7260
7261const Instruction *
7262ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7263 bool &Precise) {
7264 Precise = true;
7265 // Do a bounded search of the def relation of the requested SCEVs.
7266 SmallSet<const SCEV *, 16> Visited;
7267 SmallVector<const SCEV *> Worklist;
7268 auto pushOp = [&](const SCEV *S) {
7269 if (!Visited.insert(S).second)
7270 return;
7271 // Threshold of 30 here is arbitrary.
7272 if (Visited.size() > 30) {
7273 Precise = false;
7274 return;
7275 }
7276 Worklist.push_back(S);
7277 };
7278
7279 for (const auto *S : Ops)
7280 pushOp(S);
7281
7282 const Instruction *Bound = nullptr;
7283 while (!Worklist.empty()) {
7284 auto *S = Worklist.pop_back_val();
7285 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7286 if (!Bound || DT.dominates(Bound, DefI))
7287 Bound = DefI;
7288 } else {
7289 for (const auto *Op : S->operands())
7290 pushOp(Op);
7291 }
7292 }
7293 return Bound ? Bound : &*F.getEntryBlock().begin();
7294}
7295
7296const Instruction *
7297ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7298 bool Discard;
7299 return getDefiningScopeBound(Ops, Discard);
7300}
7301
7302bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7303 const Instruction *B) {
7304 if (A->getParent() == B->getParent() &&
7305 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7306 B->getIterator()))
7307 return true;
7308
7309 auto *BLoop = LI.getLoopFor(B->getParent());
7310 if (BLoop && BLoop->getHeader() == B->getParent() &&
7311 BLoop->getLoopPreheader() == A->getParent() &&
7312 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7313 A->getParent()->end()) &&
7314 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7315 B->getIterator()))
7316 return true;
7317 return false;
7318}
7319
7320
7321bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7322 // Only proceed if we can prove that I does not yield poison.
7323 if (!programUndefinedIfPoison(I))
7324 return false;
7325
7326 // At this point we know that if I is executed, then it does not wrap
7327 // according to at least one of NSW or NUW. If I is not executed, then we do
7328 // not know if the calculation that I represents would wrap. Multiple
7329 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7330 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7331 // derived from other instructions that map to the same SCEV. We cannot make
7332 // that guarantee for cases where I is not executed. So we need to find an
7333 // upper bound on the defining scope for the SCEV, and prove that I is
7334 // executed every time we enter that scope. When the bounding scope is a
7335 // loop (the common case), this is equivalent to proving I executes on every
7336 // iteration of that loop.
7337 SmallVector<const SCEV *> SCEVOps;
7338 for (const Use &Op : I->operands()) {
7339 // I could be an extractvalue from a call to an overflow intrinsic.
7340 // TODO: We can do better here in some cases.
7341 if (isSCEVable(Op->getType()))
7342 SCEVOps.push_back(getSCEV(Op));
7343 }
7344 auto *DefI = getDefiningScopeBound(SCEVOps);
7345 return isGuaranteedToTransferExecutionTo(DefI, I);
7346}
7347
7348bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7349 // If we know that \c I can never be poison period, then that's enough.
7350 if (isSCEVExprNeverPoison(I))
7351 return true;
7352
7353 // If the loop only has one exit, then we know that, if the loop is entered,
7354 // any instruction dominating that exit will be executed. If any such
7355 // instruction would result in UB, the addrec cannot be poison.
7356 //
7357 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7358 // also handles uses outside the loop header (they just need to dominate the
7359 // single exit).
7360
7361 auto *ExitingBB = L->getExitingBlock();
7362 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7363 return false;
7364
7365 SmallPtrSet<const Value *, 16> KnownPoison;
7366 SmallVector<const Instruction *, 8> Worklist;
7367
7368 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7369 // things that are known to be poison under that assumption go on the
7370 // Worklist.
7371 KnownPoison.insert(I);
7372 Worklist.push_back(I);
7373
7374 while (!Worklist.empty()) {
7375 const Instruction *Poison = Worklist.pop_back_val();
7376
7377 for (const Use &U : Poison->uses()) {
7378 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7379 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7380 DT.dominates(PoisonUser->getParent(), ExitingBB))
7381 return true;
7382
7383 if (propagatesPoison(U) && L->contains(PoisonUser))
7384 if (KnownPoison.insert(PoisonUser).second)
7385 Worklist.push_back(PoisonUser);
7386 }
7387 }
7388
7389 return false;
7390}
7391
7392ScalarEvolution::LoopProperties
7393ScalarEvolution::getLoopProperties(const Loop *L) {
7394 using LoopProperties = ScalarEvolution::LoopProperties;
7395
7396 auto Itr = LoopPropertiesCache.find(L);
7397 if (Itr == LoopPropertiesCache.end()) {
7398 auto HasSideEffects = [](Instruction *I) {
7399 if (auto *SI = dyn_cast<StoreInst>(I))
7400 return !SI->isSimple();
7401
7402 return I->mayThrow() || I->mayWriteToMemory();
7403 };
7404
7405 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7406 /*HasNoSideEffects*/ true};
7407
7408 for (auto *BB : L->getBlocks())
7409 for (auto &I : *BB) {
7410 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7411 LP.HasNoAbnormalExits = false;
7412 if (HasSideEffects(&I))
7413 LP.HasNoSideEffects = false;
7414 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7415 break; // We're already as pessimistic as we can get.
7416 }
7417
7418 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7419 assert(InsertPair.second && "We just checked!");
7420 Itr = InsertPair.first;
7421 }
7422
7423 return Itr->second;
7424}
7425
7426bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7427 // A mustprogress loop without side effects must be finite.
7428 // TODO: The check used here is very conservative. It's only *specific*
7429 // side effects which are well defined in infinite loops.
7430 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7431}
7432
7433const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7434 // Worklist item with a Value and a bool indicating whether all operands have
7435 // been visited already.
7436 using PointerTy = PointerIntPair<Value *, 1, bool>;
7437 SmallVector<PointerTy> Stack;
7438
7439 Stack.emplace_back(V, true);
7440 Stack.emplace_back(V, false);
7441 while (!Stack.empty()) {
7442 auto E = Stack.pop_back_val();
7443 Value *CurV = E.getPointer();
7444
7445 if (getExistingSCEV(CurV))
7446 continue;
7447
7448 SmallVector<Value *> Ops;
7449 const SCEV *CreatedSCEV = nullptr;
7450 // If all operands have been visited already, create the SCEV.
7451 if (E.getInt()) {
7452 CreatedSCEV = createSCEV(CurV);
7453 } else {
7454 // Otherwise get the operands we need to create SCEVs for before creating
7455 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7456 // just use it.
7457 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7458 }
7459
7460 if (CreatedSCEV) {
7461 insertValueToMap(CurV, CreatedSCEV);
7462 } else {
7463 // Queue CurV for SCEV creation, followed by its operands which need to
7464 // be constructed first.
7465 Stack.emplace_back(CurV, true);
7466 for (Value *Op : Ops)
7467 Stack.emplace_back(Op, false);
7468 }
7469 }
7470
7471 return getExistingSCEV(V);
7472}
7473
7474const SCEV *
7475ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7476 if (!isSCEVable(V->getType()))
7477 return getUnknown(V);
7478
7479 if (Instruction *I = dyn_cast<Instruction>(V)) {
7480 // Don't attempt to analyze instructions in blocks that aren't
7481 // reachable. Such instructions don't matter, and they aren't required
7482 // to obey basic rules for definitions dominating uses which this
7483 // analysis depends on.
7484 if (!DT.isReachableFromEntry(I->getParent()))
7485 return getUnknown(PoisonValue::get(V->getType()));
7486 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7487 return getConstant(CI);
7488 else if (isa<GlobalAlias>(V))
7489 return getUnknown(V);
7490 else if (!isa<ConstantExpr>(V))
7491 return getUnknown(V);
7492
7493 Operator *U = cast<Operator>(V);
7494 if (auto BO =
7495 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7496 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7497 switch (BO->Opcode) {
7498 case Instruction::Add:
7499 case Instruction::Mul: {
7500 // For additions and multiplications, traverse add/mul chains for which we
7501 // can potentially create a single SCEV, to reduce the number of
7502 // get{Add,Mul}Expr calls.
7503 do {
7504 if (BO->Op) {
7505 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7506 Ops.push_back(BO->Op);
7507 break;
7508 }
7509 }
7510 Ops.push_back(BO->RHS);
7511 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7512 dyn_cast<Instruction>(V));
7513 if (!NewBO ||
7514 (BO->Opcode == Instruction::Add &&
7515 (NewBO->Opcode != Instruction::Add &&
7516 NewBO->Opcode != Instruction::Sub)) ||
7517 (BO->Opcode == Instruction::Mul &&
7518 NewBO->Opcode != Instruction::Mul)) {
7519 Ops.push_back(BO->LHS);
7520 break;
7521 }
7522 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7523 // requires a SCEV for the LHS.
7524 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7525 auto *I = dyn_cast<Instruction>(BO->Op);
7526 if (I && programUndefinedIfPoison(I)) {
7527 Ops.push_back(BO->LHS);
7528 break;
7529 }
7530 }
7531 BO = NewBO;
7532 } while (true);
7533 return nullptr;
7534 }
7535 case Instruction::Sub:
7536 case Instruction::UDiv:
7537 case Instruction::URem:
7538 break;
7539 case Instruction::AShr:
7540 case Instruction::Shl:
7541 case Instruction::Xor:
7542 if (!IsConstArg)
7543 return nullptr;
7544 break;
7545 case Instruction::And:
7546 case Instruction::Or:
7547 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7548 return nullptr;
7549 break;
7550 case Instruction::LShr:
7551 return getUnknown(V);
7552 default:
7553 llvm_unreachable("Unhandled binop");
7554 break;
7555 }
7556
7557 Ops.push_back(BO->LHS);
7558 Ops.push_back(BO->RHS);
7559 return nullptr;
7560 }
7561
7562 switch (U->getOpcode()) {
7563 case Instruction::Trunc:
7564 case Instruction::ZExt:
7565 case Instruction::SExt:
7566 case Instruction::PtrToInt:
7567 Ops.push_back(U->getOperand(0));
7568 return nullptr;
7569
7570 case Instruction::BitCast:
7571 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7572 Ops.push_back(U->getOperand(0));
7573 return nullptr;
7574 }
7575 return getUnknown(V);
7576
7577 case Instruction::SDiv:
7578 case Instruction::SRem:
7579 Ops.push_back(U->getOperand(0));
7580 Ops.push_back(U->getOperand(1));
7581 return nullptr;
7582
7583 case Instruction::GetElementPtr:
7584 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7585 "GEP source element type must be sized");
7586 for (Value *Index : U->operands())
7587 Ops.push_back(Index);
7588 return nullptr;
7589
7590 case Instruction::IntToPtr:
7591 return getUnknown(V);
7592
7593 case Instruction::PHI:
7594 // Keep constructing SCEVs for phis recursively for now.
7595 return nullptr;
7596
7597 case Instruction::Select: {
7598 // Check if U is a select that can be simplified to a SCEVUnknown.
7599 auto CanSimplifyToUnknown = [this, U]() {
7600 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7601 return false;
7602
7603 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7604 if (!ICI)
7605 return false;
7606 Value *LHS = ICI->getOperand(0);
7607 Value *RHS = ICI->getOperand(1);
7608 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7609 ICI->getPredicate() == CmpInst::ICMP_NE) {
7610 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7611 return true;
7612 } else if (getTypeSizeInBits(LHS->getType()) >
7613 getTypeSizeInBits(U->getType()))
7614 return true;
7615 return false;
7616 };
7617 if (CanSimplifyToUnknown())
7618 return getUnknown(U);
7619
7620 for (Value *Inc : U->operands())
7621 Ops.push_back(Inc);
7622 return nullptr;
7623 break;
7624 }
7625 case Instruction::Call:
7626 case Instruction::Invoke:
7627 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7628 Ops.push_back(RV);
7629 return nullptr;
7630 }
7631
7632 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7633 switch (II->getIntrinsicID()) {
7634 case Intrinsic::abs:
7635 Ops.push_back(II->getArgOperand(0));
7636 return nullptr;
7637 case Intrinsic::umax:
7638 case Intrinsic::umin:
7639 case Intrinsic::smax:
7640 case Intrinsic::smin:
7641 case Intrinsic::usub_sat:
7642 case Intrinsic::uadd_sat:
7643 Ops.push_back(II->getArgOperand(0));
7644 Ops.push_back(II->getArgOperand(1));
7645 return nullptr;
7646 case Intrinsic::start_loop_iterations:
7647 case Intrinsic::annotation:
7648 case Intrinsic::ptr_annotation:
7649 Ops.push_back(II->getArgOperand(0));
7650 return nullptr;
7651 default:
7652 break;
7653 }
7654 }
7655 break;
7656 }
7657
7658 return nullptr;
7659}
7660
7661const SCEV *ScalarEvolution::createSCEV(Value *V) {
7662 if (!isSCEVable(V->getType()))
7663 return getUnknown(V);
7664
7665 if (Instruction *I = dyn_cast<Instruction>(V)) {
7666 // Don't attempt to analyze instructions in blocks that aren't
7667 // reachable. Such instructions don't matter, and they aren't required
7668 // to obey basic rules for definitions dominating uses which this
7669 // analysis depends on.
7670 if (!DT.isReachableFromEntry(I->getParent()))
7671 return getUnknown(PoisonValue::get(V->getType()));
7672 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7673 return getConstant(CI);
7674 else if (isa<GlobalAlias>(V))
7675 return getUnknown(V);
7676 else if (!isa<ConstantExpr>(V))
7677 return getUnknown(V);
7678
7679 const SCEV *LHS;
7680 const SCEV *RHS;
7681
7682 Operator *U = cast<Operator>(V);
7683 if (auto BO =
7684 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7685 switch (BO->Opcode) {
7686 case Instruction::Add: {
7687 // The simple thing to do would be to just call getSCEV on both operands
7688 // and call getAddExpr with the result. However if we're looking at a
7689 // bunch of things all added together, this can be quite inefficient,
7690 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7691 // Instead, gather up all the operands and make a single getAddExpr call.
7692 // LLVM IR canonical form means we need only traverse the left operands.
7693 SmallVector<const SCEV *, 4> AddOps;
7694 do {
7695 if (BO->Op) {
7696 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7697 AddOps.push_back(OpSCEV);
7698 break;
7699 }
7700
7701 // If a NUW or NSW flag can be applied to the SCEV for this
7702 // addition, then compute the SCEV for this addition by itself
7703 // with a separate call to getAddExpr. We need to do that
7704 // instead of pushing the operands of the addition onto AddOps,
7705 // since the flags are only known to apply to this particular
7706 // addition - they may not apply to other additions that can be
7707 // formed with operands from AddOps.
7708 const SCEV *RHS = getSCEV(BO->RHS);
7709 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7710 if (Flags != SCEV::FlagAnyWrap) {
7711 const SCEV *LHS = getSCEV(BO->LHS);
7712 if (BO->Opcode == Instruction::Sub)
7713 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7714 else
7715 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7716 break;
7717 }
7718 }
7719
7720 if (BO->Opcode == Instruction::Sub)
7721 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7722 else
7723 AddOps.push_back(getSCEV(BO->RHS));
7724
7725 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7726 dyn_cast<Instruction>(V));
7727 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7728 NewBO->Opcode != Instruction::Sub)) {
7729 AddOps.push_back(getSCEV(BO->LHS));
7730 break;
7731 }
7732 BO = NewBO;
7733 } while (true);
7734
7735 return getAddExpr(AddOps);
7736 }
7737
7738 case Instruction::Mul: {
7739 SmallVector<const SCEV *, 4> MulOps;
7740 do {
7741 if (BO->Op) {
7742 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7743 MulOps.push_back(OpSCEV);
7744 break;
7745 }
7746
7747 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7748 if (Flags != SCEV::FlagAnyWrap) {
7749 LHS = getSCEV(BO->LHS);
7750 RHS = getSCEV(BO->RHS);
7751 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7752 break;
7753 }
7754 }
7755
7756 MulOps.push_back(getSCEV(BO->RHS));
7757 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7758 dyn_cast<Instruction>(V));
7759 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7760 MulOps.push_back(getSCEV(BO->LHS));
7761 break;
7762 }
7763 BO = NewBO;
7764 } while (true);
7765
7766 return getMulExpr(MulOps);
7767 }
7768 case Instruction::UDiv:
7769 LHS = getSCEV(BO->LHS);
7770 RHS = getSCEV(BO->RHS);
7771 return getUDivExpr(LHS, RHS);
7772 case Instruction::URem:
7773 LHS = getSCEV(BO->LHS);
7774 RHS = getSCEV(BO->RHS);
7775 return getURemExpr(LHS, RHS);
7776 case Instruction::Sub: {
7777 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7778 if (BO->Op)
7779 Flags = getNoWrapFlagsFromUB(BO->Op);
7780 LHS = getSCEV(BO->LHS);
7781 RHS = getSCEV(BO->RHS);
7782 return getMinusSCEV(LHS, RHS, Flags);
7783 }
7784 case Instruction::And:
7785 // For an expression like x&255 that merely masks off the high bits,
7786 // use zext(trunc(x)) as the SCEV expression.
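// E.g. for an i32 %x, "%x & 255" is modeled as zext(trunc(%x to i8) to i32).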
7787 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7788 if (CI->isZero())
7789 return getSCEV(BO->RHS);
7790 if (CI->isMinusOne())
7791 return getSCEV(BO->LHS);
7792 const APInt &A = CI->getValue();
7793
7794 // Instcombine's ShrinkDemandedConstant may strip bits out of
7795 // constants, obscuring what would otherwise be a low-bits mask.
7796 // Use computeKnownBits to compute what ShrinkDemandedConstant
7797 // knew about to reconstruct a low-bits mask value.
7798 unsigned LZ = A.countl_zero();
7799 unsigned TZ = A.countr_zero();
7800 unsigned BitWidth = A.getBitWidth();
7801 KnownBits Known(BitWidth);
7802 computeKnownBits(BO->LHS, Known, getDataLayout(),
7803 0, &AC, nullptr, &DT);
7804
7805 APInt EffectiveMask =
7806 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7807 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7808 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7809 const SCEV *LHS = getSCEV(BO->LHS);
7810 const SCEV *ShiftedLHS = nullptr;
7811 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7812 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7813 // For an expression like (x * 8) & 8, simplify the multiply.
7814 unsigned MulZeros = OpC->getAPInt().countr_zero();
7815 unsigned GCD = std::min(MulZeros, TZ);
7816 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7817 SmallVector<const SCEV *, 4> MulOps;
7818 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7819 append_range(MulOps, LHSMul->operands().drop_front());
7820 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7821 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7822 }
7823 }
7824 if (!ShiftedLHS)
7825 ShiftedLHS = getUDivExpr(LHS, MulCount);
7826 return getMulExpr(
7827 getZeroExtendExpr(
7828 getTruncateExpr(ShiftedLHS,
7829 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7830 BO->LHS->getType()),
7831 MulCount);
7832 }
7833 }
7834 // Binary `and` is a bit-wise `umin`.
7835 if (BO->LHS->getType()->isIntegerTy(1)) {
7836 LHS = getSCEV(BO->LHS);
7837 RHS = getSCEV(BO->RHS);
7838 return getUMinExpr(LHS, RHS);
7839 }
7840 break;
7841
7842 case Instruction::Or:
7843 // Binary `or` is a bit-wise `umax`.
7844 if (BO->LHS->getType()->isIntegerTy(1)) {
7845 LHS = getSCEV(BO->LHS);
7846 RHS = getSCEV(BO->RHS);
7847 return getUMaxExpr(LHS, RHS);
7848 }
7849 break;
7850
7851 case Instruction::Xor:
7852 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7853 // If the RHS of xor is -1, then this is a not operation.
7854 if (CI->isMinusOne())
7855 return getNotSCEV(getSCEV(BO->LHS));
7856
7857 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7858 // This is a variant of the check for xor with -1, and it handles
7859 // the case where instcombine has trimmed non-demanded bits out
7860 // of an xor with -1.
7861 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7862 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7863 if (LBO->getOpcode() == Instruction::And &&
7864 LCI->getValue() == CI->getValue())
7865 if (const SCEVZeroExtendExpr *Z =
7866 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7867 Type *UTy = BO->LHS->getType();
7868 const SCEV *Z0 = Z->getOperand();
7869 Type *Z0Ty = Z0->getType();
7870 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7871
7872 // If C is a low-bits mask, the zero extend is serving to
7873 // mask off the high bits. Complement the operand and
7874 // re-apply the zext.
7875 if (CI->getValue().isMask(Z0TySize))
7876 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7877
7878 // If C is a single bit, it may be in the sign-bit position
7879 // before the zero-extend. In this case, represent the xor
7880 // using an add, which is equivalent, and re-apply the zext.
7881 APInt Trunc = CI->getValue().trunc(Z0TySize);
7882 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7883 Trunc.isSignMask())
7884 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7885 UTy);
7886 }
7887 }
7888 break;
7889
7890 case Instruction::Shl:
7891 // Turn shift left of a constant amount into a multiply.
7892 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7893 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7894
7895 // If the shift count is not less than the bitwidth, the result of
7896 // the shift is undefined. Don't try to analyze it, because the
7897 // resolution chosen here may differ from the resolution chosen in
7898 // other parts of the compiler.
7899 if (SA->getValue().uge(BitWidth))
7900 break;
7901
7902 // We can safely preserve the nuw flag in all cases. It's also safe to
7903 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7904 // requires special handling. It can be preserved as long as we're not
7905 // left shifting by bitwidth - 1.
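// Illustrative counterexample (not from the original source): on i8,
// "shl nsw i8 -1, 7" yields -128 with no signed overflow, but the
// corresponding multiply -1 * (i8)-128 is +128, which does overflow, so
// nsw alone cannot be transferred when shifting by bitwidth - 1.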
7906 auto Flags = SCEV::FlagAnyWrap;
7907 if (BO->Op) {
7908 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7909 if ((MulFlags & SCEV::FlagNSW) &&
7910 ((MulFlags & SCEV::FlagNSW) && ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7911 Flags = setFlags(Flags, SCEV::FlagNSW);
7912 if (MulFlags & SCEV::FlagNUW)
7913 Flags = setFlags(Flags, SCEV::FlagNUW);
7914 }
7915
7916 ConstantInt *X = ConstantInt::get(
7917 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7918 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7919 }
7920 break;
7921
7922 case Instruction::AShr:
7923 // AShr X, C, where C is a constant.
7924 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7925 if (!CI)
7926 break;
7927
7928 Type *OuterTy = BO->LHS->getType();
7929 unsigned BitWidth = getTypeSizeInBits(OuterTy);
7930 // If the shift count is not less than the bitwidth, the result of
7931 // the shift is undefined. Don't try to analyze it, because the
7932 // resolution chosen here may differ from the resolution chosen in
7933 // other parts of the compiler.
7934 if (CI->getValue().uge(BitWidth))
7935 break;
7936
7937 if (CI->isZero())
7938 return getSCEV(BO->LHS); // shift by zero --> noop
7939
7940 uint64_t AShrAmt = CI->getZExtValue();
7941 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7942
7943 Operator *L = dyn_cast<Operator>(BO->LHS);
7944 const SCEV *AddTruncateExpr = nullptr;
7945 ConstantInt *ShlAmtCI = nullptr;
7946 const SCEV *AddConstant = nullptr;
7947
7948 if (L && L->getOpcode() == Instruction::Add) {
7949 // X = Shl A, n
7950 // Y = Add X, c
7951 // Z = AShr Y, m
7952 // n, c and m are constants.
7953
7954 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7955 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7956 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7957 if (AddOperandCI) {
7958 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7959 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7960 // Since we truncate to TruncTy, the AddConstant should be of the
7961 // same type, so create a new constant with the same type as TruncTy.
7962 // Also, the Add constant should be shifted right by the AShr amount.
7963 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
7964 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
7965 // We model the expression as sext(add(trunc(A), c << n)). Since the
7966 // sext(trunc) part is already handled below, we create an
7967 // AddExpr(TruncExp) which will be used later.
7968 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7969 }
7970 }
7971 } else if (L && L->getOpcode() == Instruction::Shl) {
7972 // X = Shl A, n
7973 // Y = AShr X, m
7974 // Both n and m are constant.
7975
7976 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7977 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7978 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7979 }
7980
7981 if (AddTruncateExpr && ShlAmtCI) {
7982 // We can merge the two given cases into a single SCEV statement,
7983 // in case n = m, the mul expression will be 2^0, so it gets resolved to
7984 // a simpler case. The following code handles the two cases:
7985 //
7986 // 1) For a two-shift sext-inreg, i.e. n = m,
7987 // use sext(trunc(x)) as the SCEV expression.
7988 //
7989 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
7990 // expression. We already checked that ShlAmt < BitWidth, so
7991 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7992 // ShlAmt - AShrAmt < Amt.
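// Illustrative examples (not from the original source), for i32 %x:
//   (ashr (shl %x, 24), 24) --> sext(trunc %x to i8) to i32
//   (ashr (shl %x, 24), 16) --> sext((trunc %x to i16) * 256) to i32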
7993 const APInt &ShlAmt = ShlAmtCI->getValue();
7994 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
7995 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7996 ShlAmtCI->getZExtValue() - AShrAmt);
7997 const SCEV *CompositeExpr =
7998 getMulExpr(AddTruncateExpr, getConstant(Mul));
7999 if (L->getOpcode() != Instruction::Shl)
8000 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8001
8002 return getSignExtendExpr(CompositeExpr, OuterTy);
8003 }
8004 }
8005 break;
8006 }
8007 }
8008
8009 switch (U->getOpcode()) {
8010 case Instruction::Trunc:
8011 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8012
8013 case Instruction::ZExt:
8014 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8015
8016 case Instruction::SExt:
8017 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8018 dyn_cast<Instruction>(V))) {
8019 // The NSW flag of a subtract does not always survive the conversion to
8020 // A + (-1)*B. By pushing sign extension onto its operands we are much
8021 // more likely to preserve NSW and allow later AddRec optimisations.
8022 //
8023 // NOTE: This is effectively duplicating this logic from getSignExtend:
8024 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8025 // but by that point the NSW information has potentially been lost.
8026 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8027 Type *Ty = U->getType();
8028 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8029 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8030 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8031 }
8032 }
8033 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8034
8035 case Instruction::BitCast:
8036 // BitCasts are no-op casts so we just eliminate the cast.
8037 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8038 return getSCEV(U->getOperand(0));
8039 break;
8040
8041 case Instruction::PtrToInt: {
8042 // Pointer to integer cast is straightforward, so do model it.
8043 const SCEV *Op = getSCEV(U->getOperand(0));
8044 Type *DstIntTy = U->getType();
8045 // But only if effective SCEV (integer) type is wide enough to represent
8046 // all possible pointer values.
8047 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8048 if (isa<SCEVCouldNotCompute>(IntOp))
8049 return getUnknown(V);
8050 return IntOp;
8051 }
8052 case Instruction::IntToPtr:
8053 // Just don't deal with inttoptr casts.
8054 return getUnknown(V);
8055
8056 case Instruction::SDiv:
8057 // If both operands are non-negative, this is just an udiv.
8058 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8059 isKnownNonNegative(getSCEV(U->getOperand(1))))
8060 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8061 break;
8062
8063 case Instruction::SRem:
8064 // If both operands are non-negative, this is just an urem.
8065 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8066 isKnownNonNegative(getSCEV(U->getOperand(1))))
8067 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8068 break;
8069
8070 case Instruction::GetElementPtr:
8071 return createNodeForGEP(cast<GEPOperator>(U));
8072
8073 case Instruction::PHI:
8074 return createNodeForPHI(cast<PHINode>(U));
8075
8076 case Instruction::Select:
8077 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8078 U->getOperand(2));
8079
8080 case Instruction::Call:
8081 case Instruction::Invoke:
8082 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8083 return getSCEV(RV);
8084
8085 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8086 switch (II->getIntrinsicID()) {
8087 case Intrinsic::abs:
8088 return getAbsExpr(
8089 getSCEV(II->getArgOperand(0)),
8090 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8091 case Intrinsic::umax:
8092 LHS = getSCEV(II->getArgOperand(0));
8093 RHS = getSCEV(II->getArgOperand(1));
8094 return getUMaxExpr(LHS, RHS);
8095 case Intrinsic::umin:
8096 LHS = getSCEV(II->getArgOperand(0));
8097 RHS = getSCEV(II->getArgOperand(1));
8098 return getUMinExpr(LHS, RHS);
8099 case Intrinsic::smax:
8100 LHS = getSCEV(II->getArgOperand(0));
8101 RHS = getSCEV(II->getArgOperand(1));
8102 return getSMaxExpr(LHS, RHS);
8103 case Intrinsic::smin:
8104 LHS = getSCEV(II->getArgOperand(0));
8105 RHS = getSCEV(II->getArgOperand(1));
8106 return getSMinExpr(LHS, RHS);
8107 case Intrinsic::usub_sat: {
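// usub.sat(X, Y) == X - umin(X, Y), and the subtraction cannot wrap
// because umin(X, Y) <= X; hence the NUW flag below.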
8108 const SCEV *X = getSCEV(II->getArgOperand(0));
8109 const SCEV *Y = getSCEV(II->getArgOperand(1));
8110 const SCEV *ClampedY = getUMinExpr(X, Y);
8111 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8112 }
8113 case Intrinsic::uadd_sat: {
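// uadd.sat(X, Y) == umin(X, ~Y) + Y; clamping X to (UINT_MAX - Y)
// guarantees the addition cannot wrap, hence the NUW flag below.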
8114 const SCEV *X = getSCEV(II->getArgOperand(0));
8115 const SCEV *Y = getSCEV(II->getArgOperand(1));
8116 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8117 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8118 }
8119 case Intrinsic::start_loop_iterations:
8120 case Intrinsic::annotation:
8121 case Intrinsic::ptr_annotation:
8122 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8123 // just equivalent to the first operand for SCEV purposes.
8124 return getSCEV(II->getArgOperand(0));
8125 case Intrinsic::vscale:
8126 return getVScale(II->getType());
8127 default:
8128 break;
8129 }
8130 }
8131 break;
8132 }
8133
8134 return getUnknown(V);
8135}
8136
8137//===----------------------------------------------------------------------===//
8138// Iteration Count Computation Code
8139//
8140
8141 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8142 if (isa<SCEVCouldNotCompute>(ExitCount))
8143 return getCouldNotCompute();
8144
8145 auto *ExitCountType = ExitCount->getType();
8146 assert(ExitCountType->isIntegerTy());
8147 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8148 1 + ExitCountType->getScalarSizeInBits());
8149 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8150}
8151
8152 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8153 Type *EvalTy,
8154 const Loop *L) {
8155 if (isa<SCEVCouldNotCompute>(ExitCount))
8156 return getCouldNotCompute();
8157
8158 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8159 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8160
8161 auto CanAddOneWithoutOverflow = [&]() {
8162 ConstantRange ExitCountRange =
8163 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8164 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8165 return true;
8166
8167 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8168 getMinusOne(ExitCount->getType()));
8169 };
8170
8171 // If we need to zero extend the backedge count, check if we can add one to
8172 // it prior to zero extending without overflow. Provided this is safe, it
8173 // allows better simplification of the +1.
8174 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8175 return getZeroExtendExpr(
8176 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8177
8178 // Get the total trip count from the count by adding 1. This may wrap.
8179 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8180}
8181
8182static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8183 if (!ExitCount)
8184 return 0;
8185
8186 ConstantInt *ExitConst = ExitCount->getValue();
8187
8188 // Guard against huge trip counts.
8189 if (ExitConst->getValue().getActiveBits() > 32)
8190 return 0;
8191
8192 // In case of integer overflow, this returns 0, which is correct.
8193 return ((unsigned)ExitConst->getZExtValue()) + 1;
8194}
8195
8196 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8197 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8198 return getConstantTripCount(ExitCount);
8199}
8200
8201unsigned
8202 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8203 const BasicBlock *ExitingBlock) {
8204 assert(ExitingBlock && "Must pass a non-null exiting block!");
8205 assert(L->isLoopExiting(ExitingBlock) &&
8206 "Exiting block must actually branch out of the loop!");
8207 const SCEVConstant *ExitCount =
8208 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8209 return getConstantTripCount(ExitCount);
8210}
8211
8212 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
8213 const auto *MaxExitCount =
8214 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
8215 return getConstantTripCount(MaxExitCount);
8216}
8217
8218 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8219 SmallVector<BasicBlock *, 8> ExitingBlocks;
8220 L->getExitingBlocks(ExitingBlocks);
8221
8222 std::optional<unsigned> Res;
8223 for (auto *ExitingBB : ExitingBlocks) {
8224 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8225 if (!Res)
8226 Res = Multiple;
8227 Res = (unsigned)std::gcd(*Res, Multiple);
8228 }
8229 return Res.value_or(1);
8230}
8231
8232 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8233 const SCEV *ExitCount) {
8234 if (ExitCount == getCouldNotCompute())
8235 return 1;
8236
8237 // Get the trip count
8238 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8239
8240 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8241 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8242 // the greatest power of 2 divisor less than 2^32.
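// For example (illustrative): a trip count known to be a multiple of
// 3 * 2^34 is reported as 2^31, the capped power-of-two divisor.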
8243 return Multiple.getActiveBits() > 32
8244 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8245 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8246}
8247
8248/// Returns the largest constant divisor of the trip count of this loop as a
8249/// normal unsigned value, if possible. This means that the actual trip count is
8250/// always a multiple of the returned value (don't forget the trip count could
8251/// very well be zero as well!).
8252///
8253 /// Returns 1 if the trip count is unknown or not guaranteed to be a
8254 /// multiple of a constant (which is also the case if the trip count is simply
8255 /// constant; use getSmallConstantTripCount for that case). It will also return
8256 /// 1 if the trip count is very large (>= 2^32).
8257///
8258/// As explained in the comments for getSmallConstantTripCount, this assumes
8259/// that control exits the loop via ExitingBlock.
8260unsigned
8261 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8262 const BasicBlock *ExitingBlock) {
8263 assert(ExitingBlock && "Must pass a non-null exiting block!");
8264 assert(L->isLoopExiting(ExitingBlock) &&
8265 "Exiting block must actually branch out of the loop!");
8266 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8267 return getSmallConstantTripMultiple(L, ExitCount);
8268}
8269
8270 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8271 const BasicBlock *ExitingBlock,
8272 ExitCountKind Kind) {
8273 switch (Kind) {
8274 case Exact:
8275 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8276 case SymbolicMaximum:
8277 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8278 case ConstantMaximum:
8279 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8280 };
8281 llvm_unreachable("Invalid ExitCountKind!");
8282}
8283
8284const SCEV *
8287 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8288}
8289
8290 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8291 ExitCountKind Kind) {
8292 switch (Kind) {
8293 case Exact:
8294 return getBackedgeTakenInfo(L).getExact(L, this);
8295 case ConstantMaximum:
8296 return getBackedgeTakenInfo(L).getConstantMax(this);
8297 case SymbolicMaximum:
8298 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8299 };
8300 llvm_unreachable("Invalid ExitCountKind!");
8301}
8302
8305 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8306}
8307
8308 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8309 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8310}
8311
8312/// Push PHI nodes in the header of the given loop onto the given Worklist.
8313static void PushLoopPHIs(const Loop *L,
8314 SmallVectorImpl<Instruction *> &Worklist,
8315 SmallPtrSetImpl<Instruction *> &Visited) {
8316 BasicBlock *Header = L->getHeader();
8317
8318 // Push all Loop-header PHIs onto the Worklist stack.
8319 for (PHINode &PN : Header->phis())
8320 if (Visited.insert(&PN).second)
8321 Worklist.push_back(&PN);
8322}
8323
8324ScalarEvolution::BackedgeTakenInfo &
8325ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8326 auto &BTI = getBackedgeTakenInfo(L);
8327 if (BTI.hasFullInfo())
8328 return BTI;
8329
8330 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8331
8332 if (!Pair.second)
8333 return Pair.first->second;
8334
8335 BackedgeTakenInfo Result =
8336 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8337
8338 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8339}
8340
8341ScalarEvolution::BackedgeTakenInfo &
8342ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8343 // Initially insert an invalid entry for this loop. If the insertion
8344 // succeeds, proceed to actually compute a backedge-taken count and
8345 // update the value. The temporary CouldNotCompute value tells SCEV
8346 // code elsewhere that it shouldn't attempt to request a new
8347 // backedge-taken count, which could result in infinite recursion.
8348 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8349 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8350 if (!Pair.second)
8351 return Pair.first->second;
8352
8353 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8354 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8355 // must be cleared in this scope.
8356 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8357
8358 // Now that we know more about the trip count for this loop, forget any
8359 // existing SCEV values for PHI nodes in this loop since they are only
8360 // conservative estimates made without the benefit of trip count
8361 // information. This invalidation is not necessary for correctness, and is
8362 // only done to produce more precise results.
8363 if (Result.hasAnyInfo()) {
8364 // Invalidate any expression using an addrec in this loop.
8365 SmallVector<const SCEV *, 8> ToForget;
8366 auto LoopUsersIt = LoopUsers.find(L);
8367 if (LoopUsersIt != LoopUsers.end())
8368 append_range(ToForget, LoopUsersIt->second);
8369 forgetMemoizedResults(ToForget);
8370
8371 // Invalidate constant-evolved loop header phis.
8372 for (PHINode &PN : L->getHeader()->phis())
8373 ConstantEvolutionLoopExitValue.erase(&PN);
8374 }
8375
8376 // Re-lookup the insert position, since the call to
8377 // computeBackedgeTakenCount above could result in a
8378 // recursive call to getBackedgeTakenInfo (on a different
8379 // loop), which would invalidate the iterator computed
8380 // earlier.
8381 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8382}
8383
8385 // This method is intended to forget all info about loops. It should
8386 // invalidate caches as if the following happened:
8387 // - The trip counts of all loops have changed arbitrarily
8388 // - Every llvm::Value has been updated in place to produce a different
8389 // result.
8390 BackedgeTakenCounts.clear();
8391 PredicatedBackedgeTakenCounts.clear();
8392 BECountUsers.clear();
8393 LoopPropertiesCache.clear();
8394 ConstantEvolutionLoopExitValue.clear();
8395 ValueExprMap.clear();
8396 ValuesAtScopes.clear();
8397 ValuesAtScopesUsers.clear();
8398 LoopDispositions.clear();
8399 BlockDispositions.clear();
8400 UnsignedRanges.clear();
8401 SignedRanges.clear();
8402 ExprValueMap.clear();
8403 HasRecMap.clear();
8404 ConstantMultipleCache.clear();
8405 PredicatedSCEVRewrites.clear();
8406 FoldCache.clear();
8407 FoldCacheUser.clear();
8408}
8409void ScalarEvolution::visitAndClearUsers(
8410 SmallVectorImpl<Instruction *> &Worklist,
8411 SmallPtrSetImpl<Instruction *> &Visited,
8412 SmallVectorImpl<const SCEV *> &ToForget) {
8413 while (!Worklist.empty()) {
8414 Instruction *I = Worklist.pop_back_val();
8415 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8416 continue;
8417
8418 ValueExprMapType::iterator It =
8419 ValueExprMap.find_as(static_cast<Value *>(I));
8420 if (It != ValueExprMap.end()) {
8421 eraseValueFromMap(It->first);
8422 ToForget.push_back(It->second);
8423 if (PHINode *PN = dyn_cast<PHINode>(I))
8424 ConstantEvolutionLoopExitValue.erase(PN);
8425 }
8426
8427 PushDefUseChildren(I, Worklist, Visited);
8428 }
8429}
8430
8431 void ScalarEvolution::forgetLoop(const Loop *L) {
8432 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8433 SmallVector<Instruction *, 32> Worklist;
8434 SmallPtrSet<Instruction *, 16> Visited;
8435 SmallVector<const SCEV *, 16> ToForget;
8436
8437 // Iterate over all the loops and sub-loops to drop SCEV information.
8438 while (!LoopWorklist.empty()) {
8439 auto *CurrL = LoopWorklist.pop_back_val();
8440
8441 // Drop any stored trip count value.
8442 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8443 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8444
8445 // Drop information about predicated SCEV rewrites for this loop.
8446 for (auto I = PredicatedSCEVRewrites.begin();
8447 I != PredicatedSCEVRewrites.end();) {
8448 std::pair<const SCEV *, const Loop *> Entry = I->first;
8449 if (Entry.second == CurrL)
8450 PredicatedSCEVRewrites.erase(I++);
8451 else
8452 ++I;
8453 }
8454
8455 auto LoopUsersItr = LoopUsers.find(CurrL);
8456 if (LoopUsersItr != LoopUsers.end()) {
8457 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8458 LoopUsersItr->second.end());
8459 }
8460
8461 // Drop information about expressions based on loop-header PHIs.
8462 PushLoopPHIs(CurrL, Worklist, Visited);
8463 visitAndClearUsers(Worklist, Visited, ToForget);
8464
8465 LoopPropertiesCache.erase(CurrL);
8466 // Forget all contained loops too, to avoid dangling entries in the
8467 // ValuesAtScopes map.
8468 LoopWorklist.append(CurrL->begin(), CurrL->end());
8469 }
8470 forgetMemoizedResults(ToForget);
8471}
8472
8473 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8474 forgetLoop(L->getOutermostLoop());
8475}
8476
8477 void ScalarEvolution::forgetValue(Value *V) {
8478 Instruction *I = dyn_cast<Instruction>(V);
8479 if (!I) return;
8480
8481 // Drop information about expressions based on loop-header PHIs.
8482 SmallVector<Instruction *, 16> Worklist;
8483 SmallPtrSet<Instruction *, 8> Visited;
8484 SmallVector<const SCEV *, 8> ToForget;
8485 Worklist.push_back(I);
8486 Visited.insert(I);
8487 visitAndClearUsers(Worklist, Visited, ToForget);
8488
8489 forgetMemoizedResults(ToForget);
8490}
8491
8493 if (!isSCEVable(V->getType()))
8494 return;
8495
8496 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8497 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8498 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8499 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8500 if (const SCEV *S = getExistingSCEV(V)) {
8501 struct InvalidationRootCollector {
8502 Loop *L;
8503 SmallVector<const SCEV *, 8> Roots;
8504
8505 InvalidationRootCollector(Loop *L) : L(L) {}
8506
8507 bool follow(const SCEV *S) {
8508 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8509 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8510 if (L->contains(I))
8511 Roots.push_back(S);
8512 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8513 if (L->contains(AddRec->getLoop()))
8514 Roots.push_back(S);
8515 }
8516 return true;
8517 }
8518 bool isDone() const { return false; }
8519 };
8520
8521 InvalidationRootCollector C(L);
8522 visitAll(S, C);
8523 forgetMemoizedResults(C.Roots);
8524 }
8525
8526 // Also perform the normal invalidation.
8527 forgetValue(V);
8528}
8529
8530void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8531
8532 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8533 // Unless a specific value is passed to invalidation, completely clear both
8534 // caches.
8535 if (!V) {
8536 BlockDispositions.clear();
8537 LoopDispositions.clear();
8538 return;
8539 }
8540
8541 if (!isSCEVable(V->getType()))
8542 return;
8543
8544 const SCEV *S = getExistingSCEV(V);
8545 if (!S)
8546 return;
8547
8548 // Invalidate the block and loop dispositions cached for S. Dispositions of
8549 // S's users may change if S's disposition changes (i.e. a user may change to
8550 // loop-invariant, if S changes to loop invariant), so also invalidate
8551 // dispositions of S's users recursively.
8552 SmallVector<const SCEV *, 8> Worklist = {S};
8553 SmallPtrSet<const SCEV *, 8> Seen;
8554 while (!Worklist.empty()) {
8555 const SCEV *Curr = Worklist.pop_back_val();
8556 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8557 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8558 if (!LoopDispoRemoved && !BlockDispoRemoved)
8559 continue;
8560 auto Users = SCEVUsers.find(Curr);
8561 if (Users != SCEVUsers.end())
8562 for (const auto *User : Users->second)
8563 if (Seen.insert(User).second)
8564 Worklist.push_back(User);
8565 }
8566}
8567
8568/// Get the exact loop backedge taken count considering all loop exits. A
8569/// computable result can only be returned for loops with all exiting blocks
8570/// dominating the latch. howFarToZero assumes that the limit of each loop test
8571/// is never skipped. This is a valid assumption as long as the loop exits via
8572/// that test. For precise results, it is the caller's responsibility to specify
8573/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8574const SCEV *
8575ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8576 SmallVector<const SCEVPredicate *, 4> *Preds) const {
8577 // If any exits were not computable, the loop is not computable.
8578 if (!isComplete() || ExitNotTaken.empty())
8579 return SE->getCouldNotCompute();
8580
8581 const BasicBlock *Latch = L->getLoopLatch();
8582 // All exiting blocks we have collected must dominate the only backedge.
8583 if (!Latch)
8584 return SE->getCouldNotCompute();
8585
8586 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8587 // count is simply a minimum out of all these calculated exit counts.
8588 SmallVector<const SCEV *, 2> Ops;
8589 for (const auto &ENT : ExitNotTaken) {
8590 const SCEV *BECount = ENT.ExactNotTaken;
8591 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8592 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8593 "We should only have known counts for exiting blocks that dominate "
8594 "latch!");
8595
8596 Ops.push_back(BECount);
8597
8598 if (Preds)
8599 for (const auto *P : ENT.Predicates)
8600 Preds->push_back(P);
8601
8602 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8603 "Predicate should be always true!");
8604 }
8605
8606 // If an earlier exit exits on the first iteration (exit count zero), then
8607 // a later poison exit count should not propagate into the result. These
8608 // are exactly the semantics provided by umin_seq.
8609 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8610}
8611
8612/// Get the exact not taken count for this loop exit.
8613const SCEV *
8614ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8615 ScalarEvolution *SE) const {
8616 for (const auto &ENT : ExitNotTaken)
8617 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8618 return ENT.ExactNotTaken;
8619
8620 return SE->getCouldNotCompute();
8621}
8622
8623const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8624 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8625 for (const auto &ENT : ExitNotTaken)
8626 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8627 return ENT.ConstantMaxNotTaken;
8628
8629 return SE->getCouldNotCompute();
8630}
8631
8632const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8633 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8634 for (const auto &ENT : ExitNotTaken)
8635 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8636 return ENT.SymbolicMaxNotTaken;
8637
8638 return SE->getCouldNotCompute();
8639}
8640
8641/// getConstantMax - Get the constant max backedge taken count for the loop.
8642const SCEV *
8643ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8644 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8645 return !ENT.hasAlwaysTruePredicate();
8646 };
8647
8648 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8649 return SE->getCouldNotCompute();
8650
8651 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8652 isa<SCEVConstant>(getConstantMax())) &&
8653 "No point in having a non-constant max backedge taken count!");
8654 return getConstantMax();
8655}
8656
8657const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8658 const Loop *L, ScalarEvolution *SE,
8659 SmallVector<const SCEVPredicate *, 4> *Predicates) {
8660 if (!SymbolicMax) {
8661 // Form an expression for the maximum exit count possible for this loop. We
8662 // merge the max and exact information to approximate a version of
8663 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8664 // constants.
8665 SmallVector<const SCEV *, 4> ExitCounts;
8666
8667 for (const auto &ENT : ExitNotTaken) {
8668 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8669 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8670 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8671 "We should only have known counts for exiting blocks that "
8672 "dominate latch!");
8673 ExitCounts.push_back(ExitCount);
8674 if (Predicates)
8675 for (const auto *P : ENT.Predicates)
8676 Predicates->push_back(P);
8677
8678 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8679 "Predicate should be always true!");
8680 }
8681 }
8682 if (ExitCounts.empty())
8683 SymbolicMax = SE->getCouldNotCompute();
8684 else
8685 SymbolicMax =
8686 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8687 }
8688 return SymbolicMax;
8689}
8690
8691bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8692 ScalarEvolution *SE) const {
8693 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8694 return !ENT.hasAlwaysTruePredicate();
8695 };
8696 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8697}
8698
8699 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8700 : ExitLimit(E, E, E, false, std::nullopt) {}
8701
8702 ScalarEvolution::ExitLimit::ExitLimit(
8703 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8704 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8705 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8706 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8707 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8708 // If we prove the max count is zero, so is the symbolic bound. This happens
8709 // in practice due to differences in a) how context sensitive we've chosen
8710 // to be and b) how we reason about bounds implied by UB.
8711 if (ConstantMaxNotTaken->isZero()) {
8712 this->ExactNotTaken = E = ConstantMaxNotTaken;
8713 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8714 }
8715
8716 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8717 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8718 "Exact is not allowed to be less precise than Constant Max");
8719 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8720 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8721 "Exact is not allowed to be less precise than Symbolic Max");
8722 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8723 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8724 "Symbolic Max is not allowed to be less precise than Constant Max");
8725 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8726 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8727 "No point in having a non-constant max backedge taken count!");
8728 for (const auto *PredSet : PredSetList)
8729 for (const auto *P : *PredSet)
8730 addPredicate(P);
8731 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8732 "Backedge count should be int");
8733 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8734 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8735 "Max backedge count should be int");
8736}
8737
8738 ScalarEvolution::ExitLimit::ExitLimit(
8739 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8740 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8742 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8743 { &PredSet }) {}
8744
8745/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8746/// computable exit into a persistent ExitNotTakenInfo array.
8747ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8749 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8750 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8751 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8752
8753 ExitNotTaken.reserve(ExitCounts.size());
8754 std::transform(ExitCounts.begin(), ExitCounts.end(),
8755 std::back_inserter(ExitNotTaken),
8756 [&](const EdgeExitInfo &EEI) {
8757 BasicBlock *ExitBB = EEI.first;
8758 const ExitLimit &EL = EEI.second;
8759 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8760 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8761 EL.Predicates);
8762 });
8763 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8764 isa<SCEVConstant>(ConstantMax)) &&
8765 "No point in having a non-constant max backedge taken count!");
8766}
8767
8768/// Compute the number of times the backedge of the specified loop will execute.
8769ScalarEvolution::BackedgeTakenInfo
8770ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8771 bool AllowPredicates) {
8772 SmallVector<BasicBlock *, 8> ExitingBlocks;
8773 L->getExitingBlocks(ExitingBlocks);
8774
8775 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8776
8777 SmallVector<EdgeExitInfo, 4> ExitCounts;
8778 bool CouldComputeBECount = true;
8779 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8780 const SCEV *MustExitMaxBECount = nullptr;
8781 const SCEV *MayExitMaxBECount = nullptr;
8782 bool MustExitMaxOrZero = false;
8783 bool IsOnlyExit = ExitingBlocks.size() == 1;
8784
8785 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8786 // and compute maxBECount.
8787 // Do a union of all the predicates here.
8788 for (BasicBlock *ExitBB : ExitingBlocks) {
8789 // We canonicalize untaken exits to br (constant); ignore them so that
8790 // proving an exit untaken doesn't negatively impact our ability to reason
8791 // about the loop as a whole.
8792 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8793 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8794 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8795 if (ExitIfTrue == CI->isZero())
8796 continue;
8797 }
8798
8799 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8800
8801 assert((AllowPredicates || EL.Predicates.empty()) &&
8802 "Predicated exit limit when predicates are not allowed!");
8803
8804 // 1. For each exit that can be computed, add an entry to ExitCounts.
8805 // CouldComputeBECount is true only if all exits can be computed.
8806 if (EL.ExactNotTaken != getCouldNotCompute())
8807 ++NumExitCountsComputed;
8808 else
8809 // We couldn't compute an exact value for this exit, so
8810 // we won't be able to compute an exact value for the loop.
8811 CouldComputeBECount = false;
8812 // Remember exit count if either exact or symbolic is known. Because
8813 // Exact always implies symbolic, only check symbolic.
8814 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8815 ExitCounts.emplace_back(ExitBB, EL);
8816 else {
8817 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8818 "Exact is known but symbolic isn't?");
8819 ++NumExitCountsNotComputed;
8820 }
8821
8822 // 2. Derive the loop's MaxBECount from each exit's max number of
8823 // non-exiting iterations. Partition the loop exits into two kinds:
8824 // LoopMustExits and LoopMayExits.
8825 //
8826 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8827 // is a LoopMayExit. If any computable LoopMustExit is found, then
8828 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8829 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8830 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8831 // any
8832 // computable EL.ConstantMaxNotTaken.
8833 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8834 DT.dominates(ExitBB, Latch)) {
8835 if (!MustExitMaxBECount) {
8836 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8837 MustExitMaxOrZero = EL.MaxOrZero;
8838 } else {
8839 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8840 EL.ConstantMaxNotTaken);
8841 }
8842 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8843 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8844 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8845 else {
8846 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8847 EL.ConstantMaxNotTaken);
8848 }
8849 }
8850 }
8851 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8852 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8853 // The loop backedge will be taken the maximum or zero times if there's
8854 // a single exit that must be taken the maximum or zero times.
8855 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8856
8857 // Remember which SCEVs are used in exit limits for invalidation purposes.
8858 // We only care about non-constant SCEVs here, so we can ignore
8859 // EL.ConstantMaxNotTaken
8860 // and MaxBECount, which must be SCEVConstant.
8861 for (const auto &Pair : ExitCounts) {
8862 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8863 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8864 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8865 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8866 {L, AllowPredicates});
8867 }
8868 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8869 MaxBECount, MaxOrZero);
8870}
8871
8872 ScalarEvolution::ExitLimit
8873 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8874 bool IsOnlyExit, bool AllowPredicates) {
8875 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8876 // If our exiting block does not dominate the latch, then its connection with
8877 // the loop's exit limit may be far from trivial.
8878 const BasicBlock *Latch = L->getLoopLatch();
8879 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8880 return getCouldNotCompute();
8881
8882 Instruction *Term = ExitingBlock->getTerminator();
8883 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8884 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8885 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8886 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8887 "It should have one successor in loop and one exit block!");
8888 // Proceed to the next level to examine the exit condition expression.
8889 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8890 /*ControlsOnlyExit=*/IsOnlyExit,
8891 AllowPredicates);
8892 }
8893
8894 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8895 // For switch, make sure that there is a single exit from the loop.
8896 BasicBlock *Exit = nullptr;
8897 for (auto *SBB : successors(ExitingBlock))
8898 if (!L->contains(SBB)) {
8899 if (Exit) // Multiple exit successors.
8900 return getCouldNotCompute();
8901 Exit = SBB;
8902 }
8903 assert(Exit && "Exiting block must have at least one exit");
8904 return computeExitLimitFromSingleExitSwitch(
8905 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8906 }
8907
8908 return getCouldNotCompute();
8909}
8910
8911 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8912 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8913 bool AllowPredicates) {
8914 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8915 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8916 ControlsOnlyExit, AllowPredicates);
8917}
8918
8919std::optional<ScalarEvolution::ExitLimit>
8920ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8921 bool ExitIfTrue, bool ControlsOnlyExit,
8922 bool AllowPredicates) {
8923 (void)this->L;
8924 (void)this->ExitIfTrue;
8925 (void)this->AllowPredicates;
8926
8927 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8928 this->AllowPredicates == AllowPredicates &&
8929 "Variance in assumed invariant key components!");
8930 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8931 if (Itr == TripCountMap.end())
8932 return std::nullopt;
8933 return Itr->second;
8934}
8935
8936void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8937 bool ExitIfTrue,
8938 bool ControlsOnlyExit,
8939 bool AllowPredicates,
8940 const ExitLimit &EL) {
8941 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8942 this->AllowPredicates == AllowPredicates &&
8943 "Variance in assumed invariant key components!");
8944
8945 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
8946 assert(InsertResult.second && "Expected successful insertion!");
8947 (void)InsertResult;
8948 (void)ExitIfTrue;
8949}
8950
8951ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8952 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8953 bool ControlsOnlyExit, bool AllowPredicates) {
8954
8955 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8956 AllowPredicates))
8957 return *MaybeEL;
8958
8959 ExitLimit EL = computeExitLimitFromCondImpl(
8960 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
8961 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8962 return EL;
8963}
8964
8965ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8966 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8967 bool ControlsOnlyExit, bool AllowPredicates) {
8968 // Handle BinOp conditions (And, Or).
8969 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8970 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
8971 return *LimitFromBinOp;
8972
8973 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8974 // Proceed to the next level to examine the icmp.
8975 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8976 ExitLimit EL =
8977 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
8978 if (EL.hasFullInfo() || !AllowPredicates)
8979 return EL;
8980
8981 // Try again, but use SCEV predicates this time.
8982 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
8983 ControlsOnlyExit,
8984 /*AllowPredicates=*/true);
8985 }
8986
8987 // Check for a constant condition. These are normally stripped out by
8988 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8989 // preserve the CFG and is temporarily leaving constant conditions
8990 // in place.
8991 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8992 if (ExitIfTrue == !CI->getZExtValue())
8993 // The backedge is always taken.
8994 return getCouldNotCompute();
8995 // The backedge is never taken.
8996 return getZero(CI->getType());
8997 }
8998
8999 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9000 // with a constant step, we can form an equivalent icmp predicate and figure
9001 // out how many iterations will be taken before we exit.
9002 const WithOverflowInst *WO;
9003 const APInt *C;
9004 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9005 match(WO->getRHS(), m_APInt(C))) {
9006 ConstantRange NWR =
9007 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9008 WO->getNoWrapKind());
9009 CmpInst::Predicate Pred;
9010 APInt NewRHSC, Offset;
9011 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9012 if (!ExitIfTrue)
9013 Pred = ICmpInst::getInversePredicate(Pred);
9014 auto *LHS = getSCEV(WO->getLHS());
9015 if (Offset != 0)
9016 LHS = getAddExpr(LHS, getConstant(Offset));
9017 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9018 ControlsOnlyExit, AllowPredicates);
9019 if (EL.hasAnyInfo())
9020 return EL;
9021 }
9022
9023 // If it's not an integer or pointer comparison then compute it the hard way.
9024 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9025}
9026
9027std::optional<ScalarEvolution::ExitLimit>
9028ScalarEvolution::computeExitLimitFromCondFromBinOp(
9029 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9030 bool ControlsOnlyExit, bool AllowPredicates) {
9031 // Check if the controlling expression for this loop is an And or Or.
9032 Value *Op0, *Op1;
9033 bool IsAnd = false;
9034 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9035 IsAnd = true;
9036 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9037 IsAnd = false;
9038 else
9039 return std::nullopt;
9040
9041 // EitherMayExit is true in these two cases:
9042 // br (and Op0 Op1), loop, exit
9043 // br (or Op0 Op1), exit, loop
9044 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9045 ExitLimit EL0 = computeExitLimitFromCondCached(
9046 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9047 AllowPredicates);
9048 ExitLimit EL1 = computeExitLimitFromCondCached(
9049 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9050 AllowPredicates);
9051
9052 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9053 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9054 if (isa<ConstantInt>(Op1))
9055 return Op1 == NeutralElement ? EL0 : EL1;
9056 if (isa<ConstantInt>(Op0))
9057 return Op0 == NeutralElement ? EL1 : EL0;
9058
9059 const SCEV *BECount = getCouldNotCompute();
9060 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9061 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9062 if (EitherMayExit) {
9063 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9064 // Both conditions must keep the loop executing for it to continue.
9065 // Choose the less conservative count.
9066 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9067 EL1.ExactNotTaken != getCouldNotCompute()) {
9068 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9069 UseSequentialUMin);
9070 }
9071 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9072 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9073 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9074 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9075 else
9076 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9077 EL1.ConstantMaxNotTaken);
9078 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9079 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9080 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9081 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9082 else
9083 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9084 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9085 } else {
9086 // Both conditions must be the same at the same time for the loop to exit.
9087 // For now, be conservative.
9088 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9089 BECount = EL0.ExactNotTaken;
9090 }
9091
9092 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9093 // to be more aggressive when computing BECount than when computing
9094 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9095 // and
9096 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9097 // EL1.ConstantMaxNotTaken to not.
9098 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9099 !isa<SCEVCouldNotCompute>(BECount))
9100 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9101 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9102 SymbolicMaxBECount =
9103 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9104 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9105 { &EL0.Predicates, &EL1.Predicates });
9106}
9107
9108ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9109 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9110 bool AllowPredicates) {
9111 // If the condition was exit on true, convert the condition to exit on false
9112 ICmpInst::Predicate Pred;
9113 if (!ExitIfTrue)
9114 Pred = ExitCond->getPredicate();
9115 else
9116 Pred = ExitCond->getInversePredicate();
9117 const ICmpInst::Predicate OriginalPred = Pred;
9118
9119 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9120 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9121
9122 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9123 AllowPredicates);
9124 if (EL.hasAnyInfo())
9125 return EL;
9126
9127 auto *ExhaustiveCount =
9128 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9129
9130 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9131 return ExhaustiveCount;
9132
9133 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9134 ExitCond->getOperand(1), L, OriginalPred);
9135}
9136ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9137 const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9138 bool ControlsOnlyExit, bool AllowPredicates) {
9139
9140 // Try to evaluate any dependencies out of the loop.
9141 LHS = getSCEVAtScope(LHS, L);
9142 RHS = getSCEVAtScope(RHS, L);
9143
9144 // At this point, we would like to compute how many iterations of the
9145 // loop the predicate will return true for these inputs.
9146 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9147 // If there is a loop-invariant, force it into the RHS.
9148 std::swap(LHS, RHS);
9149 Pred = ICmpInst::getSwappedPredicate(Pred);
9150 }
9151
9152 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9153 loopIsFiniteByAssumption(L);
9154 // Simplify the operands before analyzing them.
9155 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9156
9157 // If we have a comparison of a chrec against a constant, try to use value
9158 // ranges to answer this query.
9159 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9160 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9161 if (AddRec->getLoop() == L) {
9162 // Form the constant range.
9163 ConstantRange CompRange =
9164 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9165
9166 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9167 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9168 }
9169
9170 // If this loop must exit based on this condition (or execute undefined
9171 // behaviour), and we can prove the test sequence produced must repeat
9172 // the same values on self-wrap of the IV, then we can infer that IV
9173 // doesn't self wrap because if it did, we'd have an infinite (undefined)
9174 // loop.
9175 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9176 // TODO: We can peel off any functions which are invertible *in L*. Loop
9177 // invariant terms are effectively constants for our purposes here.
9178 auto *InnerLHS = LHS;
9179 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9180 InnerLHS = ZExt->getOperand();
9181 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
9182 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
9183 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9184 StrideC && StrideC->getAPInt().isPowerOf2()) {
9185 auto Flags = AR->getNoWrapFlags();
9186 Flags = setFlags(Flags, SCEV::FlagNW);
9189 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9190 }
9191 }
9192 }
9193
9194 switch (Pred) {
9195 case ICmpInst::ICMP_NE: { // while (X != Y)
9196 // Convert to: while (X-Y != 0)
9197 if (LHS->getType()->isPointerTy()) {
9198 LHS = getLosslessPtrToIntExpr(LHS);
9199 if (isa<SCEVCouldNotCompute>(LHS))
9200 return LHS;
9201 }
9202 if (RHS->getType()->isPointerTy()) {
9203 RHS = getLosslessPtrToIntExpr(RHS);
9204 if (isa<SCEVCouldNotCompute>(RHS))
9205 return RHS;
9206 }
9207 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9208 AllowPredicates);
9209 if (EL.hasAnyInfo())
9210 return EL;
9211 break;
9212 }
9213 case ICmpInst::ICMP_EQ: { // while (X == Y)
9214 // Convert to: while (X-Y == 0)
9215 if (LHS->getType()->isPointerTy()) {
9216 LHS = getLosslessPtrToIntExpr(LHS);
9217 if (isa<SCEVCouldNotCompute>(LHS))
9218 return LHS;
9219 }
9220 if (RHS->getType()->isPointerTy()) {
9221 RHS = getLosslessPtrToIntExpr(RHS);
9222 if (isa<SCEVCouldNotCompute>(RHS))
9223 return RHS;
9224 }
9225 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9226 if (EL.hasAnyInfo()) return EL;
9227 break;
9228 }
9229 case ICmpInst::ICMP_SLE:
9230 case ICmpInst::ICMP_ULE:
9231 // Since the loop is finite, an invariant RHS cannot include the boundary
9232 // value, otherwise it would loop forever.
9233 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9234 !isLoopInvariant(RHS, L)) {
9235 // Otherwise, perform the addition in a wider type, to avoid overflow.
9236 // If the LHS is an addrec with the appropriate nowrap flag, the
9237 // extension will be sunk into it and the exit count can be analyzed.
9238 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9239 if (!OldType)
9240 break;
9241 // Prefer doubling the bitwidth over adding a single bit to make it more
9242 // likely that we use a legal type.
9243 auto *NewType =
9244 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9245 if (ICmpInst::isSigned(Pred)) {
9246 LHS = getSignExtendExpr(LHS, NewType);
9247 RHS = getSignExtendExpr(RHS, NewType);
9248 } else {
9249 LHS = getZeroExtendExpr(LHS, NewType);
9250 RHS = getZeroExtendExpr(RHS, NewType);
9251 }
9252 }
9253 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9254 [[fallthrough]];
9255 case ICmpInst::ICMP_SLT:
9256 case ICmpInst::ICMP_ULT: { // while (X < Y)
9257 bool IsSigned = ICmpInst::isSigned(Pred);
9258 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9259 AllowPredicates);
9260 if (EL.hasAnyInfo())
9261 return EL;
9262 break;
9263 }
9264 case ICmpInst::ICMP_SGE:
9265 case ICmpInst::ICMP_UGE:
9266 // Since the loop is finite, an invariant RHS cannot include the boundary
9267 // value, otherwise it would loop forever.
9268 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9269 !isLoopInvariant(RHS, L))
9270 break;
9271 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9272 [[fallthrough]];
9273 case ICmpInst::ICMP_SGT:
9274 case ICmpInst::ICMP_UGT: { // while (X > Y)
9275 bool IsSigned = ICmpInst::isSigned(Pred);
9276 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9277 AllowPredicates);
9278 if (EL.hasAnyInfo())
9279 return EL;
9280 break;
9281 }
9282 default:
9283 break;
9284 }
9285
9286 return getCouldNotCompute();
9287}
9288
9289 ScalarEvolution::ExitLimit
9290 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9291 SwitchInst *Switch,
9292 BasicBlock *ExitingBlock,
9293 bool ControlsOnlyExit) {
9294 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9295
9296 // Give up if the exit is the default dest of a switch.
9297 if (Switch->getDefaultDest() == ExitingBlock)
9298 return getCouldNotCompute();
9299
9300 assert(L->contains(Switch->getDefaultDest()) &&
9301 "Default case must not exit the loop!");
9302 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9303 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9304
9305 // while (X != Y) --> while (X-Y != 0)
9306 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9307 if (EL.hasAnyInfo())
9308 return EL;
9309
9310 return getCouldNotCompute();
9311}
9312
9313static ConstantInt *
9314 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9315 ScalarEvolution &SE) {
9316 const SCEV *InVal = SE.getConstant(C);
9317 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9318 assert(isa<SCEVConstant>(Val) &&
9319 "Evaluation of SCEV at constant didn't fold correctly?");
9320 return cast<SCEVConstant>(Val)->getValue();
9321}
9322
9323ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9324 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9325 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9326 if (!RHS)
9327 return getCouldNotCompute();
9328
9329 const BasicBlock *Latch = L->getLoopLatch();
9330 if (!Latch)
9331 return getCouldNotCompute();
9332
9333 const BasicBlock *Predecessor = L->getLoopPredecessor();
9334 if (!Predecessor)
9335 return getCouldNotCompute();
9336
9337 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9338 // Return LHS in OutLHS and shift_op in OutOpCode.
9339 auto MatchPositiveShift =
9340 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9341
9342 using namespace PatternMatch;
9343
9344 ConstantInt *ShiftAmt;
9345 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9346 OutOpCode = Instruction::LShr;
9347 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9348 OutOpCode = Instruction::AShr;
9349 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9350 OutOpCode = Instruction::Shl;
9351 else
9352 return false;
9353
9354 return ShiftAmt->getValue().isStrictlyPositive();
9355 };
9356
9357 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9358 //
9359 // loop:
9360 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9361 // %iv.shifted = lshr i32 %iv, <positive constant>
9362 //
9363 // Return true on a successful match. Return the corresponding PHI node (%iv
9364 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9365 auto MatchShiftRecurrence =
9366 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9367 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9368
9369 {
9370 Instruction::BinaryOps OpC;
9371 Value *V;
9372
9373 // If we encounter a shift instruction, "peel off" the shift operation,
9374 // and remember that we did so. Later when we inspect %iv's backedge
9375 // value, we will make sure that the backedge value uses the same
9376 // operation.
9377 //
9378 // Note: the peeled shift operation does not have to be the same
9379 // instruction as the one feeding into the PHI's backedge value. We only
9380 // really care about it being the same *kind* of shift instruction --
9381 // that's all that is required for our later inferences to hold.
9382 if (MatchPositiveShift(LHS, V, OpC)) {
9383 PostShiftOpCode = OpC;
9384 LHS = V;
9385 }
9386 }
9387
9388 PNOut = dyn_cast<PHINode>(LHS);
9389 if (!PNOut || PNOut->getParent() != L->getHeader())
9390 return false;
9391
9392 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9393 Value *OpLHS;
9394
9395 return
9396 // The backedge value for the PHI node must be a shift by a positive
9397 // amount
9398 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9399
9400 // of the PHI node itself
9401 OpLHS == PNOut &&
9402
9403 // and the kind of shift should match the kind of shift we peeled
9404 // off, if any.
9405 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9406 };
9407
9408 PHINode *PN;
9409 Instruction::BinaryOps OpCode;
9410 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9411 return getCouldNotCompute();
9412
9413 const DataLayout &DL = getDataLayout();
9414
9415 // The key rationale for this optimization is that for some kinds of shift
9416 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9417 // within a finite number of iterations. If the condition guarding the
9418 // backedge (in the sense that the backedge is taken if the condition is true)
9419 // is false for the value the shift recurrence stabilizes to, then we know
9420 // that the backedge is taken only a finite number of times.
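 //
 // For example, a 32-bit recurrence {8,lshr,1} produces 8, 4, 2, 1, 0, 0, ...
 // and stabilizes to 0; if the backedge is guarded by "%iv != 0", that
 // condition is false for the stable value 0, so the backedge can only be
 // taken a finite number of times.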
9421
9422 ConstantInt *StableValue = nullptr;
9423 switch (OpCode) {
9424 default:
9425 llvm_unreachable("Impossible case!");
9426
9427 case Instruction::AShr: {
9428 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9429 // bitwidth(K) iterations.
9430 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9431 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9432 Predecessor->getTerminator(), &DT);
9433 auto *Ty = cast<IntegerType>(RHS->getType());
9434 if (Known.isNonNegative())
9435 StableValue = ConstantInt::get(Ty, 0);
9436 else if (Known.isNegative())
9437 StableValue = ConstantInt::get(Ty, -1, true);
9438 else
9439 return getCouldNotCompute();
9440
9441 break;
9442 }
9443 case Instruction::LShr:
9444 case Instruction::Shl:
9445 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9446 // stabilize to 0 in at most bitwidth(K) iterations.
9447 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9448 break;
9449 }
9450
9451 auto *Result =
9452 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9453 assert(Result->getType()->isIntegerTy(1) &&
9454 "Otherwise cannot be an operand to a branch instruction");
9455
9456 if (Result->isZeroValue()) {
9457 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9458 const SCEV *UpperBound =
9459 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9460 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9461 }
9462
9463 return getCouldNotCompute();
9464}
9465
9466/// Return true if we can constant fold an instruction of the specified type,
9467/// assuming that all operands were constants.
9468static bool CanConstantFold(const Instruction *I) {
9469 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9470 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9471 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9472 return true;
9473
9474 if (const CallInst *CI = dyn_cast<CallInst>(I))
9475 if (const Function *F = CI->getCalledFunction())
9476 return canConstantFoldCallTo(CI, F);
9477 return false;
9478}
9479
9480/// Determine whether this instruction can constant evolve within this loop
9481/// assuming its operands can all constant evolve.
9482static bool canConstantEvolve(Instruction *I, const Loop *L) {
9483 // An instruction outside of the loop can't be derived from a loop PHI.
9484 if (!L->contains(I)) return false;
9485
9486 if (isa<PHINode>(I)) {
9487 // We don't currently keep track of the control flow needed to evaluate
9488 // PHIs, so we cannot handle PHIs inside of loops.
9489 return L->getHeader() == I->getParent();
9490 }
9491
9492 // If we won't be able to constant fold this expression even if the operands
9493 // are constants, bail early.
9494 return CanConstantFold(I);
9495}
9496
9497/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9498/// recursing through each instruction operand until reaching a loop header phi.
9499static PHINode *
9500getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9501 DenseMap<Instruction *, PHINode *> &PHIMap,
9502 unsigned Depth) {
9503 if (Depth > MaxConstantEvolvingDepth)
9504 return nullptr;
9505
9506 // Otherwise, we can evaluate this instruction if all of its operands are
9507 // constant or derived from a PHI node themselves.
9508 PHINode *PHI = nullptr;
9509 for (Value *Op : UseInst->operands()) {
9510 if (isa<Constant>(Op)) continue;
9511
9512 Instruction *OpInst = dyn_cast<Instruction>(Op);
9513 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9514
9515 PHINode *P = dyn_cast<PHINode>(OpInst);
9516 if (!P)
9517 // If this operand is already visited, reuse the prior result.
9518 // We may have P != PHI if this is the deepest point at which the
9519 // inconsistent paths meet.
9520 P = PHIMap.lookup(OpInst);
9521 if (!P) {
9522 // Recurse and memoize the results, whether a phi is found or not.
9523 // This recursive call invalidates pointers into PHIMap.
9524 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9525 PHIMap[OpInst] = P;
9526 }
9527 if (!P)
9528 return nullptr; // Not evolving from PHI
9529 if (PHI && PHI != P)
9530 return nullptr; // Evolving from multiple different PHIs.
9531 PHI = P;
9532 }
9533 // This is an expression evolving from a constant PHI!
9534 return PHI;
9535}
9536
9537/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9538/// in the loop that V is derived from. We allow arbitrary operations along the
9539/// way, but the operands of an operation must either be constants or a value
9540/// derived from a constant PHI. If this expression does not fit with these
9541/// constraints, return null.
9542static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9543 Instruction *I = dyn_cast<Instruction>(V);
9544 if (!I || !canConstantEvolve(I, L)) return nullptr;
9545
9546 if (PHINode *PN = dyn_cast<PHINode>(I))
9547 return PN;
9548
9549 // Record non-constant instructions contained by the loop.
9550 DenseMap<Instruction *, PHINode *> PHIMap;
9551 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9552}
9553
9554/// EvaluateExpression - Given an expression that passes the
9555/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9556/// in the loop has the value PHIVal. If we can't fold this expression for some
9557/// reason, return null.
9558static Constant *EvaluateExpression(Value *V, const Loop *L,
9559 DenseMap<Instruction *, Constant *> &Vals,
9560 const DataLayout &DL,
9561 const TargetLibraryInfo *TLI) {
9562 // Convenient constant check, but redundant for recursive calls.
9563 if (Constant *C = dyn_cast<Constant>(V)) return C;
9564 Instruction *I = dyn_cast<Instruction>(V);
9565 if (!I) return nullptr;
9566
9567 if (Constant *C = Vals.lookup(I)) return C;
9568
9569 // An instruction inside the loop depends on a value outside the loop that we
9570 // weren't given a mapping for, or a value such as a call inside the loop.
9571 if (!canConstantEvolve(I, L)) return nullptr;
9572
9573 // An unmapped PHI can be due to a branch or another loop inside this loop,
9574 // or due to this not being the initial iteration through a loop where we
9575 // couldn't compute the evolution of this particular PHI last time.
9576 if (isa<PHINode>(I)) return nullptr;
9577
9578 std::vector<Constant*> Operands(I->getNumOperands());
9579
9580 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9581 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9582 if (!Operand) {
9583 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9584 if (!Operands[i]) return nullptr;
9585 continue;
9586 }
9587 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9588 Vals[Operand] = C;
9589 if (!C) return nullptr;
9590 Operands[i] = C;
9591 }
9592
9593 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9594 /*AllowNonDeterministic=*/false);
9595}
9596
9597
9598// If every incoming value to PN except the one for BB is a specific Constant,
9599// return that, else return nullptr.
9600static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9601 Constant *IncomingVal = nullptr;
9602
9603 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9604 if (PN->getIncomingBlock(i) == BB)
9605 continue;
9606
9607 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9608 if (!CurrentVal)
9609 return nullptr;
9610
9611 if (IncomingVal != CurrentVal) {
9612 if (IncomingVal)
9613 return nullptr;
9614 IncomingVal = CurrentVal;
9615 }
9616 }
9617
9618 return IncomingVal;
9619}
9620
9621/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9622/// in the header of its containing loop, we know the loop executes a
9623/// constant number of times, and the PHI node is just a recurrence
9624/// involving constants, fold it.
9625Constant *
9626ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9627 const APInt &BEs,
9628 const Loop *L) {
9629 auto I = ConstantEvolutionLoopExitValue.find(PN);
9630 if (I != ConstantEvolutionLoopExitValue.end())
9631 return I->second;
9632
9633 if (BEs.ugt(MaxBruteForceIterations))
9634 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9635
9636 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9637
9638 DenseMap<Instruction *, Constant *> CurrentIterVals;
9639 BasicBlock *Header = L->getHeader();
9640 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9641
9642 BasicBlock *Latch = L->getLoopLatch();
9643 if (!Latch)
9644 return nullptr;
9645
9646 for (PHINode &PHI : Header->phis()) {
9647 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9648 CurrentIterVals[&PHI] = StartCST;
9649 }
9650 if (!CurrentIterVals.count(PN))
9651 return RetVal = nullptr;
9652
9653 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9654
9655 // Execute the loop symbolically to determine the exit value.
9656 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9657 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9658
9659 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9660 unsigned IterationNum = 0;
9661 const DataLayout &DL = getDataLayout();
9662 for (; ; ++IterationNum) {
9663 if (IterationNum == NumIterations)
9664 return RetVal = CurrentIterVals[PN]; // Got exit value!
9665
9666 // Compute the value of the PHIs for the next iteration.
9667 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9668 DenseMap<Instruction *, Constant *> NextIterVals;
9669 Constant *NextPHI =
9670 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9671 if (!NextPHI)
9672 return nullptr; // Couldn't evaluate!
9673 NextIterVals[PN] = NextPHI;
9674
9675 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9676
9677 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9678 // cease to be able to evaluate one of them or if they stop evolving,
9679 // because that doesn't necessarily prevent us from computing PN.
9680 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9681 for (const auto &I : CurrentIterVals) {
9682 PHINode *PHI = dyn_cast<PHINode>(I.first);
9683 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9684 PHIsToCompute.emplace_back(PHI, I.second);
9685 }
9686 // We use two distinct loops because EvaluateExpression may invalidate any
9687 // iterators into CurrentIterVals.
9688 for (const auto &I : PHIsToCompute) {
9689 PHINode *PHI = I.first;
9690 Constant *&NextPHI = NextIterVals[PHI];
9691 if (!NextPHI) { // Not already computed.
9692 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9693 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9694 }
9695 if (NextPHI != I.second)
9696 StoppedEvolving = false;
9697 }
9698
9699 // If all entries in CurrentIterVals == NextIterVals then we can stop
9700 // iterating, the loop can't continue to change.
9701 if (StoppedEvolving)
9702 return RetVal = CurrentIterVals[PN];
9703
9704 CurrentIterVals.swap(NextIterVals);
9705 }
9706}
9707
9708const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9709 Value *Cond,
9710 bool ExitWhen) {
9711 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9712 if (!PN) return getCouldNotCompute();
9713
9714 // If the loop is canonicalized, the PHI will have exactly two entries.
9715 // That's the only form we support here.
9716 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9717
9718 DenseMap<Instruction *, Constant *> CurrentIterVals;
9719 BasicBlock *Header = L->getHeader();
9720 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9721
9722 BasicBlock *Latch = L->getLoopLatch();
9723 assert(Latch && "Should follow from NumIncomingValues == 2!");
9724
9725 for (PHINode &PHI : Header->phis()) {
9726 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9727 CurrentIterVals[&PHI] = StartCST;
9728 }
9729 if (!CurrentIterVals.count(PN))
9730 return getCouldNotCompute();
9731
9732 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9733 // the loop symbolically to determine when the condition gets a value of
9734 // "ExitWhen".
9735 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9736 const DataLayout &DL = getDataLayout();
9737 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9738 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9739 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9740
9741 // Couldn't symbolically evaluate.
9742 if (!CondVal) return getCouldNotCompute();
9743
9744 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9745 ++NumBruteForceTripCountsComputed;
9746 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9747 }
9748
9749 // Update all the PHI nodes for the next iteration.
9750 DenseMap<Instruction *, Constant *> NextIterVals;
9751
9752 // Create a list of which PHIs we need to compute. We want to do this before
9753 // calling EvaluateExpression on them because that may invalidate iterators
9754 // into CurrentIterVals.
9755 SmallVector<PHINode *, 8> PHIsToCompute;
9756 for (const auto &I : CurrentIterVals) {
9757 PHINode *PHI = dyn_cast<PHINode>(I.first);
9758 if (!PHI || PHI->getParent() != Header) continue;
9759 PHIsToCompute.push_back(PHI);
9760 }
9761 for (PHINode *PHI : PHIsToCompute) {
9762 Constant *&NextPHI = NextIterVals[PHI];
9763 if (NextPHI) continue; // Already computed!
9764
9765 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9766 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9767 }
9768 CurrentIterVals.swap(NextIterVals);
9769 }
9770
9771 // Too many iterations were needed to evaluate.
9772 return getCouldNotCompute();
9773}
9774
9775const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9776 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9777 ValuesAtScopes[V];
9778 // Check to see if we've folded this expression at this loop before.
9779 for (auto &LS : Values)
9780 if (LS.first == L)
9781 return LS.second ? LS.second : V;
9782
9783 Values.emplace_back(L, nullptr);
9784
9785 // Otherwise compute it.
9786 const SCEV *C = computeSCEVAtScope(V, L);
9787 for (auto &LS : reverse(ValuesAtScopes[V]))
9788 if (LS.first == L) {
9789 LS.second = C;
9790 if (!isa<SCEVConstant>(C))
9791 ValuesAtScopesUsers[C].push_back({L, V});
9792 break;
9793 }
9794 return C;
9795}
9796
9797/// This builds up a Constant using the ConstantExpr interface. That way, we
9798/// will return Constants for objects which aren't represented by a
9799/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9800/// Returns NULL if the SCEV isn't representable as a Constant.
9801static Constant *BuildConstantFromSCEV(const SCEV *V) {
9802 switch (V->getSCEVType()) {
9803 case scCouldNotCompute:
9804 case scAddRecExpr:
9805 case scVScale:
9806 return nullptr;
9807 case scConstant:
9808 return cast<SCEVConstant>(V)->getValue();
9809 case scUnknown:
9810 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9811 case scPtrToInt: {
9812 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9813 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9814 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9815
9816 return nullptr;
9817 }
9818 case scTruncate: {
9819 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9820 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9821 return ConstantExpr::getTrunc(CastOp, ST->getType());
9822 return nullptr;
9823 }
9824 case scAddExpr: {
9825 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9826 Constant *C = nullptr;
9827 for (const SCEV *Op : SA->operands()) {
9828 Constant *OpC = BuildConstantFromSCEV(Op);
9829 if (!OpC)
9830 return nullptr;
9831 if (!C) {
9832 C = OpC;
9833 continue;
9834 }
9835 assert(!C->getType()->isPointerTy() &&
9836 "Can only have one pointer, and it must be last");
9837 if (OpC->getType()->isPointerTy()) {
9838 // The offsets have been converted to bytes. We can add bytes using
9839 // an i8 GEP.
9840 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9841 OpC, C);
9842 } else {
9843 C = ConstantExpr::getAdd(C, OpC);
9844 }
9845 }
9846 return C;
9847 }
9848 case scMulExpr:
9849 case scSignExtend:
9850 case scZeroExtend:
9851 case scUDivExpr:
9852 case scSMaxExpr:
9853 case scUMaxExpr:
9854 case scSMinExpr:
9855 case scUMinExpr:
9856 case scSequentialUMinExpr:
9857 return nullptr;
9858 }
9859 llvm_unreachable("Unknown SCEV kind!");
9860}
9861
9862const SCEV *
9863ScalarEvolution::getWithOperands(const SCEV *S,
9864 ArrayRef<const SCEV *> NewOps) {
9865 switch (S->getSCEVType()) {
9866 case scTruncate:
9867 case scZeroExtend:
9868 case scSignExtend:
9869 case scPtrToInt:
9870 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9871 case scAddRecExpr: {
9872 auto *AddRec = cast<SCEVAddRecExpr>(S);
9873 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9874 }
9875 case scAddExpr:
9876 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9877 case scMulExpr:
9878 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9879 case scUDivExpr:
9880 return getUDivExpr(NewOps[0], NewOps[1]);
9881 case scUMaxExpr:
9882 case scSMaxExpr:
9883 case scUMinExpr:
9884 case scSMinExpr:
9885 return getMinMaxExpr(S->getSCEVType(), NewOps);
9886 case scSequentialUMinExpr:
9887 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9888 case scConstant:
9889 case scVScale:
9890 case scUnknown:
9891 return S;
9892 case scCouldNotCompute:
9893 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9894 }
9895 llvm_unreachable("Unknown SCEV kind!");
9896}
9897
9898const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9899 switch (V->getSCEVType()) {
9900 case scConstant:
9901 case scVScale:
9902 return V;
9903 case scAddRecExpr: {
9904 // If this is a loop recurrence for a loop that does not contain L, then we
9905 // are dealing with the final value computed by the loop.
9906 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9907 // First, attempt to evaluate each operand.
9908 // Avoid performing the look-up in the common case where the specified
9909 // expression has no loop-variant portions.
9910 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9911 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9912 if (OpAtScope == AddRec->getOperand(i))
9913 continue;
9914
9915 // Okay, at least one of these operands is loop variant but might be
9916 // foldable. Build a new instance of the folded commutative expression.
9917 SmallVector<const SCEV *, 8> NewOps;
9918 NewOps.reserve(AddRec->getNumOperands());
9919 append_range(NewOps, AddRec->operands().take_front(i));
9920 NewOps.push_back(OpAtScope);
9921 for (++i; i != e; ++i)
9922 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9923
9924 const SCEV *FoldedRec = getAddRecExpr(
9925 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
9926 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9927 // The addrec may be folded to a nonrecurrence, for example, if the
9928 // induction variable is multiplied by zero after constant folding. Go
9929 // ahead and return the folded value.
9930 if (!AddRec)
9931 return FoldedRec;
9932 break;
9933 }
9934
9935 // If the scope is outside the addrec's loop, evaluate it by using the
9936 // loop exit value of the addrec.
9937 if (!AddRec->getLoop()->contains(L)) {
9938 // To evaluate this recurrence, we need to know how many times the AddRec
9939 // loop iterates. Compute this now.
9940 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9941 if (BackedgeTakenCount == getCouldNotCompute())
9942 return AddRec;
9943
9944 // Then, evaluate the AddRec.
9945 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
9946 }
9947
9948 return AddRec;
9949 }
9950 case scTruncate:
9951 case scZeroExtend:
9952 case scSignExtend:
9953 case scPtrToInt:
9954 case scAddExpr:
9955 case scMulExpr:
9956 case scUDivExpr:
9957 case scUMaxExpr:
9958 case scSMaxExpr:
9959 case scUMinExpr:
9960 case scSMinExpr:
9961 case scSequentialUMinExpr: {
9962 ArrayRef<const SCEV *> Ops = V->operands();
9963 // Avoid performing the look-up in the common case where the specified
9964 // expression has no loop-variant portions.
9965 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
9966 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
9967 if (OpAtScope != Ops[i]) {
9968 // Okay, at least one of these operands is loop variant but might be
9969 // foldable. Build a new instance of the folded commutative expression.
9970 SmallVector<const SCEV *, 8> NewOps;
9971 NewOps.reserve(Ops.size());
9972 append_range(NewOps, Ops.take_front(i));
9973 NewOps.push_back(OpAtScope);
9974
9975 for (++i; i != e; ++i) {
9976 OpAtScope = getSCEVAtScope(Ops[i], L);
9977 NewOps.push_back(OpAtScope);
9978 }
9979
9980 return getWithOperands(V, NewOps);
9981 }
9982 }
9983 // If we got here, all operands are loop invariant.
9984 return V;
9985 }
9986 case scUnknown: {
9987 // If this instruction is evolved from a constant-evolving PHI, compute the
9988 // exit value from the loop without using SCEVs.
9989 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
9990 Instruction *I = dyn_cast<Instruction>(SU->getValue());
9991 if (!I)
9992 return V; // This is some other type of SCEVUnknown, just return it.
9993
9994 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9995 const Loop *CurrLoop = this->LI[I->getParent()];
9996 // Looking for loop exit value.
9997 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9998 PN->getParent() == CurrLoop->getHeader()) {
9999 // Okay, there is no closed form solution for the PHI node. Check
10000 // to see if the loop that contains it has a known backedge-taken
10001 // count. If so, we may be able to force computation of the exit
10002 // value.
10003 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10004 // This trivial case can show up in some degenerate cases where
10005 // the incoming IR has not yet been fully simplified.
10006 if (BackedgeTakenCount->isZero()) {
10007 Value *InitValue = nullptr;
10008 bool MultipleInitValues = false;
10009 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10010 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10011 if (!InitValue)
10012 InitValue = PN->getIncomingValue(i);
10013 else if (InitValue != PN->getIncomingValue(i)) {
10014 MultipleInitValues = true;
10015 break;
10016 }
10017 }
10018 }
10019 if (!MultipleInitValues && InitValue)
10020 return getSCEV(InitValue);
10021 }
10022 // Do we have a loop invariant value flowing around the backedge
10023 // for a loop which must execute the backedge?
10024 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10025 isKnownNonZero(BackedgeTakenCount) &&
10026 PN->getNumIncomingValues() == 2) {
10027
10028 unsigned InLoopPred =
10029 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10030 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10031 if (CurrLoop->isLoopInvariant(BackedgeVal))
10032 return getSCEV(BackedgeVal);
10033 }
10034 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10035 // Okay, we know how many times the containing loop executes. If
10036 // this is a constant evolving PHI node, get the final value at
10037 // the specified iteration number.
10038 Constant *RV =
10039 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10040 if (RV)
10041 return getSCEV(RV);
10042 }
10043 }
10044 }
10045
10046 // Okay, this is an expression that we cannot symbolically evaluate
10047 // into a SCEV. Check to see if it's possible to symbolically evaluate
10048 // the arguments into constants, and if so, try to constant propagate the
10049 // result. This is particularly useful for computing loop exit values.
10050 if (!CanConstantFold(I))
10051 return V; // This is some other type of SCEVUnknown, just return it.
10052
10053 SmallVector<Constant *, 4> Operands;
10054 Operands.reserve(I->getNumOperands());
10055 bool MadeImprovement = false;
10056 for (Value *Op : I->operands()) {
10057 if (Constant *C = dyn_cast<Constant>(Op)) {
10058 Operands.push_back(C);
10059 continue;
10060 }
10061
10062 // If any of the operands is non-constant and if they are
10063 // non-integer and non-pointer, don't even try to analyze them
10064 // with scev techniques.
10065 if (!isSCEVable(Op->getType()))
10066 return V;
10067
10068 const SCEV *OrigV = getSCEV(Op);
10069 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10070 MadeImprovement |= OrigV != OpV;
10071
10072 Constant *C = BuildConstantFromSCEV(OpV);
10073 if (!C)
10074 return V;
10075 assert(C->getType() == Op->getType() && "Type mismatch");
10076 Operands.push_back(C);
10077 }
10078
10079 // Check to see if getSCEVAtScope actually made an improvement.
10080 if (!MadeImprovement)
10081 return V; // This is some other type of SCEVUnknown, just return it.
10082
10083 Constant *C = nullptr;
10084 const DataLayout &DL = getDataLayout();
10085 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10086 /*AllowNonDeterministic=*/false);
10087 if (!C)
10088 return V;
10089 return getSCEV(C);
10090 }
10091 case scCouldNotCompute:
10092 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10093 }
10094 llvm_unreachable("Unknown SCEV type!");
10095}
10096
10097const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10098 return getSCEVAtScope(getSCEV(V), L);
10099}
10100
10101const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10102 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10103 return stripInjectiveFunctions(ZExt->getOperand());
10104 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10105 return stripInjectiveFunctions(SExt->getOperand());
10106 return S;
10107}
10108
10109/// Finds the minimum unsigned root of the following equation:
10110///
10111/// A * X = B (mod N)
10112///
10113/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10114/// A and B isn't important.
10115///
10116/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
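///
/// For example, with BW = 4 (N = 16), A = 6 and B = 4: D = gcd(6, 16) = 2, B
/// is divisible by D, A/D = 3 has multiplicative inverse 3 modulo N/D = 8, and
/// the minimum root is X = (3 * 4 mod 16) / 2 = 6; indeed 6 * 6 = 36 == 4
/// (mod 16).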
10117static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10118 ScalarEvolution &SE) {
10119 uint32_t BW = A.getBitWidth();
10120 assert(BW == SE.getTypeSizeInBits(B->getType()));
10121 assert(A != 0 && "A must be non-zero.");
10122
10123 // 1. D = gcd(A, N)
10124 //
10125 // The gcd of A and N may have only one prime factor: 2. The number of
10126 // trailing zeros in A is its multiplicity
10127 uint32_t Mult2 = A.countr_zero();
10128 // D = 2^Mult2
10129
10130 // 2. Check if B is divisible by D.
10131 //
10132 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10133 // is not less than multiplicity of this prime factor for D.
10134 if (SE.getMinTrailingZeros(B) < Mult2)
10135 return SE.getCouldNotCompute();
10136
10137 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10138 // modulo (N / D).
10139 //
10140 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10141 // (N / D) in general. The inverse itself always fits into BW bits, though,
10142 // so we immediately truncate it.
10143 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10144 APInt I = AD.multiplicativeInverse().zext(BW);
10145
10146 // 4. Compute the minimum unsigned root of the equation:
10147 // I * (B / D) mod (N / D)
10148 // To simplify the computation, we factor out the divide by D:
10149 // (I * B mod N) / D
10150 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10151 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10152}
10153
10154/// For a given quadratic addrec, generate coefficients of the corresponding
10155/// quadratic equation, multiplied by a common value to ensure that they are
10156/// integers.
10157/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10158/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10159/// were multiplied by, and BitWidth is the bit width of the original addrec
10160/// coefficients.
10161/// This function returns std::nullopt if the addrec coefficients are not
10162/// compile-time constants.
10163static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10164GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10165 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10166 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10167 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10168 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10169 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10170 << *AddRec << '\n');
10171
10172 // We currently can only solve this if the coefficients are constants.
10173 if (!LC || !MC || !NC) {
10174 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10175 return std::nullopt;
10176 }
10177
10178 APInt L = LC->getAPInt();
10179 APInt M = MC->getAPInt();
10180 APInt N = NC->getAPInt();
10181 assert(!N.isZero() && "This is not a quadratic addrec");
10182
10183 unsigned BitWidth = LC->getAPInt().getBitWidth();
10184 unsigned NewWidth = BitWidth + 1;
10185 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10186 << BitWidth << '\n');
10187 // The sign-extension (as opposed to a zero-extension) here matches the
10188 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10189 N = N.sext(NewWidth);
10190 M = M.sext(NewWidth);
10191 L = L.sext(NewWidth);
10192
10193 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10194 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10195 // L+M, L+2M+N, L+3M+3N, ...
10196 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10197 //
10198 // The equation Acc = 0 is then
10199 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10200 // In a quadratic form it becomes:
10201 // N n^2 + (2M-N) n + 2L = 0.
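 //
 // For example, for {0,+,1,+,2} (L = 0, M = 1, N = 2) the accumulated value
 // after n iterations is n + n(n-1) = n^2, and the equation becomes
 // 2n^2 + 0*n + 0 = 0, i.e. A = 2, B = 0, C = 0 with the common multiplier T = 2.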
10202
10203 APInt A = N;
10204 APInt B = 2 * M - A;
10205 APInt C = 2 * L;
10206 APInt T = APInt(NewWidth, 2);
10207 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10208 << "x + " << C << ", coeff bw: " << NewWidth
10209 << ", multiplied by " << T << '\n');
10210 return std::make_tuple(A, B, C, T, BitWidth);
10211}
10212
10213/// Helper function to compare optional APInts:
10214/// (a) if X and Y both exist, return min(X, Y),
10215/// (b) if neither X nor Y exist, return std::nullopt,
10216/// (c) if exactly one of X and Y exists, return that value.
10217static std::optional<APInt> MinOptional(std::optional<APInt> X,
10218 std::optional<APInt> Y) {
10219 if (X && Y) {
10220 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10221 APInt XW = X->sext(W);
10222 APInt YW = Y->sext(W);
10223 return XW.slt(YW) ? *X : *Y;
10224 }
10225 if (!X && !Y)
10226 return std::nullopt;
10227 return X ? *X : *Y;
10228}
10229
10230/// Helper function to truncate an optional APInt to a given BitWidth.
10231/// When solving addrec-related equations, it is preferable to return a value
10232/// that has the same bit width as the original addrec's coefficients. If the
10233/// solution fits in the original bit width, truncate it (except for i1).
10234/// Returning a value of a different bit width may inhibit some optimizations.
10235///
10236/// In general, a solution to a quadratic equation generated from an addrec
10237/// may require BW+1 bits, where BW is the bit width of the addrec's
10238/// coefficients. The reason is that the coefficients of the quadratic
10239/// equation are BW+1 bits wide (to avoid truncation when converting from
10240/// the addrec to the equation).
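///
/// For example, a 9-bit solution with value 7 computed for an addrec with
/// 8-bit coefficients is truncated back to 8 bits, whereas a 9-bit solution
/// with value 300 is returned unchanged.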
10241static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10242 unsigned BitWidth) {
10243 if (!X)
10244 return std::nullopt;
10245 unsigned W = X->getBitWidth();
10246 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10247 return X->trunc(BitWidth);
10248 return X;
10249}
10250
10251/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10252/// iterations. The values L, M, N are assumed to be signed, and they
10253/// should all have the same bit widths.
10254/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10255/// where BW is the bit width of the addrec's coefficients.
10256/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10257/// returned as such, otherwise the bit width of the returned value may
10258/// be greater than BW.
10259///
10260/// This function returns std::nullopt if
10261/// (a) the addrec coefficients are not constant, or
10262/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10263/// like x^2 = 5, no integer solutions exist, in other cases an integer
10264/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10265static std::optional<APInt>
10266SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10267 APInt A, B, C, M;
10268 unsigned BitWidth;
10269 auto T = GetQuadraticEquation(AddRec);
10270 if (!T)
10271 return std::nullopt;
10272
10273 std::tie(A, B, C, M, BitWidth) = *T;
10274 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10275 std::optional<APInt> X =
10276 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10277 if (!X)
10278 return std::nullopt;
10279
10280 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10281 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10282 if (!V->isZero())
10283 return std::nullopt;
10284
10285 return TruncIfPossible(X, BitWidth);
10286}
10287
10288/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10289/// iterations. The values M, N are assumed to be signed, and they
10290/// should all have the same bit widths.
10291/// Find the least n such that c(n) does not belong to the given range,
10292/// while c(n-1) does.
10293///
10294/// This function returns std::nullopt if
10295/// (a) the addrec coefficients are not constant, or
10296/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10297/// bounds of the range.
10298static std::optional<APInt>
10299SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10300 const ConstantRange &Range, ScalarEvolution &SE) {
10301 assert(AddRec->getOperand(0)->isZero() &&
10302 "Starting value of addrec should be 0");
10303 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10304 << Range << ", addrec " << *AddRec << '\n');
10305 // This case is handled in getNumIterationsInRange. Here we can assume that
10306 // we start in the range.
10307 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10308 "Addrec's initial value should be in range");
10309
10310 APInt A, B, C, M;
10311 unsigned BitWidth;
10312 auto T = GetQuadraticEquation(AddRec);
10313 if (!T)
10314 return std::nullopt;
10315
10316 // Be careful about the return value: there can be two reasons for not
10317 // returning an actual number. First, if no solutions to the equations
10318 // were found, and second, if the solutions don't leave the given range.
10319 // The first case means that the actual solution is "unknown", the second
10320 // means that it's known, but not valid. If the solution is unknown, we
10321 // cannot make any conclusions.
10322 // Return a pair: the optional solution and a flag indicating if the
10323 // solution was found.
10324 auto SolveForBoundary =
10325 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10326 // Solve for signed overflow and unsigned overflow, pick the lower
10327 // solution.
10328 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10329 << Bound << " (before multiplying by " << M << ")\n");
10330 Bound *= M; // The quadratic equation multiplier.
10331
10332 std::optional<APInt> SO;
10333 if (BitWidth > 1) {
10334 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10335 "signed overflow\n");
10336 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10337 }
10338 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10339 "unsigned overflow\n");
10340 std::optional<APInt> UO =
10341 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10342
10343 auto LeavesRange = [&] (const APInt &X) {
10344 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10345 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10346 if (Range.contains(V0->getValue()))
10347 return false;
10348 // X should be at least 1, so X-1 is non-negative.
10349 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10350 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10351 if (Range.contains(V1->getValue()))
10352 return true;
10353 return false;
10354 };
10355
10356 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10357 // can be a solution, but the function failed to find it. We cannot treat it
10358 // as "no solution".
10359 if (!SO || !UO)
10360 return {std::nullopt, false};
10361
10362 // Check the smaller value first to see if it leaves the range.
10363 // At this point, both SO and UO must have values.
10364 std::optional<APInt> Min = MinOptional(SO, UO);
10365 if (LeavesRange(*Min))
10366 return { Min, true };
10367 std::optional<APInt> Max = Min == SO ? UO : SO;
10368 if (LeavesRange(*Max))
10369 return { Max, true };
10370
10371 // Solutions were found, but were eliminated, hence the "true".
10372 return {std::nullopt, true};
10373 };
10374
10375 std::tie(A, B, C, M, BitWidth) = *T;
10376 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10377 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10378 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10379 auto SL = SolveForBoundary(Lower);
10380 auto SU = SolveForBoundary(Upper);
10381 // If any of the solutions was unknown, no meaningful conclusions can
10382 // be made.
10383 if (!SL.second || !SU.second)
10384 return std::nullopt;
10385
10386 // Claim: The correct solution is not some value between Min and Max.
10387 //
10388 // Justification: Assuming that Min and Max are different values, one of
10389 // them is when the first signed overflow happens, the other is when the
10390 // first unsigned overflow happens. Crossing the range boundary is only
10391 // possible via an overflow (treating 0 as a special case of it, modeling
10392 // an overflow as crossing k*2^W for some k).
10393 //
10394 // The interesting case here is when Min was eliminated as an invalid
10395 // solution, but Max was not. The argument is that if there was another
10396 // overflow between Min and Max, it would also have been eliminated if
10397 // it was considered.
10398 //
10399 // For a given boundary, it is possible to have two overflows of the same
10400 // type (signed/unsigned) without having the other type in between: this
10401 // can happen when the vertex of the parabola is between the iterations
10402 // corresponding to the overflows. This is only possible when the two
10403 // overflows cross k*2^W for the same k. In such case, if the second one
10404 // left the range (and was the first one to do so), the first overflow
10405 // would have to enter the range, which would mean that either we had left
10406 // the range before or that we started outside of it. Both of these cases
10407 // are contradictions.
10408 //
10409 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10410 // solution is not some value between the Max for this boundary and the
10411 // Min of the other boundary.
10412 //
10413 // Justification: Assume that we had such Max_A and Min_B corresponding
10414 // to range boundaries A and B and such that Max_A < Min_B. If there was
10415 // a solution between Max_A and Min_B, it would have to be caused by an
10416 // overflow corresponding to either A or B. It cannot correspond to B,
10417 // since Min_B is the first occurrence of such an overflow. If it
10418 // corresponded to A, it would have to be either a signed or an unsigned
10419 // overflow that is larger than both eliminated overflows for A. But
10420 // between the eliminated overflows and this overflow, the values would
10421 // cover the entire value space, thus crossing the other boundary, which
10422 // is a contradiction.
10423
10424 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10425}
10426
10427ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10428 const Loop *L,
10429 bool ControlsOnlyExit,
10430 bool AllowPredicates) {
10431
10432 // This is only used for loops with a "x != y" exit test. The exit condition
10433 // is now expressed as a single expression, V = x-y. So the exit test is
10434 // effectively V != 0. We know and take advantage of the fact that this
10435 // expression is only used in a comparison-with-zero context.
10436
10438 // If the value is a constant
10439 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10440 // If the value is already zero, the branch will execute zero times.
10441 if (C->getValue()->isZero()) return C;
10442 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10443 }
10444
10445 const SCEVAddRecExpr *AddRec =
10446 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10447
10448 if (!AddRec && AllowPredicates)
10449 // Try to make this an AddRec using runtime tests, in the first X
10450 // iterations of this loop, where X is the SCEV expression found by the
10451 // algorithm below.
10452 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10453
10454 if (!AddRec || AddRec->getLoop() != L)
10455 return getCouldNotCompute();
10456
10457 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10458 // the quadratic equation to solve it.
10459 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10460 // We can only use this value if the chrec ends up with an exact zero
10461 // value at this index. When solving for "X*X != 5", for example, we
10462 // should not accept a root of 2.
10463 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10464 const auto *R = cast<SCEVConstant>(getConstant(*S));
10465 return ExitLimit(R, R, R, false, Predicates);
10466 }
10467 return getCouldNotCompute();
10468 }
10469
10470 // Otherwise we can only handle this if it is affine.
10471 if (!AddRec->isAffine())
10472 return getCouldNotCompute();
10473
10474 // If this is an affine expression, the execution count of this branch is
10475 // the minimum unsigned root of the following equation:
10476 //
10477 // Start + Step*N = 0 (mod 2^BW)
10478 //
10479 // equivalent to:
10480 //
10481 // Step*N = -Start (mod 2^BW)
10482 //
10483 // where BW is the common bit width of Start and Step.
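 //
 // For example, with i8 operands (BW = 8), Start = 4 and Step = -2, the
 // equation 4 + (-2)*N == 0 (mod 256) is first satisfied at N = 2, giving a
 // backedge-taken count of 2.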
10484
10485 // Get the initial value for the loop.
10486 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10487 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10488 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10489
10490 if (!isLoopInvariant(Step, L))
10491 return getCouldNotCompute();
10492
10493 LoopGuards Guards = LoopGuards::collect(L, *this);
10494 // Specialize step for this loop so we get context sensitive facts below.
10495 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10496
10497 // For positive steps (counting up until unsigned overflow):
10498 // N = -Start/Step (as unsigned)
10499 // For negative steps (counting down to zero):
10500 // N = Start/-Step
10501 // First compute the unsigned distance from zero in the direction of Step.
10502 bool CountDown = isKnownNegative(StepWLG);
10503 if (!CountDown && !isKnownNonNegative(StepWLG))
10504 return getCouldNotCompute();
10505
10506 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10507 // Handle unitary steps, which cannot wraparound.
10508 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10509 // N = Distance (as unsigned)
10510 if (StepC &&
10511 (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
10512 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10513 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10514
10515 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10516 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10517 // case, and see if we can improve the bound.
10518 //
10519 // Explicitly handling this here is necessary because getUnsignedRange
10520 // isn't context-sensitive; it doesn't know that we only care about the
10521 // range inside the loop.
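 //
 // For example, if Distance is the i8 expression (n - 1) and the loop entry is
 // guarded by "n != 0", the generic unsigned maximum of Distance is 255, but
 // Distance is then at most unsigned_max(n) - 1 = 254, so the bound improves
 // to 254.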
10522 const SCEV *Zero = getZero(Distance->getType());
10523 const SCEV *One = getOne(Distance->getType());
10524 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10525 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10526 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10527 // as "unsigned_max(Distance + 1) - 1".
10528 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10529 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10530 }
10531 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10532 Predicates);
10533 }
10534
10535 // If the condition controls loop exit (the loop exits only if the expression
10536 // is true) and the addition is no-wrap we can use unsigned divide to
10537 // compute the backedge count. In this case, the step may not divide the
10538 // distance, but we don't care because if the condition is "missed" the loop
10539 // will have undefined behavior due to wrapping.
10540 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10541 loopHasNoAbnormalExits(AddRec->getLoop())) {
10542
10543 // If the stride is zero, the loop must be infinite. In C++, most loops
10544 // are finite by assumption, in which case the step being zero implies
10545 // UB must execute if the loop is entered.
10546 if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG))
10547 return getCouldNotCompute();
10548
10549 const SCEV *Exact =
10550 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10551 const SCEV *ConstantMax = getCouldNotCompute();
10552 if (Exact != getCouldNotCompute()) {
10553 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10554 ConstantMax =
10555 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10556 }
10557 const SCEV *SymbolicMax =
10558 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10559 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10560 }
10561
10562 // Solve the general equation.
10563 if (!StepC || StepC->getValue()->isZero())
10564 return getCouldNotCompute();
10565 const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
10566 getNegativeSCEV(Start), *this);
10567
10568 const SCEV *M = E;
10569 if (E != getCouldNotCompute()) {
10570 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10571 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10572 }
10573 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10574 return ExitLimit(E, M, S, false, Predicates);
10575}
10576
10577ScalarEvolution::ExitLimit
10578ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10579 // Loops that look like: while (X == 0) are very strange indeed. We don't
10580 // handle them yet except for the trivial case. This could be expanded in the
10581 // future as needed.
10582
10583 // If the value is a constant, check to see if it is known to be non-zero
10584 // already. If so, the backedge will execute zero times.
10585 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10586 if (!C->getValue()->isZero())
10587 return getZero(C->getType());
10588 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10589 }
10590
10591 // We could implement others, but I really doubt anyone writes loops like
10592 // this, and if they did, they would already be constant folded.
10593 return getCouldNotCompute();
10594}
10595
10596std::pair<const BasicBlock *, const BasicBlock *>
10597ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10598 const {
10599 // If the block has a unique predecessor, then there is no path from the
10600 // predecessor to the block that does not go through the direct edge
10601 // from the predecessor to the block.
10602 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10603 return {Pred, BB};
10604
10605 // A loop's header is defined to be a block that dominates the loop.
10606 // If the header has a unique predecessor outside the loop, it must be
10607 // a block that has exactly one successor that can reach the loop.
10608 if (const Loop *L = LI.getLoopFor(BB))
10609 return {L->getLoopPredecessor(), L->getHeader()};
10610
10611 return {nullptr, nullptr};
10612}
10613
10614/// SCEV structural equivalence is usually sufficient for testing whether two
10615/// expressions are equal, however for the purposes of looking for a condition
10616/// guarding a loop, it can be useful to be a little more general, since a
10617/// front-end may have replicated the controlling expression.
10618static bool HasSameValue(const SCEV *A, const SCEV *B) {
10619 // Quick check to see if they are the same SCEV.
10620 if (A == B) return true;
10621
10622 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10623 // Not all instructions that are "identical" compute the same value. For
10624 // instance, two distinct alloca instructions allocating the same type are
10625 // identical and do not read memory; but compute distinct values.
10626 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10627 };
10628
10629 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10630 // two different instructions with the same value. Check for this case.
10631 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10632 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10633 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10634 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10635 if (ComputesEqualValues(AI, BI))
10636 return true;
10637
10638 // Otherwise assume they may have a different value.
10639 return false;
10640}
10641
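// Match S against the canonical SCEV form of a subtraction: an add of the
// shape ((-1) * RHS) + LHS. For example, the SCEV for "%b - %a" is matched
// with LHS = %b and RHS = %a.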
10642static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10643 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10644 if (!Add || Add->getNumOperands() != 2)
10645 return false;
10646 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10647 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10648 LHS = Add->getOperand(1);
10649 RHS = ME->getOperand(1);
10650 return true;
10651 }
10652 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10653 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10654 LHS = Add->getOperand(0);
10655 RHS = ME->getOperand(1);
10656 return true;
10657 }
10658 return false;
10659}
10660
10661bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
10662 const SCEV *&LHS, const SCEV *&RHS,
10663 unsigned Depth) {
10664 bool Changed = false;
10665 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10666 // '0 != 0'.
10667 auto TrivialCase = [&](bool TriviallyTrue) {
10668 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10669 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10670 return true;
10671 };
10672 // If we hit the max recursion limit bail out.
10673 if (Depth >= 3)
10674 return false;
10675
10676 // Canonicalize a constant to the right side.
10677 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10678 // Check for both operands constant.
10679 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10680 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10681 return TrivialCase(false);
10682 return TrivialCase(true);
10683 }
10684 // Otherwise swap the operands to put the constant on the right.
10685 std::swap(LHS, RHS);
10686 Pred = ICmpInst::getSwappedPredicate(Pred);
10687 Changed = true;
10688 }
10689
10690 // If we're comparing an addrec with a value which is loop-invariant in the
10691 // addrec's loop, put the addrec on the left. Also make a dominance check,
10692 // as both operands could be addrecs loop-invariant in each other's loop.
10693 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10694 const Loop *L = AR->getLoop();
10695 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10696 std::swap(LHS, RHS);
10697 Pred = ICmpInst::getSwappedPredicate(Pred);
10698 Changed = true;
10699 }
10700 }
10701
10702 // If there's a constant operand, canonicalize comparisons with boundary
10703 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
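 // For example, "x u>= 5" becomes "x u> 4" and "x s<= 7" becomes "x s< 8".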
10704 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10705 const APInt &RA = RC->getAPInt();
10706
10707 bool SimplifiedByConstantRange = false;
10708
10709 if (!ICmpInst::isEquality(Pred)) {
10710 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10711 if (ExactCR.isFullSet())
10712 return TrivialCase(true);
10713 if (ExactCR.isEmptySet())
10714 return TrivialCase(false);
10715
10716 APInt NewRHS;
10717 CmpInst::Predicate NewPred;
10718 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10719 ICmpInst::isEquality(NewPred)) {
10720 // We were able to convert an inequality to an equality.
10721 Pred = NewPred;
10722 RHS = getConstant(NewRHS);
10723 Changed = SimplifiedByConstantRange = true;
10724 }
10725 }
10726
10727 if (!SimplifiedByConstantRange) {
10728 switch (Pred) {
10729 default:
10730 break;
10731 case ICmpInst::ICMP_EQ:
10732 case ICmpInst::ICMP_NE:
10733 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10734 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10735 Changed = true;
10736 break;
10737
10738 // The "Should have been caught earlier!" messages refer to the fact
10739 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10740 // should have fired on the corresponding cases, and canonicalized the
10741 // check to a trivial case.
10742
10743 case ICmpInst::ICMP_UGE:
10744 assert(!RA.isMinValue() && "Should have been caught earlier!");
10745 Pred = ICmpInst::ICMP_UGT;
10746 RHS = getConstant(RA - 1);
10747 Changed = true;
10748 break;
10749 case ICmpInst::ICMP_ULE:
10750 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10751 Pred = ICmpInst::ICMP_ULT;
10752 RHS = getConstant(RA + 1);
10753 Changed = true;
10754 break;
10755 case ICmpInst::ICMP_SGE:
10756 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10757 Pred = ICmpInst::ICMP_SGT;
10758 RHS = getConstant(RA - 1);
10759 Changed = true;
10760 break;
10761 case ICmpInst::ICMP_SLE:
10762 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10763 Pred = ICmpInst::ICMP_SLT;
10764 RHS = getConstant(RA + 1);
10765 Changed = true;
10766 break;
10767 }
10768 }
10769 }
10770
10771 // Check for obvious equality.
10772 if (HasSameValue(LHS, RHS)) {
10773 if (ICmpInst::isTrueWhenEqual(Pred))
10774 return TrivialCase(true);
10775 if (ICmpInst::isFalseWhenEqual(Pred))
10776 return TrivialCase(false);
10777 }
10778
10779 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10780 // adding or subtracting 1 from one of the operands.
10781 switch (Pred) {
10782 case ICmpInst::ICMP_SLE:
10783 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10784 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10785 SCEV::FlagNSW);
10786 Pred = ICmpInst::ICMP_SLT;
10787 Changed = true;
10788 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10789 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10790 SCEV::FlagNSW);
10791 Pred = ICmpInst::ICMP_SLT;
10792 Changed = true;
10793 }
10794 break;
10795 case ICmpInst::ICMP_SGE:
10796 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10797 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10798 SCEV::FlagNSW);
10799 Pred = ICmpInst::ICMP_SGT;
10800 Changed = true;
10801 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10802 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10803 SCEV::FlagNSW);
10804 Pred = ICmpInst::ICMP_SGT;
10805 Changed = true;
10806 }
10807 break;
10808 case ICmpInst::ICMP_ULE:
10809 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10810 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10811 SCEV::FlagNUW);
10812 Pred = ICmpInst::ICMP_ULT;
10813 Changed = true;
10814 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10815 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10816 Pred = ICmpInst::ICMP_ULT;
10817 Changed = true;
10818 }
10819 break;
10820 case ICmpInst::ICMP_UGE:
10821 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10822 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10823 Pred = ICmpInst::ICMP_UGT;
10824 Changed = true;
10825 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10826 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10827 SCEV::FlagNUW);
10828 Pred = ICmpInst::ICMP_UGT;
10829 Changed = true;
10830 }
10831 break;
10832 default:
10833 break;
10834 }
10835
10836 // TODO: More simplifications are possible here.
10837
10838 // Recursively simplify until we either hit a recursion limit or nothing
10839 // changes.
10840 if (Changed)
10841 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10842
10843 return Changed;
10844}
10845
10846bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10847 return getSignedRangeMax(S).isNegative();
10848}
10849
10850bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10851 return getSignedRangeMin(S).isStrictlyPositive();
10852}
10853
10854bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10855 return !getSignedRangeMin(S).isNegative();
10856}
10857
10858bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10859 return !getSignedRangeMax(S).isStrictlyPositive();
10860}
10861
10862bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10863 // Query push down for cases where the unsigned range is
10864 // less than sufficient.
10865 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10866 return isKnownNonZero(SExt->getOperand(0));
10867 return getUnsignedRangeMin(S) != 0;
10868}
10869
10870std::pair<const SCEV *, const SCEV *>
10871ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10872 // Compute SCEV on entry of loop L.
10873 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10874 if (Start == getCouldNotCompute())
10875 return { Start, Start };
10876 // Compute post increment SCEV for loop L.
10877 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10878 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10879 return { Start, PostInc };
10880}
10881
10882bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
10883 const SCEV *LHS, const SCEV *RHS) {
10884 // First collect all loops.
10886 getUsedLoops(LHS, LoopsUsed);
10887 getUsedLoops(RHS, LoopsUsed);
10888
10889 if (LoopsUsed.empty())
10890 return false;
10891
10892 // Domination relationship must be a linear order on collected loops.
10893#ifndef NDEBUG
10894 for (const auto *L1 : LoopsUsed)
10895 for (const auto *L2 : LoopsUsed)
10896 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
10897 DT.dominates(L2->getHeader(), L1->getHeader())) &&
10898 "Domination relationship is not a linear order");
10899#endif
10900
10901 const Loop *MDL =
10902 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
10903 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
10904 });
10905
10906 // Get init and post increment value for LHS.
10907 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
10908 // If LHS contains an unknown non-invariant SCEV then bail out.
10909 if (SplitLHS.first == getCouldNotCompute())
10910 return false;
10911 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
10912 // Get init and post increment value for RHS.
10913 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
10914 // If RHS contains an unknown non-invariant SCEV then bail out.
10915 if (SplitRHS.first == getCouldNotCompute())
10916 return false;
10917 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
10918 // It is possible that init SCEV contains an invariant load but it does
10919 // not dominate MDL and is not available at MDL loop entry, so we should
10920 // check it here.
10921 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
10922 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
10923 return false;
10924
10925 // It seems the backedge guard check is faster than the entry one, so in some
10926 // cases it can speed up the whole estimation by short-circuiting.
10927 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
10928 SplitRHS.second) &&
10929 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
10930}
10931
10932 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
10933 const SCEV *LHS, const SCEV *RHS) {
10934 // Canonicalize the inputs first.
10935 (void)SimplifyICmpOperands(Pred, LHS, RHS);
10936
10937 if (isKnownViaInduction(Pred, LHS, RHS))
10938 return true;
10939
10940 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
10941 return true;
10942
10943 // Otherwise see what can be done with some simple reasoning.
10944 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
10945}
10946
10947 std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
10948 const SCEV *LHS,
10949 const SCEV *RHS) {
10950 if (isKnownPredicate(Pred, LHS, RHS))
10951 return true;
10952 if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
10953 return false;
10954 return std::nullopt;
10955}
10956
10957 bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
10958 const SCEV *LHS, const SCEV *RHS,
10959 const Instruction *CtxI) {
10960 // TODO: Analyze guards and assumes from Context's block.
10961 return isKnownPredicate(Pred, LHS, RHS) ||
10962 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
10963}
10964
10965std::optional<bool>
10966 ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
10967 const SCEV *RHS, const Instruction *CtxI) {
10968 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
10969 if (KnownWithoutContext)
10970 return KnownWithoutContext;
10971
10972 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
10973 return true;
10974 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
10975 ICmpInst::getInversePredicate(Pred),
10976 LHS, RHS))
10977 return false;
10978 return std::nullopt;
10979}
10980
10981 bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
10982 const SCEVAddRecExpr *LHS,
10983 const SCEV *RHS) {
10984 const Loop *L = LHS->getLoop();
10985 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
10986 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
10987}
10988
10989std::optional<ScalarEvolution::MonotonicPredicateType>
10990 ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
10991 ICmpInst::Predicate Pred) {
10992 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
10993
10994#ifndef NDEBUG
10995 // Verify an invariant: inverting the predicate should turn a monotonically
10996 // increasing change to a monotonically decreasing one, and vice versa.
10997 if (Result) {
10998 auto ResultSwapped =
10999 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11000
11001 assert(*ResultSwapped != *Result &&
11002 "monotonicity should flip as we flip the predicate");
11003 }
11004#endif
11005
11006 return Result;
11007}
11008
11009std::optional<ScalarEvolution::MonotonicPredicateType>
11010ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11011 ICmpInst::Predicate Pred) {
11012 // A zero step value for LHS means the induction variable is essentially a
11013 // loop invariant value. We don't really depend on the predicate actually
11014 // flipping from false to true (for increasing predicates, and the other way
11015 // around for decreasing predicates), all we care about is that *if* the
11016 // predicate changes then it only changes from false to true.
11017 //
11018 // A zero step value in itself is not very useful, but there may be places
11019 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11020 // as general as possible.
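// Illustrative example (hypothetical operands): for LHS = {0,+,1}<nuw> and
// Pred = ICMP_UGT, the predicate "{0,+,1} u> RHS" can only switch from false
// to true as the IV grows, so the logic below reports MonotonicallyIncreasing.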
11021
11022 // Only handle LE/LT/GE/GT predicates.
11023 if (!ICmpInst::isRelational(Pred))
11024 return std::nullopt;
11025
11026 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11027 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11028 "Should be greater or less!");
11029
11030 // Check that AR does not wrap.
11031 if (ICmpInst::isUnsigned(Pred)) {
11032 if (!LHS->hasNoUnsignedWrap())
11033 return std::nullopt;
11034 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11035 }
11036 assert(ICmpInst::isSigned(Pred) &&
11037 "Relational predicate is either signed or unsigned!");
11038 if (!LHS->hasNoSignedWrap())
11039 return std::nullopt;
11040
11041 const SCEV *Step = LHS->getStepRecurrence(*this);
11042
11043 if (isKnownNonNegative(Step))
11044 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11045
11046 if (isKnownNonPositive(Step))
11047 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11048
11049 return std::nullopt;
11050}
11051
11052std::optional<ScalarEvolution::LoopInvariantPredicate>
11053 ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
11054 const SCEV *LHS, const SCEV *RHS,
11055 const Loop *L,
11056 const Instruction *CtxI) {
11057 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11058 if (!isLoopInvariant(RHS, L)) {
11059 if (!isLoopInvariant(LHS, L))
11060 return std::nullopt;
11061
11062 std::swap(LHS, RHS);
11063 Pred = ICmpInst::getSwappedPredicate(Pred);
11064 }
11065
11066 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11067 if (!ArLHS || ArLHS->getLoop() != L)
11068 return std::nullopt;
11069
11070 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11071 if (!MonotonicType)
11072 return std::nullopt;
11073 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11074 // true as the loop iterates, and the backedge is control dependent on
11075 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11076 //
11077 // * if the predicate was false in the first iteration then the predicate
11078 // is never evaluated again, since the loop exits without taking the
11079 // backedge.
11080 // * if the predicate was true in the first iteration then it will
11081 // continue to be true for all future iterations since it is
11082 // monotonically increasing.
11083 //
11084 // For both the above possibilities, we can replace the loop varying
11085 // predicate with its value on the first iteration of the loop (which is
11086 // loop invariant).
11087 //
11088 // A similar reasoning applies for a monotonically decreasing predicate, by
11089 // replacing true with false and false with true in the above two bullets.
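// Illustrative example (hypothetical operands): with ArLHS = {Start,+,1}<nuw>
// and Pred = ICMP_UGE, the predicate "ArLHS u>= RHS" can only flip from false
// to true as the IV grows, so if the backedge is guarded by it, its value
// throughout the loop equals its first-iteration value, i.e. "Start u>= RHS".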
11090 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11091 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11092
11093 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11094 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11095 RHS);
11096
11097 if (!CtxI)
11098 return std::nullopt;
11099 // Try to prove via context.
11100 // TODO: Support other cases.
11101 switch (Pred) {
11102 default:
11103 break;
11104 case ICmpInst::ICMP_ULE:
11105 case ICmpInst::ICMP_ULT: {
11106 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11107 // Given preconditions
11108 // (1) ArLHS does not cross the border of positive and negative parts of
11109 // range because of:
11110 // - Positive step; (TODO: lift this limitation)
11111 // - nuw - does not cross zero boundary;
11112 // - nsw - does not cross SINT_MAX boundary;
11113 // (2) ArLHS <s RHS
11114 // (3) RHS >=s 0
11115 // we can replace the loop variant ArLHS <u RHS condition with loop
11116 // invariant Start(ArLHS) <u RHS.
11117 //
11118 // Because of (1) there are two options:
11119 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11120 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11121 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11122 // Because of (2) ArLHS <u RHS is trivially true.
11123 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11124 // We can strengthen this to Start(ArLHS) <u RHS.
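// Illustrative numeric sketch (hypothetical values): ArLHS = {2,+,1}<nuw><nsw>
// with RHS = 10 known non-negative. ArLHS never crosses the sign boundary, so
// "ArLHS <s 10" and "ArLHS <u 10" agree, and the loop-varying "ArLHS <u 10"
// can be replaced by the loop-invariant "2 u< 10".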
11125 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11126 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11127 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11128 isKnownNonNegative(RHS) &&
11129 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11130 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11131 RHS);
11132 }
11133 }
11134
11135 return std::nullopt;
11136}
11137
11138std::optional<ScalarEvolution::LoopInvariantPredicate>
11139 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11140 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11141 const Instruction *CtxI, const SCEV *MaxIter) {
11142 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11143 Pred, LHS, RHS, L, CtxI, MaxIter))
11144 return LIP;
11145 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11146 // Number of iterations expressed as UMIN isn't always great for expressing
11147 // the value on the last iteration. If the straightforward approach didn't
11148 // work, try the following trick: if a predicate is invariant for X, it
11149 // is also invariant for umin(X, ...). So try to find something that works
11150 // among subexpressions of MaxIter expressed as umin.
11151 for (auto *Op : UMin->operands())
11152 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11153 Pred, LHS, RHS, L, CtxI, Op))
11154 return LIP;
11155 return std::nullopt;
11156}
11157
11158std::optional<ScalarEvolution::LoopInvariantPredicate>
11159 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11160 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11161 const Instruction *CtxI, const SCEV *MaxIter) {
11162 // Try to prove the following set of facts:
11163 // - The predicate is monotonic in the iteration space.
11164 // - If the check does not fail on the 1st iteration:
11165 // - No overflow will happen during first MaxIter iterations;
11166 // - It will not fail on the MaxIter'th iteration.
11167 // If the check does fail on the 1st iteration, we leave the loop and no
11168 // other checks matter.
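// Illustrative sketch (hypothetical setup): for AR = {0,+,1}, Pred = ICMP_ULT,
// and MaxIter equal to the loop's backedge-taken count, the code below
// evaluates AR at iteration MaxIter, checks that the predicate still holds for
// that last value and that the IV cannot wrap on the way there, and then
// returns the first-iteration condition "0 u< RHS" as the invariant check.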
11169
11170 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11171 if (!isLoopInvariant(RHS, L)) {
11172 if (!isLoopInvariant(LHS, L))
11173 return std::nullopt;
11174
11175 std::swap(LHS, RHS);
11176 Pred = ICmpInst::getSwappedPredicate(Pred);
11177 }
11178
11179 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11180 if (!AR || AR->getLoop() != L)
11181 return std::nullopt;
11182
11183 // The predicate must be relational (i.e. <, <=, >=, >).
11184 if (!ICmpInst::isRelational(Pred))
11185 return std::nullopt;
11186
11187 // TODO: Support steps other than +/- 1.
11188 const SCEV *Step = AR->getStepRecurrence(*this);
11189 auto *One = getOne(Step->getType());
11190 auto *MinusOne = getNegativeSCEV(One);
11191 if (Step != One && Step != MinusOne)
11192 return std::nullopt;
11193
11194 // Type mismatch here means that MaxIter is potentially larger than max
11195 // unsigned value in the start type, which means we cannot prove no wrap for the
11196 // indvar.
11197 if (AR->getType() != MaxIter->getType())
11198 return std::nullopt;
11199
11200 // Value of IV on suggested last iteration.
11201 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11202 // Does it still meet the requirement?
11203 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11204 return std::nullopt;
11205 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11206 // not exceed max unsigned value of this type), this effectively proves
11207 // that there is no wrap during the iteration. To prove that there is no
11208 // signed/unsigned wrap, we need to check that
11209 // Start <= Last for step = 1 or Start >= Last for step = -1.
11210 ICmpInst::Predicate NoOverflowPred =
11211 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11212 if (Step == MinusOne)
11213 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
11214 const SCEV *Start = AR->getStart();
11215 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11216 return std::nullopt;
11217
11218 // Everything is fine.
11219 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11220}
11221
11222bool ScalarEvolution::isKnownPredicateViaConstantRanges(
11223 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
11224 if (HasSameValue(LHS, RHS))
11225 return ICmpInst::isTrueWhenEqual(Pred);
11226
11227 // This code is split out from isKnownPredicate because it is called from
11228 // within isLoopEntryGuardedByCond.
11229
11230 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11231 const ConstantRange &RangeRHS) {
11232 return RangeLHS.icmp(Pred, RangeRHS);
11233 };
11234
11235 // The check at the top of the function catches the case where the values are
11236 // known to be equal.
11237 if (Pred == CmpInst::ICMP_EQ)
11238 return false;
11239
11240 if (Pred == CmpInst::ICMP_NE) {
11241 auto SL = getSignedRange(LHS);
11242 auto SR = getSignedRange(RHS);
11243 if (CheckRanges(SL, SR))
11244 return true;
11245 auto UL = getUnsignedRange(LHS);
11246 auto UR = getUnsignedRange(RHS);
11247 if (CheckRanges(UL, UR))
11248 return true;
11249 auto *Diff = getMinusSCEV(LHS, RHS);
11250 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11251 }
11252
11253 if (CmpInst::isSigned(Pred)) {
11254 auto SL = getSignedRange(LHS);
11255 auto SR = getSignedRange(RHS);
11256 return CheckRanges(SL, SR);
11257 }
11258
11259 auto UL = getUnsignedRange(LHS);
11260 auto UR = getUnsignedRange(RHS);
11261 return CheckRanges(UL, UR);
11262}
11263
11264bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
11265 const SCEV *LHS,
11266 const SCEV *RHS) {
11267 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11268 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11269 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11270 // OutC1 and OutC2.
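// Illustrative example (hypothetical values): X = (%a + 1)<nsw> and
// Y = (%a + 3)<nsw> match with OutC1 = 1 and OutC2 = 3, so the ICMP_SLE case
// below concludes X s<= Y because 1 s<= 3.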
11271 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11272 APInt &OutC1, APInt &OutC2,
11273 SCEV::NoWrapFlags ExpectedFlags) {
11274 const SCEV *XNonConstOp, *XConstOp;
11275 const SCEV *YNonConstOp, *YConstOp;
11276 SCEV::NoWrapFlags XFlagsPresent;
11277 SCEV::NoWrapFlags YFlagsPresent;
11278
11279 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11280 XConstOp = getZero(X->getType());
11281 XNonConstOp = X;
11282 XFlagsPresent = ExpectedFlags;
11283 }
11284 if (!isa<SCEVConstant>(XConstOp) ||
11285 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11286 return false;
11287
11288 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11289 YConstOp = getZero(Y->getType());
11290 YNonConstOp = Y;
11291 YFlagsPresent = ExpectedFlags;
11292 }
11293
11294 if (!isa<SCEVConstant>(YConstOp) ||
11295 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11296 return false;
11297
11298 if (YNonConstOp != XNonConstOp)
11299 return false;
11300
11301 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11302 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11303
11304 return true;
11305 };
11306
11307 APInt C1;
11308 APInt C2;
11309
11310 switch (Pred) {
11311 default:
11312 break;
11313
11314 case ICmpInst::ICMP_SGE:
11315 std::swap(LHS, RHS);
11316 [[fallthrough]];
11317 case ICmpInst::ICMP_SLE:
11318 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11319 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11320 return true;
11321
11322 break;
11323
11324 case ICmpInst::ICMP_SGT:
11325 std::swap(LHS, RHS);
11326 [[fallthrough]];
11327 case ICmpInst::ICMP_SLT:
11328 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11329 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11330 return true;
11331
11332 break;
11333
11334 case ICmpInst::ICMP_UGE:
11335 std::swap(LHS, RHS);
11336 [[fallthrough]];
11337 case ICmpInst::ICMP_ULE:
11338 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11339 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11340 return true;
11341
11342 break;
11343
11344 case ICmpInst::ICMP_UGT:
11345 std::swap(LHS, RHS);
11346 [[fallthrough]];
11347 case ICmpInst::ICMP_ULT:
11348 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11349 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11350 return true;
11351 break;
11352 }
11353
11354 return false;
11355}
11356
11357bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
11358 const SCEV *LHS,
11359 const SCEV *RHS) {
11360 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11361 return false;
11362
11363 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11364 // the stack can result in exponential time complexity.
11365 SaveAndRestore Restore(ProvingSplitPredicate, true);
11366
11367 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11368 //
11369 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11370 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11371 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11372 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11373 // use isKnownPredicate later if needed.
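// Illustrative sketch (hypothetical names): to establish "%i u< %n" with %n
// known non-negative, the return below only needs "0 s<= %i" and "%i s< %n";
// both sides then lie in the non-negative range where signed and unsigned
// comparison coincide.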
11374 return isKnownNonNegative(RHS) &&
11375 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11376 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11377}
11378
11379bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
11380 ICmpInst::Predicate Pred,
11381 const SCEV *LHS, const SCEV *RHS) {
11382 // No need to even try if we know the module has no guards.
11383 if (!HasGuards)
11384 return false;
11385
11386 return any_of(*BB, [&](const Instruction &I) {
11387 using namespace llvm::PatternMatch;
11388
11389 Value *Condition;
11390 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11391 m_Value(Condition))) &&
11392 isImpliedCond(Pred, LHS, RHS, Condition, false);
11393 });
11394}
11395
11396/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11397/// protected by a conditional between LHS and RHS. This is used to
11398 /// eliminate casts.
11399bool
11400 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11401 ICmpInst::Predicate Pred,
11402 const SCEV *LHS, const SCEV *RHS) {
11403 // Interpret a null as meaning no loop, where there is obviously no guard
11404 // (interprocedural conditions notwithstanding). Do not bother about
11405 // unreachable loops.
11406 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11407 return true;
11408
11409 if (VerifyIR)
11410 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11411 "This cannot be done on broken IR!");
11412
11413
11414 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11415 return true;
11416
11417 BasicBlock *Latch = L->getLoopLatch();
11418 if (!Latch)
11419 return false;
11420
11421 BranchInst *LoopContinuePredicate =
11422 dyn_cast<BranchInst>(Latch->getTerminator());
11423 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11424 isImpliedCond(Pred, LHS, RHS,
11425 LoopContinuePredicate->getCondition(),
11426 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11427 return true;
11428
11429 // We don't want more than one activation of the following loops on the stack
11430 // -- that can lead to O(n!) time complexity.
11431 if (WalkingBEDominatingConds)
11432 return false;
11433
11434 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11435
11436 // See if we can exploit a trip count to prove the predicate.
11437 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11438 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11439 if (LatchBECount != getCouldNotCompute()) {
11440 // We know that Latch branches back to the loop header exactly
11441 // LatchBECount times. This means the backedge condition at Latch is
11442 // equivalent to "{0,+,1} u< LatchBECount".
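// For instance (illustrative): if the latch is known to take the backedge
// exactly 7 times, then proving "LHS Pred RHS" from "{0,+,1} u< 7" below is as
// good as proving it from the actual latch condition.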
11443 Type *Ty = LatchBECount->getType();
11444 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11445 const SCEV *LoopCounter =
11446 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11447 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11448 LatchBECount))
11449 return true;
11450 }
11451
11452 // Check conditions due to any @llvm.assume intrinsics.
11453 for (auto &AssumeVH : AC.assumptions()) {
11454 if (!AssumeVH)
11455 continue;
11456 auto *CI = cast<CallInst>(AssumeVH);
11457 if (!DT.dominates(CI, Latch->getTerminator()))
11458 continue;
11459
11460 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11461 return true;
11462 }
11463
11464 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11465 return true;
11466
11467 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11468 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11469 assert(DTN && "should reach the loop header before reaching the root!");
11470
11471 BasicBlock *BB = DTN->getBlock();
11472 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11473 return true;
11474
11475 BasicBlock *PBB = BB->getSinglePredecessor();
11476 if (!PBB)
11477 continue;
11478
11479 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11480 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11481 continue;
11482
11483 Value *Condition = ContinuePredicate->getCondition();
11484
11485 // If we have an edge `E` within the loop body that dominates the only
11486 // latch, the condition guarding `E` also guards the backedge. This
11487 // reasoning works only for loops with a single latch.
11488
11489 BasicBlockEdge DominatingEdge(PBB, BB);
11490 if (DominatingEdge.isSingleEdge()) {
11491 // We're constructively (and conservatively) enumerating edges within the
11492 // loop body that dominate the latch. The dominator tree better agree
11493 // with us on this:
11494 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11495
11496 if (isImpliedCond(Pred, LHS, RHS, Condition,
11497 BB != ContinuePredicate->getSuccessor(0)))
11498 return true;
11499 }
11500 }
11501
11502 return false;
11503}
11504
11505 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11506 ICmpInst::Predicate Pred,
11507 const SCEV *LHS,
11508 const SCEV *RHS) {
11509 // Do not bother proving facts for unreachable code.
11510 if (!DT.isReachableFromEntry(BB))
11511 return true;
11512 if (VerifyIR)
11513 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11514 "This cannot be done on broken IR!");
11515
11516 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11517 // the facts (a >= b && a != b) separately. A typical situation is when the
11518 // non-strict comparison is known from ranges and non-equality is known from
11519 // dominating predicates. If we are proving strict comparison, we always try
11520 // to prove non-equality and non-strict comparison separately.
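// Illustrative example (hypothetical operands): "a u< b" at this block's entry
// can be established by proving "a u<= b" (say, from constant ranges) together
// with "a != b" (say, from a dominating branch).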
11521 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11522 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11523 bool ProvedNonStrictComparison = false;
11524 bool ProvedNonEquality = false;
11525
11526 auto SplitAndProve =
11527 [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
11528 if (!ProvedNonStrictComparison)
11529 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11530 if (!ProvedNonEquality)
11531 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11532 if (ProvedNonStrictComparison && ProvedNonEquality)
11533 return true;
11534 return false;
11535 };
11536
11537 if (ProvingStrictComparison) {
11538 auto ProofFn = [&](ICmpInst::Predicate P) {
11539 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11540 };
11541 if (SplitAndProve(ProofFn))
11542 return true;
11543 }
11544
11545 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11546 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11547 const Instruction *CtxI = &BB->front();
11548 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11549 return true;
11550 if (ProvingStrictComparison) {
11551 auto ProofFn = [&](ICmpInst::Predicate P) {
11552 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11553 };
11554 if (SplitAndProve(ProofFn))
11555 return true;
11556 }
11557 return false;
11558 };
11559
11560 // Starting at the block's predecessor, climb up the predecessor chain as long
11561 // as we can find predecessors that have unique successors leading to the
11562 // original block.
11563 const Loop *ContainingLoop = LI.getLoopFor(BB);
11564 const BasicBlock *PredBB;
11565 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11566 PredBB = ContainingLoop->getLoopPredecessor();
11567 else
11568 PredBB = BB->getSinglePredecessor();
11569 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11570 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11571 const BranchInst *BlockEntryPredicate =
11572 dyn_cast<BranchInst>(Pair.first->getTerminator());
11573 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11574 continue;
11575
11576 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11577 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11578 return true;
11579 }
11580
11581 // Check conditions due to any @llvm.assume intrinsics.
11582 for (auto &AssumeVH : AC.assumptions()) {
11583 if (!AssumeVH)
11584 continue;
11585 auto *CI = cast<CallInst>(AssumeVH);
11586 if (!DT.dominates(CI, BB))
11587 continue;
11588
11589 if (ProveViaCond(CI->getArgOperand(0), false))
11590 return true;
11591 }
11592
11593 // Check conditions due to any @llvm.experimental.guard intrinsics.
11594 auto *GuardDecl = F.getParent()->getFunction(
11595 Intrinsic::getName(Intrinsic::experimental_guard));
11596 if (GuardDecl)
11597 for (const auto *GU : GuardDecl->users())
11598 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11599 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11600 if (ProveViaCond(Guard->getArgOperand(0), false))
11601 return true;
11602 return false;
11603}
11604
11605 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
11606 ICmpInst::Predicate Pred,
11607 const SCEV *LHS,
11608 const SCEV *RHS) {
11609 // Interpret a null as meaning no loop, where there is obviously no guard
11610 // (interprocedural conditions notwithstanding).
11611 if (!L)
11612 return false;
11613
11614 // Both LHS and RHS must be available at loop entry.
11616 "LHS is not available at Loop Entry");
11618 "RHS is not available at Loop Entry");
11619
11620 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11621 return true;
11622
11623 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11624}
11625
11626bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11627 const SCEV *RHS,
11628 const Value *FoundCondValue, bool Inverse,
11629 const Instruction *CtxI) {
11630 // A false condition implies anything. Do not bother analyzing it further.
11631 if (FoundCondValue ==
11632 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11633 return true;
11634
11635 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11636 return false;
11637
11638 auto ClearOnExit =
11639 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11640
11641 // Recursively handle And and Or conditions.
11642 const Value *Op0, *Op1;
11643 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11644 if (!Inverse)
11645 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11646 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11647 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11648 if (Inverse)
11649 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11650 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11651 }
11652
11653 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11654 if (!ICI) return false;
11655
11656 // Now that we have found a conditional branch that dominates the loop or
11657 // controls the loop latch, check to see if it is the comparison we are looking for.
11658 ICmpInst::Predicate FoundPred;
11659 if (Inverse)
11660 FoundPred = ICI->getInversePredicate();
11661 else
11662 FoundPred = ICI->getPredicate();
11663
11664 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11665 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11666
11667 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11668}
11669
11670bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11671 const SCEV *RHS,
11672 ICmpInst::Predicate FoundPred,
11673 const SCEV *FoundLHS, const SCEV *FoundRHS,
11674 const Instruction *CtxI) {
11675 // Balance the types.
11676 if (getTypeSizeInBits(LHS->getType()) <
11677 getTypeSizeInBits(FoundLHS->getType())) {
11678 // For unsigned and equality predicates, try to prove that both found
11679 // operands fit into narrow unsigned range. If so, try to prove facts in
11680 // narrow types.
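// Illustrative example (hypothetical types): if LHS/RHS are i32 while
// FoundLHS/FoundRHS are i64 values provably u<= UINT32_MAX, the found fact is
// truncated to i32 below and the implication is re-checked in the narrow type.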
11681 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11682 !FoundRHS->getType()->isPointerTy()) {
11683 auto *NarrowType = LHS->getType();
11684 auto *WideType = FoundLHS->getType();
11685 auto BitWidth = getTypeSizeInBits(NarrowType);
11686 const SCEV *MaxValue = getZeroExtendExpr(
11687 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11688 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11689 MaxValue) &&
11690 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11691 MaxValue)) {
11692 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11693 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11694 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11695 TruncFoundRHS, CtxI))
11696 return true;
11697 }
11698 }
11699
11700 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11701 return false;
11702 if (CmpInst::isSigned(Pred)) {
11703 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11704 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11705 } else {
11706 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11707 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11708 }
11709 } else if (getTypeSizeInBits(LHS->getType()) >
11710 getTypeSizeInBits(FoundLHS->getType())) {
11711 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11712 return false;
11713 if (CmpInst::isSigned(FoundPred)) {
11714 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11715 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11716 } else {
11717 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11718 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11719 }
11720 }
11721 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11722 FoundRHS, CtxI);
11723}
11724
11725bool ScalarEvolution::isImpliedCondBalancedTypes(
11726 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11727 ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
11728 const Instruction *CtxI) {
11729 assert(getTypeSizeInBits(LHS->getType()) ==
11730 getTypeSizeInBits(FoundLHS->getType()) &&
11731 "Types should be balanced!");
11732 // Canonicalize the query to match the way instcombine will have
11733 // canonicalized the comparison.
11734 if (SimplifyICmpOperands(Pred, LHS, RHS))
11735 if (LHS == RHS)
11736 return CmpInst::isTrueWhenEqual(Pred);
11737 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11738 if (FoundLHS == FoundRHS)
11739 return CmpInst::isFalseWhenEqual(FoundPred);
11740
11741 // Check to see if we can make the LHS or RHS match.
11742 if (LHS == FoundRHS || RHS == FoundLHS) {
11743 if (isa<SCEVConstant>(RHS)) {
11744 std::swap(FoundLHS, FoundRHS);
11745 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
11746 } else {
11747 std::swap(LHS, RHS);
11748 Pred = ICmpInst::getSwappedPredicate(Pred);
11749 }
11750 }
11751
11752 // Check whether the found predicate is the same as the desired predicate.
11753 if (FoundPred == Pred)
11754 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11755
11756 // Check whether swapping the found predicate makes it the same as the
11757 // desired predicate.
11758 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
11759 // We can write the implication
11760 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11761 // using one of the following ways:
11762 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11763 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11764 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11765 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11766 // Forms 1. and 2. require swapping the operands of one condition. Don't
11767 // do this if it would break canonical constant/addrec ordering.
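// Illustrative example (hypothetical operands): proving "a u< b" from a known
// fact "b u> a" uses form 1. (swap the found operands); the code below only
// performs such swaps when they keep constants and addrecs in their canonical
// position.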
11768 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11769 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11770 CtxI);
11771 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11772 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11773
11774 // There's no clear preference between forms 3. and 4., try both. Avoid
11775 // forming getNotSCEV of pointer values as the resulting subtract is
11776 // not legal.
11777 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11778 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11779 FoundLHS, FoundRHS, CtxI))
11780 return true;
11781
11782 if (!FoundLHS->getType()->isPointerTy() &&
11783 !FoundRHS->getType()->isPointerTy() &&
11784 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11785 getNotSCEV(FoundRHS), CtxI))
11786 return true;
11787
11788 return false;
11789 }
11790
11791 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11792 CmpInst::Predicate P2) {
11793 assert(P1 != P2 && "Handled earlier!");
11794 return CmpInst::isRelational(P2) &&
11795 P1 == CmpInst::getFlippedSignednessPredicate(P2);
11796 };
11797 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11798 // Unsigned comparison is the same as signed comparison when both the
11799 // operands are non-negative or negative.
11800 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11801 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11802 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11803 // Create local copies that we can freely swap and canonicalize our
11804 // conditions to "le/lt".
11805 ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11806 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11807 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11808 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11809 CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
11810 CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
11811 std::swap(CanonicalLHS, CanonicalRHS);
11812 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11813 }
11814 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11815 "Must be!");
11816 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11817 ICmpInst::isLE(CanonicalFoundPred)) &&
11818 "Must be!");
11819 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11820 // Use implication:
11821 // x <u y && y >=s 0 --> x <s y.
11822 // If we can prove the left part, the right part is also proven.
11823 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11824 CanonicalRHS, CanonicalFoundLHS,
11825 CanonicalFoundRHS);
11826 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11827 // Use implication:
11828 // x <s y && y <s 0 --> x <u y.
11829 // If we can prove the left part, the right part is also proven.
11830 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11831 CanonicalRHS, CanonicalFoundLHS,
11832 CanonicalFoundRHS);
11833 }
11834
11835 // Check if we can make progress by sharpening ranges.
11836 if (FoundPred == ICmpInst::ICMP_NE &&
11837 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11838
11839 const SCEVConstant *C = nullptr;
11840 const SCEV *V = nullptr;
11841
11842 if (isa<SCEVConstant>(FoundLHS)) {
11843 C = cast<SCEVConstant>(FoundLHS);
11844 V = FoundRHS;
11845 } else {
11846 C = cast<SCEVConstant>(FoundRHS);
11847 V = FoundLHS;
11848 }
11849
11850 // The guarding predicate tells us that C != V. If the known range
11851 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11852 // range we consider has to correspond to same signedness as the
11853 // predicate we're interested in folding.
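// Illustrative example (hypothetical range): if V's unsigned range is [3, 100)
// and the guard tells us V != 3, the code below effectively sharpens the lower
// bound to 4 (SharperMin = Min + 1).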
11854
11855 APInt Min = ICmpInst::isSigned(Pred) ?
11856 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11857
11858 if (Min == C->getAPInt()) {
11859 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11860 // This is true even if (Min + 1) wraps around -- in case of
11861 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
11862
11863 APInt SharperMin = Min + 1;
11864
11865 switch (Pred) {
11866 case ICmpInst::ICMP_SGE:
11867 case ICmpInst::ICMP_UGE:
11868 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11869 // RHS, we're done.
11870 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11871 CtxI))
11872 return true;
11873 [[fallthrough]];
11874
11875 case ICmpInst::ICMP_SGT:
11876 case ICmpInst::ICMP_UGT:
11877 // We know from the range information that (V `Pred` Min ||
11878 // V == Min). We know from the guarding condition that !(V
11879 // == Min). This gives us
11880 //
11881 // V `Pred` Min || V == Min && !(V == Min)
11882 // => V `Pred` Min
11883 //
11884 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11885
11886 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
11887 return true;
11888 break;
11889
11890 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
11891 case ICmpInst::ICMP_SLE:
11892 case ICmpInst::ICMP_ULE:
11893 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11894 LHS, V, getConstant(SharperMin), CtxI))
11895 return true;
11896 [[fallthrough]];
11897
11898 case ICmpInst::ICMP_SLT:
11899 case ICmpInst::ICMP_ULT:
11900 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11901 LHS, V, getConstant(Min), CtxI))
11902 return true;
11903 break;
11904
11905 default:
11906 // No change
11907 break;
11908 }
11909 }
11910 }
11911
11912 // Check whether the actual condition is beyond sufficient.
11913 if (FoundPred == ICmpInst::ICMP_EQ)
11914 if (ICmpInst::isTrueWhenEqual(Pred))
11915 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11916 return true;
11917 if (Pred == ICmpInst::ICMP_NE)
11918 if (!ICmpInst::isTrueWhenEqual(FoundPred))
11919 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11920 return true;
11921
11922 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
11923 return true;
11924
11925 // Otherwise assume the worst.
11926 return false;
11927}
11928
11929bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
11930 const SCEV *&L, const SCEV *&R,
11931 SCEV::NoWrapFlags &Flags) {
11932 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
11933 if (!AE || AE->getNumOperands() != 2)
11934 return false;
11935
11936 L = AE->getOperand(0);
11937 R = AE->getOperand(1);
11938 Flags = AE->getNoWrapFlags();
11939 return true;
11940}
11941
11942std::optional<APInt>
11943 ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
11944 // We avoid subtracting expressions here because this function is usually
11945 // fairly deep in the call stack (i.e. is called many times).
11946
11947 // X - X = 0.
11948 if (More == Less)
11949 return APInt(getTypeSizeInBits(More->getType()), 0);
11950
11951 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11952 const auto *LAR = cast<SCEVAddRecExpr>(Less);
11953 const auto *MAR = cast<SCEVAddRecExpr>(More);
11954
11955 if (LAR->getLoop() != MAR->getLoop())
11956 return std::nullopt;
11957
11958 // We look at affine expressions only; not for correctness but to keep
11959 // getStepRecurrence cheap.
11960 if (!LAR->isAffine() || !MAR->isAffine())
11961 return std::nullopt;
11962
11963 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11964 return std::nullopt;
11965
11966 Less = LAR->getStart();
11967 More = MAR->getStart();
11968
11969 // fall through
11970 }
11971
11972 if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11973 const auto &M = cast<SCEVConstant>(More)->getAPInt();
11974 const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11975 return M - L;
11976 }
11977
11978 SCEV::NoWrapFlags Flags;
11979 const SCEV *LLess = nullptr, *RLess = nullptr;
11980 const SCEV *LMore = nullptr, *RMore = nullptr;
11981 const SCEVConstant *C1 = nullptr, *C2 = nullptr;
11982 // Compare (X + C1) vs X.
11983 if (splitBinaryAdd(Less, LLess, RLess, Flags))
11984 if ((C1 = dyn_cast<SCEVConstant>(LLess)))
11985 if (RLess == More)
11986 return -(C1->getAPInt());
11987
11988 // Compare X vs (X + C2).
11989 if (splitBinaryAdd(More, LMore, RMore, Flags))
11990 if ((C2 = dyn_cast<SCEVConstant>(LMore)))
11991 if (RMore == Less)
11992 return C2->getAPInt();
11993
11994 // Compare (X + C1) vs (X + C2).
11995 if (C1 && C2 && RLess == RMore)
11996 return C2->getAPInt() - C1->getAPInt();
11997
11998 return std::nullopt;
11999}
12000
12001bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12002 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
12003 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
12004 // Try to recognize the following pattern:
12005 //
12006 // FoundRHS = ...
12007 // ...
12008 // loop:
12009 // FoundLHS = {Start,+,W}
12010 // context_bb: // Basic block from the same loop
12011 // known(Pred, FoundLHS, FoundRHS)
12012 //
12013 // If some predicate is known in the context of a loop, it is also known on
12014 // each iteration of this loop, including the first iteration. Therefore, in
12015 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12016 // prove the original pred using this fact.
12017 if (!CtxI)
12018 return false;
12019 const BasicBlock *ContextBB = CtxI->getParent();
12020 // Make sure AR varies in the context block.
12021 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12022 const Loop *L = AR->getLoop();
12023 // Make sure that context belongs to the loop and executes on 1st iteration
12024 // (if it ever executes at all).
12025 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12026 return false;
12027 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12028 return false;
12029 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12030 }
12031
12032 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12033 const Loop *L = AR->getLoop();
12034 // Make sure that context belongs to the loop and executes on 1st iteration
12035 // (if it ever executes at all).
12036 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12037 return false;
12038 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12039 return false;
12040 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12041 }
12042
12043 return false;
12044}
12045
12046bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
12047 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
12048 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12049 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12050 return false;
12051
12052 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12053 if (!AddRecLHS)
12054 return false;
12055
12056 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12057 if (!AddRecFoundLHS)
12058 return false;
12059
12060 // We'd like to let SCEV reason about control dependencies, so we constrain
12061 // both the inequalities to be about add recurrences on the same loop. This
12062 // way we can use isLoopEntryGuardedByCond later.
12063
12064 const Loop *L = AddRecFoundLHS->getLoop();
12065 if (L != AddRecLHS->getLoop())
12066 return false;
12067
12068 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12069 //
12070 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12071 // ... (2)
12072 //
12073 // Informal proof for (2), assuming (1) [*]:
12074 //
12075 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12076 //
12077 // Then
12078 //
12079 // FoundLHS s< FoundRHS s< INT_MIN - C
12080 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12081 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12082 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12083 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12084 // <=> FoundLHS + C s< FoundRHS + C
12085 //
12086 // [*]: (1) can be proved by ruling out overflow.
12087 //
12088 // [**]: This can be proved by analyzing all the four possibilities:
12089 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12090 // (A s>= 0, B s>= 0).
12091 //
12092 // Note:
12093 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12094 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12095 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12096 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12097 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12098 // C)".
12099
12100 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12101 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12102 if (!LDiff || !RDiff || *LDiff != *RDiff)
12103 return false;
12104
12105 if (LDiff->isMinValue())
12106 return true;
12107
12108 APInt FoundRHSLimit;
12109
12110 if (Pred == CmpInst::ICMP_ULT) {
12111 FoundRHSLimit = -(*RDiff);
12112 } else {
12113 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12114 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12115 }
12116
12117 // Try to prove (1) or (2), as needed.
12118 return isAvailableAtLoopEntry(FoundRHS, L) &&
12119 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12120 getConstant(FoundRHSLimit));
12121}
12122
12123bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
12124 const SCEV *LHS, const SCEV *RHS,
12125 const SCEV *FoundLHS,
12126 const SCEV *FoundRHS, unsigned Depth) {
12127 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12128
12129 auto ClearOnExit = make_scope_exit([&]() {
12130 if (LPhi) {
12131 bool Erased = PendingMerges.erase(LPhi);
12132 assert(Erased && "Failed to erase LPhi!");
12133 (void)Erased;
12134 }
12135 if (RPhi) {
12136 bool Erased = PendingMerges.erase(RPhi);
12137 assert(Erased && "Failed to erase RPhi!");
12138 (void)Erased;
12139 }
12140 });
12141
12142 // Find the respective Phis and check that they are not already pending.
12143 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12144 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12145 if (!PendingMerges.insert(Phi).second)
12146 return false;
12147 LPhi = Phi;
12148 }
12149 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12150 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12151 // If we detect a loop of Phi nodes being processed by this method, for
12152 // example:
12153 //
12154 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12155 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12156 //
12157 // we don't want to deal with a case that complex, so return conservative
12158 // answer false.
12159 if (!PendingMerges.insert(Phi).second)
12160 return false;
12161 RPhi = Phi;
12162 }
12163
12164 // If none of LHS, RHS is a Phi, nothing to do here.
12165 if (!LPhi && !RPhi)
12166 return false;
12167
12168 // If there is a SCEVUnknown Phi we are interested in, make it left.
12169 if (!LPhi) {
12170 std::swap(LHS, RHS);
12171 std::swap(FoundLHS, FoundRHS);
12172 std::swap(LPhi, RPhi);
12173 Pred = ICmpInst::getSwappedPredicate(Pred);
12174 }
12175
12176 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12177 const BasicBlock *LBB = LPhi->getParent();
12178 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12179
12180 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12181 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12182 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12183 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12184 };
12185
12186 if (RPhi && RPhi->getParent() == LBB) {
12187 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12188 // If we compare two Phis from the same block, and for each entry block
12189 // the predicate is true for incoming values from this block, then the
12190 // predicate is also true for the Phis.
12191 for (const BasicBlock *IncBB : predecessors(LBB)) {
12192 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12193 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12194 if (!ProvedEasily(L, R))
12195 return false;
12196 }
12197 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12198 // Case two: RHS is also a Phi from the same basic block, and it is an
12199 // AddRec. It means that there is a loop which has both AddRec and Unknown
12200 // PHIs, for it we can compare incoming values of AddRec from above the loop
12201 // and latch with their respective incoming values of LPhi.
12202 // TODO: Generalize to handle loops with many inputs in a header.
12203 if (LPhi->getNumIncomingValues() != 2) return false;
12204
12205 auto *RLoop = RAR->getLoop();
12206 auto *Predecessor = RLoop->getLoopPredecessor();
12207 assert(Predecessor && "Loop with AddRec with no predecessor?");
12208 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12209 if (!ProvedEasily(L1, RAR->getStart()))
12210 return false;
12211 auto *Latch = RLoop->getLoopLatch();
12212 assert(Latch && "Loop with AddRec with no latch?");
12213 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12214 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12215 return false;
12216 } else {
12217 // In all other cases go over inputs of LHS and compare each of them to RHS,
12218 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12219 // At this point RHS is either a non-Phi, or it is a Phi from some block
12220 // different from LBB.
12221 for (const BasicBlock *IncBB : predecessors(LBB)) {
12222 // Check that RHS is available in this block.
12223 if (!dominates(RHS, IncBB))
12224 return false;
12225 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12226 // Make sure L does not refer to a value from a potentially previous
12227 // iteration of a loop.
12228 if (!properlyDominates(L, LBB))
12229 return false;
12230 if (!ProvedEasily(L, RHS))
12231 return false;
12232 }
12233 }
12234 return true;
12235}
12236
12237bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
12238 const SCEV *LHS,
12239 const SCEV *RHS,
12240 const SCEV *FoundLHS,
12241 const SCEV *FoundRHS) {
12242 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12243 // sure that we are dealing with same LHS.
12244 if (RHS == FoundRHS) {
12245 std::swap(LHS, RHS);
12246 std::swap(FoundLHS, FoundRHS);
12247 Pred = ICmpInst::getSwappedPredicate(Pred);
12248 }
12249 if (LHS != FoundLHS)
12250 return false;
12251
12252 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12253 if (!SUFoundRHS)
12254 return false;
12255
12256 Value *Shiftee, *ShiftValue;
12257
12258 using namespace PatternMatch;
12259 if (match(SUFoundRHS->getValue(),
12260 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12261 auto *ShifteeS = getSCEV(Shiftee);
12262 // Prove one of the following:
12263 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12264 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12265 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12266 // ---> LHS <s RHS
12267 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12268 // ---> LHS <=s RHS
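// Illustrative example (hypothetical value %len): with FoundRHS = (%len lshr 2)
// and RHS = %len, the shiftee is %len and "%len u<= %len" holds trivially, so
// "LHS u< (%len lshr 2)" implies "LHS u< %len" by the first rule above.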
12269 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12270 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12271 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12272 if (isKnownNonNegative(ShifteeS))
12273 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12274 }
12275
12276 return false;
12277}
12278
12279bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
12280 const SCEV *LHS, const SCEV *RHS,
12281 const SCEV *FoundLHS,
12282 const SCEV *FoundRHS,
12283 const Instruction *CtxI) {
12284 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12285 return true;
12286
12287 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12288 return true;
12289
12290 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12291 return true;
12292
12293 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12294 CtxI))
12295 return true;
12296
12297 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12298 FoundLHS, FoundRHS);
12299}
12300
12301/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12302template <typename MinMaxExprType>
12303static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12304 const SCEV *Candidate) {
12305 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12306 if (!MinMaxExpr)
12307 return false;
12308
12309 return is_contained(MinMaxExpr->operands(), Candidate);
12310}
12311
12312 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12313 ICmpInst::Predicate Pred,
12314 const SCEV *LHS, const SCEV *RHS) {
12315 // If both sides are affine addrecs for the same loop, with equal
12316 // steps, and we know the recurrences don't wrap, then we only
12317 // need to check the predicate on the starting values.
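// For example (illustrative): {A,+,S}<nsw> s< {B,+,S}<nsw> over the same loop
// holds on every iteration provided "A s< B", since both sides advance by the
// same step S and neither recurrence wraps.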
12318
12319 if (!ICmpInst::isRelational(Pred))
12320 return false;
12321
12322 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12323 if (!LAR)
12324 return false;
12325 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12326 if (!RAR)
12327 return false;
12328 if (LAR->getLoop() != RAR->getLoop())
12329 return false;
12330 if (!LAR->isAffine() || !RAR->isAffine())
12331 return false;
12332
12333 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12334 return false;
12335
12336 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12337 SCEV::FlagNSW : SCEV::FlagNUW;
12338 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12339 return false;
12340
12341 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12342}
12343
12344/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12345/// expression?
12346 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
12347 ICmpInst::Predicate Pred,
12348 const SCEV *LHS, const SCEV *RHS) {
12349 switch (Pred) {
12350 default:
12351 return false;
12352
12353 case ICmpInst::ICMP_SGE:
12354 std::swap(LHS, RHS);
12355 [[fallthrough]];
12356 case ICmpInst::ICMP_SLE:
12357 return
12358 // min(A, ...) <= A
12359 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12360 // A <= max(A, ...)
12361 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12362
12363 case ICmpInst::ICMP_UGE:
12364 std::swap(LHS, RHS);
12365 [[fallthrough]];
12366 case ICmpInst::ICMP_ULE:
12367 return
12368 // min(A, ...) <= A
12369 // FIXME: what about umin_seq?
12370 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12371 // A <= max(A, ...)
12372 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12373 }
12374
12375 llvm_unreachable("covered switch fell through?!");
12376}
12377
12378bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
12379 const SCEV *LHS, const SCEV *RHS,
12380 const SCEV *FoundLHS,
12381 const SCEV *FoundRHS,
12382 unsigned Depth) {
12385 "LHS and RHS have different sizes?");
12386 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12387 getTypeSizeInBits(FoundRHS->getType()) &&
12388 "FoundLHS and FoundRHS have different sizes?");
12389 // We want to avoid hurting the compile time with analysis of too big trees.
12390 if (Depth > MaxSCEVOperationsImplicationDepth)
12391 return false;
12392
12393 // We only want to work with GT comparison so far.
12394 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12395 Pred = CmpInst::getSwappedPredicate(Pred);
12396 std::swap(LHS, RHS);
12397 std::swap(FoundLHS, FoundRHS);
12398 }
12399
12400 // For unsigned, try to reduce it to corresponding signed comparison.
12401 if (Pred == ICmpInst::ICMP_UGT)
12402 // We can replace unsigned predicate with its signed counterpart if all
12403 // involved values are non-negative.
12404 // TODO: We could have better support for unsigned.
12405 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12406 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12407 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12408 // use this fact to prove that LHS and RHS are non-negative.
12409 const SCEV *MinusOne = getMinusOne(LHS->getType());
12410 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12411 FoundRHS) &&
12412 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12413 FoundRHS))
12414 Pred = ICmpInst::ICMP_SGT;
12415 }
12416
12417 if (Pred != ICmpInst::ICMP_SGT)
12418 return false;
12419
12420 auto GetOpFromSExt = [&](const SCEV *S) {
12421 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12422 return Ext->getOperand();
12423 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12424 // the constant in some cases.
12425 return S;
12426 };
12427
12428 // Acquire values from extensions.
12429 auto *OrigLHS = LHS;
12430 auto *OrigFoundLHS = FoundLHS;
12431 LHS = GetOpFromSExt(LHS);
12432 FoundLHS = GetOpFromSExt(FoundLHS);
12433
12434 // Can the SGT predicate be proved trivially or by using the found context?
12435 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12436 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12437 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12438 FoundRHS, Depth + 1);
12439 };
12440
12441 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12442 // We want to avoid creation of any new non-constant SCEV. Since we are
12443 // going to compare the operands to RHS, we should be certain that we don't
12444 // need any size extensions for this. So let's decline all cases when the
12445 // sizes of types of LHS and RHS do not match.
12446 // TODO: Maybe try to get RHS from sext to catch more cases?
12447 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12448 return false;
12449
12450 // Should not overflow.
12451 if (!LHSAddExpr->hasNoSignedWrap())
12452 return false;
12453
12454 auto *LL = LHSAddExpr->getOperand(0);
12455 auto *LR = LHSAddExpr->getOperand(1);
12456 auto *MinusOne = getMinusOne(RHS->getType());
12457
12458 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12459 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12460 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12461 };
12462 // Try to prove the following rule:
12463 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12464 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
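// Illustrative instance (made-up values): if LHS = X + 20 is known not to
// overflow, the found context proves X >s -1 (i.e. X is non-negative), and
// it also proves 20 >s RHS, then the first rule above gives LHS >s RHS.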
12465 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12466 return true;
12467 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12468 Value *LL, *LR;
12469 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12470
12471 using namespace llvm::PatternMatch;
12472
12473 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12474 // Rules for division.
12475 // We are going to perform some comparisons with Denominator and its
12476 // derivative expressions. In the general case, creating a SCEV for it may
12477 // lead to a complex analysis of the entire graph, and in particular it
12478 // can request trip count recalculation for the same loop. That result would
12479 // then be cached as SCEVCouldNotCompute to avoid infinite recursion. To avoid
12480 // this, we only want to create SCEVs that are constants in this section.
12481 // So we bail if Denominator is not a constant.
12482 if (!isa<ConstantInt>(LR))
12483 return false;
12484
12485 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12486
12487 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12488 // then a SCEV for the numerator already exists and matches with FoundLHS.
12489 auto *Numerator = getExistingSCEV(LL);
12490 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12491 return false;
12492
12493 // Make sure that the numerator matches with FoundLHS and the denominator
12494 // is positive.
12495 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12496 return false;
12497
12498 auto *DTy = Denominator->getType();
12499 auto *FRHSTy = FoundRHS->getType();
12500 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12501 // One of the types is a pointer and the other one is not. We cannot extend
12502 // them properly to a wider type, so let us just reject this case.
12503 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12504 // to avoid this check.
12505 return false;
12506
12507 // Given that:
12508 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12509 auto *WTy = getWiderType(DTy, FRHSTy);
12510 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12511 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12512
12513 // Try to prove the following rule:
12514 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12515 // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
12516 // divide it by a Denominator < 4, we will get at least 1.
12517 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12518 if (isKnownNonPositive(RHS) &&
12519 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12520 return true;
12521
12522 // Try to prove the following rule:
12523 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12524 // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12525 // If we divide it by Denominator > 2, then:
12526 // 1. If FoundLHS is negative, then the result is 0.
12527 // 2. If FoundLHS is non-negative, then the result is non-negative.
12528 // Either way, the result is non-negative.
12529 auto *MinusOne = getMinusOne(WTy);
12530 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12531 if (isKnownNegative(RHS) &&
12532 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12533 return true;
12534 }
12535 }
12536
12537 // If our expression contained SCEVUnknown Phis, and we split it down and now
12538 // need to prove something for them, try to prove the predicate for all
12539 // possible incoming values of those Phis.
12540 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12541 return true;
12542
12543 return false;
12544}
12545
12546bool ScalarEvolution::isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
12547 const SCEV *LHS, const SCEV *RHS) {
12548 // zext x u<= sext x, sext x s<= zext x
12549 switch (Pred) {
12550 case ICmpInst::ICMP_SGE:
12551 std::swap(LHS, RHS);
12552 [[fallthrough]];
12553 case ICmpInst::ICMP_SLE: {
12554 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12555 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12556 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12557 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12558 return true;
12559 break;
12560 }
12561 case ICmpInst::ICMP_UGE:
12562 std::swap(LHS, RHS);
12563 [[fallthrough]];
12564 case ICmpInst::ICMP_ULE: {
12565 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12566 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12567 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12568 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12569 return true;
12570 break;
12571 }
12572 default:
12573 break;
12574 };
12575 return false;
12576}
12577
12578bool
12579ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
12580 const SCEV *LHS, const SCEV *RHS) {
12581 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12582 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12583 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12584 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12585 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12586}
12587
12588bool
12589ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
12590 const SCEV *LHS, const SCEV *RHS,
12591 const SCEV *FoundLHS,
12592 const SCEV *FoundRHS) {
12593 switch (Pred) {
12594 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
12595 case ICmpInst::ICMP_EQ:
12596 case ICmpInst::ICMP_NE:
12597 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12598 return true;
12599 break;
12600 case ICmpInst::ICMP_SLT:
12601 case ICmpInst::ICMP_SLE:
12602 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12603 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12604 return true;
12605 break;
12606 case ICmpInst::ICMP_SGT:
12607 case ICmpInst::ICMP_SGE:
12608 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12609 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12610 return true;
12611 break;
12612 case ICmpInst::ICMP_ULT:
12613 case ICmpInst::ICMP_ULE:
12614 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12615 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12616 return true;
12617 break;
12618 case ICmpInst::ICMP_UGT:
12619 case ICmpInst::ICMP_UGE:
12620 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12621 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12622 return true;
12623 break;
12624 }
12625
12626 // Maybe it can be proved via operations?
12627 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12628 return true;
12629
12630 return false;
12631}
12632
12633bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
12634 const SCEV *LHS,
12635 const SCEV *RHS,
12636 ICmpInst::Predicate FoundPred,
12637 const SCEV *FoundLHS,
12638 const SCEV *FoundRHS) {
12639 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12640 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12641 // reduce the compile time impact of this optimization.
12642 return false;
12643
12644 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12645 if (!Addend)
12646 return false;
12647
12648 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12649
12650 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12651 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12652 ConstantRange FoundLHSRange =
12653 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12654
12655 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12656 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12657
12658 // We can also compute the range of values for `LHS` that satisfy the
12659 // consequent, "`LHS` `Pred` `RHS`":
12660 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12661 // The antecedent implies the consequent if every value of `LHS` that
12662 // satisfies the antecedent also satisfies the consequent.
12663 return LHSRange.icmp(Pred, ConstRHS);
12664}
12665
12666bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12667 bool IsSigned) {
12668 assert(isKnownPositive(Stride) && "Positive stride expected!");
12669
12670 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12671 const SCEV *One = getOne(Stride->getType());
12672
12673 if (IsSigned) {
12674 APInt MaxRHS = getSignedRangeMax(RHS);
12675 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12676 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12677
12678 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12679 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12680 }
12681
12682 APInt MaxRHS = getUnsignedRangeMax(RHS);
12683 APInt MaxValue = APInt::getMaxValue(BitWidth);
12684 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12685
12686 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12687 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12688}
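// Worked example of the check above (illustrative, 8-bit unsigned): if
// getUnsignedRangeMax(RHS) is 250 and the stride is at most 10, then
// MaxValue - MaxStrideMinusOne = 255 - 9 = 246 <u 250, so an increment may
// wrap past 255 before the exit is taken and we report possible overflow.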
12689
12690bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12691 bool IsSigned) {
12692
12693 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12694 const SCEV *One = getOne(Stride->getType());
12695
12696 if (IsSigned) {
12697 APInt MinRHS = getSignedRangeMin(RHS);
12698 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12699 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12700
12701 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12702 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12703 }
12704
12705 APInt MinRHS = getUnsignedRangeMin(RHS);
12706 APInt MinValue = APInt::getMinValue(BitWidth);
12707 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12708
12709 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12710 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12711}
12712
12713const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12714 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12715 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12716 // expression fixes the case of N=0.
12717 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12718 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12719 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12720}
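// Worked example (illustrative): for N = 7, D = 3 the expression above gives
// umin(7,1) + floor((7 - 1) / 3) = 1 + 2 = 3 = ceil(7/3); for N = 0 it gives
// umin(0,1) + floor(0 / 3) = 0, which is the desired answer for N = 0.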
12721
12722const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12723 const SCEV *Stride,
12724 const SCEV *End,
12725 unsigned BitWidth,
12726 bool IsSigned) {
12727 // The logic in this function assumes we can represent a positive stride.
12728 // If we can't, the backedge-taken count must be zero.
12729 if (IsSigned && BitWidth == 1)
12730 return getZero(Stride->getType());
12731
12732 // The code below has only been closely audited for negative strides in the
12733 // unsigned comparison case; it may be correct for signed comparisons, but
12734 // that needs to be established.
12735 if (IsSigned && isKnownNegative(Stride))
12736 return getCouldNotCompute();
12737
12738 // Calculate the maximum backedge count based on the range of values
12739 // permitted by Start, End, and Stride.
12740 APInt MinStart =
12741 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12742
12743 APInt MinStride =
12744 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12745
12746 // We assume either the stride is positive, or the backedge-taken count
12747 // is zero. So force StrideForMaxBECount to be at least one.
12748 APInt One(BitWidth, 1);
12749 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12750 : APIntOps::umax(One, MinStride);
12751
12752 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12753 : APInt::getMaxValue(BitWidth);
12754 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12755
12756 // Although End can be a MAX expression we estimate MaxEnd considering only
12757 // the case End = RHS of the loop termination condition. This is safe because
12758 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12759 // taken count.
12760 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12761 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12762
12763 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12764 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12765 : APIntOps::umax(MaxEnd, MinStart);
12766
12767 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12768 getConstant(StrideForMaxBECount) /* Step */);
12769}
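// Worked example of the bound above (illustrative, unsigned 8-bit): with
// MinStart = 0, StrideForMaxBECount = 3 and getUnsignedRangeMax(End) = 100
// (already below Limit = 253), MaxEnd = 100 and the returned bound is
// ceil((100 - 0) /u 3) = 34 backedges.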
12770
12771ScalarEvolution::ExitLimit
12772ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12773 const Loop *L, bool IsSigned,
12774 bool ControlsOnlyExit, bool AllowPredicates) {
12775 SmallVector<const SCEVPredicate *, 4> Predicates;
12776
12777 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12778 bool PredicatedIV = false;
12779
12780 auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
12781 // Can we prove this loop *must* be UB if overflow of IV occurs?
12782 // Reasoning goes as follows:
12783 // * Suppose the IV did self wrap.
12784 // * If Stride evenly divides the iteration space, then once wrap
12785 // occurs, the loop must revisit the same values.
12786 // * We know that RHS is invariant, and that none of those values
12787 // caused this exit to be taken previously. Thus, this exit is
12788 // dynamically dead.
12789 // * If this is the sole exit, then a dead exit implies the loop
12790 // must be infinite if there are no abnormal exits.
12791 // * If the loop were infinite, then it must either not be mustprogress
12792 // or have side effects. Otherwise, it must be UB.
12793 // * It can't (by assumption) be UB, so we have contradicted our
12794 // premise and can conclude the IV did not in fact self-wrap.
12795 if (!isLoopInvariant(RHS, L))
12796 return false;
12797
12798 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
12799 if (!StrideC || !StrideC->getAPInt().isPowerOf2())
12800 return false;
12801
12802 if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
12803 return false;
12804
12805 return loopIsFiniteByAssumption(L);
12806 };
12807
12808 if (!IV) {
12809 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12810 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12811 if (AR && AR->getLoop() == L && AR->isAffine()) {
12812 auto canProveNUW = [&]() {
12813 // We can use the comparison to infer no-wrap flags only if it fully
12814 // controls the loop exit.
12815 if (!ControlsOnlyExit)
12816 return false;
12817
12818 if (!isLoopInvariant(RHS, L))
12819 return false;
12820
12821 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12822 // We need the sequence defined by AR to strictly increase in the
12823 // unsigned integer domain for the logic below to hold.
12824 return false;
12825
12826 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12827 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12828 // If RHS <=u Limit, then there must exist a value V in the sequence
12829 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12830 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12831 // overflow occurs. This limit also implies that a signed comparison
12832 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12833 // the high bits on both sides must be zero.
12834 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12835 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12836 Limit = Limit.zext(OuterBitWidth);
12837 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12838 };
12839 auto Flags = AR->getNoWrapFlags();
12840 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12841 Flags = setFlags(Flags, SCEV::FlagNUW);
12842
12843 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12844 if (AR->hasNoUnsignedWrap()) {
12845 // Emulate what getZeroExtendExpr would have done during construction
12846 // if we'd been able to infer the fact just above at that time.
12847 const SCEV *Step = AR->getStepRecurrence(*this);
12848 Type *Ty = ZExt->getType();
12849 auto *S = getAddRecExpr(
12850 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12851 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12852 IV = dyn_cast<SCEVAddRecExpr>(S);
12853 }
12854 }
12855 }
12856 }
12857
12858
12859 if (!IV && AllowPredicates) {
12860 // Try to make this an AddRec using runtime tests, in the first X
12861 // iterations of this loop, where X is the SCEV expression found by the
12862 // algorithm below.
12863 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12864 PredicatedIV = true;
12865 }
12866
12867 // Avoid weird loops
12868 if (!IV || IV->getLoop() != L || !IV->isAffine())
12869 return getCouldNotCompute();
12870
12871 // A precondition of this method is that the condition being analyzed
12872 // reaches an exiting branch which dominates the latch. Given that, we can
12873 // assume that an increment which violates the nowrap specification and
12874 // produces poison must cause undefined behavior when the resulting poison
12875 // value is branched upon and thus we can conclude that the backedge is
12876 // taken no more often than would be required to produce that poison value.
12877 // Note that a well defined loop can exit on the iteration which violates
12878 // the nowrap specification if there is another exit (either explicit or
12879 // implicit/exceptional) which causes the loop to execute before the
12880 // exiting instruction we're analyzing would trigger UB.
12881 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12882 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
12883 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
12884
12885 const SCEV *Stride = IV->getStepRecurrence(*this);
12886
12887 bool PositiveStride = isKnownPositive(Stride);
12888
12889 // Avoid negative or zero stride values.
12890 if (!PositiveStride) {
12891 // We can compute the correct backedge taken count for loops with unknown
12892 // strides if we can prove that the loop is not an infinite loop with side
12893 // effects. Here's the loop structure we are trying to handle -
12894 //
12895 // i = start
12896 // do {
12897 // A[i] = i;
12898 // i += s;
12899 // } while (i < end);
12900 //
12901 // The backedge taken count for such loops is evaluated as -
12902 // (max(end, start + stride) - start - 1) /u stride
12903 //
12904 // The additional preconditions that we need to check to prove correctness
12905 // of the above formula are as follows -
12906 //
12907 // a) IV is either nuw or nsw depending upon signedness (indicated by the
12908 // NoWrap flag).
12909 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
12910 // no side effects within the loop)
12911 // c) loop has a single static exit (with no abnormal exits)
12912 //
12913 // Precondition a) implies that if the stride is negative, this is a single
12914 // trip loop. The backedge taken count formula reduces to zero in this case.
12915 //
12916 // Precondition b) and c) combine to imply that if rhs is invariant in L,
12917 // then a zero stride means the backedge can't be taken without executing
12918 // undefined behavior.
12919 //
12920 // The positive stride case is the same as isKnownPositive(Stride) returning
12921 // true (original behavior of the function).
12922 //
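// Worked example of the formula above (illustrative): for start = 0,
// stride = 3, end = 10 it evaluates to (max(10, 0 + 3) - 0 - 1) /u 3 =
// 9 /u 3 = 3, matching the three backedges taken at i = 3, 6 and 9 in the
// do-while loop sketched above.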
12923 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
12924 !loopHasNoAbnormalExits(L))
12925 return getCouldNotCompute();
12926
12927 if (!isKnownNonZero(Stride)) {
12928 // If we have a step of zero, and RHS isn't invariant in L, we don't know
12929 // if it might eventually be greater than start and if so, on which
12930 // iteration. We can't even produce a useful upper bound.
12931 if (!isLoopInvariant(RHS, L))
12932 return getCouldNotCompute();
12933
12934 // We allow a potentially zero stride, but we need to divide by stride
12935 // below. Since the loop can't be infinite and this check must control
12936 // the sole exit, we can infer the exit must be taken on the first
12937 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
12938 // we know the numerator in the divides below must be zero, so we can
12939 // pick an arbitrary non-zero value for the denominator (e.g. stride)
12940 // and produce the right result.
12941 // FIXME: Handle the case where Stride is poison?
12942 auto wouldZeroStrideBeUB = [&]() {
12943 // Proof by contradiction. Suppose the stride were zero. If we can
12944 // prove that the backedge *is* taken on the first iteration, then since
12945 // we know this condition controls the sole exit, we must have an
12946 // infinite loop. We can't have a (well defined) infinite loop per
12947 // check just above.
12948 // Note: The (Start - Stride) term is used to get the start' term from
12949 // (start' + stride,+,stride). Remember that we only care about the
12950 // result of this expression when stride == 0 at runtime.
12951 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
12952 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
12953 };
12954 if (!wouldZeroStrideBeUB()) {
12955 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
12956 }
12957 }
12958 } else if (!Stride->isOne() && !NoWrap) {
12959 auto isUBOnWrap = [&]() {
12960 // From no-self-wrap, we need to then prove no-(un)signed-wrap. This
12961 // follows trivially from the fact that every (un)signed-wrapped, but
12962 // not self-wrapped value must be LT the last value before
12963 // (un)signed wrap. Since we know that last value didn't cause an exit,
12964 // neither will any smaller one.
12965 return canAssumeNoSelfWrap(IV);
12966 };
12967
12968 // Avoid proven overflow cases: this will ensure that the backedge taken
12969 // count will not generate any unsigned overflow. Relaxed no-overflow
12970 // conditions exploit NoWrapFlags, allowing to optimize in presence of
12971 // undefined behaviors like the case of C language.
12972 if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
12973 return getCouldNotCompute();
12974 }
12975
12976 // On all paths just preceding, we established the following invariant:
12977 // IV can be assumed not to overflow up to and including the exiting
12978 // iteration. We proved this in one of two ways:
12979 // 1) We can show overflow doesn't occur before the exiting iteration
12980 // 1a) canIVOverflowOnLT, and b) step of one
12981 // 2) We can show that if overflow occurs, the loop must execute UB
12982 // before any possible exit.
12983 // Note that we have not yet proved RHS invariant (in general).
12984
12985 const SCEV *Start = IV->getStart();
12986
12987 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
12988 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
12989 // Use integer-typed versions for actual computation; we can't subtract
12990 // pointers in general.
12991 const SCEV *OrigStart = Start;
12992 const SCEV *OrigRHS = RHS;
12993 if (Start->getType()->isPointerTy()) {
12994 Start = getLosslessPtrToIntExpr(Start);
12995 if (isa<SCEVCouldNotCompute>(Start))
12996 return Start;
12997 }
12998 if (RHS->getType()->isPointerTy()) {
12999 RHS = getLosslessPtrToIntExpr(RHS);
13000 if (isa<SCEVCouldNotCompute>(RHS))
13001 return RHS;
13002 }
13003
13004 const SCEV *End = nullptr, *BECount = nullptr,
13005 *BECountIfBackedgeTaken = nullptr;
13006 if (!isLoopInvariant(RHS, L)) {
13007 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13008 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13009 RHSAddRec->getNoWrapFlags()) {
13010 // The structure of loop we are trying to calculate backedge count of:
13011 //
13012 // left = left_start
13013 // right = right_start
13014 //
13015 // while(left < right){
13016 // ... do something here ...
13017 // left += s1; // stride of left is s1 (s1 > 0)
13018 // right += s2; // stride of right is s2 (s2 < 0)
13019 // }
13020 //
13021
13022 const SCEV *RHSStart = RHSAddRec->getStart();
13023 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13024
13025 // If Stride - RHSStride is positive and does not overflow, we can write
13026 // backedge count as ->
13027 // ceil((End - Start) /u (Stride - RHSStride))
13028 // Where, End = max(RHSStart, Start)
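// For instance (illustrative values): Start = 0, RHSStart = 10, Stride = 2
// and RHSStride = -3 give Denominator = 2 - (-3) = 5 and End = max(10, 0) =
// 10, so the formula above evaluates to ceil((10 - 0) /u 5) = 2; the gap
// between the two induction variables closes by 5 each iteration.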
13029
13030 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13031 if (isKnownNegative(RHSStride) &&
13032 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13033 RHSStride)) {
13034
13035 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13036 if (isKnownPositive(Denominator)) {
13037 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13038 : getUMaxExpr(RHSStart, Start);
13039
13040 // We can do this because End >= Start, as End = max(RHSStart, Start)
13041 const SCEV *Delta = getMinusSCEV(End, Start);
13042
13043 BECount = getUDivCeilSCEV(Delta, Denominator);
13044 BECountIfBackedgeTaken =
13045 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13046 }
13047 }
13048 }
13049 if (BECount == nullptr) {
13050 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13051 // given the start, stride and max value for the end bound of the
13052 // loop (RHS), and the fact that IV does not overflow (which is
13053 // checked above).
13054 const SCEV *MaxBECount = computeMaxBECountForLT(
13055 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13056 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13057 MaxBECount, false /*MaxOrZero*/, Predicates);
13058 }
13059 } else {
13060 // We use the expression (max(End,Start)-Start)/Stride to describe the
13061 // backedge count, since if the backedge is taken at least once,
13062 // max(End,Start) is End and so the result is as above, and if not,
13063 // max(End,Start) is Start so we get a backedge count of zero.
13064 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13065 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13066 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13067 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13068 // Can we prove max(RHS,Start) > Start - Stride?
13069 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13070 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13071 // In this case, we can use a refined formula for computing backedge
13072 // taken count. The general formula remains:
13073 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13074 // We want to use the alternate formula:
13075 // "((End - 1) - (Start - Stride)) /u Stride"
13076 // Let's do a quick case analysis to show these are equivalent under
13077 // our precondition that max(RHS,Start) > Start - Stride.
13078 // * For RHS <= Start, the backedge-taken count must be zero.
13079 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13080 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13081 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13082 // of Stride. For 0 stride, we've used umax(Stride, 1) above,
13083 // reducing this to the stride of 1 case.
13084 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13085 // Stride".
13086 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13087 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13088 // "((RHS - (Start - Stride) - 1) /u Stride".
13089 // Our preconditions trivially imply no overflow in that form.
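// As a quick sanity check (illustrative values): Start = 0, Stride = 2 and
// RHS = 9 give ((9 - 1) - (0 - 2)) /u 2 = 10 /u 2 = 5, the same count as
// ceil((9 - 0) / 2) from the general formula.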
13090 const SCEV *MinusOne = getMinusOne(Stride->getType());
13091 const SCEV *Numerator =
13092 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13093 BECount = getUDivExpr(Numerator, Stride);
13094 }
13095
13096 if (!BECount) {
13097 auto canProveRHSGreaterThanEqualStart = [&]() {
13098 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13099 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13100 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13101
13102 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13103 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13104 return true;
13105
13106 // (RHS > Start - 1) implies RHS >= Start.
13107 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13108 // "Start - 1" doesn't overflow.
13109 // * For signed comparison, if Start - 1 does overflow, it's equal
13110 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13111 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13112 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13113 //
13114 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13115 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13116 auto *StartMinusOne =
13117 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13118 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13119 };
13120
13121 // If we know that RHS >= Start in the context of loop, then we know
13122 // that max(RHS, Start) = RHS at this point.
13123 if (canProveRHSGreaterThanEqualStart()) {
13124 End = RHS;
13125 } else {
13126 // If RHS < Start, the backedge will be taken zero times. So in
13127 // general, we can write the backedge-taken count as:
13128 //
13129 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13130 //
13131 // We convert it to the following to make it more convenient for SCEV:
13132 //
13133 // ceil(max(RHS, Start) - Start) / Stride
13134 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13135
13136 // See what would happen if we assume the backedge is taken. This is
13137 // used to compute MaxBECount.
13138 BECountIfBackedgeTaken =
13139 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13140 }
13141
13142 // At this point, we know:
13143 //
13144 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13145 // 2. The index variable doesn't overflow.
13146 //
13147 // Therefore, we know N exists such that
13148 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13149 // doesn't overflow.
13150 //
13151 // Using this information, try to prove whether the addition in
13152 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13153 const SCEV *One = getOne(Stride->getType());
13154 bool MayAddOverflow = [&] {
13155 if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
13156 if (StrideC->getAPInt().isPowerOf2()) {
13157 // Suppose Stride is a power of two, and Start/End are unsigned
13158 // integers. Let UMAX be the largest representable unsigned
13159 // integer.
13160 //
13161 // By the preconditions of this function, we know
13162 // "(Start + Stride * N) >= End", and this doesn't overflow.
13163 // As a formula:
13164 //
13165 // End <= (Start + Stride * N) <= UMAX
13166 //
13167 // Subtracting Start from all the terms:
13168 //
13169 // End - Start <= Stride * N <= UMAX - Start
13170 //
13171 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13172 //
13173 // End - Start <= Stride * N <= UMAX
13174 //
13175 // Stride * N is a multiple of Stride. Therefore,
13176 //
13177 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13178 //
13179 // Since Stride is a power of two, UMAX + 1 is divisible by
13180 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13181 // write:
13182 //
13183 // End - Start <= Stride * N <= UMAX - Stride + 1
13184 //
13185 // Dropping the middle term:
13186 //
13187 // End - Start <= UMAX - Stride + 1
13188 //
13189 // Adding Stride - 1 to both sides:
13190 //
13191 // (End - Start) + (Stride - 1) <= UMAX
13192 //
13193 // In other words, the addition doesn't have unsigned overflow.
13194 //
13195 // A similar proof works if we treat Start/End as signed values.
13196 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13197 // to use signed max instead of unsigned max. Note that we're
13198 // trying to prove a lack of unsigned overflow in either case.
13199 return false;
13200 }
13201 }
13202 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13203 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13204 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13205 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13206 // 1 <s End.
13207 //
13208 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13209 // End.
13210 return false;
13211 }
13212 return true;
13213 }();
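// Concrete instance of the power-of-two case above (illustrative, 8-bit
// unsigned): for Stride = 4, UMAX mod Stride = 3, so End - Start <=
// 255 - 4 + 1 = 252 and (End - Start) + (Stride - 1) <= 252 + 3 = 255 =
// UMAX, i.e. the addition cannot overflow.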
13214
13215 const SCEV *Delta = getMinusSCEV(End, Start);
13216 if (!MayAddOverflow) {
13217 // floor((D + (S - 1)) / S)
13218 // We prefer this formulation if it's legal because it's fewer
13219 // operations.
13220 BECount =
13221 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13222 } else {
13223 BECount = getUDivCeilSCEV(Delta, Stride);
13224 }
13225 }
13226 }
13227
13228 const SCEV *ConstantMaxBECount;
13229 bool MaxOrZero = false;
13230 if (isa<SCEVConstant>(BECount)) {
13231 ConstantMaxBECount = BECount;
13232 } else if (BECountIfBackedgeTaken &&
13233 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13234 // If we know exactly how many times the backedge will be taken if it's
13235 // taken at least once, then the backedge count will either be that or
13236 // zero.
13237 ConstantMaxBECount = BECountIfBackedgeTaken;
13238 MaxOrZero = true;
13239 } else {
13240 ConstantMaxBECount = computeMaxBECountForLT(
13241 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13242 }
13243
13244 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13245 !isa<SCEVCouldNotCompute>(BECount))
13246 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13247
13248 const SCEV *SymbolicMaxBECount =
13249 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13250 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13251 Predicates);
13252}
13253
13254ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13255 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13256 bool ControlsOnlyExit, bool AllowPredicates) {
13257 SmallVector<const SCEVPredicate *, 4> Predicates;
13258 // We handle only IV > Invariant
13259 if (!isLoopInvariant(RHS, L))
13260 return getCouldNotCompute();
13261
13262 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13263 if (!IV && AllowPredicates)
13264 // Try to make this an AddRec using runtime tests, in the first X
13265 // iterations of this loop, where X is the SCEV expression found by the
13266 // algorithm below.
13267 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13268
13269 // Avoid weird loops
13270 if (!IV || IV->getLoop() != L || !IV->isAffine())
13271 return getCouldNotCompute();
13272
13273 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13274 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13275 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13276
13277 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13278
13279 // Avoid negative or zero stride values
13280 if (!isKnownPositive(Stride))
13281 return getCouldNotCompute();
13282
13283 // Avoid proven overflow cases: this will ensure that the backedge taken count
13284 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13285 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13286 // behaviors like the case of C language.
13287 if (!Stride->isOne() && !NoWrap)
13288 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13289 return getCouldNotCompute();
13290
13291 const SCEV *Start = IV->getStart();
13292 const SCEV *End = RHS;
13293 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13294 // If we know that Start >= RHS in the context of loop, then we know that
13295 // min(RHS, Start) = RHS at this point.
13296 if (isLoopEntryGuardedByCond(
13297 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13298 End = RHS;
13299 else
13300 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13301 }
13302
13303 if (Start->getType()->isPointerTy()) {
13304 Start = getLosslessPtrToIntExpr(Start);
13305 if (isa<SCEVCouldNotCompute>(Start))
13306 return Start;
13307 }
13308 if (End->getType()->isPointerTy()) {
13309 End = getLosslessPtrToIntExpr(End);
13310 if (isa<SCEVCouldNotCompute>(End))
13311 return End;
13312 }
13313
13314 // Compute ((Start - End) + (Stride - 1)) / Stride.
13315 // FIXME: This can overflow. Holding off on fixing this for now;
13316 // howManyGreaterThans will hopefully be gone soon.
13317 const SCEV *One = getOne(Stride->getType());
13318 const SCEV *BECount = getUDivExpr(
13319 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13320
13321 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13322 : getUnsignedRangeMax(Start);
13323
13324 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13325 : getUnsignedRangeMin(Stride);
13326
13327 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13328 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13329 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13330
13331 // Although End can be a MIN expression we estimate MinEnd considering only
13332 // the case End = RHS. This is safe because in the other case (Start - End)
13333 // is zero, leading to a zero maximum backedge taken count.
13334 APInt MinEnd =
13335 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13336 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13337
13338 const SCEV *ConstantMaxBECount =
13339 isa<SCEVConstant>(BECount)
13340 ? BECount
13341 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13342 getConstant(MinStride));
13343
13344 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13345 ConstantMaxBECount = BECount;
13346 const SCEV *SymbolicMaxBECount =
13347 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13348
13349 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13350 Predicates);
13351}
13352
13353const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13354 ScalarEvolution &SE) const {
13355 if (Range.isFullSet()) // Infinite loop.
13356 return SE.getCouldNotCompute();
13357
13358 // If the start is a non-zero constant, shift the range to simplify things.
13359 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13360 if (!SC->getValue()->isZero()) {
13361 SmallVector<const SCEV *, 4> Operands(operands());
13362 Operands[0] = SE.getZero(SC->getType());
13363 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13364 getNoWrapFlags(FlagNW));
13365 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13366 return ShiftedAddRec->getNumIterationsInRange(
13367 Range.subtract(SC->getAPInt()), SE);
13368 // This is strange and shouldn't happen.
13369 return SE.getCouldNotCompute();
13370 }
13371
13372 // The only time we can solve this is when we have all constant indices.
13373 // Otherwise, we cannot determine the overflow conditions.
13374 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13375 return SE.getCouldNotCompute();
13376
13377 // Okay at this point we know that all elements of the chrec are constants and
13378 // that the start element is zero.
13379
13380 // First check to see if the range contains zero. If not, the first
13381 // iteration exits.
13382 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13383 if (!Range.contains(APInt(BitWidth, 0)))
13384 return SE.getZero(getType());
13385
13386 if (isAffine()) {
13387 // If this is an affine expression then we have this situation:
13388 // Solve {0,+,A} in Range === Ax in Range
13389
13390 // We know that zero is in the range. If A is positive then we know that
13391 // the upper value of the range must be the first possible exit value.
13392 // If A is negative then the lower of the range is the last possible loop
13393 // value. Also note that we already checked for a full range.
13394 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13395 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13396
13397 // The exit value should be (End+A)/A.
13398 APInt ExitVal = (End + A).udiv(A);
13399 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13400
13401 // Evaluate at the exit value. If we really did fall out of the valid
13402 // range, then we computed our trip count, otherwise wrap around or other
13403 // things must have happened.
13404 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13405 if (Range.contains(Val->getValue()))
13406 return SE.getCouldNotCompute(); // Something strange happened
13407
13408 // Ensure that the previous value is in the range.
13409 assert(Range.contains(
13410 EvaluateConstantChrecAtConstant(this,
13411 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13412 "Linear scev computation is off in a bad way!");
13413 return SE.getConstant(ExitValue);
13414 }
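// Worked example for the affine case above (illustrative): {0,+,3} with
// Range = [0, 10) has A = 3, End = 9 and ExitVal = (9 + 3) /u 3 = 4; the
// chrec value at iteration 4 is 12 (outside the range) while the value at
// iteration 3 is 9 (still inside), so the constant 4 is returned.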
13415
13416 if (isQuadratic()) {
13417 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13418 return SE.getConstant(*S);
13419 }
13420
13421 return SE.getCouldNotCompute();
13422}
13423
13424const SCEVAddRecExpr *
13425SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13426 assert(getNumOperands() > 1 && "AddRec with zero step?");
13427 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13428 // but in this case we cannot guarantee that the value returned will be an
13429 // AddRec because SCEV does not have a fixed point where it stops
13430 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13431 // may happen if we reach arithmetic depth limit while simplifying. So we
13432 // construct the returned value explicitly.
13433 SmallVector<const SCEV *, 3> Ops;
13434 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13435 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13436 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13437 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13438 // We know that the last operand is not a constant zero (otherwise it would
13439 // have been popped out earlier). This guarantees us that if the result has
13440 // the same last operand, then it will also not be popped out, meaning that
13441 // the returned value will be an AddRec.
13442 const SCEV *Last = getOperand(getNumOperands() - 1);
13443 assert(!Last->isZero() && "Recurrency with zero step?");
13444 Ops.push_back(Last);
13445 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13446 SCEV::FlagAnyWrap));
13447}
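// Illustrative example of the construction above: the post-increment form of
// {0,+,1,+,2} (which evaluates to n^2) is {0+1,+,1+2,+,2} = {1,+,3,+,2},
// which evaluates to (n+1)^2 as expected.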
13448
13449// Return true when S contains at least an undef value.
13450bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13451 return SCEVExprContains(S, [](const SCEV *S) {
13452 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13453 return isa<UndefValue>(SU->getValue());
13454 return false;
13455 });
13456}
13457
13458// Return true when S contains a value that is a nullptr.
13459bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13460 return SCEVExprContains(S, [](const SCEV *S) {
13461 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13462 return SU->getValue() == nullptr;
13463 return false;
13464 });
13465}
13466
13467/// Return the size of an element read or written by Inst.
13468const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13469 Type *Ty;
13470 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13471 Ty = Store->getValueOperand()->getType();
13472 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13473 Ty = Load->getType();
13474 else
13475 return nullptr;
13476
13477 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13478 return getSizeOfExpr(ETy, Ty);
13479}
13480
13481//===----------------------------------------------------------------------===//
13482// SCEVCallbackVH Class Implementation
13483//===----------------------------------------------------------------------===//
13484
13485void ScalarEvolution::SCEVCallbackVH::deleted() {
13486 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13487 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13488 SE->ConstantEvolutionLoopExitValue.erase(PN);
13489 SE->eraseValueFromMap(getValPtr());
13490 // this now dangles!
13491}
13492
13493void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13494 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13495
13496 // Forget all the expressions associated with users of the old value,
13497 // so that future queries will recompute the expressions using the new
13498 // value.
13499 SE->forgetValue(getValPtr());
13500 // this now dangles!
13501}
13502
13503ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13504 : CallbackVH(V), SE(se) {}
13505
13506//===----------------------------------------------------------------------===//
13507// ScalarEvolution Class Implementation
13508//===----------------------------------------------------------------------===//
13509
13510ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13511 AssumptionCache &AC, DominatorTree &DT,
13512 LoopInfo &LI)
13513 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13514 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13515 LoopDispositions(64), BlockDispositions(64) {
13516 // To use guards for proving predicates, we need to scan every instruction in
13517 // relevant basic blocks, and not just terminators. Doing this is a waste of
13518 // time if the IR does not actually contain any calls to
13519 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13520 //
13521 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13522 // to _add_ guards to the module when there weren't any before, and wants
13523 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13524 // efficient in lieu of being smart in that rather obscure case.
13525
13526 auto *GuardDecl = F.getParent()->getFunction(
13527 Intrinsic::getName(Intrinsic::experimental_guard));
13528 HasGuards = GuardDecl && !GuardDecl->use_empty();
13529}
13530
13531ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13532 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13533 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13534 ValueExprMap(std::move(Arg.ValueExprMap)),
13535 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13536 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13537 PendingMerges(std::move(Arg.PendingMerges)),
13538 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13539 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13540 PredicatedBackedgeTakenCounts(
13541 std::move(Arg.PredicatedBackedgeTakenCounts)),
13542 BECountUsers(std::move(Arg.BECountUsers)),
13543 ConstantEvolutionLoopExitValue(
13544 std::move(Arg.ConstantEvolutionLoopExitValue)),
13545 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13546 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13547 LoopDispositions(std::move(Arg.LoopDispositions)),
13548 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13549 BlockDispositions(std::move(Arg.BlockDispositions)),
13550 SCEVUsers(std::move(Arg.SCEVUsers)),
13551 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13552 SignedRanges(std::move(Arg.SignedRanges)),
13553 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13554 UniquePreds(std::move(Arg.UniquePreds)),
13555 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13556 LoopUsers(std::move(Arg.LoopUsers)),
13557 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13558 FirstUnknown(Arg.FirstUnknown) {
13559 Arg.FirstUnknown = nullptr;
13560}
13561
13562ScalarEvolution::~ScalarEvolution() {
13563 // Iterate through all the SCEVUnknown instances and call their
13564 // destructors, so that they release their references to their values.
13565 for (SCEVUnknown *U = FirstUnknown; U;) {
13566 SCEVUnknown *Tmp = U;
13567 U = U->Next;
13568 Tmp->~SCEVUnknown();
13569 }
13570 FirstUnknown = nullptr;
13571
13572 ExprValueMap.clear();
13573 ValueExprMap.clear();
13574 HasRecMap.clear();
13575 BackedgeTakenCounts.clear();
13576 PredicatedBackedgeTakenCounts.clear();
13577
13578 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13579 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13580 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13581 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13582 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13583}
13584
13585bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13586 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13587}
13588
13589/// When printing a top-level SCEV for trip counts, it's helpful to include
13590/// a type for constants which are otherwise hard to disambiguate.
13591static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13592 if (isa<SCEVConstant>(S))
13593 OS << *S->getType() << " ";
13594 OS << *S;
13595}
13596
13597static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13598 const Loop *L) {
13599 // Print all inner loops first
13600 for (Loop *I : *L)
13601 PrintLoopInfo(OS, SE, I);
13602
13603 OS << "Loop ";
13604 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13605 OS << ": ";
13606
13607 SmallVector<BasicBlock *, 8> ExitingBlocks;
13608 L->getExitingBlocks(ExitingBlocks);
13609 if (ExitingBlocks.size() != 1)
13610 OS << "<multiple exits> ";
13611
13612 auto *BTC = SE->getBackedgeTakenCount(L);
13613 if (!isa<SCEVCouldNotCompute>(BTC)) {
13614 OS << "backedge-taken count is ";
13615 PrintSCEVWithTypeHint(OS, BTC);
13616 } else
13617 OS << "Unpredictable backedge-taken count.";
13618 OS << "\n";
13619
13620 if (ExitingBlocks.size() > 1)
13621 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13622 OS << " exit count for " << ExitingBlock->getName() << ": ";
13623 PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
13624 OS << "\n";
13625 }
13626
13627 OS << "Loop ";
13628 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13629 OS << ": ";
13630
13631 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13632 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13633 OS << "constant max backedge-taken count is ";
13634 PrintSCEVWithTypeHint(OS, ConstantBTC);
13635 if (SE->isBackedgeTakenCountMaxOrZero(L))
13636 OS << ", actual taken count either this or zero.";
13637 } else {
13638 OS << "Unpredictable constant max backedge-taken count. ";
13639 }
13640
13641 OS << "\n"
13642 "Loop ";
13643 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13644 OS << ": ";
13645
13646 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13647 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13648 OS << "symbolic max backedge-taken count is ";
13649 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13650 if (SE->isBackedgeTakenCountMaxOrZero(L))
13651 OS << ", actual taken count either this or zero.";
13652 } else {
13653 OS << "Unpredictable symbolic max backedge-taken count. ";
13654 }
13655 OS << "\n";
13656
13657 if (ExitingBlocks.size() > 1)
13658 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13659 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13660 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13661 ScalarEvolution::SymbolicMaximum);
13662 PrintSCEVWithTypeHint(OS, ExitBTC);
13663 OS << "\n";
13664 }
13665
13666 SmallVector<const SCEVPredicate *, 4> Preds;
13667 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13668 if (PBT != BTC || !Preds.empty()) {
13669 OS << "Loop ";
13670 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13671 OS << ": ";
13672 if (!isa<SCEVCouldNotCompute>(PBT)) {
13673 OS << "Predicated backedge-taken count is ";
13674 PrintSCEVWithTypeHint(OS, PBT);
13675 } else
13676 OS << "Unpredictable predicated backedge-taken count.";
13677 OS << "\n";
13678 OS << " Predicates:\n";
13679 for (const auto *P : Preds)
13680 P->print(OS, 4);
13681 }
13682
13683 Preds.clear();
13684 auto *PredSymbolicMax =
13685 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13686 if (SymbolicBTC != PredSymbolicMax) {
13687 OS << "Loop ";
13688 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13689 OS << ": ";
13690 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13691 OS << "Predicated symbolic max backedge-taken count is ";
13692 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13693 } else
13694 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13695 OS << "\n";
13696 OS << " Predicates:\n";
13697 for (const auto *P : Preds)
13698 P->print(OS, 4);
13699 }
13700
13701 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13702 OS << "Loop ";
13703 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13704 OS << ": ";
13705 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13706 }
13707}
13708
13709namespace llvm {
13710raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13711 switch (LD) {
13712 case ScalarEvolution::LoopVariant:
13713 OS << "Variant";
13714 break;
13715 case ScalarEvolution::LoopInvariant:
13716 OS << "Invariant";
13717 break;
13718 case ScalarEvolution::LoopComputable:
13719 OS << "Computable";
13720 break;
13721 }
13722 return OS;
13723}
13724
13725raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13726 switch (BD) {
13727 case ScalarEvolution::DoesNotDominateBlock:
13728 OS << "DoesNotDominate";
13729 break;
13730 case ScalarEvolution::DominatesBlock:
13731 OS << "Dominates";
13732 break;
13733 case ScalarEvolution::ProperlyDominatesBlock:
13734 OS << "ProperlyDominates";
13735 break;
13736 }
13737 return OS;
13738}
13739} // namespace llvm
13740
13741void ScalarEvolution::print(raw_ostream &OS) const {
13742 // ScalarEvolution's implementation of the print method is to print
13743 // out SCEV values of all instructions that are interesting. Doing
13744 // this potentially causes it to create new SCEV objects though,
13745 // which technically conflicts with the const qualifier. This isn't
13746 // observable from outside the class though, so casting away the
13747 // const isn't dangerous.
13748 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13749
13750 if (ClassifyExpressions) {
13751 OS << "Classifying expressions for: ";
13752 F.printAsOperand(OS, /*PrintType=*/false);
13753 OS << "\n";
13754 for (Instruction &I : instructions(F))
13755 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13756 OS << I << '\n';
13757 OS << " --> ";
13758 const SCEV *SV = SE.getSCEV(&I);
13759 SV->print(OS);
13760 if (!isa<SCEVCouldNotCompute>(SV)) {
13761 OS << " U: ";
13762 SE.getUnsignedRange(SV).print(OS);
13763 OS << " S: ";
13764 SE.getSignedRange(SV).print(OS);
13765 }
13766
13767 const Loop *L = LI.getLoopFor(I.getParent());
13768
13769 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13770 if (AtUse != SV) {
13771 OS << " --> ";
13772 AtUse->print(OS);
13773 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13774 OS << " U: ";
13775 SE.getUnsignedRange(AtUse).print(OS);
13776 OS << " S: ";
13777 SE.getSignedRange(AtUse).print(OS);
13778 }
13779 }
13780
13781 if (L) {
13782 OS << "\t\t" "Exits: ";
13783 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13784 if (!SE.isLoopInvariant(ExitValue, L)) {
13785 OS << "<<Unknown>>";
13786 } else {
13787 OS << *ExitValue;
13788 }
13789
13790 bool First = true;
13791 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13792 if (First) {
13793 OS << "\t\t" "LoopDispositions: { ";
13794 First = false;
13795 } else {
13796 OS << ", ";
13797 }
13798
13799 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13800 OS << ": " << SE.getLoopDisposition(SV, Iter);
13801 }
13802
13803 for (const auto *InnerL : depth_first(L)) {
13804 if (InnerL == L)
13805 continue;
13806 if (First) {
13807 OS << "\t\t" "LoopDispositions: { ";
13808 First = false;
13809 } else {
13810 OS << ", ";
13811 }
13812
13813 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13814 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13815 }
13816
13817 OS << " }";
13818 }
13819
13820 OS << "\n";
13821 }
13822 }
13823
13824 OS << "Determining loop execution counts for: ";
13825 F.printAsOperand(OS, /*PrintType=*/false);
13826 OS << "\n";
13827 for (Loop *I : LI)
13828 PrintLoopInfo(OS, &SE, I);
13829}
13830
13831ScalarEvolution::LoopDisposition
13832ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
13833 auto &Values = LoopDispositions[S];
13834 for (auto &V : Values) {
13835 if (V.getPointer() == L)
13836 return V.getInt();
13837 }
13838 Values.emplace_back(L, LoopVariant);
13839 LoopDisposition D = computeLoopDisposition(S, L);
13840 auto &Values2 = LoopDispositions[S];
13841 for (auto &V : llvm::reverse(Values2)) {
13842 if (V.getPointer() == L) {
13843 V.setInt(D);
13844 break;
13845 }
13846 }
13847 return D;
13848}
13849
13850ScalarEvolution::LoopDisposition
13851ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13852 switch (S->getSCEVType()) {
13853 case scConstant:
13854 case scVScale:
13855 return LoopInvariant;
13856 case scAddRecExpr: {
13857 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13858
13859 // If L is the addrec's loop, it's computable.
13860 if (AR->getLoop() == L)
13861 return LoopComputable;
13862
13863 // Add recurrences are never invariant in the function-body (null loop).
13864 if (!L)
13865 return LoopVariant;
13866
13867 // Everything that is not defined at loop entry is variant.
13868 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13869 return LoopVariant;
13870 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13871 " dominate the contained loop's header?");
13872
13873 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13874 if (AR->getLoop()->contains(L))
13875 return LoopInvariant;
13876
13877 // This recurrence is variant w.r.t. L if any of its operands
13878 // are variant.
13879 for (const auto *Op : AR->operands())
13880 if (!isLoopInvariant(Op, L))
13881 return LoopVariant;
13882
13883 // Otherwise it's loop-invariant.
13884 return LoopInvariant;
13885 }
13886 case scTruncate:
13887 case scZeroExtend:
13888 case scSignExtend:
13889 case scPtrToInt:
13890 case scAddExpr:
13891 case scMulExpr:
13892 case scUDivExpr:
13893 case scUMaxExpr:
13894 case scSMaxExpr:
13895 case scUMinExpr:
13896 case scSMinExpr:
13897 case scSequentialUMinExpr: {
13898 bool HasVarying = false;
13899 for (const auto *Op : S->operands()) {
13900      LoopDisposition D = getLoopDisposition(Op, L);
13901      if (D == LoopVariant)
13902 return LoopVariant;
13903 if (D == LoopComputable)
13904 HasVarying = true;
13905 }
13906 return HasVarying ? LoopComputable : LoopInvariant;
13907 }
13908 case scUnknown:
13909 // All non-instruction values are loop invariant. All instructions are loop
13910 // invariant if they are not contained in the specified loop.
13911 // Instructions are never considered invariant in the function body
13912 // (null loop) because they are defined within the "loop".
13913 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13914 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13915 return LoopInvariant;
13916 case scCouldNotCompute:
13917 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13918 }
13919 llvm_unreachable("Unknown SCEV kind!");
13920}
13921
13922bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13923  return getLoopDisposition(S, L) == LoopInvariant;
13924}
13925
13926bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13927  return getLoopDisposition(S, L) == LoopComputable;
13928}
13929
13930ScalarEvolution::BlockDisposition
13931ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13932  auto &Values = BlockDispositions[S];
13933 for (auto &V : Values) {
13934 if (V.getPointer() == BB)
13935 return V.getInt();
13936 }
13937 Values.emplace_back(BB, DoesNotDominateBlock);
13938 BlockDisposition D = computeBlockDisposition(S, BB);
13939 auto &Values2 = BlockDispositions[S];
13940 for (auto &V : llvm::reverse(Values2)) {
13941 if (V.getPointer() == BB) {
13942 V.setInt(D);
13943 break;
13944 }
13945 }
13946 return D;
13947}
13948
13949ScalarEvolution::BlockDisposition
13950ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13951 switch (S->getSCEVType()) {
13952 case scConstant:
13953  case scVScale:
13954    return ProperlyDominatesBlock;
13955  case scAddRecExpr: {
13956 // This uses a "dominates" query instead of "properly dominates" query
13957 // to test for proper dominance too, because the instruction which
13958 // produces the addrec's value is a PHI, and a PHI effectively properly
13959 // dominates its entire containing block.
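    // For example (illustrative): {0,+,1}<%loop> is materialized by a PHI in
    // %loop's header, and that PHI dominates every other instruction in the
    // header block, so checking dominance of the header is sufficient here.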
13960 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13961 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13962 return DoesNotDominateBlock;
13963
13964 // Fall through into SCEVNAryExpr handling.
13965 [[fallthrough]];
13966 }
13967 case scTruncate:
13968 case scZeroExtend:
13969 case scSignExtend:
13970 case scPtrToInt:
13971 case scAddExpr:
13972 case scMulExpr:
13973 case scUDivExpr:
13974 case scUMaxExpr:
13975 case scSMaxExpr:
13976 case scUMinExpr:
13977 case scSMinExpr:
13978 case scSequentialUMinExpr: {
13979 bool Proper = true;
13980 for (const SCEV *NAryOp : S->operands()) {
13981      BlockDisposition D = getBlockDisposition(NAryOp, BB);
13982      if (D == DoesNotDominateBlock)
13983 return DoesNotDominateBlock;
13984 if (D == DominatesBlock)
13985 Proper = false;
13986 }
13987 return Proper ? ProperlyDominatesBlock : DominatesBlock;
13988 }
13989 case scUnknown:
13990 if (Instruction *I =
13991 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13992 if (I->getParent() == BB)
13993 return DominatesBlock;
13994      if (DT.properlyDominates(I->getParent(), BB))
13995        return ProperlyDominatesBlock;
13996      return DoesNotDominateBlock;
13997    }
13998    return ProperlyDominatesBlock;
13999  case scCouldNotCompute:
14000 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14001 }
14002 llvm_unreachable("Unknown SCEV kind!");
14003}
14004
14005bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14006 return getBlockDisposition(S, BB) >= DominatesBlock;
14007}
14008
14009bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14010  return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14011}
14012
14013bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14014 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14015}
14016
14017void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14018 bool Predicated) {
14019 auto &BECounts =
14020 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14021 auto It = BECounts.find(L);
14022 if (It != BECounts.end()) {
14023 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14024 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14025 if (!isa<SCEVConstant>(S)) {
14026 auto UserIt = BECountUsers.find(S);
14027 assert(UserIt != BECountUsers.end());
14028 UserIt->second.erase({L, Predicated});
14029 }
14030 }
14031 }
14032 BECounts.erase(It);
14033 }
14034}
14035
14036void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14037 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
14038 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14039
14040 while (!Worklist.empty()) {
14041 const SCEV *Curr = Worklist.pop_back_val();
14042 auto Users = SCEVUsers.find(Curr);
14043 if (Users != SCEVUsers.end())
14044 for (const auto *User : Users->second)
14045 if (ToForget.insert(User).second)
14046 Worklist.push_back(User);
14047 }
14048
14049 for (const auto *S : ToForget)
14050 forgetMemoizedResultsImpl(S);
14051
14052 for (auto I = PredicatedSCEVRewrites.begin();
14053 I != PredicatedSCEVRewrites.end();) {
14054 std::pair<const SCEV *, const Loop *> Entry = I->first;
14055 if (ToForget.count(Entry.first))
14056 PredicatedSCEVRewrites.erase(I++);
14057 else
14058 ++I;
14059 }
14060}
14061
14062void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14063 LoopDispositions.erase(S);
14064 BlockDispositions.erase(S);
14065 UnsignedRanges.erase(S);
14066 SignedRanges.erase(S);
14067 HasRecMap.erase(S);
14068 ConstantMultipleCache.erase(S);
14069
14070 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14071 UnsignedWrapViaInductionTried.erase(AR);
14072 SignedWrapViaInductionTried.erase(AR);
14073 }
14074
14075 auto ExprIt = ExprValueMap.find(S);
14076 if (ExprIt != ExprValueMap.end()) {
14077 for (Value *V : ExprIt->second) {
14078 auto ValueIt = ValueExprMap.find_as(V);
14079 if (ValueIt != ValueExprMap.end())
14080 ValueExprMap.erase(ValueIt);
14081 }
14082 ExprValueMap.erase(ExprIt);
14083 }
14084
14085 auto ScopeIt = ValuesAtScopes.find(S);
14086 if (ScopeIt != ValuesAtScopes.end()) {
14087 for (const auto &Pair : ScopeIt->second)
14088 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14089 llvm::erase(ValuesAtScopesUsers[Pair.second],
14090 std::make_pair(Pair.first, S));
14091 ValuesAtScopes.erase(ScopeIt);
14092 }
14093
14094 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14095 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14096 for (const auto &Pair : ScopeUserIt->second)
14097 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14098 ValuesAtScopesUsers.erase(ScopeUserIt);
14099 }
14100
14101 auto BEUsersIt = BECountUsers.find(S);
14102 if (BEUsersIt != BECountUsers.end()) {
14103 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14104 auto Copy = BEUsersIt->second;
14105 for (const auto &Pair : Copy)
14106 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14107 BECountUsers.erase(BEUsersIt);
14108 }
14109
14110 auto FoldUser = FoldCacheUser.find(S);
14111 if (FoldUser != FoldCacheUser.end())
14112 for (auto &KV : FoldUser->second)
14113 FoldCache.erase(KV);
14114 FoldCacheUser.erase(S);
14115}
14116
14117void
14118ScalarEvolution::getUsedLoops(const SCEV *S,
14119 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14120 struct FindUsedLoops {
14121 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14122 : LoopsUsed(LoopsUsed) {}
14123    SmallPtrSetImpl<const Loop *> &LoopsUsed;
14124    bool follow(const SCEV *S) {
14125 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14126 LoopsUsed.insert(AR->getLoop());
14127 return true;
14128 }
14129
14130 bool isDone() const { return false; }
14131 };
14132
14133 FindUsedLoops F(LoopsUsed);
14134  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14135}
14136
14137void ScalarEvolution::getReachableBlocks(
14138    SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14139  SmallVector<BasicBlock *> Worklist;
14140  Worklist.push_back(&F.getEntryBlock());
14141 while (!Worklist.empty()) {
14142 BasicBlock *BB = Worklist.pop_back_val();
14143 if (!Reachable.insert(BB).second)
14144 continue;
14145
14146 Value *Cond;
14147 BasicBlock *TrueBB, *FalseBB;
14148 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14149 m_BasicBlock(FalseBB)))) {
14150 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14151 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14152 continue;
14153 }
14154
14155 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14156 const SCEV *L = getSCEV(Cmp->getOperand(0));
14157 const SCEV *R = getSCEV(Cmp->getOperand(1));
14158 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
14159 Worklist.push_back(TrueBB);
14160 continue;
14161 }
14162 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
14163 R)) {
14164 Worklist.push_back(FalseBB);
14165 continue;
14166 }
14167 }
14168 }
14169
14170 append_range(Worklist, successors(BB));
14171 }
14172}
14173
14174void ScalarEvolution::verify() const {
14175  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14176 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14177
14178 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14179
14180  // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14181 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14182 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14183
14184 const SCEV *visitConstant(const SCEVConstant *Constant) {
14185 return SE.getConstant(Constant->getAPInt());
14186 }
14187
14188 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14189 return SE.getUnknown(Expr->getValue());
14190 }
14191
14192 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14193 return SE.getCouldNotCompute();
14194 }
14195 };
14196
14197 SCEVMapper SCM(SE2);
14198 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14199 SE2.getReachableBlocks(ReachableBlocks, F);
14200
14201 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14202 if (containsUndefs(Old) || containsUndefs(New)) {
14203 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14204 // not propagate undef aggressively). This means we can (and do) fail
14205 // verification in cases where a transform makes a value go from "undef"
14206 // to "undef+1" (say). The transform is fine, since in both cases the
14207 // result is "undef", but SCEV thinks the value increased by 1.
14208 return nullptr;
14209 }
14210
14211 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14212 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14213 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14214 return nullptr;
14215
14216 return Delta;
14217 };
14218
14219 while (!LoopStack.empty()) {
14220 auto *L = LoopStack.pop_back_val();
14221 llvm::append_range(LoopStack, *L);
14222
14223 // Only verify BECounts in reachable loops. For an unreachable loop,
14224 // any BECount is legal.
14225 if (!ReachableBlocks.contains(L->getHeader()))
14226 continue;
14227
14228 // Only verify cached BECounts. Computing new BECounts may change the
14229 // results of subsequent SCEV uses.
14230 auto It = BackedgeTakenCounts.find(L);
14231 if (It == BackedgeTakenCounts.end())
14232 continue;
14233
14234 auto *CurBECount =
14235 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14236 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14237
14238 if (CurBECount == SE2.getCouldNotCompute() ||
14239 NewBECount == SE2.getCouldNotCompute()) {
14240 // NB! This situation is legal, but is very suspicious -- whatever pass
14241      // changed the loop to make a trip count go from could not compute to
14242 // computable or vice-versa *should have* invalidated SCEV. However, we
14243 // choose not to assert here (for now) since we don't want false
14244 // positives.
14245 continue;
14246 }
14247
14248 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14249 SE.getTypeSizeInBits(NewBECount->getType()))
14250 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14251 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14252 SE.getTypeSizeInBits(NewBECount->getType()))
14253 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14254
14255 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14256 if (Delta && !Delta->isZero()) {
14257 dbgs() << "Trip Count for " << *L << " Changed!\n";
14258 dbgs() << "Old: " << *CurBECount << "\n";
14259 dbgs() << "New: " << *NewBECount << "\n";
14260 dbgs() << "Delta: " << *Delta << "\n";
14261 std::abort();
14262 }
14263 }
14264
14265 // Collect all valid loops currently in LoopInfo.
14266 SmallPtrSet<Loop *, 32> ValidLoops;
14267 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14268 while (!Worklist.empty()) {
14269 Loop *L = Worklist.pop_back_val();
14270 if (ValidLoops.insert(L).second)
14271 Worklist.append(L->begin(), L->end());
14272 }
14273 for (const auto &KV : ValueExprMap) {
14274#ifndef NDEBUG
14275 // Check for SCEV expressions referencing invalid/deleted loops.
14276 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14277 assert(ValidLoops.contains(AR->getLoop()) &&
14278 "AddRec references invalid loop");
14279 }
14280#endif
14281
14282 // Check that the value is also part of the reverse map.
14283 auto It = ExprValueMap.find(KV.second);
14284 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14285 dbgs() << "Value " << *KV.first
14286 << " is in ValueExprMap but not in ExprValueMap\n";
14287 std::abort();
14288 }
14289
14290 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14291 if (!ReachableBlocks.contains(I->getParent()))
14292 continue;
14293 const SCEV *OldSCEV = SCM.visit(KV.second);
14294 const SCEV *NewSCEV = SE2.getSCEV(I);
14295 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14296 if (Delta && !Delta->isZero()) {
14297 dbgs() << "SCEV for value " << *I << " changed!\n"
14298 << "Old: " << *OldSCEV << "\n"
14299 << "New: " << *NewSCEV << "\n"
14300 << "Delta: " << *Delta << "\n";
14301 std::abort();
14302 }
14303 }
14304 }
14305
14306 for (const auto &KV : ExprValueMap) {
14307 for (Value *V : KV.second) {
14308 auto It = ValueExprMap.find_as(V);
14309 if (It == ValueExprMap.end()) {
14310 dbgs() << "Value " << *V
14311 << " is in ExprValueMap but not in ValueExprMap\n";
14312 std::abort();
14313 }
14314 if (It->second != KV.first) {
14315 dbgs() << "Value " << *V << " mapped to " << *It->second
14316 << " rather than " << *KV.first << "\n";
14317 std::abort();
14318 }
14319 }
14320 }
14321
14322 // Verify integrity of SCEV users.
14323 for (const auto &S : UniqueSCEVs) {
14324 for (const auto *Op : S.operands()) {
14325 // We do not store dependencies of constants.
14326 if (isa<SCEVConstant>(Op))
14327 continue;
14328 auto It = SCEVUsers.find(Op);
14329 if (It != SCEVUsers.end() && It->second.count(&S))
14330 continue;
14331 dbgs() << "Use of operand " << *Op << " by user " << S
14332 << " is not being tracked!\n";
14333 std::abort();
14334 }
14335 }
14336
14337 // Verify integrity of ValuesAtScopes users.
14338 for (const auto &ValueAndVec : ValuesAtScopes) {
14339 const SCEV *Value = ValueAndVec.first;
14340 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14341 const Loop *L = LoopAndValueAtScope.first;
14342 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14343 if (!isa<SCEVConstant>(ValueAtScope)) {
14344 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14345 if (It != ValuesAtScopesUsers.end() &&
14346 is_contained(It->second, std::make_pair(L, Value)))
14347 continue;
14348 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14349 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14350 std::abort();
14351 }
14352 }
14353 }
14354
14355 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14356 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14357 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14358 const Loop *L = LoopAndValue.first;
14359 const SCEV *Value = LoopAndValue.second;
14360 assert(!isa<SCEVConstant>(Value));
14361 auto It = ValuesAtScopes.find(Value);
14362 if (It != ValuesAtScopes.end() &&
14363 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14364 continue;
14365 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14366 << *ValueAtScope << " missing in ValuesAtScopes\n";
14367 std::abort();
14368 }
14369 }
14370
14371 // Verify integrity of BECountUsers.
14372 auto VerifyBECountUsers = [&](bool Predicated) {
14373 auto &BECounts =
14374 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14375 for (const auto &LoopAndBEInfo : BECounts) {
14376 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14377 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14378 if (!isa<SCEVConstant>(S)) {
14379 auto UserIt = BECountUsers.find(S);
14380 if (UserIt != BECountUsers.end() &&
14381 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14382 continue;
14383 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14384 << " missing from BECountUsers\n";
14385 std::abort();
14386 }
14387 }
14388 }
14389 }
14390 };
14391 VerifyBECountUsers(/* Predicated */ false);
14392 VerifyBECountUsers(/* Predicated */ true);
14393
14394  // Verify integrity of the loop disposition cache.
14395 for (auto &[S, Values] : LoopDispositions) {
14396 for (auto [Loop, CachedDisposition] : Values) {
14397 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14398 if (CachedDisposition != RecomputedDisposition) {
14399 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14400 << " is incorrect: cached " << CachedDisposition << ", actual "
14401 << RecomputedDisposition << "\n";
14402 std::abort();
14403 }
14404 }
14405 }
14406
14407 // Verify integrity of the block disposition cache.
14408 for (auto &[S, Values] : BlockDispositions) {
14409 for (auto [BB, CachedDisposition] : Values) {
14410 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14411 if (CachedDisposition != RecomputedDisposition) {
14412 dbgs() << "Cached disposition of " << *S << " for block %"
14413 << BB->getName() << " is incorrect: cached " << CachedDisposition
14414 << ", actual " << RecomputedDisposition << "\n";
14415 std::abort();
14416 }
14417 }
14418 }
14419
14420 // Verify FoldCache/FoldCacheUser caches.
14421 for (auto [FoldID, Expr] : FoldCache) {
14422 auto I = FoldCacheUser.find(Expr);
14423 if (I == FoldCacheUser.end()) {
14424 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14425 << "!\n";
14426 std::abort();
14427 }
14428 if (!is_contained(I->second, FoldID)) {
14429 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14430 std::abort();
14431 }
14432 }
14433 for (auto [Expr, IDs] : FoldCacheUser) {
14434 for (auto &FoldID : IDs) {
14435 auto I = FoldCache.find(FoldID);
14436 if (I == FoldCache.end()) {
14437 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14438 << "!\n";
14439 std::abort();
14440 }
14441 if (I->second != Expr) {
14442 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14443 << *I->second << " != " << *Expr << "!\n";
14444 std::abort();
14445 }
14446 }
14447 }
14448
14449 // Verify that ConstantMultipleCache computations are correct. We check that
14450 // cached multiples and recomputed multiples are multiples of each other to
14451 // verify correctness. It is possible that a recomputed multiple is different
14452 // from the cached multiple due to strengthened no wrap flags or changes in
14453 // KnownBits computations.
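  // E.g. (hypothetical values): a cached multiple of 4 and a recomputed
  // multiple of 8 are accepted because one divides the other, whereas 4 vs. 6
  // would fail the check below.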
14454 for (auto [S, Multiple] : ConstantMultipleCache) {
14455 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14456 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14457 Multiple.urem(RecomputedMultiple) != 0 &&
14458 RecomputedMultiple.urem(Multiple) != 0)) {
14459 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14460 << *S << " : Computed " << RecomputedMultiple
14461 << " but cache contains " << Multiple << "!\n";
14462 std::abort();
14463 }
14464 }
14465}
14466
14467bool ScalarEvolution::invalidate(
14468    Function &F, const PreservedAnalyses &PA,
14469    FunctionAnalysisManager::Invalidator &Inv) {
14470  // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14471 // of its dependencies is invalidated.
14472 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14473 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14474 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14475         Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14476         Inv.invalidate<LoopAnalysis>(F, PA);
14477}
14478
14479AnalysisKey ScalarEvolutionAnalysis::Key;
14480
14481ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14482                                             FunctionAnalysisManager &AM) {
14483  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14484 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14485 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14486 auto &LI = AM.getResult<LoopAnalysis>(F);
14487 return ScalarEvolution(F, TLI, AC, DT, LI);
14488}
14489
14490PreservedAnalyses
14491ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14492  AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14493  return PreservedAnalyses::all();
14494}
14495
14496PreservedAnalyses
14497ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14498  // For compatibility with opt's -analyze feature under legacy pass manager
14499 // which was not ported to NPM. This keeps tests using
14500 // update_analyze_test_checks.py working.
14501 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14502 << F.getName() << "':\n";
14503  AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14504  return PreservedAnalyses::all();
14505}
14506
14507INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14508                      "Scalar Evolution Analysis", false, true)
14509INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14510INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14511INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14512INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14513INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14514                    "Scalar Evolution Analysis", false, true)
14515
14516char ScalarEvolutionWrapperPass::ID = 0;
14517
14518ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
14519  initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
14520}
14521
14522bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14523  SE.reset(new ScalarEvolution(
14524 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14525 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14526 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14527 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14528 return false;
14529}
14530
14531void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14532
14533void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14534  SE->print(OS);
14535}
14536
14537void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14538  if (!VerifySCEV)
14539 return;
14540
14541 SE->verify();
14542}
14543
14544void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14545  AU.setPreservesAll();
14546  AU.addRequiredTransitive<AssumptionCacheTracker>();
14547  AU.addRequiredTransitive<LoopInfoWrapperPass>();
14548  AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14549  AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14550}
14551
14552const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14553                                                        const SCEV *RHS) {
14554  return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14555}
14556
14557const SCEVPredicate *
14558ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14559                                     const SCEV *LHS, const SCEV *RHS) {
14560  FoldingSetNodeID ID;
14561  assert(LHS->getType() == RHS->getType() &&
14562 "Type mismatch between LHS and RHS");
14563 // Unique this node based on the arguments
14564 ID.AddInteger(SCEVPredicate::P_Compare);
14565 ID.AddInteger(Pred);
14566 ID.AddPointer(LHS);
14567 ID.AddPointer(RHS);
14568 void *IP = nullptr;
14569 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14570 return S;
14571 SCEVComparePredicate *Eq = new (SCEVAllocator)
14572 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14573 UniquePreds.InsertNode(Eq, IP);
14574 return Eq;
14575}
14576
14577const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14578    const SCEVAddRecExpr *AR,
14579    SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14580  FoldingSetNodeID ID;
14581  // Unique this node based on the arguments
14582 ID.AddInteger(SCEVPredicate::P_Wrap);
14583 ID.AddPointer(AR);
14584 ID.AddInteger(AddedFlags);
14585 void *IP = nullptr;
14586 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14587 return S;
14588 auto *OF = new (SCEVAllocator)
14589 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14590 UniquePreds.InsertNode(OF, IP);
14591 return OF;
14592}
14593
14594namespace {
14595
14596class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14597public:
14598
14599 /// Rewrites \p S in the context of a loop L and the SCEV predication
14600 /// infrastructure.
14601 ///
14602 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14603 /// equivalences present in \p Pred.
14604 ///
14605 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14606 /// \p NewPreds such that the result will be an AddRecExpr.
14607 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14608                             SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14609                             const SCEVPredicate *Pred) {
14610 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14611 return Rewriter.visit(S);
14612 }
14613
14614 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14615 if (Pred) {
14616 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14617 for (const auto *Pred : U->getPredicates())
14618 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14619 if (IPred->getLHS() == Expr &&
14620 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14621 return IPred->getRHS();
14622 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14623 if (IPred->getLHS() == Expr &&
14624 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14625 return IPred->getRHS();
14626 }
14627 }
14628 return convertToAddRecWithPreds(Expr);
14629 }
14630
14631 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14632 const SCEV *Operand = visit(Expr->getOperand());
14633 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14634 if (AR && AR->getLoop() == L && AR->isAffine()) {
14635 // This couldn't be folded because the operand didn't have the nuw
14636 // flag. Add the nusw flag as an assumption that we could make.
14637 const SCEV *Step = AR->getStepRecurrence(SE);
14638 Type *Ty = Expr->getType();
14639 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14640 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14641 SE.getSignExtendExpr(Step, Ty), L,
14642 AR->getNoWrapFlags());
14643 }
14644 return SE.getZeroExtendExpr(Operand, Expr->getType());
14645 }
14646
14647 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14648 const SCEV *Operand = visit(Expr->getOperand());
14649 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14650 if (AR && AR->getLoop() == L && AR->isAffine()) {
14651 // This couldn't be folded because the operand didn't have the nsw
14652 // flag. Add the nssw flag as an assumption that we could make.
14653 const SCEV *Step = AR->getStepRecurrence(SE);
14654 Type *Ty = Expr->getType();
14655 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14656 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14657 SE.getSignExtendExpr(Step, Ty), L,
14658 AR->getNoWrapFlags());
14659 }
14660 return SE.getSignExtendExpr(Operand, Expr->getType());
14661 }
14662
14663private:
14664 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14665                                 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14666                                 const SCEVPredicate *Pred)
14667 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14668
14669 bool addOverflowAssumption(const SCEVPredicate *P) {
14670 if (!NewPreds) {
14671 // Check if we've already made this assumption.
14672 return Pred && Pred->implies(P);
14673 }
14674 NewPreds->insert(P);
14675 return true;
14676 }
14677
14678 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14679                             SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14680    auto *A = SE.getWrapPredicate(AR, AddedFlags);
14681 return addOverflowAssumption(A);
14682 }
14683
14684 // If \p Expr represents a PHINode, we try to see if it can be represented
14685 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14686 // to add this predicate as a runtime overflow check, we return the AddRec.
14687 // If \p Expr does not meet these conditions (is not a PHI node, or we
14688 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14689 // return \p Expr.
14690 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14691 if (!isa<PHINode>(Expr->getValue()))
14692 return Expr;
14693 std::optional<
14694 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14695 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14696 if (!PredicatedRewrite)
14697 return Expr;
14698 for (const auto *P : PredicatedRewrite->second){
14699 // Wrap predicates from outer loops are not supported.
14700 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14701 if (L != WP->getExpr()->getLoop())
14702 return Expr;
14703 }
14704 if (!addOverflowAssumption(P))
14705 return Expr;
14706 }
14707 return PredicatedRewrite->first;
14708 }
14709
14710  SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
14711  const SCEVPredicate *Pred;
14712 const Loop *L;
14713};
14714
14715} // end anonymous namespace
14716
14717const SCEV *
14718ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14719                                       const SCEVPredicate &Preds) {
14720 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14721}
14722
14723const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14724    const SCEV *S, const Loop *L,
14725    SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
14726  SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
14727  S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14728 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14729
14730 if (!AddRec)
14731 return nullptr;
14732
14733 // Since the transformation was successful, we can now transfer the SCEV
14734 // predicates.
14735 for (const auto *P : TransformPreds)
14736 Preds.insert(P);
14737
14738 return AddRec;
14739}
14740
14741/// SCEV predicates
14742SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
14743                             SCEVPredicateKind Kind)
14744 : FastID(ID), Kind(Kind) {}
14745
14746SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14747                                           const ICmpInst::Predicate Pred,
14748 const SCEV *LHS, const SCEV *RHS)
14749 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14750 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14751 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14752}
14753
14754bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14755  const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14756
14757 if (!Op)
14758 return false;
14759
14760 if (Pred != ICmpInst::ICMP_EQ)
14761 return false;
14762
14763 return Op->LHS == LHS && Op->RHS == RHS;
14764}
14765
14766bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14767
14768void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
14769  if (Pred == ICmpInst::ICMP_EQ)
14770 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14771 else
14772 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14773 << *RHS << "\n";
14774
14775}
14776
14777SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14778                                     const SCEVAddRecExpr *AR,
14779 IncrementWrapFlags Flags)
14780 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14781
14782const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14783
14784bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14785  const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14786
14787 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14788}
14789
14790bool SCEVWrapPredicate::isAlwaysTrue() const {
14791  SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
14792 IncrementWrapFlags IFlags = Flags;
14793
14794 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
14795 IFlags = clearFlags(IFlags, IncrementNSSW);
14796
14797 return IFlags == IncrementAnyWrap;
14798}
14799
14800void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
14801  OS.indent(Depth) << *getExpr() << " Added Flags: ";
14802  if (SCEVWrapPredicate::IncrementNUSW == setFlags(Flags, IncrementNUSW))
14803    OS << "<nusw>";
14804  if (SCEVWrapPredicate::IncrementNSSW == setFlags(Flags, IncrementNSSW))
14805    OS << "<nssw>";
14806 OS << "\n";
14807}
14808
14809SCEVWrapPredicate::IncrementWrapFlags
14810SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
14811                                   ScalarEvolution &SE) {
14812 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
14813 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
14814
14815 // We can safely transfer the NSW flag as NSSW.
14816 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
14817 ImpliedFlags = IncrementNSSW;
14818
14819 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
14820 // If the increment is positive, the SCEV NUW flag will also imply the
14821 // WrapPredicate NUSW flag.
14822 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
14823 if (Step->getValue()->getValue().isNonNegative())
14824 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
14825 }
14826
14827 return ImpliedFlags;
14828}
14829
14830/// Union predicates don't get cached so create a dummy set ID for it.
14831SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14832    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14833 for (const auto *P : Preds)
14834 add(P);
14835}
14836
14837bool SCEVUnionPredicate::isAlwaysTrue() const {
14838  return all_of(Preds,
14839 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14840}
14841
14842bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
14843  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14844 return all_of(Set->Preds,
14845 [this](const SCEVPredicate *I) { return this->implies(I); });
14846
14847 return any_of(Preds,
14848 [N](const SCEVPredicate *I) { return I->implies(N); });
14849}
14850
14851void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
14852  for (const auto *Pred : Preds)
14853 Pred->print(OS, Depth);
14854}
14855
14856void SCEVUnionPredicate::add(const SCEVPredicate *N) {
14857 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
14858 for (const auto *Pred : Set->Preds)
14859 add(Pred);
14860 return;
14861 }
14862
14863 // Only add predicate if it is not already implied by this union predicate.
14864 if (!implies(N))
14865 Preds.push_back(N);
14866}
14867
14868PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
14869                                                     Loop &L)
14870    : SE(SE), L(L) {
14871  SmallVector<const SCEVPredicate *, 4> Empty;
14872  Preds = std::make_unique<SCEVUnionPredicate>(Empty);
14873}
14874
14875void ScalarEvolution::registerUser(const SCEV *User,
14876                                   ArrayRef<const SCEV *> Ops) {
14877  for (const auto *Op : Ops)
14878 // We do not expect that forgetting cached data for SCEVConstants will ever
14879 // open any prospects for sharpening or introduce any correctness issues,
14880 // so we don't bother storing their dependencies.
14881 if (!isa<SCEVConstant>(Op))
14882 SCEVUsers[Op].insert(User);
14883}
14884
14885const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
14886  const SCEV *Expr = SE.getSCEV(V);
14887 RewriteEntry &Entry = RewriteMap[Expr];
14888
14889 // If we already have an entry and the version matches, return it.
14890 if (Entry.second && Generation == Entry.first)
14891 return Entry.second;
14892
14893 // We found an entry but it's stale. Rewrite the stale entry
14894 // according to the current predicate.
14895 if (Entry.second)
14896 Expr = Entry.second;
14897
14898 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
14899 Entry = {Generation, NewSCEV};
14900
14901 return NewSCEV;
14902}
14903
14904const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
14905  if (!BackedgeCount) {
14906    SmallVector<const SCEVPredicate *, 4> Preds;
14907    BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
14908 for (const auto *P : Preds)
14909 addPredicate(*P);
14910 }
14911 return BackedgeCount;
14912}
14913
14914const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
14915  if (!SymbolicMaxBackedgeCount) {
14916    SmallVector<const SCEVPredicate *, 4> Preds;
14917    SymbolicMaxBackedgeCount =
14918        SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
14919 for (const auto *P : Preds)
14920 addPredicate(*P);
14921 }
14922 return SymbolicMaxBackedgeCount;
14923}
14924
14925void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
14926  if (Preds->implies(&Pred))
14927 return;
14928
14929 auto &OldPreds = Preds->getPredicates();
14930 SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
14931 NewPreds.push_back(&Pred);
14932 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
14933 updateGeneration();
14934}
14935
14936const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
14937  return *Preds;
14938}
14939
14940void PredicatedScalarEvolution::updateGeneration() {
14941 // If the generation number wrapped recompute everything.
14942 if (++Generation == 0) {
14943 for (auto &II : RewriteMap) {
14944 const SCEV *Rewritten = II.second.second;
14945 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
14946 }
14947 }
14948}
14949
14950void PredicatedScalarEvolution::setNoOverflow(
14951    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14952  const SCEV *Expr = getSCEV(V);
14953 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14954
14955 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
14956
14957 // Clear the statically implied flags.
14958 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
14959 addPredicate(*SE.getWrapPredicate(AR, Flags));
14960
14961 auto II = FlagsMap.insert({V, Flags});
14962 if (!II.second)
14963 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
14964}
14965
14966bool PredicatedScalarEvolution::hasNoOverflow(
14967    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14968  const SCEV *Expr = getSCEV(V);
14969  const auto *AR = cast<SCEVAddRecExpr>(Expr);
14970
14971  Flags = SCEVWrapPredicate::clearFlags(
14972      Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
14973
14974 auto II = FlagsMap.find(V);
14975
14976 if (II != FlagsMap.end())
14977 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
14978
14979  return Flags == SCEVWrapPredicate::IncrementAnyWrap;
14980}
14981
14982const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
14983  const SCEV *Expr = this->getSCEV(V);
14984  SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
14985  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
14986
14987 if (!New)
14988 return nullptr;
14989
14990 for (const auto *P : NewPreds)
14991 addPredicate(*P);
14992
14993 RewriteMap[SE.getSCEV(V)] = {Generation, New};
14994 return New;
14995}
14996
14997PredicatedScalarEvolution::PredicatedScalarEvolution(
14998    const PredicatedScalarEvolution &Init)
14999    : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15000 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15001 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15002 for (auto I : Init.FlagsMap)
15003 FlagsMap.insert(I);
15004}
15005
15006void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15007  // For each block.
15008 for (auto *BB : L.getBlocks())
15009 for (auto &I : *BB) {
15010 if (!SE.isSCEVable(I.getType()))
15011 continue;
15012
15013 auto *Expr = SE.getSCEV(&I);
15014 auto II = RewriteMap.find(Expr);
15015
15016 if (II == RewriteMap.end())
15017 continue;
15018
15019 // Don't print things that are not interesting.
15020 if (II->second.second == Expr)
15021 continue;
15022
15023 OS.indent(Depth) << "[PSE]" << I << ":\n";
15024 OS.indent(Depth + 2) << *Expr << "\n";
15025 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15026 }
15027}
15028
15029// Match the mathematical pattern A - (A / B) * B, where A and B can be
15030// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15031// for URem with constant power-of-2 second operands.
15032// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15033// 4, A / B becomes X / 8).
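// For example (illustrative): "%a urem %b" is usually built as the SCEV
// (%a + (-1 * (%a /u %b) * %b)), while "%a urem 4" with a power-of-2 divisor
// may instead appear as (zext i2 (trunc i32 %a to i2) to i32).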
15034bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15035 const SCEV *&RHS) {
15036 if (Expr->getType()->isPointerTy())
15037 return false;
15038
15039 // Try to match 'zext (trunc A to iB) to iY', which is used
15040 // for URem with constant power-of-2 second operands. Make sure the size of
15041 // the operand A matches the size of the whole expressions.
15042 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15043 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15044 LHS = Trunc->getOperand();
15045 // Bail out if the type of the LHS is larger than the type of the
15046 // expression for now.
15047 if (getTypeSizeInBits(LHS->getType()) >
15048 getTypeSizeInBits(Expr->getType()))
15049 return false;
15050 if (LHS->getType() != Expr->getType())
15051 LHS = getZeroExtendExpr(LHS, Expr->getType());
15052      RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15053                            << getTypeSizeInBits(Trunc->getType()));
15054 return true;
15055 }
15056 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15057 if (Add == nullptr || Add->getNumOperands() != 2)
15058 return false;
15059
15060 const SCEV *A = Add->getOperand(1);
15061 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15062
15063 if (Mul == nullptr)
15064 return false;
15065
15066 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15067 // (SomeExpr + (-(SomeExpr / B) * B)).
15068 if (Expr == getURemExpr(A, B)) {
15069 LHS = A;
15070 RHS = B;
15071 return true;
15072 }
15073 return false;
15074 };
15075
15076 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15077 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15078 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15079 MatchURemWithDivisor(Mul->getOperand(2));
15080
15081 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15082 if (Mul->getNumOperands() == 2)
15083 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15084 MatchURemWithDivisor(Mul->getOperand(0)) ||
15085 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15086 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15087 return false;
15088}
15089
15090ScalarEvolution::LoopGuards
15091ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15092  LoopGuards Guards(SE);
15093 SmallVector<const SCEV *> ExprsToRewrite;
15094 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15095 const SCEV *RHS,
15096                              DenseMap<const SCEV *, const SCEV *>
15097                                  &RewriteMap) {
15098 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15099 // replacement SCEV which isn't directly implied by the structure of that
15100 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15101 // legal. See the scoping rules for flags in the header to understand why.
15102
15103 // If LHS is a constant, apply information to the other expression.
15104 if (isa<SCEVConstant>(LHS)) {
15105 std::swap(LHS, RHS);
15106 Predicate = CmpInst::getSwappedPredicate(Predicate);
15107 }
15108
15109 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15110 // create this form when combining two checks of the form (X u< C2 + C1) and
15111 // (X >=u C1).
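    // E.g. (illustrative): "%x u>= 5 && %x u< 13" is combined by InstCombine
    // into "(%x - 5) u< 8", i.e. SCEV ((-5 + %x) u< 8); the code below
    // recovers the range [5, 13) for %x and records it as
    // %x -> umax(5, umin(%x, 12)).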
15112 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15113 &ExprsToRewrite]() {
15114 auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
15115 if (!AddExpr || AddExpr->getNumOperands() != 2)
15116 return false;
15117
15118 auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
15119 auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
15120 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15121 if (!C1 || !C2 || !LHSUnknown)
15122 return false;
15123
15124 auto ExactRegion =
15125 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15126 .sub(C1->getAPInt());
15127
15128 // Bail out, unless we have a non-wrapping, monotonic range.
15129 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15130 return false;
15131 auto I = RewriteMap.find(LHSUnknown);
15132 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15133 RewriteMap[LHSUnknown] = SE.getUMaxExpr(
15134 SE.getConstant(ExactRegion.getUnsignedMin()),
15135 SE.getUMinExpr(RewrittenLHS,
15136 SE.getConstant(ExactRegion.getUnsignedMax())));
15137 ExprsToRewrite.push_back(LHSUnknown);
15138 return true;
15139 };
15140 if (MatchRangeCheckIdiom())
15141 return;
15142
15143 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15144 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15145 // the non-constant operand and in \p LHS the constant operand.
15146 auto IsMinMaxSCEVWithNonNegativeConstant =
15147 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15148 const SCEV *&RHS) {
15149 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15150 if (MinMax->getNumOperands() != 2)
15151 return false;
15152 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15153 if (C->getAPInt().isNegative())
15154 return false;
15155 SCTy = MinMax->getSCEVType();
15156 LHS = MinMax->getOperand(0);
15157 RHS = MinMax->getOperand(1);
15158 return true;
15159 }
15160 }
15161 return false;
15162 };
15163
15164 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15165 // constant, and returns their APInt in ExprVal and in DivisorVal.
15166 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15167 APInt &ExprVal, APInt &DivisorVal) {
15168 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15169 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15170 if (!ConstExpr || !ConstDivisor)
15171 return false;
15172 ExprVal = ConstExpr->getAPInt();
15173 DivisorVal = ConstDivisor->getAPInt();
15174 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15175 };
15176
15177  // Return a new SCEV that modifies \p Expr to the closest number divisible
15178  // by \p Divisor and greater than or equal to Expr.
15179 // For now, only handle constant Expr and Divisor.
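    // E.g. (hypothetical constants): Expr = 5 and Divisor = 4 yields 8, while
    // Expr = 8 is already a multiple of 4 and is returned unchanged.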
15180 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15181 const SCEV *Divisor) {
15182 APInt ExprVal;
15183 APInt DivisorVal;
15184 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15185 return Expr;
15186 APInt Rem = ExprVal.urem(DivisorVal);
15187 if (!Rem.isZero())
15188 // return the SCEV: Expr + Divisor - Expr % Divisor
15189 return SE.getConstant(ExprVal + DivisorVal - Rem);
15190 return Expr;
15191 };
15192
15193  // Return a new SCEV that modifies \p Expr to the closest number divisible
15194  // by \p Divisor and less than or equal to Expr.
15195 // For now, only handle constant Expr and Divisor.
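    // E.g. (hypothetical constants): Expr = 5 and Divisor = 4 yields 4
    // (5 - 5 % 4).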
15196 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15197 const SCEV *Divisor) {
15198 APInt ExprVal;
15199 APInt DivisorVal;
15200 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15201 return Expr;
15202 APInt Rem = ExprVal.urem(DivisorVal);
15203 // return the SCEV: Expr - Expr % Divisor
15204 return SE.getConstant(ExprVal - Rem);
15205 };
15206
15207  // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15208 // recursively. This is done by aligning up/down the constant value to the
15209 // Divisor.
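    // E.g. (illustrative): with Divisor = 4, umin(7, %x) becomes umin(4, %x),
    // since the constant operand of a umin is aligned down; for a umax it
    // would be aligned up instead.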
15210 std::function<const SCEV *(const SCEV *, const SCEV *)>
15211 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15212 const SCEV *Divisor) {
15213 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15214 SCEVTypes SCTy;
15215 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15216 MinMaxRHS))
15217 return MinMaxExpr;
15218 auto IsMin =
15219 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15220 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15221 "Expected non-negative operand!");
15222 auto *DivisibleExpr =
15223 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15224 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15225        SmallVector<const SCEV *> Ops = {
15226            ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15227 return SE.getMinMaxExpr(SCTy, Ops);
15228 };
15229
15230 // If we have LHS == 0, check if LHS is computing a property of some unknown
15231 // SCEV %v which we can rewrite %v to express explicitly.
15232 const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15233 if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15234 RHSC->getValue()->isNullValue()) {
15235 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15236 // explicitly express that.
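      // E.g. (illustrative): a dominating guard "%n urem 8 == 0" lets %n be
      // rewritten to ((%n /u 8) * 8), making the divisibility of %n by 8
      // explicit to later queries.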
15237 const SCEV *URemLHS = nullptr;
15238 const SCEV *URemRHS = nullptr;
15239 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15240 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15241 auto I = RewriteMap.find(LHSUnknown);
15242 const SCEV *RewrittenLHS =
15243 I != RewriteMap.end() ? I->second : LHSUnknown;
15244 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15245 const auto *Multiple =
15246 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15247 RewriteMap[LHSUnknown] = Multiple;
15248 ExprsToRewrite.push_back(LHSUnknown);
15249 return;
15250 }
15251 }
15252 }
15253
15254 // Do not apply information for constants or if RHS contains an AddRec.
15255 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15256 return;
15257
15258 // If RHS is SCEVUnknown, make sure the information is applied to it.
15259 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15260 std::swap(LHS, RHS);
15261 Predicate = CmpInst::getSwappedPredicate(Predicate);
15262 }
15263
15264 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15265 // and \p FromRewritten are the same (i.e. there has been no rewrite
15266 // registered for \p From), then puts this value in the list of rewritten
15267 // expressions.
15268 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15269 const SCEV *To) {
15270 if (From == FromRewritten)
15271 ExprsToRewrite.push_back(From);
15272 RewriteMap[From] = To;
15273 };
15274
15275 // Checks whether \p S has already been rewritten. In that case returns the
15276 // existing rewrite because we want to chain further rewrites onto the
15277 // already rewritten value. Otherwise returns \p S.
15278 auto GetMaybeRewritten = [&](const SCEV *S) {
15279 auto I = RewriteMap.find(S);
15280 return I != RewriteMap.end() ? I->second : S;
15281 };
15282
15283 // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15284  // \p Expr. The check is done recursively on \p Expr, which is assumed to
15285 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15286 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15287 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15288 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15289 // DividesBy.
15290 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15291 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15292 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15293 if (Mul->getNumOperands() != 2)
15294 return false;
15295 auto *MulLHS = Mul->getOperand(0);
15296 auto *MulRHS = Mul->getOperand(1);
15297 if (isa<SCEVConstant>(MulLHS))
15298 std::swap(MulLHS, MulRHS);
15299 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15300 if (Div->getOperand(1) == MulRHS) {
15301 DividesBy = MulRHS;
15302 return true;
15303 }
15304 }
15305 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15306 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15307 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15308 return false;
15309 };
15310
15311  // Return true if \p Expr is known to be divisible by \p DividesBy.
15312 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15313 [&](const SCEV *Expr, const SCEV *DividesBy) {
15314 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15315 return true;
15316 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15317 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15318 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15319 return false;
15320 };
15321
15322 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15323 const SCEV *DividesBy = nullptr;
15324 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15325 // Check that the whole expression is divided by DividesBy
15326 DividesBy =
15327 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15328
15329 // Collect rewrites for LHS and its transitive operands based on the
15330 // condition.
15331 // For min/max expressions, also apply the guard to its operands:
15332 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15333 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15334 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15335 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15336
15337 // We cannot express strict predicates in SCEV, so instead we replace them
15338 // with non-strict ones against plus or minus one of RHS depending on the
15339 // predicate.
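    // E.g. (illustrative): "%x u< %n" is handled as "%x u<= umax(%n, 1) - 1",
    // and "%x s> %n" as "%x s>= %n + 1".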
15340 const SCEV *One = SE.getOne(RHS->getType());
15341 switch (Predicate) {
15342 case CmpInst::ICMP_ULT:
15343 if (RHS->getType()->isPointerTy())
15344 return;
15345 RHS = SE.getUMaxExpr(RHS, One);
15346 [[fallthrough]];
15347 case CmpInst::ICMP_SLT: {
15348 RHS = SE.getMinusSCEV(RHS, One);
15349 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15350 break;
15351 }
15352 case CmpInst::ICMP_UGT:
15353 case CmpInst::ICMP_SGT:
15354 RHS = SE.getAddExpr(RHS, One);
15355 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15356 break;
15357 case CmpInst::ICMP_ULE:
15358 case CmpInst::ICMP_SLE:
15359 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15360 break;
15361 case CmpInst::ICMP_UGE:
15362 case CmpInst::ICMP_SGE:
15363 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15364 break;
15365 default:
15366 break;
15367 }
15368
15369    SmallVector<const SCEV *, 16> Worklist(1, LHS);
15370    SmallPtrSet<const SCEV *, 16> Visited;
15371
15372 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15373 append_range(Worklist, S->operands());
15374 };
15375
15376 while (!Worklist.empty()) {
15377 const SCEV *From = Worklist.pop_back_val();
15378 if (isa<SCEVConstant>(From))
15379 continue;
15380 if (!Visited.insert(From).second)
15381 continue;
15382 const SCEV *FromRewritten = GetMaybeRewritten(From);
15383 const SCEV *To = nullptr;
15384
15385 switch (Predicate) {
15386 case CmpInst::ICMP_ULT:
15387 case CmpInst::ICMP_ULE:
15388 To = SE.getUMinExpr(FromRewritten, RHS);
15389 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15390 EnqueueOperands(UMax);
15391 break;
15392 case CmpInst::ICMP_SLT:
15393 case CmpInst::ICMP_SLE:
15394 To = SE.getSMinExpr(FromRewritten, RHS);
15395 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15396 EnqueueOperands(SMax);
15397 break;
15398 case CmpInst::ICMP_UGT:
15399 case CmpInst::ICMP_UGE:
15400 To = SE.getUMaxExpr(FromRewritten, RHS);
15401 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15402 EnqueueOperands(UMin);
15403 break;
15404 case CmpInst::ICMP_SGT:
15405 case CmpInst::ICMP_SGE:
15406 To = SE.getSMaxExpr(FromRewritten, RHS);
15407 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15408 EnqueueOperands(SMin);
15409 break;
15410 case CmpInst::ICMP_EQ:
15411 if (isa<SCEVConstant>(RHS))
15412 To = RHS;
15413 break;
15414 case CmpInst::ICMP_NE:
15415 if (isa<SCEVConstant>(RHS) &&
15416 cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15417 const SCEV *OneAlignedUp =
15418 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15419 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15420 }
15421 break;
15422 default:
15423 break;
15424 }
15425
15426 if (To)
15427 AddRewrite(From, FromRewritten, To);
15428 }
15429 };
15430
15431 BasicBlock *Header = L->getHeader();
15432  SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
15433  // First, collect information from assumptions dominating the loop.
15434 for (auto &AssumeVH : SE.AC.assumptions()) {
15435 if (!AssumeVH)
15436 continue;
15437 auto *AssumeI = cast<CallInst>(AssumeVH);
15438 if (!SE.DT.dominates(AssumeI, Header))
15439 continue;
15440 Terms.emplace_back(AssumeI->getOperand(0), true);
15441 }
15442
15443 // Second, collect information from llvm.experimental.guards dominating the loop.
15444 auto *GuardDecl = SE.F.getParent()->getFunction(
15445 Intrinsic::getName(Intrinsic::experimental_guard));
15446 if (GuardDecl)
15447 for (const auto *GU : GuardDecl->users())
15448 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15449 if (Guard->getFunction() == Header->getParent() &&
15450 SE.DT.dominates(Guard, Header))
15451 Terms.emplace_back(Guard->getArgOperand(0), true);
15452
15453 // Third, collect conditions from dominating branches. Starting at the loop
15454 // predecessor, climb up the predecessor chain, as long as there are
15455 // predecessors that can be found that have unique successors leading to the
15456 // original header.
15457 // TODO: share this logic with isLoopEntryGuardedByCond.
15458 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15459 L->getLoopPredecessor(), Header);
15460 Pair.first;
15461 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15462
15463 const BranchInst *LoopEntryPredicate =
15464 dyn_cast<BranchInst>(Pair.first->getTerminator());
15465 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15466 continue;
15467
15468 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15469 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15470 }
15471
15472 // Now apply the information from the collected conditions to
15473 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15474  // earliest condition is processed first. This ensures the SCEVs with the
15475 // shortest dependency chains are constructed first.
15476 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15477 SmallVector<Value *, 8> Worklist;
15478    SmallPtrSet<Value *, 8> Visited;
15479    Worklist.push_back(Term);
15480 while (!Worklist.empty()) {
15481 Value *Cond = Worklist.pop_back_val();
15482 if (!Visited.insert(Cond).second)
15483 continue;
15484
15485 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15486 auto Predicate =
15487 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15488 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15489 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15490 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15491 continue;
15492 }
15493
15494 Value *L, *R;
15495 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15496 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15497 Worklist.push_back(L);
15498 Worklist.push_back(R);
15499 }
15500 }
15501 }
15502
15503 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15504 // the replacement expressions are contained in the ranges of the replaced
15505 // expressions.
15506 Guards.PreserveNUW = true;
15507 Guards.PreserveNSW = true;
15508 for (const SCEV *Expr : ExprsToRewrite) {
15509 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15510 Guards.PreserveNUW &=
15511 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15512 Guards.PreserveNSW &=
15513 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15514 }
15515
15516  // Now that all rewrite information is collected, rewrite the collected
15517 // expressions with the information in the map. This applies information to
15518 // sub-expressions.
15519 if (ExprsToRewrite.size() > 1) {
15520 for (const SCEV *Expr : ExprsToRewrite) {
15521 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15522 Guards.RewriteMap.erase(Expr);
15523 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15524 }
15525 }
15526 return Guards;
15527}
15528
15529const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15530  /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15531 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15532 /// replacement is loop invariant in the loop of the AddRec.
15533 class SCEVLoopGuardRewriter
15534 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15535     const DenseMap<const SCEV *, const SCEV *> &Map;
15536
15537     SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15538
15539 public:
15540 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15541 const ScalarEvolution::LoopGuards &Guards)
15542 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15543 if (Guards.PreserveNUW)
15544 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15545 if (Guards.PreserveNSW)
15546 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15547 }
15548
15549 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15550
15551 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15552 auto I = Map.find(Expr);
15553 if (I == Map.end())
15554 return Expr;
15555 return I->second;
15556 }
15557
15558 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15559 auto I = Map.find(Expr);
15560 if (I == Map.end()) {
15561         // If we didn't find the exact ZExt expr in the map, check if there's
15562 // an entry for a smaller ZExt we can use instead.
15563 Type *Ty = Expr->getType();
15564 const SCEV *Op = Expr->getOperand(0);
15565 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15566 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15567 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15568 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15569 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15570 auto I = Map.find(NarrowExt);
15571 if (I != Map.end())
15572 return SE.getZeroExtendExpr(I->second, Ty);
15573 Bitwidth = Bitwidth / 2;
15574 }
15575
15576         return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15577             Expr);
15578 }
15579 return I->second;
15580 }
15581
15582 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15583 auto I = Map.find(Expr);
15584 if (I == Map.end())
15585         return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15586             Expr);
15587 return I->second;
15588 }
15589
15590 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15591 auto I = Map.find(Expr);
15592 if (I == Map.end())
15593         return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
15594       return I->second;
15595 }
15596
15597 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15598 auto I = Map.find(Expr);
15599 if (I == Map.end())
15600         return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15601       return I->second;
15602 }
15603
15604 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15605       SmallVector<const SCEV *, 2> Operands;
15606       bool Changed = false;
15607 for (const auto *Op : Expr->operands()) {
15608 Operands.push_back(
15609             SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15610         Changed |= Op != Operands.back();
15611 }
15612 // We are only replacing operands with equivalent values, so transfer the
15613 // flags from the original expression.
15614 return !Changed ? Expr
15615 : SE.getAddExpr(Operands,
15616                                       ScalarEvolution::maskFlags(
15617                                           Expr->getNoWrapFlags(), FlagMask));
15618 }
15619
15620 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15621       SmallVector<const SCEV *, 2> Operands;
15622       bool Changed = false;
15623 for (const auto *Op : Expr->operands()) {
15624 Operands.push_back(
15625             SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15626         Changed |= Op != Operands.back();
15627 }
15628 // We are only replacing operands with equivalent values, so transfer the
15629 // flags from the original expression.
15630 return !Changed ? Expr
15631 : SE.getMulExpr(Operands,
15632                                       ScalarEvolution::maskFlags(
15633                                           Expr->getNoWrapFlags(), FlagMask));
15634 }
15635 };
15636
15637 if (RewriteMap.empty())
15638 return Expr;
15639
15640 SCEVLoopGuardRewriter Rewriter(SE, *this);
15641 return Rewriter.visit(Expr);
15642}
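To make the rewriter's traversal concrete, here is a hedged sketch reusing the hypothetical guard "%n u< 1024" and the Guards object from the earlier sketch (NVal is still an invented Value*): Add and Mul nodes are rebuilt from rewritten operands, SCEVUnknowns are replaced through the map, and AddRecs are deliberately returned unchanged.

    const SCEV *N = SE.getSCEV(NVal);                  // SCEVUnknown for %n
    const SCEV *Four = SE.getConstant(N->getType(), 4);
    const SCEV *Expr = SE.getMulExpr(Four, N);         // (4 * %n)
    const SCEV *Rewritten = Guards.rewrite(Expr);
    // With RewriteMap[%n] mapping to (%n umin 1023), visitMulExpr revisits the
    // operands and rebuilds the product, giving roughly (4 * (%n umin 1023)).
    // An AddRec operand would be left untouched, because visitAddRecExpr
    // returns the expression as-is.

The exact resulting SCEV shape depends on folding, so the comments above describe the expected outcome rather than guaranteed output.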
15643
15644const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15645 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
15646}
15647
15648 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
15649                                              const LoopGuards &Guards) {
15650 return Guards.rewrite(Expr);
15651}
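A brief usage note on the two overloads above, as a sketch assuming SE and L as before and hypothetical expressions ExprA and ExprB: the Loop overload is the convenient one-shot form, while precomputing LoopGuards is preferable when several expressions are rewritten against the same loop, since the dominating conditions then only need to be walked once.

    // One-shot form: collect and apply in a single call.
    const SCEV *GuardedBTC =
        SE.applyLoopGuards(SE.getBackedgeTakenCount(L), L);

    // Repeated queries: collect once, then rewrite any number of expressions.
    ScalarEvolution::LoopGuards Guards =
        ScalarEvolution::LoopGuards::collect(L, SE);
    const SCEV *A = SE.applyLoopGuards(ExprA, Guards);
    const SCEV *B = SE.applyLoopGuards(ExprB, Guards);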
static const LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
basic Basic Alias true
block Block Frequency Analysis
BlockVerifier::State From
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition: Compiler.h:537
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
uint64_t Size
bool End
Definition: ELF_riscv.cpp:480
Generic implementation of equivalence classes through the use Tarjan's efficient union-find algorithm...
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:512
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
if(VerifyEach)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
static bool rewrite(Function &F)
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
This file contains some templates that are useful if you are working with the STL at all.
raw_pwrite_stream & OS
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static int CompareValueComplexity(EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static bool CollectAddOperandsWithScales(DenseMap< const SCEV *, APInt > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Virtual Register Rewriter
Definition: VirtRegMap.cpp:237
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
Class for arbitrary precision integers.
Definition: APInt.h:78
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1941
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1543
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:981
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:403
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1500
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1372
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:608
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1002
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1472
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:906
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:186
APInt abs() const
Get the absolute value.
Definition: APInt.h:1753
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1162
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:360
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:446
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1636
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1448
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1091
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:189
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:196
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:309
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1146
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:199
unsigned countTrailingZeros() const
Definition: APInt.h:1606
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:336
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:807
APInt multiplicativeInverse() const
Definition: APInt.cpp:1244
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1130
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:954
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:853
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:286
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:321
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1110
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:180
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:412
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:219
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1201
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:49
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:292
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:310
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:405
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:228
iterator end() const
Definition: ArrayRef.h:154
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
iterator begin() const
Definition: ArrayRef.h:153
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:438
const Instruction & front() const
Definition: BasicBlock.h:461
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:457
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:209
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:229
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:442
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1062
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:757
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:786
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:787
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:781
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:780
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:784
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:782
@ ICMP_EQ
equal
Definition: InstrTypes.h:778
@ ICMP_NE
not equal
Definition: InstrTypes.h:779
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:785
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:783
bool isSigned() const
Definition: InstrTypes.h:1007
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:909
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1056
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:953
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:871
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:847
Predicate getFlippedSignednessPredicate()
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert.
Definition: InstrTypes.h:1050
bool isUnsigned() const
Definition: InstrTypes.h:1013
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:1003
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2593
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2255
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1240
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2599
static Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2587
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2241
This is the shared class of boolean and integer constants.
Definition: Constants.h:81
bool isMinusOne() const
This function will return true iff every bit in this constant is set to true.
Definition: Constants.h:218
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
Definition: Constants.h:212
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:206
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:857
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:155
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:146
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:864
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
ConstantRange subtract(const APInt &CI) const
Subtract the specified constant from the endpoints of this constant range.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:42
bool isNullValue() const
Return true if this is the value that would be returned by getNullValue.
Definition: Constants.cpp:90
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:720
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:878
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:774
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:905
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:672
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:202
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:155
bool erase(const KeyT &Val)
Definition: DenseMap.h:345
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:180
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:151
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:145
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:220
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:290
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:327
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
const BasicBlock & getEntryBlock() const
Definition: Function.h:800
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:406
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:403
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
Class to represent integer types.
Definition: DerivedTypes.h:40
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:278
An instruction for reading from memory.
Definition: Instructions.h:174
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:571
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:598
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:44
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:61
Metadata node.
Definition: Metadata.h:1067
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
Definition: Module.cpp:193
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:32
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:42
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:77
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:110
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:104
iterator_range< const_block_iterator > blocks() const
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:662
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1852
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:116
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
bool implies(const SCEVPredicate *N) const override
Implementation of the SCEVPredicate interface.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
virtual bool implies(const SCEVPredicate *N) const =0
Returns true if this predicate implies N.
SCEVPredicate(const SCEVPredicate &)=default
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds)
Union predicates don't get cached so create a dummy set ID for it.
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
unsigned getSmallConstantMaxTripCount(const Loop *L)
Returns the upper bound of the loop trip count as a normal unsigned value.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
std::optional< bool > evaluatePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCount or zero.
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect ScalarEvolution's ability to compute a trip count.
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS, and RHS.
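A minimal sketch, assuming SE is a ScalarEvolution instance and A, B are Values of the same integer type:
  if (SE.isKnownPredicate(ICmpInst::ICMP_SLT, SE.getSCEV(A), SE.getSCEV(B))) {
    // "A <s B" is known to hold wherever both values are defined.
  }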
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
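For illustration, a sketch that builds the canonical induction variable {0,+,1} for a loop L, assuming Ty is an integer type and SE exists:
  const SCEV *Start = SE.getZero(Ty);
  const SCEV *Step = SE.getOne(Ty);
  const SCEV *IV = SE.getAddRecExpr(Start, Step, L, SCEV::FlagNUW); // {0,+,1}<nuw><L>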
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop of the recurrence LHS.
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
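A sketch of the intended guard pattern, since getTypeSizeInBits may only be called on SCEVable types; Ty and SE are assumed to come from the caller:
  if (SE.isSCEVable(Ty)) {
    uint64_t Bits = SE.getTypeSizeInBits(Ty);
    (void)Bits; // only valid to query because isSCEVable(Ty) returned true
  }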
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type that represents how SCEV will treat the given type, for which isSCEVable must return true.
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvariantPredicate with LHS and RHS being invariants.
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the same instruction.
std::optional< bool > evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are required to be true in order for the answer to be correct.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its value.
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operation with them.
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
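A sketch of how the disposition is typically consumed, with S and L assumed to exist in the caller:
  if (SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant) {
    // S evaluates to the same value on every iteration of L.
  }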
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the given type.
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallPtrSetImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as required.
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly computable, return SCEVCouldNotCompute.
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a single pointer operand.
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
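A sketch contrasting the exact count with its constant upper bound, assuming SE and L exist:
  const SCEV *Exact = SE.getBackedgeTakenCount(L, ScalarEvolution::Exact);
  const SCEV *Max = SE.getBackedgeTakenCount(L, ScalarEvolution::ConstantMaximum);
  // Exact may be SCEVCouldNotCompute even when Max is a usable constant bound.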
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at the given Context during at least the first MaxIter iterations, return a LoopInvariantPredicate with LHS and RHS being invariants.
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
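A sketch of the common pattern of refining an exit count with loop-guard information; SE and L are assumed from the caller:
  const SCEV *BTC = SE.getBackedgeTakenCount(L);
  const SCEV *Refined = SE.applyLoopGuards(BTC, L); // may be tighter than BTC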
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution namespace.
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and return the result as an APInt if it is a constant, or std::nullopt if it isn't.
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are required to be true in order for the answer to be correct.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across all small sizes.
Definition: SmallPtrSet.h:323
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:344
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:418
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:479
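A sketch of the usual visited-set idiom with this container; S is assumed to be an existing SCEV pointer:
  SmallPtrSet<const SCEV *, 8> Visited;
  if (Visited.insert(S).second) {
    // First time S is seen; insert(...).second is false for duplicates.
  }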
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less than N).
Definition: SmallSet.h:135
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
size_type size() const
Definition: SmallSet.h:161
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication based on the SmallVector 'N' template parameter.
Definition: SmallVector.h:586
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void reserve(size_type N)
Definition: SmallVector.h:676
iterator erase(const_iterator CI)
Definition: SmallVector.h:750
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:818
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
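A sketch combining SmallVector with the n-ary SCEV builders listed above; SE, LHS, and RHS are assumed to exist:
  SmallVector<const SCEV *, 2> Ops;
  Ops.push_back(LHS);
  Ops.push_back(RHS);
  const SCEV *Sum = SE.getAddExpr(Ops); // canonical form of LHS + RHS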
An instruction for storing to memory.
Definition: Instructions.h:290
Used to lazily calculate structure layout information for a target machine, based on the DataLayout structure.
Definition: DataLayout.h:622
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:651
TypeSize getSizeInBits() const
Definition: DataLayout.h:631
Class to represent struct types.
Definition: DerivedTypes.h:216
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:255
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:243
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:228
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:242
Use & Op()
Definition: User.h:133
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5105
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2197
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2202
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2207
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g. 32 for i32).
Definition: APInt.cpp:2781
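A hypothetical example of solving q(n) = 2n^2 + 3n - 5 in a 32-bit wrapped range (the names A, B, C, Root are illustrative only):
  APInt A(32, 2), B(32, 3), C(32, -5, /*isSigned=*/true);
  std::optional<APInt> Root =
      APIntOps::SolveQuadraticEquationWrap(A, B, C, /*RangeWidth=*/32);
  // q(1) == 0, so Root should hold the value 1 for this input.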
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2212
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:767
@ Entry
Definition: COFF.h:811
@ Exit
Definition: COFF.h:812
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:1071
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:822
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
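A sketch of the matcher style used throughout this file; V is assumed to be a Value* from the caller:
  using namespace llvm::PatternMatch;
  Value *X = nullptr;
  const APInt *C = nullptr;
  if (match(V, m_Shl(m_Value(X), m_APInt(C)))) {
    // V is a shift-left of some value X by a constant amount *C.
  }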
@ ReallyHidden
Definition: CommandLine.h:138
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:463
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
Definition: DynamicAPInt.h:388
void stable_sort(R &&Range)
Definition: STLExtras.h:1995
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when executed with any operands that appear in KnownPoison holding a poison value at the point of execution.
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7128
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2067
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0s from the least significant bit to the most significant bit, stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO's result is used only along the paths control dependent on the computation not overflowing.
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2059
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1729
bool getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, const TargetLibraryInfo *TLI, ObjectSizeOpts Opts={})
Compute the size of the object pointed by Ptr.
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:419
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1150
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1140
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:120
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range function attribute.
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:260
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the given range.
Definition: STLExtras.h:1914
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOne bit sets.
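A sketch of a typical query, assuming V is a Value* and DL is the module's DataLayout:
  KnownBits Known(V->getType()->getScalarSizeInBits());
  computeKnownBits(V, Known, DL);
  if (Known.isNonNegative()) {
    // The sign bit of V is known to be zero.
  }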
DWARFExpression::Operation Op
auto max_element(R &&Range)
Definition: STLExtras.h:1986
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:293
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:191
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1849
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one of its successors.
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given predicate occurs in a range.
Definition: STLExtras.h:1921
auto predecessors(const MachineBasicBlock *BB)
bool isAllocationFn(const Value *V, const TargetLibraryInfo *TLI)
Tests if a value is a call or invoke to a library function that allocates or reallocates memory (either malloc, calloc, realloc, or strdup-like).
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1879
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis pass type.
Definition: Analysis.h:28
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:290
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:97
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:428
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:40
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:370
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:185
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:134
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:118
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:94
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:285
Various options to control the behavior of getObjectSize.
bool NullIsUnknownSize
If this is true, null pointers in address space 0 will be treated as though they can't be evaluated.
bool RoundToAlign
Whether to round the result up to the alignment of allocas, byval arguments, and global variables.
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to the not-taken path.
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
void addPredicate(const SCEVPredicate *P)