ScalarEvolution.cpp
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
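//
// For illustration (a sketch, not text from the original header): given IR
// such as
//
//   loop:
//     %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//     %iv.next = add nuw nsw i32 %iv, 1
//
// the PHI %iv is recognized as a polynomial recurrence and represented as the
// add recurrence {0,+,1}<nuw><nsw><%loop>. Because each distinct expression
// is created only once, "getSCEV(A) == getSCEV(B)" is a valid equality test;
// no structural comparison is needed.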
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
84#include "llvm/Config/llvm-config.h"
85#include "llvm/IR/Argument.h"
86#include "llvm/IR/BasicBlock.h"
87#include "llvm/IR/CFG.h"
88#include "llvm/IR/Constant.h"
90#include "llvm/IR/Constants.h"
91#include "llvm/IR/DataLayout.h"
93#include "llvm/IR/Dominators.h"
94#include "llvm/IR/Function.h"
95#include "llvm/IR/GlobalAlias.h"
96#include "llvm/IR/GlobalValue.h"
98#include "llvm/IR/InstrTypes.h"
99#include "llvm/IR/Instruction.h"
100#include "llvm/IR/Instructions.h"
102#include "llvm/IR/Intrinsics.h"
103#include "llvm/IR/LLVMContext.h"
104#include "llvm/IR/Operator.h"
105#include "llvm/IR/PatternMatch.h"
106#include "llvm/IR/Type.h"
107#include "llvm/IR/Use.h"
108#include "llvm/IR/User.h"
109#include "llvm/IR/Value.h"
110#include "llvm/IR/Verifier.h"
112#include "llvm/Pass.h"
113#include "llvm/Support/Casting.h"
116#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136
137#define DEBUG_TYPE "scalar-evolution"
138
139STATISTIC(NumExitCountsComputed,
140 "Number of loop exits with predictable exit counts");
141STATISTIC(NumExitCountsNotComputed,
142 "Number of loop exits without predictable exit counts");
143STATISTIC(NumBruteForceTripCountsComputed,
144 "Number of loops with trip counts computed by force");
145
146#ifdef EXPENSIVE_CHECKS
147bool llvm::VerifySCEV = true;
148#else
149bool llvm::VerifySCEV = false;
150#endif
151
153 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
154 cl::desc("Maximum number of iterations SCEV will "
155 "symbolically execute a constant "
156 "derived loop"),
157 cl::init(100));
158
160 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
161 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163 "verify-scev-strict", cl::Hidden,
164 cl::desc("Enable stricter verification with -verify-scev is passed"));
165
167 "scev-verify-ir", cl::Hidden,
168 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
169 cl::init(false));
170
172 "scev-mulops-inline-threshold", cl::Hidden,
173 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
174 cl::init(32));
175
177 "scev-addops-inline-threshold", cl::Hidden,
178 cl::desc("Threshold for inlining addition operands into a SCEV"),
179 cl::init(500));
180
182 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
183 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
184 cl::init(32));
185
187 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
188 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
189 cl::init(2));
190
192 "scalar-evolution-max-value-compare-depth", cl::Hidden,
193 cl::desc("Maximum depth of recursive value complexity comparisons"),
194 cl::init(2));
195
197 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
198 cl::desc("Maximum depth of recursive arithmetics"),
199 cl::init(32));
200
202 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
203 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
204
206 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
207 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
208 cl::init(8));
209
211 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
212 cl::desc("Max coefficients in AddRec during evolving"),
213 cl::init(8));
214
216 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
217 cl::desc("Size of the expression which is considered huge"),
218 cl::init(4096));
219
221 "scev-range-iter-threshold", cl::Hidden,
222 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223 cl::init(32));
224
225static cl::opt<bool>
226ClassifyExpressions("scalar-evolution-classify-expressions",
227 cl::Hidden, cl::init(true),
228 cl::desc("When printing analysis, include information on every instruction"));
229
231 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
232 cl::init(false),
233 cl::desc("Use more powerful methods of sharpening expression ranges. May "
234 "be costly in terms of compile time"));
235
237 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
238 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
239 "Phi strongly connected components"),
240 cl::init(8));
241
242static cl::opt<bool>
243 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
244 cl::desc("Handle <= and >= in finite loops"),
245 cl::init(true));
246
248 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
249 cl::desc("Infer nuw/nsw flags using context where suitable"),
250 cl::init(true));
251
252//===----------------------------------------------------------------------===//
253// SCEV class definitions
254//===----------------------------------------------------------------------===//
255
256//===----------------------------------------------------------------------===//
257// Implementation of the SCEV class.
258//
259
260#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
261LLVM_DUMP_METHOD void SCEV::dump() const {
262 print(dbgs());
263 dbgs() << '\n';
264}
265#endif
266
267void SCEV::print(raw_ostream &OS) const {
268 switch (getSCEVType()) {
269 case scConstant:
270 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
271 return;
272 case scVScale:
273 OS << "vscale";
274 return;
275 case scPtrToInt: {
276 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
277 const SCEV *Op = PtrToInt->getOperand();
278 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
279 << *PtrToInt->getType() << ")";
280 return;
281 }
282 case scTruncate: {
283 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
284 const SCEV *Op = Trunc->getOperand();
285 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
286 << *Trunc->getType() << ")";
287 return;
288 }
289 case scZeroExtend: {
290 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
291 const SCEV *Op = ZExt->getOperand();
292 OS << "(zext " << *Op->getType() << " " << *Op << " to "
293 << *ZExt->getType() << ")";
294 return;
295 }
296 case scSignExtend: {
297 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
298 const SCEV *Op = SExt->getOperand();
299 OS << "(sext " << *Op->getType() << " " << *Op << " to "
300 << *SExt->getType() << ")";
301 return;
302 }
303 case scAddRecExpr: {
304 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
305 OS << "{" << *AR->getOperand(0);
306 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
307 OS << ",+," << *AR->getOperand(i);
308 OS << "}<";
309 if (AR->hasNoUnsignedWrap())
310 OS << "nuw><";
311 if (AR->hasNoSignedWrap())
312 OS << "nsw><";
313 if (AR->hasNoSelfWrap() &&
315 OS << "nw><";
316 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
317 OS << ">";
318 return;
319 }
320 case scAddExpr:
321 case scMulExpr:
322 case scUMaxExpr:
323 case scSMaxExpr:
324 case scUMinExpr:
325 case scSMinExpr:
326 case scSequentialUMinExpr: {
327 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
328 const char *OpStr = nullptr;
329 switch (NAry->getSCEVType()) {
330 case scAddExpr: OpStr = " + "; break;
331 case scMulExpr: OpStr = " * "; break;
332 case scUMaxExpr: OpStr = " umax "; break;
333 case scSMaxExpr: OpStr = " smax "; break;
334 case scUMinExpr:
335 OpStr = " umin ";
336 break;
337 case scSMinExpr:
338 OpStr = " smin ";
339 break;
340 case scSequentialUMinExpr:
341 OpStr = " umin_seq ";
342 break;
343 default:
344 llvm_unreachable("There are no other nary expression types.");
345 }
346 OS << "(";
347 ListSeparator LS(OpStr);
348 for (const SCEV *Op : NAry->operands())
349 OS << LS << *Op;
350 OS << ")";
351 switch (NAry->getSCEVType()) {
352 case scAddExpr:
353 case scMulExpr:
354 if (NAry->hasNoUnsignedWrap())
355 OS << "<nuw>";
356 if (NAry->hasNoSignedWrap())
357 OS << "<nsw>";
358 break;
359 default:
360 // Nothing to print for other nary expressions.
361 break;
362 }
363 return;
364 }
365 case scUDivExpr: {
366 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
367 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
368 return;
369 }
370 case scUnknown:
371 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
372 return;
373 case scCouldNotCompute:
374 OS << "***COULDNOTCOMPUTE***";
375 return;
376 }
377 llvm_unreachable("Unknown SCEV kind!");
378}
379
380Type *SCEV::getType() const {
381 switch (getSCEVType()) {
382 case scConstant:
383 return cast<SCEVConstant>(this)->getType();
384 case scVScale:
385 return cast<SCEVVScale>(this)->getType();
386 case scPtrToInt:
387 case scTruncate:
388 case scZeroExtend:
389 case scSignExtend:
390 return cast<SCEVCastExpr>(this)->getType();
391 case scAddRecExpr:
392 return cast<SCEVAddRecExpr>(this)->getType();
393 case scMulExpr:
394 return cast<SCEVMulExpr>(this)->getType();
395 case scUMaxExpr:
396 case scSMaxExpr:
397 case scUMinExpr:
398 case scSMinExpr:
399 return cast<SCEVMinMaxExpr>(this)->getType();
400 case scSequentialUMinExpr:
401 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
402 case scAddExpr:
403 return cast<SCEVAddExpr>(this)->getType();
404 case scUDivExpr:
405 return cast<SCEVUDivExpr>(this)->getType();
406 case scUnknown:
407 return cast<SCEVUnknown>(this)->getType();
408 case scCouldNotCompute:
409 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
410 }
411 llvm_unreachable("Unknown SCEV kind!");
412}
413
414ArrayRef<const SCEV *> SCEV::operands() const {
415 switch (getSCEVType()) {
416 case scConstant:
417 case scVScale:
418 case scUnknown:
419 return {};
420 case scPtrToInt:
421 case scTruncate:
422 case scZeroExtend:
423 case scSignExtend:
424 return cast<SCEVCastExpr>(this)->operands();
425 case scAddRecExpr:
426 case scAddExpr:
427 case scMulExpr:
428 case scUMaxExpr:
429 case scSMaxExpr:
430 case scUMinExpr:
431 case scSMinExpr:
432 case scSequentialUMinExpr:
433 return cast<SCEVNAryExpr>(this)->operands();
434 case scUDivExpr:
435 return cast<SCEVUDivExpr>(this)->operands();
436 case scCouldNotCompute:
437 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
438 }
439 llvm_unreachable("Unknown SCEV kind!");
440}
441
442bool SCEV::isZero() const {
443 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
444 return SC->getValue()->isZero();
445 return false;
446}
447
448bool SCEV::isOne() const {
449 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
450 return SC->getValue()->isOne();
451 return false;
452}
453
455 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
456 return SC->getValue()->isMinusOne();
457 return false;
458}
459
461 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
462 if (!Mul) return false;
463
464 // If there is a constant factor, it will be first.
465 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
466 if (!SC) return false;
467
468 // Return true if the value is negative, this matches things like (-42 * V).
469 return SC->getAPInt().isNegative();
470}
471
474
475bool SCEVCouldNotCompute::classof(const SCEV *S) {
476 return S->getSCEVType() == scCouldNotCompute;
477}
478
481 ID.AddInteger(scConstant);
482 ID.AddPointer(V);
483 void *IP = nullptr;
484 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
485 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
486 UniqueSCEVs.InsertNode(S, IP);
487 return S;
488}
489
491 return getConstant(ConstantInt::get(getContext(), Val));
492}
493
494const SCEV *
495ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
496 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
497 return getConstant(ConstantInt::get(ITy, V, isSigned));
498}
499
502 ID.AddInteger(scVScale);
503 ID.AddPointer(Ty);
504 void *IP = nullptr;
505 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
506 return S;
507 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
508 UniqueSCEVs.InsertNode(S, IP);
509 return S;
510}
511
513 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
514 if (EC.isScalable())
515 Res = getMulExpr(Res, getVScale(Ty));
516 return Res;
517}
518
520 const SCEV *op, Type *ty)
521 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
522
523SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
524 Type *ITy)
525 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
527 "Must be a non-bit-width-changing pointer-to-integer cast!");
528}
529
531 SCEVTypes SCEVTy, const SCEV *op,
532 Type *ty)
533 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
534
535SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
536 Type *ty)
538 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
539 "Cannot truncate non-integer value!");
540}
541
542SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
543 const SCEV *op, Type *ty)
545 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
546 "Cannot zero extend non-integer value!");
547}
548
549SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
550 const SCEV *op, Type *ty)
552 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
553 "Cannot sign extend non-integer value!");
554}
555
556void SCEVUnknown::deleted() {
557 // Clear this SCEVUnknown from various maps.
558 SE->forgetMemoizedResults(this);
559
560 // Remove this SCEVUnknown from the uniquing map.
561 SE->UniqueSCEVs.RemoveNode(this);
562
563 // Release the value.
564 setValPtr(nullptr);
565}
566
567void SCEVUnknown::allUsesReplacedWith(Value *New) {
568 // Clear this SCEVUnknown from various maps.
569 SE->forgetMemoizedResults(this);
570
571 // Remove this SCEVUnknown from the uniquing map.
572 SE->UniqueSCEVs.RemoveNode(this);
573
574 // Replace the value pointer in case someone is still using this SCEVUnknown.
575 setValPtr(New);
576}
577
578//===----------------------------------------------------------------------===//
579// SCEV Utilities
580//===----------------------------------------------------------------------===//
581
582/// Compare the two values \p LV and \p RV in terms of their "complexity" where
583/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
584/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
585/// have been previously deemed to be "equally complex" by this routine. It is
586/// intended to avoid exponential time complexity in cases like:
587///
588/// %a = f(%x, %y)
589/// %b = f(%a, %a)
590/// %c = f(%b, %b)
591///
592/// %d = f(%x, %y)
593/// %e = f(%d, %d)
594/// %f = f(%e, %e)
595///
596/// CompareValueComplexity(%f, %c)
597///
598/// Since we do not continue running this routine on expression trees once we
599/// have seen unequal values, there is no need to track them in the cache.
600static int
602 const LoopInfo *const LI, Value *LV, Value *RV,
603 unsigned Depth) {
604 if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
605 return 0;
606
607 // Order pointer values after integer values. This helps SCEVExpander form
608 // GEPs.
609 bool LIsPointer = LV->getType()->isPointerTy(),
610 RIsPointer = RV->getType()->isPointerTy();
611 if (LIsPointer != RIsPointer)
612 return (int)LIsPointer - (int)RIsPointer;
613
614 // Compare getValueID values.
615 unsigned LID = LV->getValueID(), RID = RV->getValueID();
616 if (LID != RID)
617 return (int)LID - (int)RID;
618
619 // Sort arguments by their position.
620 if (const auto *LA = dyn_cast<Argument>(LV)) {
621 const auto *RA = cast<Argument>(RV);
622 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
623 return (int)LArgNo - (int)RArgNo;
624 }
625
626 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
627 const auto *RGV = cast<GlobalValue>(RV);
628
629 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
630 auto LT = GV->getLinkage();
631 return !(GlobalValue::isPrivateLinkage(LT) ||
633 };
634
635 // Use the names to distinguish the two values, but only if the
636 // names are semantically important.
637 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
638 return LGV->getName().compare(RGV->getName());
639 }
640
641 // For instructions, compare their loop depth, and their operand count. This
642 // is pretty loose.
643 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
644 const auto *RInst = cast<Instruction>(RV);
645
646 // Compare loop depths.
647 const BasicBlock *LParent = LInst->getParent(),
648 *RParent = RInst->getParent();
649 if (LParent != RParent) {
650 unsigned LDepth = LI->getLoopDepth(LParent),
651 RDepth = LI->getLoopDepth(RParent);
652 if (LDepth != RDepth)
653 return (int)LDepth - (int)RDepth;
654 }
655
656 // Compare the number of operands.
657 unsigned LNumOps = LInst->getNumOperands(),
658 RNumOps = RInst->getNumOperands();
659 if (LNumOps != RNumOps)
660 return (int)LNumOps - (int)RNumOps;
661
662 for (unsigned Idx : seq(LNumOps)) {
663 int Result =
664 CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
665 RInst->getOperand(Idx), Depth + 1);
666 if (Result != 0)
667 return Result;
668 }
669 }
670
671 EqCacheValue.unionSets(LV, RV);
672 return 0;
673}
674
675// Return negative, zero, or positive, if LHS is less than, equal to, or greater
676// than RHS, respectively. A three-way result allows recursive comparisons to be
677// more efficient.
678// If the max analysis depth was reached, return std::nullopt, assuming we do
679// not know if they are equivalent for sure.
680static std::optional<int>
683 const LoopInfo *const LI, const SCEV *LHS,
684 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
685 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
686 if (LHS == RHS)
687 return 0;
688
689 // Primarily, sort the SCEVs by their getSCEVType().
690 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
691 if (LType != RType)
692 return (int)LType - (int)RType;
693
694 if (EqCacheSCEV.isEquivalent(LHS, RHS))
695 return 0;
696
698 return std::nullopt;
699
700 // Aside from the getSCEVType() ordering, the particular ordering
701 // isn't very important except that it's beneficial to be consistent,
702 // so that (a + b) and (b + a) don't end up as different expressions.
703 switch (LType) {
704 case scUnknown: {
705 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
706 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
707
708 int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
709 RU->getValue(), Depth + 1);
710 if (X == 0)
711 EqCacheSCEV.unionSets(LHS, RHS);
712 return X;
713 }
714
715 case scConstant: {
716 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
717 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
718
719 // Compare constant values.
720 const APInt &LA = LC->getAPInt();
721 const APInt &RA = RC->getAPInt();
722 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
723 if (LBitWidth != RBitWidth)
724 return (int)LBitWidth - (int)RBitWidth;
725 return LA.ult(RA) ? -1 : 1;
726 }
727
728 case scVScale: {
729 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
730 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
731 return LTy->getBitWidth() - RTy->getBitWidth();
732 }
733
734 case scAddRecExpr: {
735 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
736 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
737
738 // There is always a dominance between two recs that are used by one SCEV,
739 // so we can safely sort recs by loop header dominance. We require such
740 // order in getAddExpr.
741 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
742 if (LLoop != RLoop) {
743 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
744 assert(LHead != RHead && "Two loops share the same header?");
745 if (DT.dominates(LHead, RHead))
746 return 1;
747 assert(DT.dominates(RHead, LHead) &&
748 "No dominance between recurrences used by one SCEV?");
749 return -1;
750 }
751
752 [[fallthrough]];
753 }
754
755 case scTruncate:
756 case scZeroExtend:
757 case scSignExtend:
758 case scPtrToInt:
759 case scAddExpr:
760 case scMulExpr:
761 case scUDivExpr:
762 case scSMaxExpr:
763 case scUMaxExpr:
764 case scSMinExpr:
765 case scUMinExpr:
766 case scSequentialUMinExpr: {
767 ArrayRef<const SCEV *> LOps = LHS->operands();
768 ArrayRef<const SCEV *> ROps = RHS->operands();
769
770 // Lexicographically compare n-ary-like expressions.
771 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
772 if (LNumOps != RNumOps)
773 return (int)LNumOps - (int)RNumOps;
774
775 for (unsigned i = 0; i != LNumOps; ++i) {
776 auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LOps[i],
777 ROps[i], DT, Depth + 1);
778 if (X != 0)
779 return X;
780 }
781 EqCacheSCEV.unionSets(LHS, RHS);
782 return 0;
783 }
784
785 case scCouldNotCompute:
786 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
787 }
788 llvm_unreachable("Unknown SCEV kind!");
789}
790
791/// Given a list of SCEV objects, order them by their complexity, and group
792/// objects of the same complexity together by value. When this routine is
793/// finished, we know that any duplicates in the vector are consecutive and that
794/// complexity is monotonically increasing.
795///
796/// Note that we take special precautions to ensure that we get deterministic
797/// results from this routine. In other words, we don't want the results of
798/// this to depend on where the addresses of various SCEV objects happened to
799/// land in memory.
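///
/// A small illustrative instance (not from the original comment): for the add
/// operand list (%x, 7, %y, %x), the constant sorts first and the two %x
/// occurrences end up adjacent, e.g. (7, %x, %x, %y), so a caller such as
/// getAddExpr can fold repeated operands (here into 2 * %x) by scanning only
/// neighbouring elements.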
800static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
801 LoopInfo *LI, DominatorTree &DT) {
802 if (Ops.size() < 2) return; // Noop
803
806
807 // Whether LHS has provably less complexity than RHS.
808 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
809 auto Complexity =
810 CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
811 return Complexity && *Complexity < 0;
812 };
813 if (Ops.size() == 2) {
814 // This is the common case, which also happens to be trivially simple.
815 // Special case it.
816 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
817 if (IsLessComplex(RHS, LHS))
818 std::swap(LHS, RHS);
819 return;
820 }
821
822 // Do the rough sort by complexity.
823 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
824 return IsLessComplex(LHS, RHS);
825 });
826
827 // Now that we are sorted by complexity, group elements of the same
828 // complexity. Note that this is, at worst, N^2, but the vector is likely to
829 // be extremely short in practice. Note that we take this approach because we
830 // do not want to depend on the addresses of the objects we are grouping.
831 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
832 const SCEV *S = Ops[i];
833 unsigned Complexity = S->getSCEVType();
834
835 // If there are any objects of the same complexity and same value as this
836 // one, group them.
837 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
838 if (Ops[j] == S) { // Found a duplicate.
839 // Move it to immediately after i'th element.
840 std::swap(Ops[i+1], Ops[j]);
841 ++i; // no need to rescan it.
842 if (i == e-2) return; // Done!
843 }
844 }
845 }
846}
847
848/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
849/// least HugeExprThreshold nodes).
850static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
851 return any_of(Ops, [](const SCEV *S) {
852 return S->getExpressionSize() >= HugeExprThreshold;
853 });
854}
855
856//===----------------------------------------------------------------------===//
857// Simple SCEV method implementations
858//===----------------------------------------------------------------------===//
859
860/// Compute BC(It, K). The result has width W. Assume, K > 0.
861static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
862 ScalarEvolution &SE,
863 Type *ResultTy) {
864 // Handle the simplest case efficiently.
865 if (K == 1)
866 return SE.getTruncateOrZeroExtend(It, ResultTy);
867
868 // We are using the following formula for BC(It, K):
869 //
870 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
871 //
872 // Suppose, W is the bitwidth of the return value. We must be prepared for
873 // overflow. Hence, we must assure that the result of our computation is
874 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
875 // safe in modular arithmetic.
876 //
877 // However, this code doesn't use exactly that formula; the formula it uses
878 // is something like the following, where T is the number of factors of 2 in
879 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
880 // exponentiation:
881 //
882 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
883 //
884 // This formula is trivially equivalent to the previous formula. However,
885 // this formula can be implemented much more efficiently. The trick is that
886 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
887 // arithmetic. To do exact division in modular arithmetic, all we have
888 // to do is multiply by the inverse. Therefore, this step can be done at
889 // width W.
890 //
891 // The next issue is how to safely do the division by 2^T. The way this
892 // is done is by doing the multiplication step at a width of at least W + T
893 // bits. This way, the bottom W+T bits of the product are accurate. Then,
894 // when we perform the division by 2^T (which is equivalent to a right shift
895 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
896 // truncated out after the division by 2^T.
897 //
898 // In comparison to just directly using the first formula, this technique
899 // is much more efficient; using the first formula requires W * K bits,
900// but this formula uses less than W + K bits. Also, the first formula requires
901 // a division step, whereas this formula only requires multiplies and shifts.
902 //
903 // It doesn't matter whether the subtraction step is done in the calculation
904 // width or the input iteration count's width; if the subtraction overflows,
905 // the result must be zero anyway. We prefer here to do it in the width of
906 // the induction variable because it helps a lot for certain cases; CodeGen
907 // isn't smart enough to ignore the overflow, which leads to much less
908 // efficient code if the width of the subtraction is wider than the native
909 // register width.
910 //
911 // (It's possible to not widen at all by pulling out factors of 2 before
912 // the multiplication; for example, K=2 can be calculated as
913 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
914 // extra arithmetic, so it's not an obvious win, and it gets
915 // much more complicated for K > 3.)
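  //
  // A small worked instance (illustrative only, not part of the algorithm):
  // for K = 4 and W = 8, K! = 24 = 2^3 * 3, so T = 3 and K! / 2^T = 3. The
  // multiplicative inverse of 3 modulo 2^8 is 171 (3 * 171 = 513 = 2*256 + 1).
  // For It = 5, the product 5*4*3*2 = 120 is computed at W + T = 11 bits,
  // the division by 2^T gives 120 >> 3 = 15, and 15 * 171 = 2565 = 5 (mod
  // 256), which matches BC(5, 4) = 5.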
916
917 // Protection from insane SCEVs; this bound is conservative,
918 // but it probably doesn't matter.
919 if (K > 1000)
920 return SE.getCouldNotCompute();
921
922 unsigned W = SE.getTypeSizeInBits(ResultTy);
923
924 // Calculate K! / 2^T and T; we divide out the factors of two before
925 // multiplying for calculating K! / 2^T to avoid overflow.
926 // Other overflow doesn't matter because we only care about the bottom
927 // W bits of the result.
928 APInt OddFactorial(W, 1);
929 unsigned T = 1;
930 for (unsigned i = 3; i <= K; ++i) {
931 unsigned TwoFactors = countr_zero(i);
932 T += TwoFactors;
933 OddFactorial *= (i >> TwoFactors);
934 }
935
936 // We need at least W + T bits for the multiplication step
937 unsigned CalculationBits = W + T;
938
939 // Calculate 2^T, at width T+W.
940 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
941
942 // Calculate the multiplicative inverse of K! / 2^T;
943 // this multiplication factor will perform the exact division by
944 // K! / 2^T.
945 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
946
947 // Calculate the product, at width T+W
948 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
949 CalculationBits);
950 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
951 for (unsigned i = 1; i != K; ++i) {
952 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
953 Dividend = SE.getMulExpr(Dividend,
954 SE.getTruncateOrZeroExtend(S, CalculationTy));
955 }
956
957 // Divide by 2^T
958 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
959
960 // Truncate the result, and divide by K! / 2^T.
961
962 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
963 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
964}
965
966/// Return the value of this chain of recurrences at the specified iteration
967/// number. We can evaluate this recurrence by multiplying each element in the
968/// chain by the binomial coefficient corresponding to it. In other words, we
969/// can evaluate {A,+,B,+,C,+,D} as:
970///
971/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
972///
973/// where BC(It, k) stands for binomial coefficient.
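///
/// For example (illustrative): the affine recurrence {3,+,2} evaluated at
/// iteration It = 5 gives 3*BC(5,0) + 2*BC(5,1) = 3 + 2*5 = 13, matching the
/// sequence 3, 5, 7, 9, 11, 13 produced on iterations 0 through 5.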
974const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
975 ScalarEvolution &SE) const {
976 return evaluateAtIteration(operands(), It, SE);
977}
978
979const SCEV *
980SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
981 const SCEV *It, ScalarEvolution &SE) {
982 assert(Operands.size() > 0);
983 const SCEV *Result = Operands[0];
984 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
985 // The computation is correct in the face of overflow provided that the
986 // multiplication is performed _after_ the evaluation of the binomial
987 // coefficient.
988 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
989 if (isa<SCEVCouldNotCompute>(Coeff))
990 return Coeff;
991
992 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
993 }
994 return Result;
995}
996
997//===----------------------------------------------------------------------===//
998// SCEV Expression folder implementations
999//===----------------------------------------------------------------------===//
1000
1002 unsigned Depth) {
1003 assert(Depth <= 1 &&
1004 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1005
1006 // We could be called with an integer-typed operand during SCEV rewrites.
1007 // Since the operand is an integer already, just perform zext/trunc/self cast.
1008 if (!Op->getType()->isPointerTy())
1009 return Op;
1010
1011 // What would be an ID for such a SCEV cast expression?
1013 ID.AddInteger(scPtrToInt);
1014 ID.AddPointer(Op);
1015
1016 void *IP = nullptr;
1017
1018 // Is there already an expression for such a cast?
1019 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1020 return S;
1021
1022 // It isn't legal for optimizations to construct new ptrtoint expressions
1023 // for non-integral pointers.
1024 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1025 return getCouldNotCompute();
1026
1027 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1028
1029 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1030 // is sufficiently wide to represent all possible pointer values.
1031 // We could theoretically teach SCEV to truncate wider pointers, but
1032 // that isn't implemented for now.
1034 getDataLayout().getTypeSizeInBits(IntPtrTy))
1035 return getCouldNotCompute();
1036
1037 // If not, is this expression something we can't reduce any further?
1038 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1039 // Perform some basic constant folding. If the operand of the ptr2int cast
1040 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1041 // left as-is), but produce a zero constant.
1042 // NOTE: We could handle a more general case, but lack motivational cases.
1043 if (isa<ConstantPointerNull>(U->getValue()))
1044 return getZero(IntPtrTy);
1045
1046 // Create an explicit cast node.
1047 // We can reuse the existing insert position since if we get here,
1048 // we won't have made any changes which would invalidate it.
1049 SCEV *S = new (SCEVAllocator)
1050 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1051 UniqueSCEVs.InsertNode(S, IP);
1052 registerUser(S, Op);
1053 return S;
1054 }
1055
1056 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1057 "non-SCEVUnknown's.");
1058
1059 // Otherwise, we've got some expression that is more complex than just a
1060 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1061 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1062 // only, and the expressions must otherwise be integer-typed.
1063 // So sink the cast down to the SCEVUnknown's.
1064
1065 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1066 /// which computes a pointer-typed value, and rewrites the whole expression
1067 /// tree so that *all* the computations are done on integers, and the only
1068 /// pointer-typed operands in the expression are SCEVUnknown.
1069 class SCEVPtrToIntSinkingRewriter
1070 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1072
1073 public:
1074 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1075
1076 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1077 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1078 return Rewriter.visit(Scev);
1079 }
1080
1081 const SCEV *visit(const SCEV *S) {
1082 Type *STy = S->getType();
1083 // If the expression is not pointer-typed, just keep it as-is.
1084 if (!STy->isPointerTy())
1085 return S;
1086 // Else, recursively sink the cast down into it.
1087 return Base::visit(S);
1088 }
1089
1090 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1092 bool Changed = false;
1093 for (const auto *Op : Expr->operands()) {
1094 Operands.push_back(visit(Op));
1095 Changed |= Op != Operands.back();
1096 }
1097 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1098 }
1099
1100 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1102 bool Changed = false;
1103 for (const auto *Op : Expr->operands()) {
1104 Operands.push_back(visit(Op));
1105 Changed |= Op != Operands.back();
1106 }
1107 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1108 }
1109
1110 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1111 assert(Expr->getType()->isPointerTy() &&
1112 "Should only reach pointer-typed SCEVUnknown's.");
1113 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1114 }
1115 };
1116
1117 // And actually perform the cast sinking.
1118 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1119 assert(IntOp->getType()->isIntegerTy() &&
1120 "We must have succeeded in sinking the cast, "
1121 "and ending up with an integer-typed expression!");
1122 return IntOp;
1123}
1124
1126 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1127
1128 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1129 if (isa<SCEVCouldNotCompute>(IntOp))
1130 return IntOp;
1131
1132 return getTruncateOrZeroExtend(IntOp, Ty);
1133}
1134
1136 unsigned Depth) {
1137 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1138 "This is not a truncating conversion!");
1139 assert(isSCEVable(Ty) &&
1140 "This is not a conversion to a SCEVable type!");
1141 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1142 Ty = getEffectiveSCEVType(Ty);
1143
1145 ID.AddInteger(scTruncate);
1146 ID.AddPointer(Op);
1147 ID.AddPointer(Ty);
1148 void *IP = nullptr;
1149 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1150
1151 // Fold if the operand is constant.
1152 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1153 return getConstant(
1154 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1155
1156 // trunc(trunc(x)) --> trunc(x)
1157 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1158 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1159
1160 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1161 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1162 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1163
1164 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1165 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1166 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1167
1168 if (Depth > MaxCastDepth) {
1169 SCEV *S =
1170 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1171 UniqueSCEVs.InsertNode(S, IP);
1172 registerUser(S, Op);
1173 return S;
1174 }
1175
1176 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1177 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1178 // if after transforming we have at most one truncate, not counting truncates
1179 // that replace other casts.
1180 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1181 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1183 unsigned numTruncs = 0;
1184 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1185 ++i) {
1186 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1187 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1188 isa<SCEVTruncateExpr>(S))
1189 numTruncs++;
1190 Operands.push_back(S);
1191 }
1192 if (numTruncs < 2) {
1193 if (isa<SCEVAddExpr>(Op))
1194 return getAddExpr(Operands);
1195 if (isa<SCEVMulExpr>(Op))
1196 return getMulExpr(Operands);
1197 llvm_unreachable("Unexpected SCEV type for Op.");
1198 }
1199 // Although we checked at the beginning that ID is not in the cache, it is
1200 // possible that ID was inserted into the cache during recursion or other
1201 // modifications. So if we find it, just return it.
1202 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1203 return S;
1204 }
1205
1206 // If the input value is a chrec scev, truncate the chrec's operands.
1207 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1209 for (const SCEV *Op : AddRec->operands())
1210 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1211 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1212 }
1213
1214 // Return zero if truncating to known zeros.
1215 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1216 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1217 return getZero(Ty);
1218
1219 // The cast wasn't folded; create an explicit cast node. We can reuse
1220 // the existing insert position since if we get here, we won't have
1221 // made any changes which would invalidate it.
1222 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1223 Op, Ty);
1224 UniqueSCEVs.InsertNode(S, IP);
1225 registerUser(S, Op);
1226 return S;
1227}
1228
1229// Get the limit of a recurrence such that incrementing by Step cannot cause
1230// signed overflow as long as the value of the recurrence within the
1231// loop does not exceed this limit before incrementing.
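//
// For instance (illustrative): for an i8 recurrence whose Step has signed
// range [1, 4], the limit is 127 - 4 = 123 with predicate ICMP_SLT; any value
// that is slt 123 can be incremented by at most 4 without exceeding SINT_MAX,
// so the step cannot cause signed wrap.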
1232static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1233 ICmpInst::Predicate *Pred,
1234 ScalarEvolution *SE) {
1235 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1236 if (SE->isKnownPositive(Step)) {
1237 *Pred = ICmpInst::ICMP_SLT;
1239 SE->getSignedRangeMax(Step));
1240 }
1241 if (SE->isKnownNegative(Step)) {
1242 *Pred = ICmpInst::ICMP_SGT;
1244 SE->getSignedRangeMin(Step));
1245 }
1246 return nullptr;
1247}
1248
1249// Get the limit of a recurrence such that incrementing by Step cannot cause
1250// unsigned overflow as long as the value of the recurrence within the loop does
1251// not exceed this limit before incrementing.
1253 ICmpInst::Predicate *Pred,
1254 ScalarEvolution *SE) {
1255 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1256 *Pred = ICmpInst::ICMP_ULT;
1257
1259 SE->getUnsignedRangeMax(Step));
1260}
1261
1262namespace {
1263
1264struct ExtendOpTraitsBase {
1265 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1266 unsigned);
1267};
1268
1269// Used to make code generic over signed and unsigned overflow.
1270template <typename ExtendOp> struct ExtendOpTraits {
1271 // Members present:
1272 //
1273 // static const SCEV::NoWrapFlags WrapType;
1274 //
1275 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1276 //
1277 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1278 // ICmpInst::Predicate *Pred,
1279 // ScalarEvolution *SE);
1280};
1281
1282template <>
1283struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1284 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1285
1286 static const GetExtendExprTy GetExtendExpr;
1287
1288 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1289 ICmpInst::Predicate *Pred,
1290 ScalarEvolution *SE) {
1291 return getSignedOverflowLimitForStep(Step, Pred, SE);
1292 }
1293};
1294
1295const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1297
1298template <>
1299struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1300 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1301
1302 static const GetExtendExprTy GetExtendExpr;
1303
1304 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1305 ICmpInst::Predicate *Pred,
1306 ScalarEvolution *SE) {
1307 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1308 }
1309};
1310
1311const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1313
1314} // end anonymous namespace
1315
1316// The recurrence AR has been shown to have no signed/unsigned wrap or something
1317// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1318// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1319// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1320// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1321// expression "Step + sext/zext(PreIncAR)" is congruent with
1322// "sext/zext(PostIncAR)"
1323template <typename ExtendOpTy>
1324static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1325 ScalarEvolution *SE, unsigned Depth) {
1326 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1327 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1328
1329 const Loop *L = AR->getLoop();
1330 const SCEV *Start = AR->getStart();
1331 const SCEV *Step = AR->getStepRecurrence(*SE);
1332
1333 // Check for a simple looking step prior to loop entry.
1334 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1335 if (!SA)
1336 return nullptr;
1337
1338 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1339 // subtraction is expensive. For this purpose, perform a quick and dirty
1340 // difference, by checking for Step in the operand list. Note, that
1341 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1343 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1344 if (*It == Step) {
1345 DiffOps.erase(It);
1346 break;
1347 }
1348
1349 if (DiffOps.size() == SA->getNumOperands())
1350 return nullptr;
1351
1352 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1353 // `Step`:
1354
1355 // 1. NSW/NUW flags on the step increment.
1356 auto PreStartFlags =
1358 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1359 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1360 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1361
1362 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1363 // "S+X does not sign/unsign-overflow".
1364 //
1365
1366 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1367 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1368 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1369 return PreStart;
1370
1371 // 2. Direct overflow check on the step operation's expression.
1372 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1373 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1374 const SCEV *OperandExtendedStart =
1375 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1376 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1377 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1378 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1379 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1380 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1381 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1382 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1383 }
1384 return PreStart;
1385 }
1386
1387 // 3. Loop precondition.
1389 const SCEV *OverflowLimit =
1390 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1391
1392 if (OverflowLimit &&
1393 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1394 return PreStart;
1395
1396 return nullptr;
1397}
1398
1399// Get the normalized zero or sign extended expression for this AddRec's Start.
1400template <typename ExtendOpTy>
1401static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1402 ScalarEvolution *SE,
1403 unsigned Depth) {
1404 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1405
1406 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1407 if (!PreStart)
1408 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1409
1410 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1411 Depth),
1412 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1413}
1414
1415// Try to prove away overflow by looking at "nearby" add recurrences. A
1416// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1417// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1418//
1419// Formally:
1420//
1421// {S,+,X} == {S-T,+,X} + T
1422// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1423//
1424// If ({S-T,+,X} + T) does not overflow ... (1)
1425//
1426// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1427//
1428// If {S-T,+,X} does not overflow ... (2)
1429//
1430// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1431// == {Ext(S-T)+Ext(T),+,Ext(X)}
1432//
1433// If (S-T)+T does not overflow ... (3)
1434//
1435// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1436// == {Ext(S),+,Ext(X)} == LHS
1437//
1438// Thus, if (1), (2) and (3) are true for some T, then
1439// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1440//
1441// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1442// does not overflow" restricted to the 0th iteration. Therefore we only need
1443// to check for (1) and (2).
1444//
1445// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1446// is `Delta` (defined below).
1447template <typename ExtendOpTy>
1448bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1449 const SCEV *Step,
1450 const Loop *L) {
1451 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1452
1453 // We restrict `Start` to a constant to prevent SCEV from spending too much
1454 // time here. It is correct (but more expensive) to continue with a
1455 // non-constant `Start` and do a general SCEV subtraction to compute
1456 // `PreStart` below.
1457 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1458 if (!StartC)
1459 return false;
1460
1461 APInt StartAI = StartC->getAPInt();
1462
1463 for (unsigned Delta : {-2, -1, 1, 2}) {
1464 const SCEV *PreStart = getConstant(StartAI - Delta);
1465
1467 ID.AddInteger(scAddRecExpr);
1468 ID.AddPointer(PreStart);
1469 ID.AddPointer(Step);
1470 ID.AddPointer(L);
1471 void *IP = nullptr;
1472 const auto *PreAR =
1473 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1474
1475 // Give up if we don't already have the add recurrence we need because
1476 // actually constructing an add recurrence is relatively expensive.
1477 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1478 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1480 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1481 DeltaS, &Pred, this);
1482 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1483 return true;
1484 }
1485 }
1486
1487 return false;
1488}
1489
1490// Finds an integer D for an expression (C + x + y + ...) such that the top
1491// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1492// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1493// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1494// the (C + x + y + ...) expression is \p WholeAddExpr.
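//
// A worked instance (illustrative): with a bit width of 8, C = 21 (0b00010101)
// and every other operand known to have at least 4 trailing zeros, TZ = 4 and
// D = 0b0101 = 5. Then C - D = 16, so (C - D + x + y + ...) is a multiple of
// 16, and adding D back merely fills the low 4 bits, so the top-level
// addition cannot wrap.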
1496 const SCEVConstant *ConstantTerm,
1497 const SCEVAddExpr *WholeAddExpr) {
1498 const APInt &C = ConstantTerm->getAPInt();
1499 const unsigned BitWidth = C.getBitWidth();
1500 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1501 uint32_t TZ = BitWidth;
1502 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1503 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1504 if (TZ) {
1505 // Set D to be as many least significant bits of C as possible while still
1506 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1507 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1508 }
1509 return APInt(BitWidth, 0);
1510}
1511
1512// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1513// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1514// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1515// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1517 const APInt &ConstantStart,
1518 const SCEV *Step) {
1519 const unsigned BitWidth = ConstantStart.getBitWidth();
1520 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1521 if (TZ)
1522 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1523 : ConstantStart;
1524 return APInt(BitWidth, 0);
1525}
1526
1528 const ScalarEvolution::FoldID &ID, const SCEV *S,
1531 &FoldCacheUser) {
1532 auto I = FoldCache.insert({ID, S});
1533 if (!I.second) {
1534 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1535 // entry.
1536 auto &UserIDs = FoldCacheUser[I.first->second];
1537 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1538 for (unsigned I = 0; I != UserIDs.size(); ++I)
1539 if (UserIDs[I] == ID) {
1540 std::swap(UserIDs[I], UserIDs.back());
1541 break;
1542 }
1543 UserIDs.pop_back();
1544 I.first->second = S;
1545 }
1546 auto R = FoldCacheUser.insert({S, {}});
1547 R.first->second.push_back(ID);
1548}
1549
1550const SCEV *
1551ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1552 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1553 "This is not an extending conversion!");
1554 assert(isSCEVable(Ty) &&
1555 "This is not a conversion to a SCEVable type!");
1556 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1557 Ty = getEffectiveSCEVType(Ty);
1558
1559 FoldID ID(scZeroExtend, Op, Ty);
1560 auto Iter = FoldCache.find(ID);
1561 if (Iter != FoldCache.end())
1562 return Iter->second;
1563
1564 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1565 if (!isa<SCEVZeroExtendExpr>(S))
1566 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1567 return S;
1568}
1569
1571 unsigned Depth) {
1572 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1573 "This is not an extending conversion!");
1574 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1575 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1576
1577 // Fold if the operand is constant.
1578 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1579 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1580
1581 // zext(zext(x)) --> zext(x)
1582 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1583 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1584
1585 // Before doing any expensive analysis, check to see if we've already
1586 // computed a SCEV for this Op and Ty.
1588 ID.AddInteger(scZeroExtend);
1589 ID.AddPointer(Op);
1590 ID.AddPointer(Ty);
1591 void *IP = nullptr;
1592 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1593 if (Depth > MaxCastDepth) {
1594 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1595 Op, Ty);
1596 UniqueSCEVs.InsertNode(S, IP);
1597 registerUser(S, Op);
1598 return S;
1599 }
1600
1601 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1602 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1603 // It's possible the bits taken off by the truncate were all zero bits. If
1604 // so, we should be able to simplify this further.
1605 const SCEV *X = ST->getOperand();
1607 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1608 unsigned NewBits = getTypeSizeInBits(Ty);
1609 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1610 CR.zextOrTrunc(NewBits)))
1611 return getTruncateOrZeroExtend(X, Ty, Depth);
1612 }
1613
1614 // If the input value is a chrec scev, and we can prove that the value
1615 // did not overflow the old, smaller, value, we can zero extend all of the
1616 // operands (often constants). This allows analysis of something like
1617 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1618 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1619 if (AR->isAffine()) {
1620 const SCEV *Start = AR->getStart();
1621 const SCEV *Step = AR->getStepRecurrence(*this);
1622 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1623 const Loop *L = AR->getLoop();
1624
1625 // If we have special knowledge that this addrec won't overflow,
1626 // we don't need to do any further analysis.
1627 if (AR->hasNoUnsignedWrap()) {
1628 Start =
1629 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1630 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1631 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1632 }
1633
1634 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1635 // Note that this serves two purposes: It filters out loops that are
1636 // simply not analyzable, and it covers the case where this code is
1637 // being called from within backedge-taken count analysis, such that
1638 // attempting to ask for the backedge-taken count would likely result
1639 // in infinite recursion. In the latter case, the analysis code will
1640 // cope with a conservative value, and it will take care to purge
1641 // that value once it has finished.
1642 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1643 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1644 // Manually compute the final value for AR, checking for overflow.
1645
1646 // Check whether the backedge-taken count can be losslessly cast to
1647 // the addrec's type. The count is always unsigned.
1648 const SCEV *CastedMaxBECount =
1649 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1650 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1651 CastedMaxBECount, MaxBECount->getType(), Depth);
1652 if (MaxBECount == RecastedMaxBECount) {
1653 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1654 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1655 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1656 SCEV::FlagAnyWrap, Depth + 1);
1657 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1658 SCEV::FlagAnyWrap,
1659 Depth + 1),
1660 WideTy, Depth + 1);
1661 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1662 const SCEV *WideMaxBECount =
1663 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1664 const SCEV *OperandExtendedAdd =
1665 getAddExpr(WideStart,
1666 getMulExpr(WideMaxBECount,
1667 getZeroExtendExpr(Step, WideTy, Depth + 1),
1668 SCEV::FlagAnyWrap, Depth + 1),
1669 SCEV::FlagAnyWrap, Depth + 1);
1670 if (ZAdd == OperandExtendedAdd) {
1671 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1672 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1673 // Return the expression with the addrec on the outside.
1674 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1675 Depth + 1);
1676 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1677 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1678 }
1679 // Similar to above, only this time treat the step value as signed.
1680 // This covers loops that count down.
1681 OperandExtendedAdd =
1682 getAddExpr(WideStart,
1683 getMulExpr(WideMaxBECount,
1684 getSignExtendExpr(Step, WideTy, Depth + 1),
1685 SCEV::FlagAnyWrap, Depth + 1),
1686 SCEV::FlagAnyWrap, Depth + 1);
1687 if (ZAdd == OperandExtendedAdd) {
1688 // Cache knowledge of AR NW, which is propagated to this AddRec.
1689 // Negative step causes unsigned wrap, but it still can't self-wrap.
1690 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1691 // Return the expression with the addrec on the outside.
1692 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1693 Depth + 1);
1694 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1695 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1696 }
1697 }
1698 }
1699
1700 // Normally, in the cases we can prove no-overflow via a
1701 // backedge guarding condition, we can also compute a backedge
1702 // taken count for the loop. The exceptions are assumptions and
1703 // guards present in the loop -- SCEV is not great at exploiting
1704 // these to compute max backedge taken counts, but can still use
1705 // these to prove lack of overflow. Use this fact to avoid
1706 // doing extra work that may not pay off.
1707 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1708 !AC.assumptions().empty()) {
1709
1710 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1711 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1712 if (AR->hasNoUnsignedWrap()) {
1713 // Same as nuw case above - duplicated here to avoid a compile time
1714 // issue. It's not clear that the order of checks matters, but
1715 // it's one of two possible causes of a change which was
1716 // reverted. Be conservative for the moment.
1717 Start =
1718 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1719 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1720 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1721 }
1722
1723 // For a negative step, we can extend the operands iff doing so only
1724 // traverses values in the range zext([0,UINT_MAX]).
1725 if (isKnownNegative(Step)) {
1726 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1727 getSignedRangeMin(Step));
1728 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1729 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1730 // Cache knowledge of AR NW, which is propagated to this
1731 // AddRec. Negative step causes unsigned wrap, but it
1732 // still can't self-wrap.
1733 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1734 // Return the expression with the addrec on the outside.
1735 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1736 Depth + 1);
1737 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1738 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1739 }
1740 }
1741 }
1742
1743 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1744 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1745 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1746 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1747 const APInt &C = SC->getAPInt();
1748 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1749 if (D != 0) {
1750 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1751 const SCEV *SResidual =
1752 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1753 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1754 return getAddExpr(SZExtD, SZExtR,
1755 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1756 Depth + 1);
1757 }
1758 }
1759
1760 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1761 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1762 Start =
1763 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1764 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1765 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1766 }
1767 }
1768
1769 // zext(A % B) --> zext(A) % zext(B)
1770 {
1771 const SCEV *LHS;
1772 const SCEV *RHS;
1773 if (matchURem(Op, LHS, RHS))
1774 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1775 getZeroExtendExpr(RHS, Ty, Depth + 1));
1776 }
1777
1778 // zext(A / B) --> zext(A) / zext(B).
1779 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1780 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1781 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1782
1783 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1784 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1785 if (SA->hasNoUnsignedWrap()) {
1786 // If the addition does not overflow in the unsigned sense then we can,
1787 // by definition, commute the zero extension with the addition operation.
1788 SmallVector<const SCEV *, 4> Ops;
1789 for (const auto *Op : SA->operands())
1790 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1791 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1792 }
1793
1794 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1795 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1796 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1797 //
1798 // Often address arithmetics contain expressions like
1799 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1800 // This transformation is useful while proving that such expressions are
1801 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1802 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1803 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1804 if (D != 0) {
1805 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1806 const SCEV *SResidual =
1807 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1808 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1809 return getAddExpr(SZExtD, SZExtR,
1810 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1811 Depth + 1);
1812 }
1813 }
1814 }
1815
1816 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1817 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1818 if (SM->hasNoUnsignedWrap()) {
1819 // If the multiply does not overflow in the unsigned sense then we can,
1820 // by definition, commute the zero extension with the multiply operation.
1821 SmallVector<const SCEV *, 4> Ops;
1822 for (const auto *Op : SM->operands())
1823 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1824 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1825 }
1826
1827 // zext(2^K * (trunc X to iN)) to iM ->
1828 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1829 //
1830 // Proof:
1831 //
1832 // zext(2^K * (trunc X to iN)) to iM
1833 // = zext((trunc X to iN) << K) to iM
1834 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1835 // (because shl removes the top K bits)
1836 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1837 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1838 //
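 // Illustrative instance (not from the original source): with K = 2, N = 8,
 // M = 32, this rewrites zext(4 * (trunc X to i8)) to i32 as
 // 4 * (zext (trunc X to i6) to i32); narrowing the truncate to i6 is safe
 // because the multiply by 4 shifts those bits left by K = 2 without loss.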
1839 if (SM->getNumOperands() == 2)
1840 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1841 if (MulLHS->getAPInt().isPowerOf2())
1842 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1843 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1844 MulLHS->getAPInt().logBase2();
1845 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1846 return getMulExpr(
1847 getZeroExtendExpr(MulLHS, Ty),
1848 getZeroExtendExpr(
1849 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1850 SCEV::FlagNUW, Depth + 1);
1851 }
1852 }
1853
1854 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1855 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1856 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1857 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1858 SmallVector<const SCEV *, 4> Operands;
1859 for (auto *Operand : MinMax->operands())
1860 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1861 if (isa<SCEVUMinExpr>(MinMax))
1862 return getUMinExpr(Operands);
1863 return getUMaxExpr(Operands);
1864 }
1865
1866 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1867 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1868 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1869 SmallVector<const SCEV *, 4> Operands;
1870 for (auto *Operand : MinMax->operands())
1871 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1872 return getUMinExpr(Operands, /*Sequential*/ true);
1873 }
1874
1875 // The cast wasn't folded; create an explicit cast node.
1876 // Recompute the insert position, as it may have been invalidated.
1877 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1878 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1879 Op, Ty);
1880 UniqueSCEVs.InsertNode(S, IP);
1881 registerUser(S, Op);
1882 return S;
1883}
1884
1885const SCEV *
1886ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1887 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1888 "This is not an extending conversion!");
1889 assert(isSCEVable(Ty) &&
1890 "This is not a conversion to a SCEVable type!");
1891 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1892 Ty = getEffectiveSCEVType(Ty);
1893
1894 FoldID ID(scSignExtend, Op, Ty);
1895 auto Iter = FoldCache.find(ID);
1896 if (Iter != FoldCache.end())
1897 return Iter->second;
1898
1899 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1900 if (!isa<SCEVSignExtendExpr>(S))
1901 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1902 return S;
1903}
1904
1905const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1906 unsigned Depth) {
1907 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1908 "This is not an extending conversion!");
1909 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1910 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1911 Ty = getEffectiveSCEVType(Ty);
1912
1913 // Fold if the operand is constant.
1914 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1915 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1916
1917 // sext(sext(x)) --> sext(x)
1918 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1919 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1920
1921 // sext(zext(x)) --> zext(x)
1922 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1923 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1924
1925 // Before doing any expensive analysis, check to see if we've already
1926 // computed a SCEV for this Op and Ty.
1927 FoldingSetNodeID ID;
1928 ID.AddInteger(scSignExtend);
1929 ID.AddPointer(Op);
1930 ID.AddPointer(Ty);
1931 void *IP = nullptr;
1932 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1933 // Limit recursion depth.
1934 if (Depth > MaxCastDepth) {
1935 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1936 Op, Ty);
1937 UniqueSCEVs.InsertNode(S, IP);
1938 registerUser(S, Op);
1939 return S;
1940 }
1941
1942 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1943 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1944 // It's possible the bits taken off by the truncate were all sign bits. If
1945 // so, we should be able to simplify this further.
1946 const SCEV *X = ST->getOperand();
1947 ConstantRange CR = getSignedRange(X);
1948 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1949 unsigned NewBits = getTypeSizeInBits(Ty);
1950 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1951 CR.sextOrTrunc(NewBits)))
1952 return getTruncateOrSignExtend(X, Ty, Depth);
1953 }
1954
1955 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1956 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1957 if (SA->hasNoSignedWrap()) {
1958 // If the addition does not overflow in the signed sense then we can,
1959 // by definition, commute the sign extension with the addition operation.
1960 SmallVector<const SCEV *, 4> Ops;
1961 for (const auto *Op : SA->operands())
1962 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1963 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1964 }
1965
1966 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1967 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1968 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1969 //
1970 // For instance, this will bring two seemingly different expressions:
1971 // 1 + sext(5 + 20 * %x + 24 * %y) and
1972 // sext(6 + 20 * %x + 24 * %y)
1973 // to the same form:
1974 // 2 + sext(4 + 20 * %x + 24 * %y)
1975 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1976 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1977 if (D != 0) {
1978 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1979 const SCEV *SResidual =
1980 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1981 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1982 return getAddExpr(SSExtD, SSExtR,
1983 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1984 Depth + 1);
1985 }
1986 }
1987 }
1988 // If the input value is a chrec scev, and we can prove that the value
1989 // did not overflow the old, smaller, value, we can sign extend all of the
1990 // operands (often constants). This allows analysis of something like
1991 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1992 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1993 if (AR->isAffine()) {
1994 const SCEV *Start = AR->getStart();
1995 const SCEV *Step = AR->getStepRecurrence(*this);
1996 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1997 const Loop *L = AR->getLoop();
1998
1999 // If we have special knowledge that this addrec won't overflow,
2000 // we don't need to do any further analysis.
2001 if (AR->hasNoSignedWrap()) {
2002 Start =
2003 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2004 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2005 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2006 }
2007
2008 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2009 // Note that this serves two purposes: It filters out loops that are
2010 // simply not analyzable, and it covers the case where this code is
2011 // being called from within backedge-taken count analysis, such that
2012 // attempting to ask for the backedge-taken count would likely result
2013 // in infinite recursion. In the latter case, the analysis code will
2014 // cope with a conservative value, and it will take care to purge
2015 // that value once it has finished.
2016 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2017 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2018 // Manually compute the final value for AR, checking for
2019 // overflow.
2020
2021 // Check whether the backedge-taken count can be losslessly cast to
2022 // the addrec's type. The count is always unsigned.
2023 const SCEV *CastedMaxBECount =
2024 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2025 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2026 CastedMaxBECount, MaxBECount->getType(), Depth);
2027 if (MaxBECount == RecastedMaxBECount) {
2028 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2029 // Check whether Start+Step*MaxBECount has no signed overflow.
2030 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2031 SCEV::FlagAnyWrap, Depth + 1);
2032 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2033 SCEV::FlagAnyWrap,
2034 Depth + 1),
2035 WideTy, Depth + 1);
2036 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2037 const SCEV *WideMaxBECount =
2038 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2039 const SCEV *OperandExtendedAdd =
2040 getAddExpr(WideStart,
2041 getMulExpr(WideMaxBECount,
2042 getSignExtendExpr(Step, WideTy, Depth + 1),
2043 SCEV::FlagAnyWrap, Depth + 1),
2044 SCEV::FlagAnyWrap, Depth + 1);
2045 if (SAdd == OperandExtendedAdd) {
2046 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2047 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2048 // Return the expression with the addrec on the outside.
2049 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2050 Depth + 1);
2051 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2052 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2053 }
2054 // Similar to above, only this time treat the step value as unsigned.
2055 // This covers loops that count up with an unsigned step.
2056 OperandExtendedAdd =
2057 getAddExpr(WideStart,
2058 getMulExpr(WideMaxBECount,
2059 getZeroExtendExpr(Step, WideTy, Depth + 1),
2060 SCEV::FlagAnyWrap, Depth + 1),
2061 SCEV::FlagAnyWrap, Depth + 1);
2062 if (SAdd == OperandExtendedAdd) {
2063 // If AR wraps around then
2064 //
2065 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2066 // => SAdd != OperandExtendedAdd
2067 //
2068 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2069 // (SAdd == OperandExtendedAdd => AR is NW)
2070
2071 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2072
2073 // Return the expression with the addrec on the outside.
2074 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2075 Depth + 1);
2076 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2077 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2078 }
2079 }
2080 }
2081
2082 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2083 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2084 if (AR->hasNoSignedWrap()) {
2085 // Same as nsw case above - duplicated here to avoid a compile time
2086 // issue. It's not clear that the order of checks matters, but
2087 // it's one of two possible causes of a change which was
2088 // reverted. Be conservative for the moment.
2089 Start =
2090 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2091 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2092 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2093 }
2094
2095 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2096 // if D + (C - D + Step * n) could be proven to not signed wrap
2097 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2098 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2099 const APInt &C = SC->getAPInt();
2100 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2101 if (D != 0) {
2102 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2103 const SCEV *SResidual =
2104 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2105 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2106 return getAddExpr(SSExtD, SSExtR,
2107 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2108 Depth + 1);
2109 }
2110 }
2111
2112 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2113 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2114 Start =
2115 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2116 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2117 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2118 }
2119 }
2120
2121 // If the input value is provably positive and we could not simplify
2122 // away the sext, build a zext instead.
2123 if (isKnownNonNegative(Op))
2124 return getZeroExtendExpr(Op, Ty, Depth + 1);
2125
2126 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2127 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2128 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2129 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2130 SmallVector<const SCEV *, 4> Operands;
2131 for (auto *Operand : MinMax->operands())
2132 Operands.push_back(getSignExtendExpr(Operand, Ty));
2133 if (isa<SCEVSMinExpr>(MinMax))
2134 return getSMinExpr(Operands);
2135 return getSMaxExpr(Operands);
2136 }
2137
2138 // The cast wasn't folded; create an explicit cast node.
2139 // Recompute the insert position, as it may have been invalidated.
2140 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2141 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2142 Op, Ty);
2143 UniqueSCEVs.InsertNode(S, IP);
2144 registerUser(S, { Op });
2145 return S;
2146}
2147
2148const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2149 Type *Ty) {
2150 switch (Kind) {
2151 case scTruncate:
2152 return getTruncateExpr(Op, Ty);
2153 case scZeroExtend:
2154 return getZeroExtendExpr(Op, Ty);
2155 case scSignExtend:
2156 return getSignExtendExpr(Op, Ty);
2157 case scPtrToInt:
2158 return getPtrToIntExpr(Op, Ty);
2159 default:
2160 llvm_unreachable("Not a SCEV cast expression!");
2161 }
2162}
2163
2164/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2165/// unspecified bits out to the given type.
2166const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2167 Type *Ty) {
2168 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2169 "This is not an extending conversion!");
2170 assert(isSCEVable(Ty) &&
2171 "This is not a conversion to a SCEVable type!");
2172 Ty = getEffectiveSCEVType(Ty);
2173
2174 // Sign-extend negative constants.
2175 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2176 if (SC->getAPInt().isNegative())
2177 return getSignExtendExpr(Op, Ty);
2178
2179 // Peel off a truncate cast.
2180 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2181 const SCEV *NewOp = T->getOperand();
2182 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2183 return getAnyExtendExpr(NewOp, Ty);
2184 return getTruncateOrNoop(NewOp, Ty);
2185 }
2186
2187 // Next try a zext cast. If the cast is folded, use it.
2188 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2189 if (!isa<SCEVZeroExtendExpr>(ZExt))
2190 return ZExt;
2191
2192 // Next try a sext cast. If the cast is folded, use it.
2193 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2194 if (!isa<SCEVSignExtendExpr>(SExt))
2195 return SExt;
2196
2197 // Force the cast to be folded into the operands of an addrec.
2198 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2199 SmallVector<const SCEV *, 4> Ops;
2200 for (const SCEV *Op : AR->operands())
2201 Ops.push_back(getAnyExtendExpr(Op, Ty));
2202 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2203 }
2204
2205 // If the expression is obviously signed, use the sext cast value.
2206 if (isa<SCEVSMaxExpr>(Op))
2207 return SExt;
2208
2209 // Absent any other information, use the zext cast value.
2210 return ZExt;
2211}
2212
2213/// Process the given Ops list, which is a list of operands to be added under
2214/// the given scale, and update the given map. This is a helper function for
2215/// getAddExpr. As an example of what it does, given a sequence of operands
2216/// that would form an add expression like this:
2217///
2218/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2219///
2220/// where A and B are constants, update the map with these values:
2221///
2222/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2223///
2224/// and add 13 + A*B*29 to AccumulatedConstant.
2225/// This will allow getAddExpr to produce this:
2226///
2227/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2228///
2229/// This form often exposes folding opportunities that are hidden in
2230/// the original operand list.
2231///
2232/// Return true iff it appears that any interesting folding opportunities
2233/// may be exposed. This helps getAddExpr short-circuit extra work in
2234/// the common case where no interesting opportunities are present, and
2235/// is also used as a check to avoid infinite recursion.
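/// Illustrative sketch (hypothetical operands, not part of the original
/// documentation): with Ops = {7, n, (3 * n)} and Scale = 1, the constant 7
/// is folded into AccumulatedConstant, the lone n is inserted into the map
/// with scale 1, and the (3 * n) term updates that same entry to scale 4,
/// which marks the result as an interesting folding opportunity (n * 4).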
2236static bool
2237CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2238 SmallVectorImpl<const SCEV *> &NewOps,
2239 APInt &AccumulatedConstant,
2240 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2241 ScalarEvolution &SE) {
2242 bool Interesting = false;
2243
2244 // Iterate over the add operands. They are sorted, with constants first.
2245 unsigned i = 0;
2246 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2247 ++i;
2248 // Pull a buried constant out to the outside.
2249 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2250 Interesting = true;
2251 AccumulatedConstant += Scale * C->getAPInt();
2252 }
2253
2254 // Next comes everything else. We're especially interested in multiplies
2255 // here, but they're in the middle, so just visit the rest with one loop.
2256 for (; i != Ops.size(); ++i) {
2257 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2258 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2259 APInt NewScale =
2260 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2261 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2262 // A multiplication of a constant with another add; recurse.
2263 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2264 Interesting |=
2265 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2266 Add->operands(), NewScale, SE);
2267 } else {
2268 // A multiplication of a constant with some other value. Update
2269 // the map.
2270 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2271 const SCEV *Key = SE.getMulExpr(MulOps);
2272 auto Pair = M.insert({Key, NewScale});
2273 if (Pair.second) {
2274 NewOps.push_back(Pair.first->first);
2275 } else {
2276 Pair.first->second += NewScale;
2277 // The map already had an entry for this value, which may indicate
2278 // a folding opportunity.
2279 Interesting = true;
2280 }
2281 }
2282 } else {
2283 // An ordinary operand. Update the map.
2284 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2285 M.insert({Ops[i], Scale});
2286 if (Pair.second) {
2287 NewOps.push_back(Pair.first->first);
2288 } else {
2289 Pair.first->second += Scale;
2290 // The map already had an entry for this value, which may indicate
2291 // a folding opportunity.
2292 Interesting = true;
2293 }
2294 }
2295 }
2296
2297 return Interesting;
2298}
2299
2300bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2301 const SCEV *LHS, const SCEV *RHS,
2302 const Instruction *CtxI) {
2303 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2304 SCEV::NoWrapFlags, unsigned);
2305 switch (BinOp) {
2306 default:
2307 llvm_unreachable("Unsupported binary op");
2308 case Instruction::Add:
2309 Operation = &ScalarEvolution::getAddExpr;
2310 break;
2311 case Instruction::Sub:
2312 Operation = &ScalarEvolution::getMinusSCEV;
2313 break;
2314 case Instruction::Mul:
2315 Operation = &ScalarEvolution::getMulExpr;
2316 break;
2317 }
2318
2319 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2320 Signed ? &ScalarEvolution::getSignExtendExpr
2321 : &ScalarEvolution::getZeroExtendExpr;
2322
2323 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2324 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2325 auto *WideTy =
2326 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2327
2328 const SCEV *A = (this->*Extension)(
2329 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2330 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2331 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2332 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2333 if (A == B)
2334 return true;
2335 // Can we use context to prove the fact we need?
2336 if (!CtxI)
2337 return false;
2338 // TODO: Support mul.
2339 if (BinOp == Instruction::Mul)
2340 return false;
2341 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2342 // TODO: Lift this limitation.
2343 if (!RHSC)
2344 return false;
2345 APInt C = RHSC->getAPInt();
2346 unsigned NumBits = C.getBitWidth();
2347 bool IsSub = (BinOp == Instruction::Sub);
2348 bool IsNegativeConst = (Signed && C.isNegative());
2349 // Compute the direction and magnitude by which we need to check overflow.
2350 bool OverflowDown = IsSub ^ IsNegativeConst;
2351 APInt Magnitude = C;
2352 if (IsNegativeConst) {
2353 if (C == APInt::getSignedMinValue(NumBits))
2354 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2355 // want to deal with that.
2356 return false;
2357 Magnitude = -C;
2358 }
2359
2360 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2361 if (OverflowDown) {
2362 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2363 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2364 : APInt::getMinValue(NumBits);
2365 APInt Limit = Min + Magnitude;
2366 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2367 } else {
2368 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2369 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2370 : APInt::getMaxValue(NumBits);
2371 APInt Limit = Max - Magnitude;
2372 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2373 }
2374}
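// Usage sketch (hypothetical values, not from the original source): for an
// unsigned i8 add "X + 5" with a context instruction, the constant-RHS path
// above reduces the question to proving X <= 250 (UINT8_MAX - 5) at CtxI via
// isKnownPredicateAt with an unsigned less-or-equal predicate.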
2375
2376std::optional<SCEV::NoWrapFlags>
2377ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2378 const OverflowingBinaryOperator *OBO) {
2379 // It cannot be done any better.
2380 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2381 return std::nullopt;
2382
2383 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2384
2385 if (OBO->hasNoUnsignedWrap())
2386 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2387 if (OBO->hasNoSignedWrap())
2388 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2389
2390 bool Deduced = false;
2391
2392 if (OBO->getOpcode() != Instruction::Add &&
2393 OBO->getOpcode() != Instruction::Sub &&
2394 OBO->getOpcode() != Instruction::Mul)
2395 return std::nullopt;
2396
2397 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2398 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2399
2400 const Instruction *CtxI =
2401 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2402 if (!OBO->hasNoUnsignedWrap() &&
2403 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2404 /* Signed */ false, LHS, RHS, CtxI)) {
2405 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2406 Deduced = true;
2407 }
2408
2409 if (!OBO->hasNoSignedWrap() &&
2410 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2411 /* Signed */ true, LHS, RHS, CtxI)) {
2412 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2413 Deduced = true;
2414 }
2415
2416 if (Deduced)
2417 return Flags;
2418 return std::nullopt;
2419}
2420
2421// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2422// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2423// can't-overflow flags for the operation if possible.
2424static SCEV::NoWrapFlags
2425StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2426 const ArrayRef<const SCEV *> Ops,
2427 SCEV::NoWrapFlags Flags) {
2428 using namespace std::placeholders;
2429
2430 using OBO = OverflowingBinaryOperator;
2431
2432 bool CanAnalyze =
2433 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2434 (void)CanAnalyze;
2435 assert(CanAnalyze && "don't call from other places!");
2436
2437 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2438 SCEV::NoWrapFlags SignOrUnsignWrap =
2439 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2440
2441 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2442 auto IsKnownNonNegative = [&](const SCEV *S) {
2443 return SE->isKnownNonNegative(S);
2444 };
2445
2446 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2447 Flags =
2448 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2449
2450 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2451
2452 if (SignOrUnsignWrap != SignOrUnsignMask &&
2453 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2454 isa<SCEVConstant>(Ops[0])) {
2455
2456 auto Opcode = [&] {
2457 switch (Type) {
2458 case scAddExpr:
2459 return Instruction::Add;
2460 case scMulExpr:
2461 return Instruction::Mul;
2462 default:
2463 llvm_unreachable("Unexpected SCEV op.");
2464 }
2465 }();
2466
2467 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2468
2469 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2470 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2471 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2472 Opcode, C, OBO::NoSignedWrap);
2473 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2474 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2475 }
2476
2477 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2478 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2479 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2480 Opcode, C, OBO::NoUnsignedWrap);
2481 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2482 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2483 }
2484 }
2485
2486 // <0,+,nonnegative><nw> is also nuw
2487 // TODO: Add corresponding nsw case
2488 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2489 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2490 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2491 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2492
2493 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2494 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2495 Ops.size() == 2) {
2496 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2497 if (UDiv->getOperand(1) == Ops[1])
2498 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2500 if (UDiv->getOperand(1) == Ops[0])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 }
2503
2504 return Flags;
2505}
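// Illustrative sketch (hypothetical operands, not from the original source):
// for an scAddExpr with Ops = {1, X} where the unsigned range of X is known
// to lie in [0, 100), the makeGuaranteedNoWrapRegion check above shows the
// add cannot wrap unsigned, so FlagNUW is added to the returned flags.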
2506
2507bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2508 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2509}
2510
2511/// Get a canonical add expression, or something simpler if possible.
2512const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2513 SCEV::NoWrapFlags OrigFlags,
2514 unsigned Depth) {
2515 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2516 "only nuw or nsw allowed");
2517 assert(!Ops.empty() && "Cannot get empty add!");
2518 if (Ops.size() == 1) return Ops[0];
2519#ifndef NDEBUG
2520 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2521 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2522 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2523 "SCEVAddExpr operand types don't match!");
2524 unsigned NumPtrs = count_if(
2525 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2526 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2527#endif
2528
2529 // Sort by complexity, this groups all similar expression types together.
2530 GroupByComplexity(Ops, &LI, DT);
2531
2532 // If there are any constants, fold them together.
2533 unsigned Idx = 0;
2534 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2535 ++Idx;
2536 assert(Idx < Ops.size());
2537 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2538 // We found two constants, fold them together!
2539 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2540 if (Ops.size() == 2) return Ops[0];
2541 Ops.erase(Ops.begin()+1); // Erase the folded element
2542 LHSC = cast<SCEVConstant>(Ops[0]);
2543 }
2544
2545 // If we are left with a constant zero being added, strip it off.
2546 if (LHSC->getValue()->isZero()) {
2547 Ops.erase(Ops.begin());
2548 --Idx;
2549 }
2550
2551 if (Ops.size() == 1) return Ops[0];
2552 }
2553
2554 // Delay expensive flag strengthening until necessary.
2555 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2556 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2557 };
2558
2559 // Limit recursion calls depth.
2560 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2561 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2562
2563 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2564 // Don't strengthen flags if we have no new information.
2565 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2566 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2567 Add->setNoWrapFlags(ComputeFlags(Ops));
2568 return S;
2569 }
2570
2571 // Okay, check to see if the same value occurs in the operand list more than
2572 // once. If so, merge them together into a multiply expression. Since we
2573 // sorted the list, these values are required to be adjacent.
2574 Type *Ty = Ops[0]->getType();
2575 bool FoundMatch = false;
2576 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2577 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2578 // Scan ahead to count how many equal operands there are.
2579 unsigned Count = 2;
2580 while (i+Count != e && Ops[i+Count] == Ops[i])
2581 ++Count;
2582 // Merge the values into a multiply.
2583 const SCEV *Scale = getConstant(Ty, Count);
2584 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2585 if (Ops.size() == Count)
2586 return Mul;
2587 Ops[i] = Mul;
2588 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2589 --i; e -= Count - 1;
2590 FoundMatch = true;
2591 }
2592 if (FoundMatch)
2593 return getAddExpr(Ops, OrigFlags, Depth + 1);
2594
2595 // Check for truncates. If all the operands are truncated from the same
2596 // type, see if factoring out the truncate would permit the result to be
2597 // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2598 // if the contents of the resulting outer trunc fold to something simple.
2599 auto FindTruncSrcType = [&]() -> Type * {
2600 // We're ultimately looking to fold an addrec of truncs and muls of only
2601 // constants and truncs, so if we find any other types of SCEV
2602 // as operands of the addrec then we bail and return nullptr here.
2603 // Otherwise, we return the type of the operand of a trunc that we find.
2604 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2605 return T->getOperand()->getType();
2606 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2607 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2608 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2609 return T->getOperand()->getType();
2610 }
2611 return nullptr;
2612 };
2613 if (auto *SrcType = FindTruncSrcType()) {
2614 SmallVector<const SCEV *, 8> LargeOps;
2615 bool Ok = true;
2616 // Check all the operands to see if they can be represented in the
2617 // source type of the truncate.
2618 for (const SCEV *Op : Ops) {
2619 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeOps.push_back(T->getOperand());
2625 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2626 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2628 SmallVector<const SCEV *, 8> LargeMulOps;
2629 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2630 if (const SCEVTruncateExpr *T =
2631 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2632 if (T->getOperand()->getType() != SrcType) {
2633 Ok = false;
2634 break;
2635 }
2636 LargeMulOps.push_back(T->getOperand());
2637 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2638 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2639 } else {
2640 Ok = false;
2641 break;
2642 }
2643 }
2644 if (Ok)
2645 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2646 } else {
2647 Ok = false;
2648 break;
2649 }
2650 }
2651 if (Ok) {
2652 // Evaluate the expression in the larger type.
2653 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2654 // If it folds to something simple, use it. Otherwise, don't.
2655 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2656 return getTruncateExpr(Fold, Ty);
2657 }
2658 }
2659
2660 if (Ops.size() == 2) {
2661 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2662 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2663 // C1).
2664 const SCEV *A = Ops[0];
2665 const SCEV *B = Ops[1];
2666 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2667 auto *C = dyn_cast<SCEVConstant>(A);
2668 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2669 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2670 auto C2 = C->getAPInt();
2671 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2672
2673 APInt ConstAdd = C1 + C2;
2674 auto AddFlags = AddExpr->getNoWrapFlags();
2675 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2676 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2677 ConstAdd.ule(C1)) {
2678 PreservedFlags =
2679 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2680 }
2681
2682 // Adding a constant with the same sign and small magnitude is NSW, if the
2683 // original AddExpr was NSW.
2684 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2685 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2686 ConstAdd.abs().ule(C1.abs())) {
2687 PreservedFlags =
2688 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2689 }
2690
2691 if (PreservedFlags != SCEV::FlagAnyWrap) {
2692 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2693 NewOps[0] = getConstant(ConstAdd);
2694 return getAddExpr(NewOps, PreservedFlags);
2695 }
2696 }
2697 }
2698
2699 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2700 if (Ops.size() == 2) {
2701 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2702 if (Mul && Mul->getNumOperands() == 2 &&
2703 Mul->getOperand(0)->isAllOnesValue()) {
2704 const SCEV *X;
2705 const SCEV *Y;
2706 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2707 return getMulExpr(Y, getUDivExpr(X, Y));
2708 }
2709 }
2710 }
2711
2712 // Skip past any other cast SCEVs.
2713 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2714 ++Idx;
2715
2716 // If there are add operands they would be next.
2717 if (Idx < Ops.size()) {
2718 bool DeletedAdd = false;
2719 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2720 // common NUW flag for expression after inlining. Other flags cannot be
2721 // preserved, because they may depend on the original order of operations.
2722 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2723 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2724 if (Ops.size() > AddOpsInlineThreshold ||
2725 Add->getNumOperands() > AddOpsInlineThreshold)
2726 break;
2727 // If we have an add, expand the add operands onto the end of the operands
2728 // list.
2729 Ops.erase(Ops.begin()+Idx);
2730 append_range(Ops, Add->operands());
2731 DeletedAdd = true;
2732 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2733 }
2734
2735 // If we deleted at least one add, we added operands to the end of the list,
2736 // and they are not necessarily sorted. Recurse to resort and resimplify
2737 // any operands we just acquired.
2738 if (DeletedAdd)
2739 return getAddExpr(Ops, CommonFlags, Depth + 1);
2740 }
2741
2742 // Skip over the add expression until we get to a multiply.
2743 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2744 ++Idx;
2745
2746 // Check to see if there are any folding opportunities present with
2747 // operands multiplied by constant values.
2748 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2749 uint64_t BitWidth = getTypeSizeInBits(Ty);
2750 DenseMap<const SCEV *, APInt> M;
2751 SmallVector<const SCEV *, 8> NewOps;
2752 APInt AccumulatedConstant(BitWidth, 0);
2753 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2754 Ops, APInt(BitWidth, 1), *this)) {
2755 struct APIntCompare {
2756 bool operator()(const APInt &LHS, const APInt &RHS) const {
2757 return LHS.ult(RHS);
2758 }
2759 };
2760
2761 // Some interesting folding opportunity is present, so it's worthwhile to
2762 // re-generate the operands list. Group the operands by constant scale,
2763 // to avoid multiplying by the same constant scale multiple times.
2764 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2765 for (const SCEV *NewOp : NewOps)
2766 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2767 // Re-generate the operands list.
2768 Ops.clear();
2769 if (AccumulatedConstant != 0)
2770 Ops.push_back(getConstant(AccumulatedConstant));
2771 for (auto &MulOp : MulOpLists) {
2772 if (MulOp.first == 1) {
2773 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2774 } else if (MulOp.first != 0) {
2775 Ops.push_back(getMulExpr(
2776 getConstant(MulOp.first),
2777 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2778 SCEV::FlagAnyWrap, Depth + 1));
2779 }
2780 }
2781 if (Ops.empty())
2782 return getZero(Ty);
2783 if (Ops.size() == 1)
2784 return Ops[0];
2785 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2786 }
2787 }
2788
2789 // If we are adding something to a multiply expression, make sure the
2790 // something is not already an operand of the multiply. If so, merge it into
2791 // the multiply.
2792 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2793 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2794 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2795 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2796 if (isa<SCEVConstant>(MulOpSCEV))
2797 continue;
2798 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2799 if (MulOpSCEV == Ops[AddOp]) {
2800 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2801 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2802 if (Mul->getNumOperands() != 2) {
2803 // If the multiply has more than two operands, we must get the
2804 // Y*Z term.
2805 SmallVector<const SCEV *, 4> MulOps(
2806 Mul->operands().take_front(MulOp));
2807 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2808 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2809 }
2810 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2811 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2812 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 SCEV::FlagAnyWrap, Depth + 1);
2814 if (Ops.size() == 2) return OuterMul;
2815 if (AddOp < Idx) {
2816 Ops.erase(Ops.begin()+AddOp);
2817 Ops.erase(Ops.begin()+Idx-1);
2818 } else {
2819 Ops.erase(Ops.begin()+Idx);
2820 Ops.erase(Ops.begin()+AddOp-1);
2821 }
2822 Ops.push_back(OuterMul);
2823 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2824 }
2825
2826 // Check this multiply against other multiplies being added together.
2827 for (unsigned OtherMulIdx = Idx+1;
2828 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2829 ++OtherMulIdx) {
2830 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 // If MulOp occurs in OtherMul, we can fold the two multiplies
2832 // together.
2833 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2834 OMulOp != e; ++OMulOp)
2835 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2836 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2837 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2838 if (Mul->getNumOperands() != 2) {
2839 SmallVector<const SCEV *, 4> MulOps(
2840 Mul->operands().take_front(MulOp));
2841 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2842 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2843 }
2844 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2845 if (OtherMul->getNumOperands() != 2) {
2846 SmallVector<const SCEV *, 4> MulOps(
2847 OtherMul->operands().take_front(OMulOp));
2848 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2849 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2850 }
2851 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2852 const SCEV *InnerMulSum =
2853 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2854 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 SCEV::FlagAnyWrap, Depth + 1);
2856 if (Ops.size() == 2) return OuterMul;
2857 Ops.erase(Ops.begin()+Idx);
2858 Ops.erase(Ops.begin()+OtherMulIdx-1);
2859 Ops.push_back(OuterMul);
2860 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2861 }
2862 }
2863 }
2864 }
2865
2866 // If there are any add recurrences in the operands list, see if any other
2867 // added values are loop invariant. If so, we can fold them into the
2868 // recurrence.
2869 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2870 ++Idx;
2871
2872 // Scan over all recurrences, trying to fold loop invariants into them.
2873 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2874 // Scan all of the other operands to this add and add them to the vector if
2875 // they are loop invariant w.r.t. the recurrence.
2876 SmallVector<const SCEV *, 8> LIOps;
2877 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2878 const Loop *AddRecLoop = AddRec->getLoop();
2879 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2880 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2881 LIOps.push_back(Ops[i]);
2882 Ops.erase(Ops.begin()+i);
2883 --i; --e;
2884 }
2885
2886 // If we found some loop invariants, fold them into the recurrence.
2887 if (!LIOps.empty()) {
2888 // Compute nowrap flags for the addition of the loop-invariant ops and
2889 // the addrec. Temporarily push it as an operand for that purpose. These
2890 // flags are valid in the scope of the addrec only.
2891 LIOps.push_back(AddRec);
2892 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2893 LIOps.pop_back();
2894
2895 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2896 LIOps.push_back(AddRec->getStart());
2897
2898 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2899
2900 // It is not in general safe to propagate flags valid on an add within
2901 // the addrec scope to one outside it. We must prove that the inner
2902 // scope is guaranteed to execute if the outer one does to be able to
2903 // safely propagate. We know the program is undefined if poison is
2904 // produced on the inner scoped addrec. We also know that *for this use*
2905 // the outer scoped add can't overflow (because of the flags we just
2906 // computed for the inner scoped add) without the program being undefined.
2907 // Proving that entry to the outer scope necessitates entry to the inner
2908 // scope, thus proves the program undefined if the flags would be violated
2909 // in the outer scope.
2910 SCEV::NoWrapFlags AddFlags = Flags;
2911 if (AddFlags != SCEV::FlagAnyWrap) {
2912 auto *DefI = getDefiningScopeBound(LIOps);
2913 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2914 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2915 AddFlags = SCEV::FlagAnyWrap;
2916 }
2917 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2918
2919 // Build the new addrec. Propagate the NUW and NSW flags if both the
2920 // outer add and the inner addrec are guaranteed to have no overflow.
2921 // Always propagate NW.
2922 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2923 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2924
2925 // If all of the other operands were loop invariant, we are done.
2926 if (Ops.size() == 1) return NewRec;
2927
2928 // Otherwise, add the folded AddRec by the non-invariant parts.
2929 for (unsigned i = 0;; ++i)
2930 if (Ops[i] == AddRec) {
2931 Ops[i] = NewRec;
2932 break;
2933 }
2934 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2935 }
2936
2937 // Okay, if there weren't any loop invariants to be folded, check to see if
2938 // there are multiple AddRec's with the same loop induction variable being
2939 // added together. If so, we can fold them.
2940 for (unsigned OtherIdx = Idx+1;
2941 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2942 ++OtherIdx) {
2943 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2944 // so that the 1st found AddRecExpr is dominated by all others.
2945 assert(DT.dominates(
2946 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2947 AddRec->getLoop()->getHeader()) &&
2948 "AddRecExprs are not sorted in reverse dominance order?");
2949 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2950 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2951 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2952 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2953 ++OtherIdx) {
2954 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 if (OtherAddRec->getLoop() == AddRecLoop) {
2956 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2957 i != e; ++i) {
2958 if (i >= AddRecOps.size()) {
2959 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2960 break;
2961 }
2962 SmallVector<const SCEV *, 2> TwoOps = {
2963 AddRecOps[i], OtherAddRec->getOperand(i)};
2964 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2965 }
2966 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2967 }
2968 }
2969 // Step size has changed, so we cannot guarantee no self-wraparound.
2970 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2971 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2972 }
2973 }
2974
2975 // Otherwise couldn't fold anything into this recurrence. Move onto the
2976 // next one.
2977 }
2978
2979 // Okay, it looks like we really DO need an add expr. Check to see if we
2980 // already have one, otherwise create a new one.
2981 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2982}
2983
2984const SCEV *
2985ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2986 SCEV::NoWrapFlags Flags) {
2987 FoldingSetNodeID ID;
2988 ID.AddInteger(scAddExpr);
2989 for (const SCEV *Op : Ops)
2990 ID.AddPointer(Op);
2991 void *IP = nullptr;
2992 SCEVAddExpr *S =
2993 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2994 if (!S) {
2995 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2996 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2997 S = new (SCEVAllocator)
2998 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2999 UniqueSCEVs.InsertNode(S, IP);
3000 registerUser(S, Ops);
3001 }
3002 S->setNoWrapFlags(Flags);
3003 return S;
3004}
3005
3006const SCEV *
3007ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3008 const Loop *L, SCEV::NoWrapFlags Flags) {
3009 FoldingSetNodeID ID;
3010 ID.AddInteger(scAddRecExpr);
3011 for (const SCEV *Op : Ops)
3012 ID.AddPointer(Op);
3013 ID.AddPointer(L);
3014 void *IP = nullptr;
3015 SCEVAddRecExpr *S =
3016 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3017 if (!S) {
3018 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3019 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3020 S = new (SCEVAllocator)
3021 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3022 UniqueSCEVs.InsertNode(S, IP);
3023 LoopUsers[L].push_back(S);
3024 registerUser(S, Ops);
3025 }
3026 setNoWrapFlags(S, Flags);
3027 return S;
3028}
3029
3030const SCEV *
3031ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3032 SCEV::NoWrapFlags Flags) {
3033 FoldingSetNodeID ID;
3034 ID.AddInteger(scMulExpr);
3035 for (const SCEV *Op : Ops)
3036 ID.AddPointer(Op);
3037 void *IP = nullptr;
3038 SCEVMulExpr *S =
3039 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3040 if (!S) {
3041 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3042 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3043 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3044 O, Ops.size());
3045 UniqueSCEVs.InsertNode(S, IP);
3046 registerUser(S, Ops);
3047 }
3048 S->setNoWrapFlags(Flags);
3049 return S;
3050}
3051
3052static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3053 uint64_t k = i*j;
3054 if (j > 1 && k / j != i) Overflow = true;
3055 return k;
3056}
3057
3058/// Compute the result of "n choose k", the binomial coefficient. If an
3059/// intermediate computation overflows, Overflow will be set and the return will
3060/// be garbage. Overflow is not cleared on absence of overflow.
3061static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3062 // We use the multiplicative formula:
3063 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3064 // At each iteration, we take the n-th term of the numerator and divide by the
3065 // (k-n)th term of the denominator. This division will always produce an
3066 // integral result, and helps reduce the chance of overflow in the
3067 // intermediate computations. However, we can still overflow even when the
3068 // final result would fit.
3069
3070 if (n == 0 || n == k) return 1;
3071 if (k > n) return 0;
3072
3073 if (k > n/2)
3074 k = n-k;
3075
3076 uint64_t r = 1;
3077 for (uint64_t i = 1; i <= k; ++i) {
3078 r = umul_ov(r, n-(i-1), Overflow);
3079 r /= i;
3080 }
3081 return r;
3082}
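// Worked example (illustrative, not from the original source): Choose(6, 2)
// iterates i = 1..2, giving r = 6 after the first step and
// r = (6 * 5) / 2 = 15 after the second, matching C(6, 2) = 15 with no
// overflow reported.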
3083
3084/// Determine if any of the operands in this SCEV are a constant or if
3085/// any of the add or multiply expressions in this SCEV contain a constant.
3086static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3087 struct FindConstantInAddMulChain {
3088 bool FoundConstant = false;
3089
3090 bool follow(const SCEV *S) {
3091 FoundConstant |= isa<SCEVConstant>(S);
3092 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3093 }
3094
3095 bool isDone() const {
3096 return FoundConstant;
3097 }
3098 };
3099
3100 FindConstantInAddMulChain F;
3101 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3102 ST.visitAll(StartExpr);
3103 return F.FoundConstant;
3104}
3105
3106/// Get a canonical multiply expression, or something simpler if possible.
3107const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3108 SCEV::NoWrapFlags OrigFlags,
3109 unsigned Depth) {
3110 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3111 "only nuw or nsw allowed");
3112 assert(!Ops.empty() && "Cannot get empty mul!");
3113 if (Ops.size() == 1) return Ops[0];
3114#ifndef NDEBUG
3115 Type *ETy = Ops[0]->getType();
3116 assert(!ETy->isPointerTy());
3117 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3118 assert(Ops[i]->getType() == ETy &&
3119 "SCEVMulExpr operand types don't match!");
3120#endif
3121
3122 // Sort by complexity, this groups all similar expression types together.
3123 GroupByComplexity(Ops, &LI, DT);
3124
3125 // If there are any constants, fold them together.
3126 unsigned Idx = 0;
3127 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3128 ++Idx;
3129 assert(Idx < Ops.size());
3130 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3131 // We found two constants, fold them together!
3132 Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3133 if (Ops.size() == 2) return Ops[0];
3134 Ops.erase(Ops.begin()+1); // Erase the folded element
3135 LHSC = cast<SCEVConstant>(Ops[0]);
3136 }
3137
3138 // If we have a multiply of zero, it will always be zero.
3139 if (LHSC->getValue()->isZero())
3140 return LHSC;
3141
3142 // If we are left with a constant one being multiplied, strip it off.
3143 if (LHSC->getValue()->isOne()) {
3144 Ops.erase(Ops.begin());
3145 --Idx;
3146 }
3147
3148 if (Ops.size() == 1)
3149 return Ops[0];
3150 }
3151
3152 // Delay expensive flag strengthening until necessary.
3153 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3154 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3155 };
3156
3157 // Limit recursion call depth.
3158 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3159 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3160
3161 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3162 // Don't strengthen flags if we have no new information.
3163 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3164 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3165 Mul->setNoWrapFlags(ComputeFlags(Ops));
3166 return S;
3167 }
3168
3169 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3170 if (Ops.size() == 2) {
3171 // C1*(C2+V) -> C1*C2 + C1*V
3172 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3173 // If any of Add's ops are Adds or Muls with a constant, apply this
3174 // transformation as well.
3175 //
3176 // TODO: There are some cases where this transformation is not
3177 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3178 // this transformation should be narrowed down.
3179 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3180 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3181 SCEV::FlagAnyWrap, Depth + 1);
3182 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3183 SCEV::FlagAnyWrap, Depth + 1);
3184 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3185 }
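// For example, 2 * (3 + %x) is rewritten above to (2 * 3) + (2 * %x)
// = 6 + 2 * %x, which exposes the constant 6 for folding with any sibling
// add operands.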
3186
3187 if (Ops[0]->isAllOnesValue()) {
3188 // If we have a mul by -1 of an add, try distributing the -1 among the
3189 // add operands.
3190 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3191 SmallVector<const SCEV *, 4> NewOps;
3192 bool AnyFolded = false;
3193 for (const SCEV *AddOp : Add->operands()) {
3194 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3195 Depth + 1);
3196 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3197 NewOps.push_back(Mul);
3198 }
3199 if (AnyFolded)
3200 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3201 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3202 // Negation preserves a recurrence's no self-wrap property.
3203 SmallVector<const SCEV *, 4> Operands;
3204 for (const SCEV *AddRecOp : AddRec->operands())
3205 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3206 Depth + 1));
3207 // Let M be the minimum representable signed value. AddRec with nsw
3208 // multiplied by -1 can have signed overflow if and only if it takes a
3209 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3210 // maximum signed value. In all other cases signed overflow is
3211 // impossible.
3212 auto FlagsMask = SCEV::FlagNW;
3213 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3214 auto MinInt =
3215 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3216 if (getSignedRangeMin(AddRec) != MinInt)
3217 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3218 }
3219 return getAddRecExpr(Operands, AddRec->getLoop(),
3220 AddRec->getNoWrapFlags(FlagsMask));
3221 }
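// For illustration with i8: M = -128. (-128 * -1) wraps back to -128, while
// (-127 * -1) = 127 is fine, so NSW is kept only when the recurrence
// provably never takes the value -128.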
3222 }
3223 }
3224 }
3225
3226 // Skip over the add expression until we get to a multiply.
3227 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3228 ++Idx;
3229
3230 // If there are mul operands inline them all into this expression.
3231 if (Idx < Ops.size()) {
3232 bool DeletedMul = false;
3233 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3234 if (Ops.size() > MulOpsInlineThreshold)
3235 break;
3236 // If we have a mul, expand the mul operands onto the end of the
3237 // operands list.
3238 Ops.erase(Ops.begin()+Idx);
3239 append_range(Ops, Mul->operands());
3240 DeletedMul = true;
3241 }
3242
3243 // If we deleted at least one mul, we added operands to the end of the
3244 // list, and they are not necessarily sorted. Recurse to resort and
3245 // resimplify any operands we just acquired.
3246 if (DeletedMul)
3247 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3248 }
3249
3250 // If there are any add recurrences in the operands list, see if any other
3251 // added values are loop invariant. If so, we can fold them into the
3252 // recurrence.
3253 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3254 ++Idx;
3255
3256 // Scan over all recurrences, trying to fold loop invariants into them.
3257 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3258 // Scan all of the other operands to this mul and add them to the vector
3259 // if they are loop invariant w.r.t. the recurrence.
3260 SmallVector<const SCEV *, 8> LIOps;
3261 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3262 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3263 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3264 LIOps.push_back(Ops[i]);
3265 Ops.erase(Ops.begin()+i);
3266 --i; --e;
3267 }
3268
3269 // If we found some loop invariants, fold them into the recurrence.
3270 if (!LIOps.empty()) {
3271 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3272 SmallVector<const SCEV *, 4> NewOps;
3273 NewOps.reserve(AddRec->getNumOperands());
3274 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3275
3276 // If both the mul and addrec are nuw, we can preserve nuw.
3277 // If both the mul and addrec are nsw, we can only preserve nsw if either
3278 // a) they are also nuw, or
3279 // b) all multiplications of addrec operands with scale are nsw.
3280 SCEV::NoWrapFlags Flags =
3281 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3282
3283 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3284 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3285 SCEV::FlagAnyWrap, Depth + 1));
3286
3287 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3288 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3289 Instruction::Mul, getSignedRange(Scale),
3290 OverflowingBinaryOperator::NoSignedWrap);
3291 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3292 Flags = clearFlags(Flags, SCEV::FlagNSW);
3293 }
3294 }
3295
3296 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3297
3298 // If all of the other operands were loop invariant, we are done.
3299 if (Ops.size() == 1) return NewRec;
3300
3301 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3302 for (unsigned i = 0;; ++i)
3303 if (Ops[i] == AddRec) {
3304 Ops[i] = NewRec;
3305 break;
3306 }
3307 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3308 }
3309
3310 // Okay, if there weren't any loop invariants to be folded, check to see
3311 // if there are multiple AddRec's with the same loop induction variable
3312 // being multiplied together. If so, we can fold them.
3313
3314 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3315 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3316 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3317 // ]]],+,...up to x=2n}.
3318 // Note that the arguments to choose() are always integers with values
3319 // known at compile time, never SCEV objects.
3320 //
3321 // The implementation avoids pointless extra computations when the two
3322 // addrec's are of different length (mathematically, it's equivalent to
3323 // an infinite stream of zeros on the right).
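// For the common affine case this reduces to:
//   {a,+,b}<L> * {c,+,d}<L> = {a*c,+,a*d + b*c + b*d,+,2*b*d}<L>
// e.g. {0,+,1} * {0,+,1} (i*i) becomes {0,+,1,+,2}, i.e. i^2 expressed in
// the binomial (chains-of-recurrences) basis.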
3324 bool OpsModified = false;
3325 for (unsigned OtherIdx = Idx+1;
3326 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3327 ++OtherIdx) {
3328 const SCEVAddRecExpr *OtherAddRec =
3329 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3330 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3331 continue;
3332
3333 // Limit max number of arguments to avoid creation of unreasonably big
3334 // SCEVAddRecs with very complex operands.
3335 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3336 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3337 continue;
3338
3339 bool Overflow = false;
3340 Type *Ty = AddRec->getType();
3341 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3342 SmallVector<const SCEV *, 7> AddRecOps;
3343 for (int x = 0, xe = AddRec->getNumOperands() +
3344 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3345 SmallVector <const SCEV *, 7> SumOps;
3346 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3347 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3348 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3349 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3350 z < ze && !Overflow; ++z) {
3351 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3352 uint64_t Coeff;
3353 if (LargerThan64Bits)
3354 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3355 else
3356 Coeff = Coeff1*Coeff2;
3357 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3358 const SCEV *Term1 = AddRec->getOperand(y-z);
3359 const SCEV *Term2 = OtherAddRec->getOperand(z);
3360 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3361 SCEV::FlagAnyWrap, Depth + 1));
3362 }
3363 }
3364 if (SumOps.empty())
3365 SumOps.push_back(getZero(Ty));
3366 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3367 }
3368 if (!Overflow) {
3369 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3370 SCEV::FlagAnyWrap));
3371 if (Ops.size() == 2) return NewAddRec;
3372 Ops[Idx] = NewAddRec;
3373 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3374 OpsModified = true;
3375 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3376 if (!AddRec)
3377 break;
3378 }
3379 }
3380 if (OpsModified)
3381 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3382
3383 // Otherwise couldn't fold anything into this recurrence. Move onto the
3384 // next one.
3385 }
3386
3387 // Okay, it looks like we really DO need a mul expr. Check to see if we
3388 // already have one, otherwise create a new one.
3389 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3390}
3391
3392/// Represents an unsigned remainder expression based on unsigned division.
3393const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3394 const SCEV *RHS) {
3395 assert(getEffectiveSCEVType(LHS->getType()) ==
3396 getEffectiveSCEVType(RHS->getType()) &&
3397 "SCEVURemExpr operand types don't match!");
3398
3399 // Short-circuit easy cases
3400 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3401 // If constant is one, the result is trivial
3402 if (RHSC->getValue()->isOne())
3403 return getZero(LHS->getType()); // X urem 1 --> 0
3404
3405 // If constant is a power of two, fold into a zext(trunc(LHS)).
3406 if (RHSC->getAPInt().isPowerOf2()) {
3407 Type *FullTy = LHS->getType();
3408 Type *TruncTy =
3409 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3410 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3411 }
3412 }
3413
3414 // Fall back to %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3415 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3416 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3417 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3418}
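// For example, with i32 operands: %x urem 8 becomes
// (zext i3 (trunc i32 %x to i3) to i32), while %x urem %y for a non-constant
// %y falls back to %x - ((%x udiv %y) * %y).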
3419
3420/// Get a canonical unsigned division expression, or something simpler if
3421/// possible.
3422const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3423 const SCEV *RHS) {
3424 assert(!LHS->getType()->isPointerTy() &&
3425 "SCEVUDivExpr operand can't be pointer!");
3426 assert(LHS->getType() == RHS->getType() &&
3427 "SCEVUDivExpr operand types don't match!");
3428
3429 FoldingSetNodeID ID;
3430 ID.AddInteger(scUDivExpr);
3431 ID.AddPointer(LHS);
3432 ID.AddPointer(RHS);
3433 void *IP = nullptr;
3434 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3435 return S;
3436
3437 // 0 udiv Y == 0
3438 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3439 if (LHSC->getValue()->isZero())
3440 return LHS;
3441
3442 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3443 if (RHSC->getValue()->isOne())
3444 return LHS; // X udiv 1 --> x
3445 // If the denominator is zero, the result of the udiv is undefined. Don't
3446 // try to analyze it, because the resolution chosen here may differ from
3447 // the resolution chosen in other parts of the compiler.
3448 if (!RHSC->getValue()->isZero()) {
3449 // Determine if the division can be folded into the operands of
3450 // the LHS expression.
3451 // TODO: Generalize this to non-constants by using known-bits information.
3452 Type *Ty = LHS->getType();
3453 unsigned LZ = RHSC->getAPInt().countl_zero();
3454 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3455 // For non-power-of-two values, effectively round the value up to the
3456 // nearest power of two.
3457 if (!RHSC->getAPInt().isPowerOf2())
3458 ++MaxShiftAmt;
3459 IntegerType *ExtTy =
3460 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3461 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3462 if (const SCEVConstant *Step =
3463 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3464 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3465 const APInt &StepInt = Step->getAPInt();
3466 const APInt &DivInt = RHSC->getAPInt();
3467 if (!StepInt.urem(DivInt) &&
3468 getZeroExtendExpr(AR, ExtTy) ==
3469 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3470 getZeroExtendExpr(Step, ExtTy),
3471 AR->getLoop(), SCEV::FlagAnyWrap)) {
3472 SmallVector<const SCEV *, 4> Operands;
3473 for (const SCEV *Op : AR->operands())
3474 Operands.push_back(getUDivExpr(Op, RHS));
3475 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3476 }
3477 /// Get a canonical UDivExpr for a recurrence.
3478 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3479 // We can currently only fold X%N if X is constant.
3480 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3481 if (StartC && !DivInt.urem(StepInt) &&
3482 getZeroExtendExpr(AR, ExtTy) ==
3483 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3484 getZeroExtendExpr(Step, ExtTy),
3485 AR->getLoop(), SCEV::FlagAnyWrap)) {
3486 const APInt &StartInt = StartC->getAPInt();
3487 const APInt &StartRem = StartInt.urem(StepInt);
3488 if (StartRem != 0) {
3489 const SCEV *NewLHS =
3490 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3491 AR->getLoop(), SCEV::FlagNW);
3492 if (LHS != NewLHS) {
3493 LHS = NewLHS;
3494
3495 // Reset the ID to include the new LHS, and check if it is
3496 // already cached.
3497 ID.clear();
3498 ID.AddInteger(scUDivExpr);
3499 ID.AddPointer(LHS);
3500 ID.AddPointer(RHS);
3501 IP = nullptr;
3502 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3503 return S;
3504 }
3505 }
3506 }
3507 }
3508 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3509 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3510 SmallVector<const SCEV *, 4> Operands;
3511 for (const SCEV *Op : M->operands())
3512 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3513 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3514 // Find an operand that's safely divisible.
3515 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3516 const SCEV *Op = M->getOperand(i);
3517 const SCEV *Div = getUDivExpr(Op, RHSC);
3518 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3519 Operands = SmallVector<const SCEV *, 4>(M->operands());
3520 Operands[i] = Div;
3521 return getMulExpr(Operands);
3522 }
3523 }
3524 }
3525
3526 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3527 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3528 if (auto *DivisorConstant =
3529 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3530 bool Overflow = false;
3531 APInt NewRHS =
3532 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3533 if (Overflow) {
3534 return getConstant(RHSC->getType(), 0, false);
3535 }
3536 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3537 }
3538 }
3539
3540 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3541 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3542 SmallVector<const SCEV *, 4> Operands;
3543 for (const SCEV *Op : A->operands())
3544 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3545 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3546 Operands.clear();
3547 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3548 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3549 if (isa<SCEVUDivExpr>(Op) ||
3550 getMulExpr(Op, RHS) != A->getOperand(i))
3551 break;
3552 Operands.push_back(Op);
3553 }
3554 if (Operands.size() == A->getNumOperands())
3555 return getAddExpr(Operands);
3556 }
3557 }
3558
3559 // Fold if both operands are constant.
3560 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3561 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3562 }
3563 }
3564
3565 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3566 // changes). Make sure we get a new one.
3567 IP = nullptr;
3568 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3569 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3570 LHS, RHS);
3571 UniqueSCEVs.InsertNode(S, IP);
3572 registerUser(S, {LHS, RHS});
3573 return S;
3574}
3575
3576APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3577 APInt A = C1->getAPInt().abs();
3578 APInt B = C2->getAPInt().abs();
3579 uint32_t ABW = A.getBitWidth();
3580 uint32_t BBW = B.getBitWidth();
3581
3582 if (ABW > BBW)
3583 B = B.zext(ABW);
3584 else if (ABW < BBW)
3585 A = A.zext(BBW);
3586
3587 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3588}
3589
3590/// Get a canonical unsigned division expression, or something simpler if
3591/// possible. There is no representation for an exact udiv in SCEV IR, but we
3592/// can attempt to remove factors from the LHS and RHS. We can't do this when
3593/// it's not exact because the udiv may be clearing bits.
3594const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3595 const SCEV *RHS) {
3596 // TODO: we could try to find factors in all sorts of things, but for now we
3597 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3598 // end of this file for inspiration.
3599
3600 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3601 if (!Mul || !Mul->hasNoUnsignedWrap())
3602 return getUDivExpr(LHS, RHS);
3603
3604 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3605 // If the mulexpr multiplies by a constant, then that constant must be the
3606 // first element of the mulexpr.
3607 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3608 if (LHSCst == RHSCst) {
3609 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3610 return getMulExpr(Operands);
3611 }
3612
3613 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3614 // that there's a factor provided by one of the other terms. We need to
3615 // check.
3616 APInt Factor = gcd(LHSCst, RHSCst);
3617 if (!Factor.isIntN(1)) {
3618 LHSCst =
3619 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3620 RHSCst =
3621 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3622 SmallVector<const SCEV *, 2> Operands;
3623 Operands.push_back(LHSCst);
3624 append_range(Operands, Mul->operands().drop_front());
3625 LHS = getMulExpr(Operands);
3626 RHS = RHSCst;
3627 Mul = dyn_cast<SCEVMulExpr>(LHS);
3628 if (!Mul)
3629 return getUDivExactExpr(LHS, RHS);
3630 }
3631 }
3632 }
3633
3634 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3635 if (Mul->getOperand(i) == RHS) {
3636 SmallVector<const SCEV *, 2> Operands;
3637 append_range(Operands, Mul->operands().take_front(i));
3638 append_range(Operands, Mul->operands().drop_front(i + 1));
3639 return getMulExpr(Operands);
3640 }
3641 }
3642
3643 return getUDivExpr(LHS, RHS);
3644}
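// For illustration: with LHS = (4 * %x)<nuw> and RHS = 6, gcd(4, 6) = 2, so
// the expression is first rewritten to (2 * %x) /u 3 before any further
// factoring is attempted; a direct factor such as (3 * %x * %y)<nuw> /u %x
// simply drops %x and yields 3 * %y.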
3645
3646/// Get an add recurrence expression for the specified loop. Simplify the
3647/// expression as much as possible.
3648const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3649 const Loop *L,
3650 SCEV::NoWrapFlags Flags) {
3651 SmallVector<const SCEV *, 4> Operands;
3652 Operands.push_back(Start);
3653 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3654 if (StepChrec->getLoop() == L) {
3655 append_range(Operands, StepChrec->operands());
3656 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3657 }
3658
3659 Operands.push_back(Step);
3660 return getAddRecExpr(Operands, L, Flags);
3661}
3662
3663/// Get an add recurrence expression for the specified loop. Simplify the
3664/// expression as much as possible.
3665const SCEV *
3666ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3667 const Loop *L, SCEV::NoWrapFlags Flags) {
3668 if (Operands.size() == 1) return Operands[0];
3669#ifndef NDEBUG
3670 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3671 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3672 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3673 "SCEVAddRecExpr operand types don't match!");
3674 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3675 }
3676 for (const SCEV *Op : Operands)
3678 "SCEVAddRecExpr operand is not available at loop entry!");
3679#endif
3680
3681 if (Operands.back()->isZero()) {
3682 Operands.pop_back();
3683 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3684 }
3685
3686 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3687 // use that information to infer NUW and NSW flags. However, computing a
3688 // BE count requires calling getAddRecExpr, so we may not yet have a
3689 // meaningful BE count at this point (and if we don't, we'd be stuck
3690 // with a SCEVCouldNotCompute as the cached BE count).
3691
3692 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3693
3694 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3695 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3696 const Loop *NestedLoop = NestedAR->getLoop();
3697 if (L->contains(NestedLoop)
3698 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3699 : (!NestedLoop->contains(L) &&
3700 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3701 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3702 Operands[0] = NestedAR->getStart();
3703 // AddRecs require their operands be loop-invariant with respect to their
3704 // loops. Don't perform this transformation if it would break this
3705 // requirement.
3706 bool AllInvariant = all_of(
3707 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3708
3709 if (AllInvariant) {
3710 // Create a recurrence for the outer loop with the same step size.
3711 //
3712 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3713 // inner recurrence has the same property.
3714 SCEV::NoWrapFlags OuterFlags =
3715 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3716
3717 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3718 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3719 return isLoopInvariant(Op, NestedLoop);
3720 });
3721
3722 if (AllInvariant) {
3723 // Ok, both add recurrences are valid after the transformation.
3724 //
3725 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3726 // the outer recurrence has the same property.
3727 SCEV::NoWrapFlags InnerFlags =
3728 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3729 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3730 }
3731 }
3732 // Reset Operands to its original state.
3733 Operands[0] = NestedAR;
3734 }
3735 }
3736
3737 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3738 // already have one, otherwise create a new one.
3739 return getOrCreateAddRecExpr(Operands, L, Flags);
3740}
3741
3742const SCEV *
3743ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3744 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3745 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3746 // getSCEV(Base)->getType() has the same address space as Base->getType()
3747 // because SCEV::getType() preserves the address space.
3748 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3749 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3750 if (NW != GEPNoWrapFlags::none()) {
3751 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3752 // but to do that, we have to ensure that said flag is valid in the entire
3753 // defined scope of the SCEV.
3754 // TODO: non-instructions have global scope. We might be able to prove
3755 // some global scope cases
3756 auto *GEPI = dyn_cast<Instruction>(GEP);
3757 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3758 NW = GEPNoWrapFlags::none();
3759 }
3760
3761 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3762 if (NW.hasNoUnsignedSignedWrap())
3763 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3764 if (NW.hasNoUnsignedWrap())
3765 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3766
3767 Type *CurTy = GEP->getType();
3768 bool FirstIter = true;
3769 SmallVector<const SCEV *, 4> Offsets;
3770 for (const SCEV *IndexExpr : IndexExprs) {
3771 // Compute the (potentially symbolic) offset in bytes for this index.
3772 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3773 // For a struct, add the member offset.
3774 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3775 unsigned FieldNo = Index->getZExtValue();
3776 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3777 Offsets.push_back(FieldOffset);
3778
3779 // Update CurTy to the type of the field at Index.
3780 CurTy = STy->getTypeAtIndex(Index);
3781 } else {
3782 // Update CurTy to its element type.
3783 if (FirstIter) {
3784 assert(isa<PointerType>(CurTy) &&
3785 "The first index of a GEP indexes a pointer");
3786 CurTy = GEP->getSourceElementType();
3787 FirstIter = false;
3788 } else {
3789 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3790 }
3791 // For an array, add the element offset, explicitly scaled.
3792 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3793 // Getelementptr indices are signed.
3794 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3795
3796 // Multiply the index by the element size to compute the element offset.
3797 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3798 Offsets.push_back(LocalOffset);
3799 }
3800 }
3801
3802 // Handle degenerate case of GEP without offsets.
3803 if (Offsets.empty())
3804 return BaseExpr;
3805
3806 // Add the offsets together, assuming nsw if inbounds.
3807 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3808 // Add the base address and the offset. We cannot use the nsw flag, as the
3809 // base address is unsigned. However, if we know that the offset is
3810 // non-negative, we can use nuw.
3811 bool NUW = NW.hasNoUnsignedWrap() ||
3812 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3813 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3814 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3815 assert(BaseExpr->getType() == GEPExpr->getType() &&
3816 "GEP should not change type mid-flight.");
3817 return GEPExpr;
3818}
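// For illustration, assuming a typical 64-bit DataLayout where { i32, i32 }
// occupies 8 bytes: getelementptr inbounds {i32, i32}, ptr %p, i64 %i, i32 1
// yields the byte offset (4 + (8 * %i))<nsw> (nsw because the GEP is
// inbounds), which is then added to the base expression for %p.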
3819
3820SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3821 ArrayRef<const SCEV *> Ops) {
3822 FoldingSetNodeID ID;
3823 ID.AddInteger(SCEVType);
3824 for (const SCEV *Op : Ops)
3825 ID.AddPointer(Op);
3826 void *IP = nullptr;
3827 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3828}
3829
3830const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3831 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3832 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3833}
3834
3835const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3836 SmallVectorImpl<const SCEV *> &Ops) {
3837 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3838 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3839 if (Ops.size() == 1) return Ops[0];
3840#ifndef NDEBUG
3841 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3842 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3843 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3844 "Operand types don't match!");
3845 assert(Ops[0]->getType()->isPointerTy() ==
3846 Ops[i]->getType()->isPointerTy() &&
3847 "min/max should be consistently pointerish");
3848 }
3849#endif
3850
3851 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3852 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3853
3854 // Sort by complexity, this groups all similar expression types together.
3855 GroupByComplexity(Ops, &LI, DT);
3856
3857 // Check if we have created the same expression before.
3858 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3859 return S;
3860 }
3861
3862 // If there are any constants, fold them together.
3863 unsigned Idx = 0;
3864 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3865 ++Idx;
3866 assert(Idx < Ops.size());
3867 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3868 switch (Kind) {
3869 case scSMaxExpr:
3870 return APIntOps::smax(LHS, RHS);
3871 case scSMinExpr:
3872 return APIntOps::smin(LHS, RHS);
3873 case scUMaxExpr:
3874 return APIntOps::umax(LHS, RHS);
3875 case scUMinExpr:
3876 return APIntOps::umin(LHS, RHS);
3877 default:
3878 llvm_unreachable("Unknown SCEV min/max opcode");
3879 }
3880 };
3881
3882 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3883 // We found two constants, fold them together!
3884 ConstantInt *Fold = ConstantInt::get(
3885 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3886 Ops[0] = getConstant(Fold);
3887 Ops.erase(Ops.begin()+1); // Erase the folded element
3888 if (Ops.size() == 1) return Ops[0];
3889 LHSC = cast<SCEVConstant>(Ops[0]);
3890 }
3891
3892 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3893 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3894
3895 if (IsMax ? IsMinV : IsMaxV) {
3896 // If we are left with a constant minimum(/maximum)-int, strip it off.
3897 Ops.erase(Ops.begin());
3898 --Idx;
3899 } else if (IsMax ? IsMaxV : IsMinV) {
3900 // If we have a max(/min) with a constant maximum(/minimum)-int,
3901 // it will always be the extremum.
3902 return LHSC;
3903 }
3904
3905 if (Ops.size() == 1) return Ops[0];
3906 }
3907
3908 // Find the first operation of the same kind
3909 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3910 ++Idx;
3911
3912 // Check to see if one of the operands is of the same kind. If so, expand its
3913 // operands onto our operand list, and recurse to simplify.
3914 if (Idx < Ops.size()) {
3915 bool DeletedAny = false;
3916 while (Ops[Idx]->getSCEVType() == Kind) {
3917 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3918 Ops.erase(Ops.begin()+Idx);
3919 append_range(Ops, SMME->operands());
3920 DeletedAny = true;
3921 }
3922
3923 if (DeletedAny)
3924 return getMinMaxExpr(Kind, Ops);
3925 }
3926
3927 // Okay, check to see if the same value occurs in the operand list twice. If
3928 // so, delete one. Since we sorted the list, these values are required to
3929 // be adjacent.
3930 llvm::CmpInst::Predicate GEPred =
3931 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3932 llvm::CmpInst::Predicate LEPred =
3933 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3934 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3935 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3936 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3937 if (Ops[i] == Ops[i + 1] ||
3938 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3939 // X op Y op Y --> X op Y
3940 // X op Y --> X, if we know X, Y are ordered appropriately
3941 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3942 --i;
3943 --e;
3944 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3945 Ops[i + 1])) {
3946 // X op Y --> Y, if we know X, Y are ordered appropriately
3947 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3948 --i;
3949 --e;
3950 }
3951 }
3952
3953 if (Ops.size() == 1) return Ops[0];
3954
3955 assert(!Ops.empty() && "Reduced smax down to nothing!");
3956
3957 // Okay, it looks like we really DO need an expr. Check to see if we
3958 // already have one, otherwise create a new one.
3959 FoldingSetNodeID ID;
3960 ID.AddInteger(Kind);
3961 for (const SCEV *Op : Ops)
3962 ID.AddPointer(Op);
3963 void *IP = nullptr;
3964 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3965 if (ExistingSCEV)
3966 return ExistingSCEV;
3967 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3968 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3969 SCEV *S = new (SCEVAllocator)
3970 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3971
3972 UniqueSCEVs.InsertNode(S, IP);
3973 registerUser(S, Ops);
3974 return S;
3975}
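// For example, smax(%x, smax(%y, 3), 1) is first flattened to
// smax(%x, %y, 3, 1), the constants fold to 3, and any operand known to be
// ordered with respect to its neighbor is dropped, leaving smax(%x, %y, 3).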
3976
3977namespace {
3978
3979class SCEVSequentialMinMaxDeduplicatingVisitor final
3980 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3981 std::optional<const SCEV *>> {
3982 using RetVal = std::optional<const SCEV *>;
3983 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3984
3985 ScalarEvolution &SE;
3986 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3987 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3988 SmallPtrSet<const SCEV *, 16> SeenOps;
3989
3990 bool canRecurseInto(SCEVTypes Kind) const {
3991 // We can only recurse into the SCEV expression of the same effective type
3992 // as the type of our root SCEV expression.
3993 return RootKind == Kind || NonSequentialRootKind == Kind;
3994 };
3995
3996 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3997 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3998 "Only for min/max expressions.");
3999 SCEVTypes Kind = S->getSCEVType();
4000
4001 if (!canRecurseInto(Kind))
4002 return S;
4003
4004 auto *NAry = cast<SCEVNAryExpr>(S);
4005 SmallVector<const SCEV *> NewOps;
4006 bool Changed = visit(Kind, NAry->operands(), NewOps);
4007
4008 if (!Changed)
4009 return S;
4010 if (NewOps.empty())
4011 return std::nullopt;
4012
4013 return isa<SCEVSequentialMinMaxExpr>(S)
4014 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4015 : SE.getMinMaxExpr(Kind, NewOps);
4016 }
4017
4018 RetVal visit(const SCEV *S) {
4019 // Has the whole operand been seen already?
4020 if (!SeenOps.insert(S).second)
4021 return std::nullopt;
4022 return Base::visit(S);
4023 }
4024
4025public:
4026 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4027 SCEVTypes RootKind)
4028 : SE(SE), RootKind(RootKind),
4029 NonSequentialRootKind(
4030 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4031 RootKind)) {}
4032
4033 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4034 SmallVectorImpl<const SCEV *> &NewOps) {
4035 bool Changed = false;
4036 SmallVector<const SCEV *> Ops;
4037 Ops.reserve(OrigOps.size());
4038
4039 for (const SCEV *Op : OrigOps) {
4040 RetVal NewOp = visit(Op);
4041 if (NewOp != Op)
4042 Changed = true;
4043 if (NewOp)
4044 Ops.emplace_back(*NewOp);
4045 }
4046
4047 if (Changed)
4048 NewOps = std::move(Ops);
4049 return Changed;
4050 }
4051
4052 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4053
4054 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4055
4056 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4057
4058 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4059
4060 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4061
4062 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4063
4064 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4065
4066 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4067
4068 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4069
4070 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4071
4072 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4073 return visitAnyMinMaxExpr(Expr);
4074 }
4075
4076 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4077 return visitAnyMinMaxExpr(Expr);
4078 }
4079
4080 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4081 return visitAnyMinMaxExpr(Expr);
4082 }
4083
4084 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4085 return visitAnyMinMaxExpr(Expr);
4086 }
4087
4088 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4089 return visitAnyMinMaxExpr(Expr);
4090 }
4091
4092 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4093
4094 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4095};
4096
4097} // namespace
4098
4099static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4100 switch (Kind) {
4101 case scConstant:
4102 case scVScale:
4103 case scTruncate:
4104 case scZeroExtend:
4105 case scSignExtend:
4106 case scPtrToInt:
4107 case scAddExpr:
4108 case scMulExpr:
4109 case scUDivExpr:
4110 case scAddRecExpr:
4111 case scUMaxExpr:
4112 case scSMaxExpr:
4113 case scUMinExpr:
4114 case scSMinExpr:
4115 case scUnknown:
4116 // If any operand is poison, the whole expression is poison.
4117 return true;
4118 case scSequentialUMinExpr:
4119 // FIXME: if the *first* operand is poison, the whole expression is poison.
4120 return false; // Pessimistically, say that it does not propagate poison.
4121 case scCouldNotCompute:
4122 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4123 }
4124 llvm_unreachable("Unknown SCEV kind!");
4125}
4126
4127namespace {
4128// The only way poison may be introduced in a SCEV expression is from a
4129// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4130// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4131// introduce poison -- they encode guaranteed, non-speculated knowledge.
4132//
4133// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4134// with the notable exception of umin_seq, where only poison from the first
4135// operand is (unconditionally) propagated.
4136struct SCEVPoisonCollector {
4137 bool LookThroughMaybePoisonBlocking;
4138 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4139 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4140 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4141
4142 bool follow(const SCEV *S) {
4143 if (!LookThroughMaybePoisonBlocking &&
4144 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4145 return false;
4146
4147 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4148 if (!isGuaranteedNotToBePoison(SU->getValue()))
4149 MaybePoison.insert(SU);
4150 }
4151 return true;
4152 }
4153 bool isDone() const { return false; }
4154};
4155} // namespace
4156
4157/// Return true if V is poison given that AssumedPoison is already poison.
4158static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4159 // First collect all SCEVs that might result in AssumedPoison to be poison.
4160 // We need to look through potentially poison-blocking operations here,
4161 // because we want to find all SCEVs that *might* result in poison, not only
4162 // those that are *required* to.
4163 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4164 visitAll(AssumedPoison, PC1);
4165
4166 // AssumedPoison is never poison. As the assumption is false, the implication
4167 // is true. Don't bother walking the other SCEV in this case.
4168 if (PC1.MaybePoison.empty())
4169 return true;
4170
4171 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4172 // as well. We cannot look through potentially poison-blocking operations
4173 // here, as their arguments only *may* make the result poison.
4174 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4175 visitAll(S, PC2);
4176
4177 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4178 // it will also make S poison by being part of PC2.MaybePoison.
4179 return all_of(PC1.MaybePoison, [&](const SCEVUnknown *S) {
4180 return PC2.MaybePoison.contains(S);
4181 });
4182}
4183
4184void ScalarEvolution::getPoisonGeneratingValues(
4185 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4186 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4187 visitAll(S, PC);
4188 for (const SCEVUnknown *SU : PC.MaybePoison)
4189 Result.insert(SU->getValue());
4190}
4191
4192bool ScalarEvolution::canReuseInstruction(
4193 const SCEV *S, Instruction *I,
4194 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4195 // If the instruction cannot be poison, it's always safe to reuse.
4196 if (isGuaranteedNotToBePoison(I))
4197 return true;
4198
4199 // Otherwise, it is possible that I is more poisonous than S. Collect the
4200 // poison-contributors of S, and then check whether I has any additional
4201 // poison-contributors. Poison that is contributed through poison-generating
4202 // flags is handled by dropping those flags instead.
4203 SmallPtrSet<const Value *, 8> PoisonVals;
4204 getPoisonGeneratingValues(PoisonVals, S);
4205
4206 SmallVector<Value *> Worklist;
4207 SmallPtrSet<Value *, 8> Visited;
4208 Worklist.push_back(I);
4209 while (!Worklist.empty()) {
4210 Value *V = Worklist.pop_back_val();
4211 if (!Visited.insert(V).second)
4212 continue;
4213
4214 // Avoid walking large instruction graphs.
4215 if (Visited.size() > 16)
4216 return false;
4217
4218 // Either the value can't be poison, or the S would also be poison if it
4219 // is.
4220 if (PoisonVals.contains(V) || isGuaranteedNotToBePoison(V))
4221 continue;
4222
4223 auto *I = dyn_cast<Instruction>(V);
4224 if (!I)
4225 return false;
4226
4227 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4228 // can't replace an arbitrary add with disjoint or, even if we drop the
4229 // flag. We would need to convert the or into an add.
4230 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4231 if (PDI->isDisjoint())
4232 return false;
4233
4234 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4235 // because SCEV currently assumes it can't be poison. Remove this special
4236 // case once we properly model when vscale can be poison.
4237 if (auto *II = dyn_cast<IntrinsicInst>(I);
4238 II && II->getIntrinsicID() == Intrinsic::vscale)
4239 continue;
4240
4241 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4242 return false;
4243
4244 // If the instruction can't create poison, we can recurse to its operands.
4245 if (I->hasPoisonGeneratingAnnotations())
4246 DropPoisonGeneratingInsts.push_back(I);
4247
4248 for (Value *Op : I->operands())
4249 Worklist.push_back(Op);
4250 }
4251 return true;
4252}
4253
4254const SCEV *
4255ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4256 SmallVectorImpl<const SCEV *> &Ops) {
4257 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4258 "Not a SCEVSequentialMinMaxExpr!");
4259 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4260 if (Ops.size() == 1)
4261 return Ops[0];
4262#ifndef NDEBUG
4263 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4264 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4265 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4266 "Operand types don't match!");
4267 assert(Ops[0]->getType()->isPointerTy() ==
4268 Ops[i]->getType()->isPointerTy() &&
4269 "min/max should be consistently pointerish");
4270 }
4271#endif
4272
4273 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4274 // so we can *NOT* do any kind of sorting of the expressions!
4275
4276 // Check if we have created the same expression before.
4277 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4278 return S;
4279
4280 // FIXME: there are *some* simplifications that we can do here.
4281
4282 // Keep only the first instance of an operand.
4283 {
4284 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4285 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4286 if (Changed)
4287 return getSequentialMinMaxExpr(Kind, Ops);
4288 }
4289
4290 // Check to see if one of the operands is of the same kind. If so, expand its
4291 // operands onto our operand list, and recurse to simplify.
4292 {
4293 unsigned Idx = 0;
4294 bool DeletedAny = false;
4295 while (Idx < Ops.size()) {
4296 if (Ops[Idx]->getSCEVType() != Kind) {
4297 ++Idx;
4298 continue;
4299 }
4300 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4301 Ops.erase(Ops.begin() + Idx);
4302 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4303 SMME->operands().end());
4304 DeletedAny = true;
4305 }
4306
4307 if (DeletedAny)
4308 return getSequentialMinMaxExpr(Kind, Ops);
4309 }
4310
4311 const SCEV *SaturationPoint;
4312 ICmpInst::Predicate Pred;
4313 switch (Kind) {
4314 case scSequentialUMinExpr:
4315 SaturationPoint = getZero(Ops[0]->getType());
4316 Pred = ICmpInst::ICMP_ULE;
4317 break;
4318 default:
4319 llvm_unreachable("Not a sequential min/max type.");
4320 }
4321
4322 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4323 // We can replace %x umin_seq %y with %x umin %y if either:
4324 // * %y being poison implies %x is also poison.
4325 // * %x cannot be the saturating value (e.g. zero for umin).
4326 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4327 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4328 SaturationPoint)) {
4329 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4330 Ops[i - 1] = getMinMaxExpr(
4331 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4332 SeqOps);
4333 Ops.erase(Ops.begin() + i);
4334 return getSequentialMinMaxExpr(Kind, Ops);
4335 }
4336 // Fold %x umin_seq %y to %x if %x ule %y.
4337 // TODO: We might be able to prove the predicate for a later operand.
4338 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4339 Ops.erase(Ops.begin() + i);
4340 return getSequentialMinMaxExpr(Kind, Ops);
4341 }
4342 }
4343
4344 // Okay, it looks like we really DO need an expr. Check to see if we
4345 // already have one, otherwise create a new one.
4346 FoldingSetNodeID ID;
4347 ID.AddInteger(Kind);
4348 for (const SCEV *Op : Ops)
4349 ID.AddPointer(Op);
4350 void *IP = nullptr;
4351 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4352 if (ExistingSCEV)
4353 return ExistingSCEV;
4354
4355 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4356 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4357 SCEV *S = new (SCEVAllocator)
4358 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4359
4360 UniqueSCEVs.InsertNode(S, IP);
4361 registerUser(S, Ops);
4362 return S;
4363}
4364
4365const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4366 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4367 return getSMaxExpr(Ops);
4368}
4369
4370const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4371 return getMinMaxExpr(scSMaxExpr, Ops);
4372}
4373
4374const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4375 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4376 return getUMaxExpr(Ops);
4377}
4378
4379const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4380 return getMinMaxExpr(scUMaxExpr, Ops);
4381}
4382
4383const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4384 const SCEV *RHS) {
4385 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4386 return getSMinExpr(Ops);
4387}
4388
4389const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4390 return getMinMaxExpr(scSMinExpr, Ops);
4391}
4392
4393const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4394 bool Sequential) {
4395 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4396 return getUMinExpr(Ops, Sequential);
4397}
4398
4399const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4400 bool Sequential) {
4401 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4402 : getMinMaxExpr(scUMinExpr, Ops);
4403}
4404
4405const SCEV *
4406ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4407 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4408 if (Size.isScalable())
4409 Res = getMulExpr(Res, getVScale(IntTy));
4410 return Res;
4411}
4412
4413const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4414 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4415}
4416
4417const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4418 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4419}
4420
4421const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4422 StructType *STy,
4423 unsigned FieldNo) {
4424 // We can bypass creating a target-independent constant expression and then
4425 // folding it back into a ConstantInt. This is just a compile-time
4426 // optimization.
4427 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4428 assert(!SL->getSizeInBits().isScalable() &&
4429 "Cannot get offset for structure containing scalable vector types");
4430 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4431}
4432
4433const SCEV *ScalarEvolution::getUnknown(Value *V) {
4434 // Don't attempt to do anything other than create a SCEVUnknown object
4435 // here. createSCEV only calls getUnknown after checking for all other
4436 // interesting possibilities, and any other code that calls getUnknown
4437 // is doing so in order to hide a value from SCEV canonicalization.
4438
4439 FoldingSetNodeID ID;
4440 ID.AddInteger(scUnknown);
4441 ID.AddPointer(V);
4442 void *IP = nullptr;
4443 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4444 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4445 "Stale SCEVUnknown in uniquing map!");
4446 return S;
4447 }
4448 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4449 FirstUnknown);
4450 FirstUnknown = cast<SCEVUnknown>(S);
4451 UniqueSCEVs.InsertNode(S, IP);
4452 return S;
4453}
4454
4455//===----------------------------------------------------------------------===//
4456// Basic SCEV Analysis and PHI Idiom Recognition Code
4457//
4458
4459/// Test if values of the given type are analyzable within the SCEV
4460/// framework. This primarily includes integer types, and it can optionally
4461/// include pointer types if the ScalarEvolution class has access to
4462/// target-specific information.
4463bool ScalarEvolution::isSCEVable(Type *Ty) const {
4464 // Integers and pointers are always SCEVable.
4465 return Ty->isIntOrPtrTy();
4466}
4467
4468/// Return the size in bits of the specified type, for which isSCEVable must
4469/// return true.
4470uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4471 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4472 if (Ty->isPointerTy())
4473 return getDataLayout().getIndexTypeSizeInBits(Ty);
4474 return getDataLayout().getTypeSizeInBits(Ty);
4475}
4476
4477/// Return a type with the same bitwidth as the given type and which represents
4478/// how SCEV will treat the given type, for which isSCEVable must return
4479/// true. For pointer types, this is the pointer index sized integer type.
4480Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4481 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4482
4483 if (Ty->isIntegerTy())
4484 return Ty;
4485
4486 // The only other supported type is pointer.
4487 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4488 return getDataLayout().getIndexType(Ty);
4489}
4490
4491Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4492 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4493}
4494
4495bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4496 const SCEV *B) {
4497 /// For a valid use point to exist, the defining scope of one operand
4498 /// must dominate the other.
4499 bool PreciseA, PreciseB;
4500 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4501 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4502 if (!PreciseA || !PreciseB)
4503 // Can't tell.
4504 return false;
4505 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4506 DT.dominates(ScopeB, ScopeA);
4507}
4508
4509const SCEV *ScalarEvolution::getCouldNotCompute() {
4510 return CouldNotCompute.get();
4511}
4512
4513bool ScalarEvolution::checkValidity(const SCEV *S) const {
4514 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4515 auto *SU = dyn_cast<SCEVUnknown>(S);
4516 return SU && SU->getValue() == nullptr;
4517 });
4518
4519 return !ContainsNulls;
4520}
4521
4522bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4523 HasRecMapType::iterator I = HasRecMap.find(S);
4524 if (I != HasRecMap.end())
4525 return I->second;
4526
4527 bool FoundAddRec =
4528 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4529 HasRecMap.insert({S, FoundAddRec});
4530 return FoundAddRec;
4531}
4532
4533/// Return the ValueOffsetPair set for \p S. \p S can be represented
4534/// by the value and offset from any ValueOffsetPair in the set.
4535ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4536 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4537 if (SI == ExprValueMap.end())
4538 return std::nullopt;
4539 return SI->second.getArrayRef();
4540}
4541
4542/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4543/// cannot be used separately. eraseValueFromMap should be used to remove
4544/// V from ValueExprMap and ExprValueMap at the same time.
4545void ScalarEvolution::eraseValueFromMap(Value *V) {
4546 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4547 if (I != ValueExprMap.end()) {
4548 auto EVIt = ExprValueMap.find(I->second);
4549 bool Removed = EVIt->second.remove(V);
4550 (void) Removed;
4551 assert(Removed && "Value not in ExprValueMap?");
4552 ValueExprMap.erase(I);
4553 }
4554}
4555
4556void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4557 // A recursive query may have already computed the SCEV. It should be
4558 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4559 // inferred nowrap flags.
4560 auto It = ValueExprMap.find_as(V);
4561 if (It == ValueExprMap.end()) {
4562 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4563 ExprValueMap[S].insert(V);
4564 }
4565}
4566
4567/// Return an existing SCEV if it exists, otherwise analyze the expression and
4568/// create a new one.
4569const SCEV *ScalarEvolution::getSCEV(Value *V) {
4570 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4571
4572 if (const SCEV *S = getExistingSCEV(V))
4573 return S;
4574 return createSCEVIter(V);
4575}
4576
4577const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4578 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4579
4580 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4581 if (I != ValueExprMap.end()) {
4582 const SCEV *S = I->second;
4583 assert(checkValidity(S) &&
4584 "existing SCEV has not been properly invalidated");
4585 return S;
4586 }
4587 return nullptr;
4588}
4589
4590/// Return a SCEV corresponding to -V = -1*V
4591const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4592 SCEV::NoWrapFlags Flags) {
4593 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4594 return getConstant(
4595 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4596
4597 Type *Ty = V->getType();
4598 Ty = getEffectiveSCEVType(Ty);
4599 return getMulExpr(V, getMinusOne(Ty), Flags);
4600}
4601
4602/// If Expr computes ~A, return A else return nullptr
4603static const SCEV *MatchNotExpr(const SCEV *Expr) {
4604 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4605 if (!Add || Add->getNumOperands() != 2 ||
4606 !Add->getOperand(0)->isAllOnesValue())
4607 return nullptr;
4608
4609 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4610 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4611 !AddRHS->getOperand(0)->isAllOnesValue())
4612 return nullptr;
4613
4614 return AddRHS->getOperand(1);
4615}
4616
4617/// Return a SCEV corresponding to ~V = -1-V
4618const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4619 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4620
4621 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4622 return getConstant(
4623 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4624
4625 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4626 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4627 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4628 SmallVector<const SCEV *, 2> MatchedOperands;
4629 for (const SCEV *Operand : MME->operands()) {
4630 const SCEV *Matched = MatchNotExpr(Operand);
4631 if (!Matched)
4632 return (const SCEV *)nullptr;
4633 MatchedOperands.push_back(Matched);
4634 }
4635 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4636 MatchedOperands);
4637 };
4638 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4639 return Replaced;
4640 }
4641
4642 Type *Ty = V->getType();
4643 Ty = getEffectiveSCEVType(Ty);
4644 return getMinusSCEV(getMinusOne(Ty), V);
4645}
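// For example, getNotSCEV(smin((-1 - %x), (-1 - %y))) matches both operands
// as ~%x and ~%y and folds to smax(%x, %y) instead of producing
// -1 - smin(-1 - %x, -1 - %y).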
4646
4647const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4648 assert(P->getType()->isPointerTy());
4649
4650 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4651 // The base of an AddRec is the first operand.
4652 SmallVector<const SCEV *> Ops{AddRec->operands()};
4653 Ops[0] = removePointerBase(Ops[0]);
4654 // Don't try to transfer nowrap flags for now. We could in some cases
4655 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4656 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4657 }
4658 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4659 // The base of an Add is the pointer operand.
4660 SmallVector<const SCEV *> Ops{Add->operands()};
4661 const SCEV **PtrOp = nullptr;
4662 for (const SCEV *&AddOp : Ops) {
4663 if (AddOp->getType()->isPointerTy()) {
4664 assert(!PtrOp && "Cannot have multiple pointer ops");
4665 PtrOp = &AddOp;
4666 }
4667 }
4668 *PtrOp = removePointerBase(*PtrOp);
4669 // Don't try to transfer nowrap flags for now. We could in some cases
4670 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4671 return getAddExpr(Ops);
4672 }
4673 // Any other expression must be a pointer base.
4674 return getZero(P->getType());
4675}
4676
4677const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4678 SCEV::NoWrapFlags Flags,
4679 unsigned Depth) {
4680 // Fast path: X - X --> 0.
4681 if (LHS == RHS)
4682 return getZero(LHS->getType());
4683
4684 // If we subtract two pointers with different pointer bases, bail.
4685 // Eventually, we're going to add an assertion to getMulExpr that we
4686 // can't multiply by a pointer.
4687 if (RHS->getType()->isPointerTy()) {
4688 if (!LHS->getType()->isPointerTy() ||
4689 getPointerBase(LHS) != getPointerBase(RHS))
4690 return getCouldNotCompute();
4691 LHS = removePointerBase(LHS);
4692 RHS = removePointerBase(RHS);
4693 }
4694
4695 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4696 // makes it so that we cannot make much use of NUW.
4697 auto AddFlags = SCEV::FlagAnyWrap;
4698 const bool RHSIsNotMinSigned =
4699 !getSignedRangeMin(RHS).isMinSignedValue();
4700 if (hasFlags(Flags, SCEV::FlagNSW)) {
4701 // Let M be the minimum representable signed value. Then (-1)*RHS
4702 // signed-wraps if and only if RHS is M. That can happen even for
4703 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4704 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4705 // (-1)*RHS, we need to prove that RHS != M.
4706 //
4707 // If LHS is non-negative and we know that LHS - RHS does not
4708 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4709 // either by proving that RHS > M or that LHS >= 0.
4710 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4711 AddFlags = SCEV::FlagNSW;
4712 }
4713 }
4714
4715 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4716 // RHS is NSW and LHS >= 0.
4717 //
4718 // The difficulty here is that the NSW flag may have been proven
4719 // relative to a loop that is to be found in a recurrence in LHS and
4720 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4721 // larger scope than intended.
4722 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4723
4724 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4725}
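// For illustration with i8: if RHS could be -128, (-1) * RHS wraps back to
// -128, so NSW cannot be transferred to the rewritten form LHS + (-1)*RHS
// even when LHS - RHS itself is known not to overflow; hence the
// RHS != INT_MIN / LHS >= 0 checks above.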
4726
4727const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4728 unsigned Depth) {
4729 Type *SrcTy = V->getType();
4730 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4731 "Cannot truncate or zero extend with non-integer arguments!");
4732 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4733 return V; // No conversion
4734 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4735 return getTruncateExpr(V, Ty, Depth);
4736 return getZeroExtendExpr(V, Ty, Depth);
4737}
4738
4739const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4740 unsigned Depth) {
4741 Type *SrcTy = V->getType();
4742 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4743 "Cannot truncate or zero extend with non-integer arguments!");
4744 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4745 return V; // No conversion
4746 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4747 return getTruncateExpr(V, Ty, Depth);
4748 return getSignExtendExpr(V, Ty, Depth);
4749}
4750
4751const SCEV *
4752ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4753 Type *SrcTy = V->getType();
4754 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4755 "Cannot noop or zero extend with non-integer arguments!");
4757 "getNoopOrZeroExtend cannot truncate!");
4758 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4759 return V; // No conversion
4760 return getZeroExtendExpr(V, Ty);
4761}
4762
4763const SCEV *
4764ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4765 Type *SrcTy = V->getType();
4766 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4767 "Cannot noop or sign extend with non-integer arguments!");
4769 "getNoopOrSignExtend cannot truncate!");
4770 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4771 return V; // No conversion
4772 return getSignExtendExpr(V, Ty);
4773}
4774
4775const SCEV *
4776ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4777 Type *SrcTy = V->getType();
4778 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4779 "Cannot noop or any extend with non-integer arguments!");
4781 "getNoopOrAnyExtend cannot truncate!");
4782 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4783 return V; // No conversion
4784 return getAnyExtendExpr(V, Ty);
4785}
4786
4787const SCEV *
4788ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4789 Type *SrcTy = V->getType();
4790 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4791 "Cannot truncate or noop with non-integer arguments!");
4793 "getTruncateOrNoop cannot extend!");
4794 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4795 return V; // No conversion
4796 return getTruncateExpr(V, Ty);
4797}
4798
4799const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4800 const SCEV *RHS) {
4801 const SCEV *PromotedLHS = LHS;
4802 const SCEV *PromotedRHS = RHS;
4803
4804 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4805 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4806 else
4807 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4808
4809 return getUMaxExpr(PromotedLHS, PromotedRHS);
4810}
4811
4812const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4813 const SCEV *RHS,
4814 bool Sequential) {
4815 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4816 return getUMinFromMismatchedTypes(Ops, Sequential);
4817}
4818
4819const SCEV *
4820ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4821 bool Sequential) {
4822 assert(!Ops.empty() && "At least one operand must be!");
4823 // Trivial case.
4824 if (Ops.size() == 1)
4825 return Ops[0];
4826
4827 // Find the max type first.
4828 Type *MaxType = nullptr;
4829 for (const auto *S : Ops)
4830 if (MaxType)
4831 MaxType = getWiderType(MaxType, S->getType());
4832 else
4833 MaxType = S->getType();
4834 assert(MaxType && "Failed to find maximum type!");
4835
4836 // Extend all ops to max type.
4837 SmallVector<const SCEV *, 2> PromotedOps;
4838 for (const auto *S : Ops)
4839 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4840
4841 // Generate umin.
4842 return getUMinExpr(PromotedOps, Sequential);
4843}
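// Example of the promotion above (a sketch, assuming operands of types i16,
// i32 and i64): the widest type is i64, each operand is widened with
// getNoopOrZeroExtend, and the result is
//   umin(zext i16 %a to i64, zext i32 %b to i64, %c).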
4844
4845const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4846 // A pointer operand may evaluate to a nonpointer expression, such as null.
4847 if (!V->getType()->isPointerTy())
4848 return V;
4849
4850 while (true) {
4851 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4852 V = AddRec->getStart();
4853 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4854 const SCEV *PtrOp = nullptr;
4855 for (const SCEV *AddOp : Add->operands()) {
4856 if (AddOp->getType()->isPointerTy()) {
4857 assert(!PtrOp && "Cannot have multiple pointer ops");
4858 PtrOp = AddOp;
4859 }
4860 }
4861 assert(PtrOp && "Must have pointer op");
4862 V = PtrOp;
4863 } else // Not something we can look further into.
4864 return V;
4865 }
4866}
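// Example of the walk above: for a SCEV such as {(16 + %p),+,4}<%loop>, the
// start (16 + %p) is an add with a single pointer operand, so the returned
// pointer base is %p. A non-pointer expression (e.g. null) is returned
// unchanged by the early exit at the top.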
4867
4868/// Push users of the given Instruction onto the given Worklist.
4869static void
4870PushDefUseChildren(Instruction *I, SmallVectorImpl<Instruction *> &Worklist,
4871 SmallPtrSetImpl<Instruction *> &Visited) {
4872 // Push the def-use children onto the Worklist stack.
4873 for (User *U : I->users()) {
4874 auto *UserInsn = cast<Instruction>(U);
4875 if (Visited.insert(UserInsn).second)
4876 Worklist.push_back(UserInsn);
4877 }
4878}
4879
4880namespace {
4881
4882/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4883/// expression if its loop is L. If its loop is not L, then: if
4884/// IgnoreOtherLoops is true, use the AddRec itself; otherwise the rewrite
4885/// cannot be done.
4886/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4887class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4888public:
4889 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4890 bool IgnoreOtherLoops = true) {
4891 SCEVInitRewriter Rewriter(L, SE);
4892 const SCEV *Result = Rewriter.visit(S);
4893 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4894 return SE.getCouldNotCompute();
4895 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4896 ? SE.getCouldNotCompute()
4897 : Result;
4898 }
4899
4900 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4901 if (!SE.isLoopInvariant(Expr, L))
4902 SeenLoopVariantSCEVUnknown = true;
4903 return Expr;
4904 }
4905
4906 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4907 // Only re-write AddRecExprs for this loop.
4908 if (Expr->getLoop() == L)
4909 return Expr->getStart();
4910 SeenOtherLoops = true;
4911 return Expr;
4912 }
4913
4914 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4915
4916 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4917
4918private:
4919 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4920 : SCEVRewriteVisitor(SE), L(L) {}
4921
4922 const Loop *L;
4923 bool SeenLoopVariantSCEVUnknown = false;
4924 bool SeenOtherLoops = false;
4925};
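// Example: for S = {%start,+,%step}<%L>, SCEVInitRewriter::rewrite(S, L, SE)
// yields %start. If S contains an AddRec of a different loop, the result is
// that AddRec unchanged when IgnoreOtherLoops is true, and CouldNotCompute
// otherwise.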
4926
4927/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its
4928/// post-increment expression if its loop is L; if its loop is not L, use the
4929/// AddRec itself.
4930/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4931class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4932public:
4933 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4934 SCEVPostIncRewriter Rewriter(L, SE);
4935 const SCEV *Result = Rewriter.visit(S);
4936 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4937 ? SE.getCouldNotCompute()
4938 : Result;
4939 }
4940
4941 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4942 if (!SE.isLoopInvariant(Expr, L))
4943 SeenLoopVariantSCEVUnknown = true;
4944 return Expr;
4945 }
4946
4947 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4948 // Only re-write AddRecExprs for this loop.
4949 if (Expr->getLoop() == L)
4950 return Expr->getPostIncExpr(SE);
4951 SeenOtherLoops = true;
4952 return Expr;
4953 }
4954
4955 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4956
4957 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4958
4959private:
4960 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4961 : SCEVRewriteVisitor(SE), L(L) {}
4962
4963 const Loop *L;
4964 bool SeenLoopVariantSCEVUnknown = false;
4965 bool SeenOtherLoops = false;
4966};
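// Example: for S = {%start,+,%step}<%L>, SCEVPostIncRewriter::rewrite(S, L, SE)
// yields the post-increment form {(%start + %step),+,%step}<%L>; AddRecs of
// other loops are left untouched.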
4967
4968/// This class evaluates the compare condition by matching it against the
4969/// condition of the loop latch. If there is a match, we assume a true value
4970/// for the condition while building SCEV nodes.
4971class SCEVBackedgeConditionFolder
4972 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4973public:
4974 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4975 ScalarEvolution &SE) {
4976 bool IsPosBECond = false;
4977 Value *BECond = nullptr;
4978 if (BasicBlock *Latch = L->getLoopLatch()) {
4979 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4980 if (BI && BI->isConditional()) {
4981 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4982 "Both outgoing branches should not target same header!");
4983 BECond = BI->getCondition();
4984 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4985 } else {
4986 return S;
4987 }
4988 }
4989 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4990 return Rewriter.visit(S);
4991 }
4992
4993 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4994 const SCEV *Result = Expr;
4995 bool InvariantF = SE.isLoopInvariant(Expr, L);
4996
4997 if (!InvariantF) {
4998 Instruction *I = cast<Instruction>(Expr->getValue());
4999 switch (I->getOpcode()) {
5000 case Instruction::Select: {
5001 SelectInst *SI = cast<SelectInst>(I);
5002 std::optional<const SCEV *> Res =
5003 compareWithBackedgeCondition(SI->getCondition());
5004 if (Res) {
5005 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5006 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5007 }
5008 break;
5009 }
5010 default: {
5011 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5012 if (Res)
5013 Result = *Res;
5014 break;
5015 }
5016 }
5017 }
5018 return Result;
5019 }
5020
5021private:
5022 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5023 bool IsPosBECond, ScalarEvolution &SE)
5024 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5025 IsPositiveBECond(IsPosBECond) {}
5026
5027 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5028
5029 const Loop *L;
5030 /// Loop back condition.
5031 Value *BackedgeCond = nullptr;
5032 /// Set to true if loop back is on positive branch condition.
5033 bool IsPositiveBECond;
5034};
5035
5036std::optional<const SCEV *>
5037SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5038
5039 // If value matches the backedge condition for loop latch,
5040 // then return a constant evolution node based on loopback
5041 // branch taken.
5042 if (BackedgeCond == IC)
5043 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5044 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5045 return std::nullopt;
5046}
5047
5048class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5049public:
5050 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5051 ScalarEvolution &SE) {
5052 SCEVShiftRewriter Rewriter(L, SE);
5053 const SCEV *Result = Rewriter.visit(S);
5054 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5055 }
5056
5057 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5058 // Only allow AddRecExprs for this loop.
5059 if (!SE.isLoopInvariant(Expr, L))
5060 Valid = false;
5061 return Expr;
5062 }
5063
5064 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5065 if (Expr->getLoop() == L && Expr->isAffine())
5066 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5067 Valid = false;
5068 return Expr;
5069 }
5070
5071 bool isValid() { return Valid; }
5072
5073private:
5074 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5075 : SCEVRewriteVisitor(SE), L(L) {}
5076
5077 const Loop *L;
5078 bool Valid = true;
5079};
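// Example: SCEVShiftRewriter::rewrite({%a,+,%s}<%L>, L, SE) produces
// {%a,+,%s} - %s, i.e. {(%a - %s),+,%s}<%L>, shifting the recurrence back by
// one iteration. Anything other than a loop-invariant value or an affine
// AddRec of L invalidates the rewrite.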
5080
5081} // end anonymous namespace
5082
5083SCEV::NoWrapFlags
5084ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5085 if (!AR->isAffine())
5086 return SCEV::FlagAnyWrap;
5087
5088 using OBO = OverflowingBinaryOperator;
5089
5090 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5091
5092 if (!AR->hasNoSelfWrap()) {
5093 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5094 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5095 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5096 const APInt &BECountAP = BECountMax->getAPInt();
5097 unsigned NoOverflowBitWidth =
5098 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5099 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5100 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5101 }
5102 }
5103
5104 if (!AR->hasNoSignedWrap()) {
5105 ConstantRange AddRecRange = getSignedRange(AR);
5106 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5107
5108 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5109 Instruction::Add, IncRange, OBO::NoSignedWrap);
5110 if (NSWRegion.contains(AddRecRange))
5111 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5112 }
5113
5114 if (!AR->hasNoUnsignedWrap()) {
5115 ConstantRange AddRecRange = getUnsignedRange(AR);
5116 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5117
5118 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5119 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5120 if (NUWRegion.contains(AddRecRange))
5121 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5122 }
5123
5124 return Result;
5125}
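// A worked instance of the self-wrap check above (a sketch): with a constant
// maximum backedge-taken count of 7 (3 active bits) and a step whose signed
// range needs 4 bits, NoOverflowBitWidth is 3 + 4 = 7, so for any AddRec that
// is at least 7 bits wide the recurrence can never return to its starting
// value and FlagNW may be set.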
5126
5127SCEV::NoWrapFlags
5128ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5129 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5130
5131 if (AR->hasNoSignedWrap())
5132 return Result;
5133
5134 if (!AR->isAffine())
5135 return Result;
5136
5137 // This function can be expensive, only try to prove NSW once per AddRec.
5138 if (!SignedWrapViaInductionTried.insert(AR).second)
5139 return Result;
5140
5141 const SCEV *Step = AR->getStepRecurrence(*this);
5142 const Loop *L = AR->getLoop();
5143
5144 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5145 // Note that this serves two purposes: It filters out loops that are
5146 // simply not analyzable, and it covers the case where this code is
5147 // being called from within backedge-taken count analysis, such that
5148 // attempting to ask for the backedge-taken count would likely result
5149 // in infinite recursion. In the latter case, the analysis code will
5150 // cope with a conservative value, and it will take care to purge
5151 // that value once it has finished.
5152 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5153
5154 // Normally, in the cases we can prove no-overflow via a
5155 // backedge guarding condition, we can also compute a backedge
5156 // taken count for the loop. The exceptions are assumptions and
5157 // guards present in the loop -- SCEV is not great at exploiting
5158 // these to compute max backedge taken counts, but can still use
5159 // these to prove lack of overflow. Use this fact to avoid
5160 // doing extra work that may not pay off.
5161
5162 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5163 AC.assumptions().empty())
5164 return Result;
5165
5166 // If the backedge is guarded by a comparison with the pre-inc value the
5167 // addrec is safe. Also, if the entry is guarded by a comparison with the
5168 // start value and the backedge is guarded by a comparison with the post-inc
5169 // value, the addrec is safe.
5170 ICmpInst::Predicate Pred;
5171 const SCEV *OverflowLimit =
5172 getSignedOverflowLimitForStep(Step, &Pred, this);
5173 if (OverflowLimit &&
5174 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5175 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5176 Result = setFlags(Result, SCEV::FlagNSW);
5177 }
5178 return Result;
5179}
5180SCEV::NoWrapFlags
5181ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5182 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5183
5184 if (AR->hasNoUnsignedWrap())
5185 return Result;
5186
5187 if (!AR->isAffine())
5188 return Result;
5189
5190 // This function can be expensive, only try to prove NUW once per AddRec.
5191 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5192 return Result;
5193
5194 const SCEV *Step = AR->getStepRecurrence(*this);
5195 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5196 const Loop *L = AR->getLoop();
5197
5198 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5199 // Note that this serves two purposes: It filters out loops that are
5200 // simply not analyzable, and it covers the case where this code is
5201 // being called from within backedge-taken count analysis, such that
5202 // attempting to ask for the backedge-taken count would likely result
5203 // in infinite recursion. In the latter case, the analysis code will
5204 // cope with a conservative value, and it will take care to purge
5205 // that value once it has finished.
5206 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5207
5208 // Normally, in the cases we can prove no-overflow via a
5209 // backedge guarding condition, we can also compute a backedge
5210 // taken count for the loop. The exceptions are assumptions and
5211 // guards present in the loop -- SCEV is not great at exploiting
5212 // these to compute max backedge taken counts, but can still use
5213 // these to prove lack of overflow. Use this fact to avoid
5214 // doing extra work that may not pay off.
5215
5216 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5217 AC.assumptions().empty())
5218 return Result;
5219
5220 // If the backedge is guarded by a comparison with the pre-inc value the
5221 // addrec is safe. Also, if the entry is guarded by a comparison with the
5222 // start value and the backedge is guarded by a comparison with the post-inc
5223 // value, the addrec is safe.
5224 if (isKnownPositive(Step)) {
5225 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5226 getUnsignedRangeMax(Step));
5227 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5228 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5229 Result = setFlags(Result, SCEV::FlagNUW);
5230 }
5231 }
5232
5233 return Result;
5234}
5235
5236namespace {
5237
5238/// Represents an abstract binary operation. This may exist as a
5239/// normal instruction or constant expression, or may have been
5240/// derived from an expression tree.
5241struct BinaryOp {
5242 unsigned Opcode;
5243 Value *LHS;
5244 Value *RHS;
5245 bool IsNSW = false;
5246 bool IsNUW = false;
5247
5248 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5249 /// constant expression.
5250 Operator *Op = nullptr;
5251
5252 explicit BinaryOp(Operator *Op)
5253 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5254 Op(Op) {
5255 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5256 IsNSW = OBO->hasNoSignedWrap();
5257 IsNUW = OBO->hasNoUnsignedWrap();
5258 }
5259 }
5260
5261 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5262 bool IsNUW = false)
5263 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5264};
5265
5266} // end anonymous namespace
5267
5268/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5269static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5270 AssumptionCache &AC,
5271 const DominatorTree &DT,
5272 const Instruction *CxtI) {
5273 auto *Op = dyn_cast<Operator>(V);
5274 if (!Op)
5275 return std::nullopt;
5276
5277 // Implementation detail: all the cleverness here should happen without
5278 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5279 // SCEV expressions when possible, and we should not break that.
5280
5281 switch (Op->getOpcode()) {
5282 case Instruction::Add:
5283 case Instruction::Sub:
5284 case Instruction::Mul:
5285 case Instruction::UDiv:
5286 case Instruction::URem:
5287 case Instruction::And:
5288 case Instruction::AShr:
5289 case Instruction::Shl:
5290 return BinaryOp(Op);
5291
5292 case Instruction::Or: {
5293 // Convert or disjoint into add nuw nsw.
5294 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5295 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5296 /*IsNSW=*/true, /*IsNUW=*/true);
5297 return BinaryOp(Op);
5298 }
5299
5300 case Instruction::Xor:
5301 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5302 // If the RHS of the xor is a signmask, then this is just an add.
5303 // Instcombine turns add of signmask into xor as a strength reduction step.
5304 if (RHSC->getValue().isSignMask())
5305 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5306 // Binary `xor` is a bit-wise `add`.
5307 if (V->getType()->isIntegerTy(1))
5308 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5309 return BinaryOp(Op);
5310
5311 case Instruction::LShr:
5312 // Turn logical shift right of a constant into an unsigned divide.
5313 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5314 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5315
5316 // If the shift count is not less than the bitwidth, the result of
5317 // the shift is undefined. Don't try to analyze it, because the
5318 // resolution chosen here may differ from the resolution chosen in
5319 // other parts of the compiler.
5320 if (SA->getValue().ult(BitWidth)) {
5321 Constant *X =
5322 ConstantInt::get(SA->getContext(),
5323 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5324 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5325 }
5326 }
5327 return BinaryOp(Op);
5328
5329 case Instruction::ExtractValue: {
5330 auto *EVI = cast<ExtractValueInst>(Op);
5331 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5332 break;
5333
5334 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5335 if (!WO)
5336 break;
5337
5338 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5339 bool Signed = WO->isSigned();
5340 // TODO: Should add nuw/nsw flags for mul as well.
5341 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5342 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5343
5344 // Now that we know that all uses of the arithmetic-result component of
5345 // CI are guarded by the overflow check, we can go ahead and pretend
5346 // that the arithmetic is non-overflowing.
5347 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5348 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5349 }
5350
5351 default:
5352 break;
5353 }
5354
5355 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5356 // semantics as a Sub, return a binary sub expression.
5357 if (auto *II = dyn_cast<IntrinsicInst>(V))
5358 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5359 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5360
5361 return std::nullopt;
5362}
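// Example mappings produced above: `or disjoint %a, %b` becomes an Add with
// nuw and nsw set, `xor %a, signmask` and an i1 `xor` become plain Adds,
// `lshr %a, 3` becomes a UDiv by 8, and the value extracted from an
// llvm.sadd.with.overflow whose uses are guarded by the overflow check
// becomes an Add with nsw.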
5363
5364/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5365/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5366/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5367/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5368/// follows one of the following patterns:
5369/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5370/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5371/// If the SCEV expression of \p Op conforms with one of the expected patterns
5372/// we return the type of the truncation operation, and indicate whether the
5373/// truncated type should be treated as signed/unsigned by setting
5374/// \p Signed to true/false, respectively.
5375static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5376 bool &Signed, ScalarEvolution &SE) {
5377 // The case where Op == SymbolicPHI (that is, with no type conversions on
5378 // the way) is handled by the regular add recurrence creating logic and
5379 // would have already been triggered in createAddRecForPHI. Reaching it here
5380 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5381 // because one of the other operands of the SCEVAddExpr updating this PHI is
5382 // not invariant).
5383 //
5384 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5385 // this case predicates that allow us to prove that Op == SymbolicPHI will
5386 // be added.
5387 if (Op == SymbolicPHI)
5388 return nullptr;
5389
5390 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5391 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5392 if (SourceBits != NewBits)
5393 return nullptr;
5394
5395 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5396 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5397 if (!SExt && !ZExt)
5398 return nullptr;
5399 const SCEVTruncateExpr *Trunc =
5400 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5401 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5402 if (!Trunc)
5403 return nullptr;
5404 const SCEV *X = Trunc->getOperand();
5405 if (X != SymbolicPHI)
5406 return nullptr;
5407 Signed = SExt != nullptr;
5408 return Trunc->getType();
5409}
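// Example: with %SymbolicPHI of type i64, an operand of the form
//   (sext i32 (trunc i64 %SymbolicPHI to i32) to i64)
// yields the truncated type i32 with Signed = true; the zext form yields i32
// with Signed = false; any other shape returns nullptr.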
5410
5411static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5412 if (!PN->getType()->isIntegerTy())
5413 return nullptr;
5414 const Loop *L = LI.getLoopFor(PN->getParent());
5415 if (!L || L->getHeader() != PN->getParent())
5416 return nullptr;
5417 return L;
5418}
5419
5420// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5421// computation that updates the phi follows the following pattern:
5422// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5423// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5424// If so, try to see if it can be rewritten as an AddRecExpr under some
5425// Predicates. If successful, return them as a pair. Also cache the results
5426// of the analysis.
5427//
5428// Example usage scenario:
5429// Say the Rewriter is called for the following SCEV:
5430// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5431// where:
5432// %X = phi i64 (%Start, %BEValue)
5433// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5434// and call this function with %SymbolicPHI = %X.
5435//
5436// The analysis will find that the value coming around the backedge has
5437// the following SCEV:
5438// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5439// Upon concluding that this matches the desired pattern, the function
5440// will return the pair {NewAddRec, SmallPredsVec} where:
5441// NewAddRec = {%Start,+,%Step}
5442// SmallPredsVec = {P1, P2, P3} as follows:
5443// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5444// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5445// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5446// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5447// under the predicates {P1,P2,P3}.
5448// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5449// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5450//
5451// TODO's:
5452//
5453// 1) Extend the Induction descriptor to also support inductions that involve
5454// casts: When needed (namely, when we are called in the context of the
5455// vectorizer induction analysis), a Set of cast instructions will be
5456// populated by this method, and provided back to isInductionPHI. This is
5457// needed to allow the vectorizer to properly record them to be ignored by
5458// the cost model and to avoid vectorizing them (otherwise these casts,
5459// which are redundant under the runtime overflow checks, will be
5460// vectorized, which can be costly).
5461//
5462// 2) Support additional induction/PHISCEV patterns: We also want to support
5463// inductions where the sext-trunc / zext-trunc operations (partly) occur
5464// after the induction update operation (the induction increment):
5465//
5466// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5467// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5468//
5469// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5470// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5471//
5472// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5473std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5474ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5475 SmallVector<const SCEVPredicate *, 3> Predicates;
5476
5477 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5478 // return an AddRec expression under some predicate.
5479
5480 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5481 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5482 assert(L && "Expecting an integer loop header phi");
5483
5484 // The loop may have multiple entrances or multiple exits; we can analyze
5485 // this phi as an addrec if it has a unique entry value and a unique
5486 // backedge value.
5487 Value *BEValueV = nullptr, *StartValueV = nullptr;
5488 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5489 Value *V = PN->getIncomingValue(i);
5490 if (L->contains(PN->getIncomingBlock(i))) {
5491 if (!BEValueV) {
5492 BEValueV = V;
5493 } else if (BEValueV != V) {
5494 BEValueV = nullptr;
5495 break;
5496 }
5497 } else if (!StartValueV) {
5498 StartValueV = V;
5499 } else if (StartValueV != V) {
5500 StartValueV = nullptr;
5501 break;
5502 }
5503 }
5504 if (!BEValueV || !StartValueV)
5505 return std::nullopt;
5506
5507 const SCEV *BEValue = getSCEV(BEValueV);
5508
5509 // If the value coming around the backedge is an add with the symbolic
5510 // value we just inserted, possibly with casts that we can ignore under
5511 // an appropriate runtime guard, then we found a simple induction variable!
5512 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5513 if (!Add)
5514 return std::nullopt;
5515
5516 // If there is a single occurrence of the symbolic value, possibly
5517 // casted, replace it with a recurrence.
5518 unsigned FoundIndex = Add->getNumOperands();
5519 Type *TruncTy = nullptr;
5520 bool Signed;
5521 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5522 if ((TruncTy =
5523 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5524 if (FoundIndex == e) {
5525 FoundIndex = i;
5526 break;
5527 }
5528
5529 if (FoundIndex == Add->getNumOperands())
5530 return std::nullopt;
5531
5532 // Create an add with everything but the specified operand.
5533 SmallVector<const SCEV *, 8> Ops;
5534 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5535 if (i != FoundIndex)
5536 Ops.push_back(Add->getOperand(i));
5537 const SCEV *Accum = getAddExpr(Ops);
5538
5539 // The runtime checks will not be valid if the step amount is
5540 // varying inside the loop.
5541 if (!isLoopInvariant(Accum, L))
5542 return std::nullopt;
5543
5544 // *** Part2: Create the predicates
5545
5546 // Analysis was successful: we have a phi-with-cast pattern for which we
5547 // can return an AddRec expression under the following predicates:
5548 //
5549 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5550 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5551 // P2: An Equal predicate that guarantees that
5552 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5553 // P3: An Equal predicate that guarantees that
5554 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5555 //
5556 // As we next prove, the above predicates guarantee that:
5557 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5558 //
5559 //
5560 // More formally, we want to prove that:
5561 // Expr(i+1) = Start + (i+1) * Accum
5562 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5563 //
5564 // Given that:
5565 // 1) Expr(0) = Start
5566 // 2) Expr(1) = Start + Accum
5567 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5568 // 3) Induction hypothesis (step i):
5569 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5570 //
5571 // Proof:
5572 // Expr(i+1) =
5573 // = Start + (i+1)*Accum
5574 // = (Start + i*Accum) + Accum
5575 // = Expr(i) + Accum
5576 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5577 // :: from step i
5578 //
5579 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5580 //
5581 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5582 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5583 // + Accum :: from P3
5584 //
5585 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5586 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5587 //
5588 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5589 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5590 //
5591 // By induction, the same applies to all iterations 1<=i<n:
5592 //
5593
5594 // Create a truncated addrec for which we will add a no overflow check (P1).
5595 const SCEV *StartVal = getSCEV(StartValueV);
5596 const SCEV *PHISCEV =
5597 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5598 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5599
5600 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5601 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5602 // will be constant.
5603 //
5604 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5605 // add P1.
5606 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5607 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5608 Signed ? SCEVWrapPredicate::IncrementNSSW
5609 : SCEVWrapPredicate::IncrementNUSW;
5610 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5611 Predicates.push_back(AddRecPred);
5612 }
5613
5614 // Create the Equal Predicates P2,P3:
5615
5616 // It is possible that the predicates P2 and/or P3 are computable at
5617 // compile time due to StartVal and/or Accum being constants.
5618 // If either one is, then we can check that now and escape if either P2
5619 // or P3 is false.
5620
5621 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5622 // for each of StartVal and Accum
5623 auto getExtendedExpr = [&](const SCEV *Expr,
5624 bool CreateSignExtend) -> const SCEV * {
5625 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5626 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5627 const SCEV *ExtendedExpr =
5628 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5629 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5630 return ExtendedExpr;
5631 };
5632
5633 // Given:
5634 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5635 // = getExtendedExpr(Expr)
5636 // Determine whether the predicate P: Expr == ExtendedExpr
5637 // is known to be false at compile time
5638 auto PredIsKnownFalse = [&](const SCEV *Expr,
5639 const SCEV *ExtendedExpr) -> bool {
5640 return Expr != ExtendedExpr &&
5641 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5642 };
5643
5644 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5645 if (PredIsKnownFalse(StartVal, StartExtended)) {
5646 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5647 return std::nullopt;
5648 }
5649
5650 // The Step is always Signed (because the overflow checks are either
5651 // NSSW or NUSW)
5652 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5653 if (PredIsKnownFalse(Accum, AccumExtended)) {
5654 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5655 return std::nullopt;
5656 }
5657
5658 auto AppendPredicate = [&](const SCEV *Expr,
5659 const SCEV *ExtendedExpr) -> void {
5660 if (Expr != ExtendedExpr &&
5661 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5662 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5663 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5664 Predicates.push_back(Pred);
5665 }
5666 };
5667
5668 AppendPredicate(StartVal, StartExtended);
5669 AppendPredicate(Accum, AccumExtended);
5670
5671 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5672 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5673 // into NewAR if it will also add the runtime overflow checks specified in
5674 // Predicates.
5675 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5676
5677 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5678 std::make_pair(NewAR, Predicates);
5679 // Remember the result of the analysis for this SCEV at this location.
5680 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5681 return PredRewrite;
5682}
5683
5684std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5685ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5686 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5687 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5688 if (!L)
5689 return std::nullopt;
5690
5691 // Check to see if we already analyzed this PHI.
5692 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5693 if (I != PredicatedSCEVRewrites.end()) {
5694 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5695 I->second;
5696 // Analysis was done before and failed to create an AddRec:
5697 if (Rewrite.first == SymbolicPHI)
5698 return std::nullopt;
5699 // Analysis was done before and succeeded to create an AddRec under
5700 // a predicate:
5701 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5702 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5703 return Rewrite;
5704 }
5705
5706 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5707 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5708
5709 // Record in the cache that the analysis failed
5710 if (!Rewrite) {
5711 SmallVector<const SCEVPredicate *, 3> Predicates;
5712 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5713 return std::nullopt;
5714 }
5715
5716 return Rewrite;
5717}
5718
5719// FIXME: This utility is currently required because the Rewriter currently
5720// does not rewrite this expression:
5721// {0, +, (sext ix (trunc iy to ix) to iy)}
5722// into {0, +, %step},
5723// even when the following Equal predicate exists:
5724// "%step == (sext ix (trunc iy to ix) to iy)".
5725bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5726 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5727 if (AR1 == AR2)
5728 return true;
5729
5730 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5731 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5732 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5733 return false;
5734 return true;
5735 };
5736
5737 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5738 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5739 return false;
5740 return true;
5741}
5742
5743/// A helper function for createAddRecFromPHI to handle simple cases.
5744///
5745/// This function tries to find an AddRec expression for the simplest (yet most
5746/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5747/// If it fails, createAddRecFromPHI will use a more general, but slow,
5748/// technique for finding the AddRec expression.
5749const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5750 Value *BEValueV,
5751 Value *StartValueV) {
5752 const Loop *L = LI.getLoopFor(PN->getParent());
5753 assert(L && L->getHeader() == PN->getParent());
5754 assert(BEValueV && StartValueV);
5755
5756 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5757 if (!BO)
5758 return nullptr;
5759
5760 if (BO->Opcode != Instruction::Add)
5761 return nullptr;
5762
5763 const SCEV *Accum = nullptr;
5764 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5765 Accum = getSCEV(BO->RHS);
5766 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5767 Accum = getSCEV(BO->LHS);
5768
5769 if (!Accum)
5770 return nullptr;
5771
5772 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5773 if (BO->IsNUW)
5774 Flags = setFlags(Flags, SCEV::FlagNUW);
5775 if (BO->IsNSW)
5776 Flags = setFlags(Flags, SCEV::FlagNSW);
5777
5778 const SCEV *StartVal = getSCEV(StartValueV);
5779 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5780 insertValueToMap(PN, PHISCEV);
5781
5782 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5783 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5784 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5785 proveNoWrapViaConstantRanges(AR)));
5786 }
5787
5788 // We can add Flags to the post-inc expression only if we
5789 // know that it is *undefined behavior* for BEValueV to
5790 // overflow.
5791 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5792 assert(isLoopInvariant(Accum, L) &&
5793 "Accum is defined outside L, but is not invariant?");
5794 if (isAddRecNeverPoison(BEInst, L))
5795 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5796 }
5797
5798 return PHISCEV;
5799}
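// Example of the simple case handled above: for
//   %iv = phi i64 [ 0, %entry ], [ %iv.next, %latch ]
//   %iv.next = add nuw nsw i64 %iv, 1
// the PHI maps to {0,+,1}<nuw><nsw> for the enclosing loop, with the flags
// taken from the increment and possibly strengthened via
// proveNoWrapViaConstantRanges.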
5800
5801const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5802 const Loop *L = LI.getLoopFor(PN->getParent());
5803 if (!L || L->getHeader() != PN->getParent())
5804 return nullptr;
5805
5806 // The loop may have multiple entrances or multiple exits; we can analyze
5807 // this phi as an addrec if it has a unique entry value and a unique
5808 // backedge value.
5809 Value *BEValueV = nullptr, *StartValueV = nullptr;
5810 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5811 Value *V = PN->getIncomingValue(i);
5812 if (L->contains(PN->getIncomingBlock(i))) {
5813 if (!BEValueV) {
5814 BEValueV = V;
5815 } else if (BEValueV != V) {
5816 BEValueV = nullptr;
5817 break;
5818 }
5819 } else if (!StartValueV) {
5820 StartValueV = V;
5821 } else if (StartValueV != V) {
5822 StartValueV = nullptr;
5823 break;
5824 }
5825 }
5826 if (!BEValueV || !StartValueV)
5827 return nullptr;
5828
5829 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5830 "PHI node already processed?");
5831
5832 // First, try to find an AddRec expression without creating a fictitious symbolic
5833 // value for PN.
5834 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5835 return S;
5836
5837 // Handle PHI node value symbolically.
5838 const SCEV *SymbolicName = getUnknown(PN);
5839 insertValueToMap(PN, SymbolicName);
5840
5841 // Using this symbolic name for the PHI, analyze the value coming around
5842 // the back-edge.
5843 const SCEV *BEValue = getSCEV(BEValueV);
5844
5845 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5846 // has a special value for the first iteration of the loop.
5847
5848 // If the value coming around the backedge is an add with the symbolic
5849 // value we just inserted, then we found a simple induction variable!
5850 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5851 // If there is a single occurrence of the symbolic value, replace it
5852 // with a recurrence.
5853 unsigned FoundIndex = Add->getNumOperands();
5854 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5855 if (Add->getOperand(i) == SymbolicName)
5856 if (FoundIndex == e) {
5857 FoundIndex = i;
5858 break;
5859 }
5860
5861 if (FoundIndex != Add->getNumOperands()) {
5862 // Create an add with everything but the specified operand.
5863 SmallVector<const SCEV *, 8> Ops;
5864 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5865 if (i != FoundIndex)
5866 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5867 L, *this));
5868 const SCEV *Accum = getAddExpr(Ops);
5869
5870 // This is not a valid addrec if the step amount is varying each
5871 // loop iteration, but is not itself an addrec in this loop.
5872 if (isLoopInvariant(Accum, L) ||
5873 (isa<SCEVAddRecExpr>(Accum) &&
5874 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5875 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5876
5877 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5878 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5879 if (BO->IsNUW)
5880 Flags = setFlags(Flags, SCEV::FlagNUW);
5881 if (BO->IsNSW)
5882 Flags = setFlags(Flags, SCEV::FlagNSW);
5883 }
5884 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5885 if (GEP->getOperand(0) == PN) {
5886 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5887 // If the increment has any nowrap flags, then we know the address
5888 // space cannot be wrapped around.
5889 if (NW != GEPNoWrapFlags::none())
5890 Flags = setFlags(Flags, SCEV::FlagNW);
5891 // If the GEP is nuw or nusw with non-negative offset, we know that
5892 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5893 // offset is treated as signed, while the base is unsigned.
5894 if (NW.hasNoUnsignedWrap() ||
5895 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5896 Flags = setFlags(Flags, SCEV::FlagNUW);
5897 }
5898
5899 // We cannot transfer nuw and nsw flags from subtraction
5900 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5901 // for instance.
5902 }
5903
5904 const SCEV *StartVal = getSCEV(StartValueV);
5905 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5906
5907 // Okay, for the entire analysis of this edge we assumed the PHI
5908 // to be symbolic. We now need to go back and purge all of the
5909 // entries for the scalars that use the symbolic expression.
5910 forgetMemoizedResults(SymbolicName);
5911 insertValueToMap(PN, PHISCEV);
5912
5913 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5914 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5915 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5916 proveNoWrapViaConstantRanges(AR)));
5917 }
5918
5919 // We can add Flags to the post-inc expression only if we
5920 // know that it is *undefined behavior* for BEValueV to
5921 // overflow.
5922 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5923 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5924 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5925
5926 return PHISCEV;
5927 }
5928 }
5929 } else {
5930 // Otherwise, this could be a loop like this:
5931 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5932 // In this case, j = {1,+,1} and BEValue is j.
5933 // Because the other in-value of i (0) fits the evolution of BEValue,
5934 // i really is an addrec evolution.
5935 //
5936 // We can generalize this saying that i is the shifted value of BEValue
5937 // by one iteration:
5938 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5939 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5940 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5941 if (Shifted != getCouldNotCompute() &&
5942 Start != getCouldNotCompute()) {
5943 const SCEV *StartVal = getSCEV(StartValueV);
5944 if (Start == StartVal) {
5945 // Okay, for the entire analysis of this edge we assumed the PHI
5946 // to be symbolic. We now need to go back and purge all of the
5947 // entries for the scalars that use the symbolic expression.
5948 forgetMemoizedResults(SymbolicName);
5949 insertValueToMap(PN, Shifted);
5950 return Shifted;
5951 }
5952 }
5953 }
5954
5955 // Remove the temporary PHI node SCEV that has been inserted while intending
5956 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5957 // as it will prevent later (possibly simpler) SCEV expressions from being
5958 // added to the ValueExprMap.
5959 eraseValueFromMap(PN);
5960
5961 return nullptr;
5962}
5963
5964// Try to match a control flow sequence that branches out at BI and merges back
5965// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5966// match.
5967static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5968 Value *&C, Value *&LHS, Value *&RHS) {
5969 C = BI->getCondition();
5970
5971 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5972 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5973
5974 if (!LeftEdge.isSingleEdge())
5975 return false;
5976
5977 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5978
5979 Use &LeftUse = Merge->getOperandUse(0);
5980 Use &RightUse = Merge->getOperandUse(1);
5981
5982 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5983 LHS = LeftUse;
5984 RHS = RightUse;
5985 return true;
5986 }
5987
5988 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5989 LHS = RightUse;
5990 RHS = LeftUse;
5991 return true;
5992 }
5993
5994 return false;
5995}
5996
5997const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5998 auto IsReachable =
5999 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6000 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6001 // Try to match
6002 //
6003 // br %cond, label %left, label %right
6004 // left:
6005 // br label %merge
6006 // right:
6007 // br label %merge
6008 // merge:
6009 // V = phi [ %x, %left ], [ %y, %right ]
6010 //
6011 // as "select %cond, %x, %y"
6012
6013 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6014 assert(IDom && "At least the entry block should dominate PN");
6015
6016 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6017 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6018
6019 if (BI && BI->isConditional() &&
6020 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6021 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6022 properlyDominates(getSCEV(RHS), PN->getParent()))
6023 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6024 }
6025
6026 return nullptr;
6027}
6028
6029const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6030 if (const SCEV *S = createAddRecFromPHI(PN))
6031 return S;
6032
6033 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
6034 return getSCEV(V);
6035
6036 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6037 return S;
6038
6039 // If it's not a loop phi, we can't handle it yet.
6040 return getUnknown(PN);
6041}
6042
6043bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6044 SCEVTypes RootKind) {
6045 struct FindClosure {
6046 const SCEV *OperandToFind;
6047 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6048 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6049
6050 bool Found = false;
6051
6052 bool canRecurseInto(SCEVTypes Kind) const {
6053 // We can only recurse into the SCEV expression of the same effective type
6054 // as the type of our root SCEV expression, and into zero-extensions.
6055 return RootKind == Kind || NonSequentialRootKind == Kind ||
6056 scZeroExtend == Kind;
6057 };
6058
6059 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6060 : OperandToFind(OperandToFind), RootKind(RootKind),
6061 NonSequentialRootKind(
6062 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6063 RootKind)) {}
6064
6065 bool follow(const SCEV *S) {
6066 Found = S == OperandToFind;
6067
6068 return !isDone() && canRecurseInto(S->getSCEVType());
6069 }
6070
6071 bool isDone() const { return Found; }
6072 };
6073
6074 FindClosure FC(OperandToFind, RootKind);
6075 visitAll(Root, FC);
6076 return FC.Found;
6077}
6078
6079std::optional<const SCEV *>
6080ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6081 ICmpInst *Cond,
6082 Value *TrueVal,
6083 Value *FalseVal) {
6084 // Try to match some simple smax or umax patterns.
6085 auto *ICI = Cond;
6086
6087 Value *LHS = ICI->getOperand(0);
6088 Value *RHS = ICI->getOperand(1);
6089
6090 switch (ICI->getPredicate()) {
6091 case ICmpInst::ICMP_SLT:
6092 case ICmpInst::ICMP_SLE:
6093 case ICmpInst::ICMP_ULT:
6094 case ICmpInst::ICMP_ULE:
6095 std::swap(LHS, RHS);
6096 [[fallthrough]];
6097 case ICmpInst::ICMP_SGT:
6098 case ICmpInst::ICMP_SGE:
6099 case ICmpInst::ICMP_UGT:
6100 case ICmpInst::ICMP_UGE:
6101 // a > b ? a+x : b+x -> max(a, b)+x
6102 // a > b ? b+x : a+x -> min(a, b)+x
6103 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6104 bool Signed = ICI->isSigned();
6105 const SCEV *LA = getSCEV(TrueVal);
6106 const SCEV *RA = getSCEV(FalseVal);
6107 const SCEV *LS = getSCEV(LHS);
6108 const SCEV *RS = getSCEV(RHS);
6109 if (LA->getType()->isPointerTy()) {
6110 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6111 // Need to make sure we can't produce weird expressions involving
6112 // negated pointers.
6113 if (LA == LS && RA == RS)
6114 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6115 if (LA == RS && RA == LS)
6116 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6117 }
6118 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6119 if (Op->getType()->isPointerTy()) {
6120 Op = getLosslessPtrToIntExpr(Op);
6121 if (isa<SCEVCouldNotCompute>(Op))
6122 return Op;
6123 }
6124 if (Signed)
6125 Op = getNoopOrSignExtend(Op, Ty);
6126 else
6127 Op = getNoopOrZeroExtend(Op, Ty);
6128 return Op;
6129 };
6130 LS = CoerceOperand(LS);
6131 RS = CoerceOperand(RS);
6132 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6133 break;
6134 const SCEV *LDiff = getMinusSCEV(LA, LS);
6135 const SCEV *RDiff = getMinusSCEV(RA, RS);
6136 if (LDiff == RDiff)
6137 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6138 LDiff);
6139 LDiff = getMinusSCEV(LA, RS);
6140 RDiff = getMinusSCEV(RA, LS);
6141 if (LDiff == RDiff)
6142 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6143 LDiff);
6144 }
6145 break;
6146 case ICmpInst::ICMP_NE:
6147 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6148 std::swap(TrueVal, FalseVal);
6149 [[fallthrough]];
6150 case ICmpInst::ICMP_EQ:
6151 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6152 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6153 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6154 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6155 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6156 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6157 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6158 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6159 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6160 return getAddExpr(getUMaxExpr(X, C), Y);
6161 }
6162 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6163 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6164 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6165 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6166 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6167 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6168 const SCEV *X = getSCEV(LHS);
6169 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6170 X = ZExt->getOperand();
6171 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6172 const SCEV *FalseValExpr = getSCEV(FalseVal);
6173 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6174 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6175 /*Sequential=*/true);
6176 }
6177 }
6178 break;
6179 default:
6180 break;
6181 }
6182
6183 return std::nullopt;
6184}
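// Example of the max/min matching above: for
//   %c = icmp sgt i32 %a, %b
//   %r = select i1 %c, i32 %a, i32 %b
// both differences LDiff and RDiff are zero, so %r folds to smax(%a, %b);
// with the select arms swapped it folds to smin(%a, %b).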
6185
6186static std::optional<const SCEV *>
6187createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6188 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6189 assert(CondExpr->getType()->isIntegerTy(1) &&
6190 TrueExpr->getType() == FalseExpr->getType() &&
6191 TrueExpr->getType()->isIntegerTy(1) &&
6192 "Unexpected operands of a select.");
6193
6194 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6195 // --> C + (umin_seq cond, x - C)
6196 //
6197 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6198 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6199 // --> C + (umin_seq ~cond, x - C)
6200
6201 // FIXME: while we can't legally model the case where both of the hands
6202 // are fully variable, we only require that the *difference* is constant.
6203 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6204 return std::nullopt;
6205
6206 const SCEV *X, *C;
6207 if (isa<SCEVConstant>(TrueExpr)) {
6208 CondExpr = SE->getNotSCEV(CondExpr);
6209 X = FalseExpr;
6210 C = TrueExpr;
6211 } else {
6212 X = TrueExpr;
6213 C = FalseExpr;
6214 }
6215 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6216 /*Sequential=*/true));
6217}
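// Worked i1 example for the rewrite above: `cond ? x : false` has C = 0 and
// X = x, so the result is 0 + umin_seq(cond, x - 0) = umin_seq(cond, x),
// which is 1 only when both cond and x are 1, and, unlike a plain umin, does
// not propagate poison from x when cond is 0.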
6218
6219static std::optional<const SCEV *>
6220createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6221 Value *FalseVal) {
6222 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6223 return std::nullopt;
6224
6225 const auto *SECond = SE->getSCEV(Cond);
6226 const auto *SETrue = SE->getSCEV(TrueVal);
6227 const auto *SEFalse = SE->getSCEV(FalseVal);
6228 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6229}
6230
6231const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6232 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6233 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6234 assert(TrueVal->getType() == FalseVal->getType() &&
6235 V->getType() == TrueVal->getType() &&
6236 "Types of select hands and of the result must match.");
6237
6238 // For now, only deal with i1-typed `select`s.
6239 if (!V->getType()->isIntegerTy(1))
6240 return getUnknown(V);
6241
6242 if (std::optional<const SCEV *> S =
6243 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6244 return *S;
6245
6246 return getUnknown(V);
6247}
6248
6249const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6250 Value *TrueVal,
6251 Value *FalseVal) {
6252 // Handle "constant" branch or select. This can occur for instance when a
6253 // loop pass transforms an inner loop and moves on to process the outer loop.
6254 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6255 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6256
6257 if (auto *I = dyn_cast<Instruction>(V)) {
6258 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6259 if (std::optional<const SCEV *> S =
6260 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6261 TrueVal, FalseVal))
6262 return *S;
6263 }
6264 }
6265
6266 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6267}
6268
6269/// Expand GEP instructions into add and multiply operations. This allows them
6270/// to be analyzed by regular SCEV code.
6271const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6272 assert(GEP->getSourceElementType()->isSized() &&
6273 "GEP source element type must be sized");
6274
6275 SmallVector<const SCEV *, 4> IndexExprs;
6276 for (Value *Index : GEP->indices())
6277 IndexExprs.push_back(getSCEV(Index));
6278 return getGEPExpr(GEP, IndexExprs);
6279}
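// Example of the expansion performed through getGEPExpr: assuming a typical
// data layout where i32 is 4 bytes, `getelementptr i32, ptr %p, i64 %i`
// becomes (%p + (4 * %i)), i.e. the pointer base plus the scaled index.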
6280
6281APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6282 uint32_t BitWidth = getTypeSizeInBits(S->getType());
6283 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6284 return TrailingZeros >= BitWidth
6285 ? APInt::getZero(BitWidth)
6286 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6287 };
6288 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6289 // The result is GCD of all operands results.
6290 APInt Res = getConstantMultiple(N->getOperand(0));
6291 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6292 Res = APIntOps::GreatestCommonDivisor(
6293 Res, getConstantMultiple(N->getOperand(I)));
6294 return Res;
6295 };
6296
6297 switch (S->getSCEVType()) {
6298 case scConstant:
6299 return cast<SCEVConstant>(S)->getAPInt();
6300 case scPtrToInt:
6301 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6302 case scUDivExpr:
6303 case scVScale:
6304 return APInt(BitWidth, 1);
6305 case scTruncate: {
6306 // Only multiples that are a power of 2 will hold after truncation.
6307 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6308 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6309 return GetShiftedByZeros(TZ);
6310 }
6311 case scZeroExtend: {
6312 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6313 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6314 }
6315 case scSignExtend: {
6316 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6317 return getConstantMultiple(E->getOperand()).sext(BitWidth);
6318 }
6319 case scMulExpr: {
6320 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6321 if (M->hasNoUnsignedWrap()) {
6322 // The result is the product of all operand results.
6323 APInt Res = getConstantMultiple(M->getOperand(0));
6324 for (const SCEV *Operand : M->operands().drop_front())
6325 Res = Res * getConstantMultiple(Operand);
6326 return Res;
6327 }
6328
6329 // If there are no wrap guarantees, find the trailing zeros, which is the
6330 // sum of trailing zeros for all its operands.
6331 uint32_t TZ = 0;
6332 for (const SCEV *Operand : M->operands())
6333 TZ += getMinTrailingZeros(Operand);
6334 return GetShiftedByZeros(TZ);
6335 }
6336 case scAddExpr:
6337 case scAddRecExpr: {
6338 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6339 if (N->hasNoUnsignedWrap())
6340 return GetGCDMultiple(N);
6341 // Otherwise, find the trailing zeros, which is the minimum over all operands.
6342 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6343 for (const SCEV *Operand : N->operands().drop_front())
6344 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6345 return GetShiftedByZeros(TZ);
6346 }
6347 case scUMaxExpr:
6348 case scSMaxExpr:
6349 case scUMinExpr:
6350 case scSMinExpr:
6351 case scSequentialUMinExpr:
6352 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6353 case scUnknown: {
6354 // ask ValueTracking for known bits
6355 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6356 unsigned Known =
6357 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6358 .countMinTrailingZeros();
6359 return GetShiftedByZeros(Known);
6360 }
6361 case scCouldNotCompute:
6362 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6363 }
6364 llvm_unreachable("Unknown SCEV kind!");
6365}
6366
6367APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6368 auto I = ConstantMultipleCache.find(S);
6369 if (I != ConstantMultipleCache.end())
6370 return I->second;
6371
6372 APInt Result = getConstantMultipleImpl(S);
6373 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6374 assert(InsertPair.second && "Should insert a new key");
6375 return InsertPair.first->second;
6376}
6377
6378APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6379 APInt Multiple = getConstantMultiple(S);
6380 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6381}
6382
6383uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6384 return std::min(getConstantMultiple(S).countTrailingZeros(),
6385 (unsigned)getTypeSizeInBits(S->getType()));
6386}
6387
6388/// Helper method to assign a range to V from metadata present in the IR.
6389static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6390 if (Instruction *I = dyn_cast<Instruction>(V)) {
6391 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6392 return getConstantRangeFromMetadata(*MD);
6393 if (const auto *CB = dyn_cast<CallBase>(V))
6394 if (std::optional<ConstantRange> Range = CB->getRange())
6395 return Range;
6396 }
6397 if (auto *A = dyn_cast<Argument>(V))
6398 if (std::optional<ConstantRange> Range = A->getRange())
6399 return Range;
6400
6401 return std::nullopt;
6402}
6403
6404void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6405 SCEV::NoWrapFlags Flags) {
6406 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6407 AddRec->setNoWrapFlags(Flags);
6408 UnsignedRanges.erase(AddRec);
6409 SignedRanges.erase(AddRec);
6410 ConstantMultipleCache.erase(AddRec);
6411 }
6412}
6413
6414ConstantRange ScalarEvolution::
6415getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6416 const DataLayout &DL = getDataLayout();
6417
6418 unsigned BitWidth = getTypeSizeInBits(U->getType());
6419 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6420
6421 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6422 // use information about the trip count to improve our available range. Note
6423 // that the trip count independent cases are already handled by known bits.
6424 // WARNING: The definition of recurrence used here is subtly different than
6425 // the one used by AddRec (and thus most of this file). Step is allowed to
6426 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6427 // and other addrecs in the same loop (for non-affine addrecs). The code
6428 // below intentionally handles the case where step is not loop invariant.
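// For example, a phi such as
//   %iv = phi i32 [ %start, %preheader ], [ %iv.next, %loop ]
//   %iv.next = lshr i32 %iv, %amt
// is handled here even when %amt is recomputed on every iteration.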
6429 auto *P = dyn_cast<PHINode>(U->getValue());
6430 if (!P)
6431 return FullSet;
6432
6433 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6434 // even the values that are not available in these blocks may come from them,
6435 // and this leads to a false-positive recurrence test.
6436 for (auto *Pred : predecessors(P->getParent()))
6437 if (!DT.isReachableFromEntry(Pred))
6438 return FullSet;
6439
6440 BinaryOperator *BO;
6441 Value *Start, *Step;
6442 if (!matchSimpleRecurrence(P, BO, Start, Step))
6443 return FullSet;
6444
6445 // If we found a recurrence in reachable code, we must be in a loop. Note
6446 // that BO might be in some subloop of L, and that's completely okay.
6447 auto *L = LI.getLoopFor(P->getParent());
6448 assert(L && L->getHeader() == P->getParent());
6449 if (!L->contains(BO->getParent()))
6450 // NOTE: This bailout should be an assert instead. However, asserting
6451 // the condition here exposes a case where LoopFusion is querying SCEV
6452 // with malformed loop information in the midst of the transform.
6453 // There doesn't appear to be an obvious fix, so for the moment bailout
6454 // until the caller issue can be fixed. PR49566 tracks the bug.
6455 return FullSet;
6456
6457 // TODO: Extend to other opcodes such as mul and div.
6458 switch (BO->getOpcode()) {
6459 default:
6460 return FullSet;
6461 case Instruction::AShr:
6462 case Instruction::LShr:
6463 case Instruction::Shl:
6464 break;
6465 };
6466
6467 if (BO->getOperand(0) != P)
6468 // TODO: Handle the power function forms some day.
6469 return FullSet;
6470
6471 unsigned TC = getSmallConstantMaxTripCount(L);
6472 if (!TC || TC >= BitWidth)
6473 return FullSet;
6474
6475 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6476 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6477 assert(KnownStart.getBitWidth() == BitWidth &&
6478 KnownStep.getBitWidth() == BitWidth);
6479
6480 // Compute total shift amount, being careful of overflow and bitwidths.
6481 auto MaxShiftAmt = KnownStep.getMaxValue();
6482 APInt TCAP(BitWidth, TC-1);
6483 bool Overflow = false;
6484 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6485 if (Overflow)
6486 return FullSet;
6487
6488 switch (BO->getOpcode()) {
6489 default:
6490 llvm_unreachable("filtered out above");
6491 case Instruction::AShr: {
6492 // For each ashr, three cases:
6493 // shift = 0 => unchanged value
6494 // saturation => 0 or -1
6495 // other => a value closer to zero (of the same sign)
6496 // Thus, the end value is closer to zero than the start.
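// For instance, if the start value is known to be 16 and the shift amounts
// can sum to at most 2 across all iterations, the final value is at least
// 16 >> 2 = 4, so every value of the recurrence lies in [4, 17).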
6497 auto KnownEnd = KnownBits::ashr(KnownStart,
6498 KnownBits::makeConstant(TotalShift));
6499 if (KnownStart.isNonNegative())
6500 // Analogous to lshr (simply not yet canonicalized)
6501 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6502 KnownStart.getMaxValue() + 1);
6503 if (KnownStart.isNegative())
6504 // End >=u Start && End <=s Start
6505 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6506 KnownEnd.getMaxValue() + 1);
6507 break;
6508 }
6509 case Instruction::LShr: {
6510 // For each lshr, three cases:
6511 // shift = 0 => unchanged value
6512 // saturation => 0
6513 // other => a smaller positive number
6514 // Thus, the low end of the unsigned range is the last value produced.
6515 auto KnownEnd = KnownBits::lshr(KnownStart,
6516 KnownBits::makeConstant(TotalShift));
6517 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6518 KnownStart.getMaxValue() + 1);
6519 }
6520 case Instruction::Shl: {
6521 // Iff no bits are shifted out, value increases on every shift.
6522 auto KnownEnd = KnownBits::shl(KnownStart,
6523 KnownBits::makeConstant(TotalShift));
6524 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6525 return ConstantRange(KnownStart.getMinValue(),
6526 KnownEnd.getMaxValue() + 1);
6527 break;
6528 }
6529 };
6530 return FullSet;
6531}
6532
6533const ConstantRange &
6534ScalarEvolution::getRangeRefIter(const SCEV *S,
6535 ScalarEvolution::RangeSignHint SignHint) {
6536 DenseMap<const SCEV *, ConstantRange> &Cache =
6537 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6538 : SignedRanges;
6539 SmallVector<const SCEV *> WorkList;
6540 SmallPtrSet<const SCEV *, 8> Seen;
6541
6542 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6543 // SCEVUnknown PHI node.
6544 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6545 if (!Seen.insert(Expr).second)
6546 return;
6547 if (Cache.contains(Expr))
6548 return;
6549 switch (Expr->getSCEVType()) {
6550 case scUnknown:
6551 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6552 break;
6553 [[fallthrough]];
6554 case scConstant:
6555 case scVScale:
6556 case scTruncate:
6557 case scZeroExtend:
6558 case scSignExtend:
6559 case scPtrToInt:
6560 case scAddExpr:
6561 case scMulExpr:
6562 case scUDivExpr:
6563 case scAddRecExpr:
6564 case scUMaxExpr:
6565 case scSMaxExpr:
6566 case scUMinExpr:
6567 case scSMinExpr:
6568 case scSequentialUMinExpr:
6569 WorkList.push_back(Expr);
6570 break;
6571 case scCouldNotCompute:
6572 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6573 }
6574 };
6575 AddToWorklist(S);
6576
6577 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6578 for (unsigned I = 0; I != WorkList.size(); ++I) {
6579 const SCEV *P = WorkList[I];
6580 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6581 // If it is not a `SCEVUnknown`, just recurse into operands.
6582 if (!UnknownS) {
6583 for (const SCEV *Op : P->operands())
6584 AddToWorklist(Op);
6585 continue;
6586 }
6587 // `SCEVUnknown`'s require special treatment.
6588 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6589 if (!PendingPhiRangesIter.insert(P).second)
6590 continue;
6591 for (auto &Op : reverse(P->operands()))
6592 AddToWorklist(getSCEV(Op));
6593 }
6594 }
6595
6596 if (!WorkList.empty()) {
6597 // Use getRangeRef to compute ranges for items in the worklist in reverse
6598 // order. This will force ranges for earlier operands to be computed before
6599 // their users in most cases.
6600 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6601 getRangeRef(P, SignHint);
6602
6603 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6604 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6605 PendingPhiRangesIter.erase(P);
6606 }
6607 }
6608
6609 return getRangeRef(S, SignHint, 0);
6610}
6611
6612/// Determine the range for a particular SCEV. If SignHint is
6613/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6614/// with a "cleaner" unsigned (resp. signed) representation.
6615const ConstantRange &ScalarEvolution::getRangeRef(
6616 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6617 DenseMap<const SCEV *, ConstantRange> &Cache =
6618 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6619 : SignedRanges;
6620 ConstantRange::PreferredRangeType RangeType =
6621 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6622 : ConstantRange::Signed;
6623
6624 // See if we've computed this range already.
6625 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6626 if (I != Cache.end())
6627 return I->second;
6628
6629 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6630 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6631
6632 // Switch to iteratively computing the range for S, if it is part of a deeply
6633 // nested expression.
6634 if (Depth > RangeIterThreshold)
6635 return getRangeRefIter(S, SignHint);
6636
6637 unsigned BitWidth = getTypeSizeInBits(S->getType());
6638 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6639 using OBO = OverflowingBinaryOperator;
6640
6641 // If the value has known zeros, the maximum value will have those known zeros
6642 // as well.
6643 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6644 APInt Multiple = getNonZeroConstantMultiple(S);
6645 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6646 if (!Remainder.isZero())
6647 ConservativeResult =
6648 ConstantRange(APInt::getZero(BitWidth),
6649 APInt::getMaxValue(BitWidth) - Remainder + 1);
6650 }
6651 else {
6652 uint32_t TZ = getMinTrailingZeros(S);
6653 if (TZ != 0) {
6654 ConservativeResult = ConstantRange(
6655 APInt::getSignedMinValue(BitWidth),
6656 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6657 }
6658 }
6659
6660 switch (S->getSCEVType()) {
6661 case scConstant:
6662 llvm_unreachable("Already handled above.");
6663 case scVScale:
6664 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6665 case scTruncate: {
6666 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6667 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6668 return setRange(
6669 Trunc, SignHint,
6670 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6671 }
6672 case scZeroExtend: {
6673 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6674 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6675 return setRange(
6676 ZExt, SignHint,
6677 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6678 }
6679 case scSignExtend: {
6680 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6681 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6682 return setRange(
6683 SExt, SignHint,
6684 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6685 }
6686 case scPtrToInt: {
6687 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6688 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6689 return setRange(PtrToInt, SignHint, X);
6690 }
6691 case scAddExpr: {
6692 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6693 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6694 unsigned WrapType = OBO::AnyWrap;
6695 if (Add->hasNoSignedWrap())
6696 WrapType |= OBO::NoSignedWrap;
6697 if (Add->hasNoUnsignedWrap())
6698 WrapType |= OBO::NoUnsignedWrap;
6699 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
6700 X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1),
6701 WrapType, RangeType);
6702 return setRange(Add, SignHint,
6703 ConservativeResult.intersectWith(X, RangeType));
6704 }
6705 case scMulExpr: {
6706 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6707 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6708 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
6709 X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
6710 return setRange(Mul, SignHint,
6711 ConservativeResult.intersectWith(X, RangeType));
6712 }
6713 case scUDivExpr: {
6714 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6715 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6716 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6717 return setRange(UDiv, SignHint,
6718 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6719 }
6720 case scAddRecExpr: {
6721 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6722 // If there's no unsigned wrap, the value will never be less than its
6723 // initial value.
6724 if (AddRec->hasNoUnsignedWrap()) {
6725 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6726 if (!UnsignedMinValue.isZero())
6727 ConservativeResult = ConservativeResult.intersectWith(
6728 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6729 }
6730
6731 // If there's no signed wrap, and all the operands except the initial value
6732 // have the same sign or zero, the value won't ever be:
6733 // 1: smaller than the initial value if the operands are non-negative,
6734 // 2: bigger than the initial value if the operands are non-positive.
6735 // In both cases, the value cannot cross the signed min/max boundary.
6736 if (AddRec->hasNoSignedWrap()) {
6737 bool AllNonNeg = true;
6738 bool AllNonPos = true;
6739 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6740 if (!isKnownNonNegative(AddRec->getOperand(i)))
6741 AllNonNeg = false;
6742 if (!isKnownNonPositive(AddRec->getOperand(i)))
6743 AllNonPos = false;
6744 }
6745 if (AllNonNeg)
6746 ConservativeResult = ConservativeResult.intersectWith(
6747 ConstantRange(getSignedRangeMin(AddRec->getStart()),
6748 APInt::getSignedMinValue(BitWidth)),
6749 RangeType);
6750 else if (AllNonPos)
6751 ConservativeResult = ConservativeResult.intersectWith(
6752 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6753 getSignedRangeMax(AddRec->getStart()) +
6754 1),
6755 RangeType);
6756 }
6757
6758 // TODO: non-affine addrec
6759 if (AddRec->isAffine()) {
6760 const SCEV *MaxBEScev =
6761 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6762 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6763 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6764
6765 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6766 // MaxBECount's active bits are all <= AddRec's bit width.
6767 if (MaxBECount.getBitWidth() > BitWidth &&
6768 MaxBECount.getActiveBits() <= BitWidth)
6769 MaxBECount = MaxBECount.trunc(BitWidth);
6770 else if (MaxBECount.getBitWidth() < BitWidth)
6771 MaxBECount = MaxBECount.zext(BitWidth);
6772
6773 if (MaxBECount.getBitWidth() == BitWidth) {
6774 auto RangeFromAffine = getRangeForAffineAR(
6775 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6776 ConservativeResult =
6777 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6778
6779 auto RangeFromFactoring = getRangeViaFactoring(
6780 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6781 ConservativeResult =
6782 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6783 }
6784 }
6785
6786 // Now try symbolic BE count and more powerful methods.
6787 if (UseExpensiveRangeSharpening) {
6788 const SCEV *SymbolicMaxBECount =
6789 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6790 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6791 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6792 AddRec->hasNoSelfWrap()) {
6793 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6794 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6795 ConservativeResult =
6796 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6797 }
6798 }
6799 }
6800
6801 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6802 }
6803 case scUMaxExpr:
6804 case scSMaxExpr:
6805 case scUMinExpr:
6806 case scSMinExpr:
6807 case scSequentialUMinExpr: {
6808 Intrinsic::ID ID;
6809 switch (S->getSCEVType()) {
6810 case scUMaxExpr:
6811 ID = Intrinsic::umax;
6812 break;
6813 case scSMaxExpr:
6814 ID = Intrinsic::smax;
6815 break;
6816 case scUMinExpr:
6817 case scSequentialUMinExpr:
6818 ID = Intrinsic::umin;
6819 break;
6820 case scSMinExpr:
6821 ID = Intrinsic::smin;
6822 break;
6823 default:
6824 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6825 }
6826
6827 const auto *NAry = cast<SCEVNAryExpr>(S);
6828 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6829 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6830 X = X.intrinsic(
6831 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6832 return setRange(S, SignHint,
6833 ConservativeResult.intersectWith(X, RangeType));
6834 }
6835 case scUnknown: {
6836 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6837 Value *V = U->getValue();
6838
6839 // Check if the IR explicitly contains !range metadata.
6840 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6841 if (MDRange)
6842 ConservativeResult =
6843 ConservativeResult.intersectWith(*MDRange, RangeType);
6844
6845 // Use facts about recurrences in the underlying IR. Note that add
6846 // recurrences are AddRecExprs and thus don't hit this path. This
6847 // primarily handles shift recurrences.
6848 auto CR = getRangeForUnknownRecurrence(U);
6849 ConservativeResult = ConservativeResult.intersectWith(CR);
6850
6851 // See if ValueTracking can give us a useful range.
6852 const DataLayout &DL = getDataLayout();
6853 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6854 if (Known.getBitWidth() != BitWidth)
6855 Known = Known.zextOrTrunc(BitWidth);
6856
6857 // ValueTracking may be able to compute a tighter result for the number of
6858 // sign bits than for the value of those sign bits.
6859 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6860 if (U->getType()->isPointerTy()) {
6861 // If the pointer size is larger than the index type size, this can cause
6862 // NS to be larger than BitWidth. So compensate for this.
6863 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6864 int ptrIdxDiff = ptrSize - BitWidth;
6865 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6866 NS -= ptrIdxDiff;
6867 }
6868
6869 if (NS > 1) {
6870 // If we know any of the sign bits, we know all of the sign bits.
6871 if (!Known.Zero.getHiBits(NS).isZero())
6872 Known.Zero.setHighBits(NS);
6873 if (!Known.One.getHiBits(NS).isZero())
6874 Known.One.setHighBits(NS);
6875 }
6876
6877 if (Known.getMinValue() != Known.getMaxValue() + 1)
6878 ConservativeResult = ConservativeResult.intersectWith(
6879 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6880 RangeType);
6881 if (NS > 1)
6882 ConservativeResult = ConservativeResult.intersectWith(
6883 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6884 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6885 RangeType);
6886
6887 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6888 // Strengthen the range if the underlying IR value is a
6889 // global/alloca/heap allocation using the size of the object.
6890 ObjectSizeOpts Opts;
6891 Opts.RoundToAlign = false;
6892 Opts.NullIsUnknownSize = true;
6893 uint64_t ObjSize;
6894 if ((isa<GlobalVariable>(V) || isa<AllocaInst>(V) ||
6895 isAllocationFn(V, &TLI)) &&
6896 getObjectSize(V, ObjSize, DL, &TLI, Opts) && ObjSize > 1) {
6897 // The highest address the object can start is ObjSize bytes before the
6898 // end (unsigned max value). If this value is not a multiple of the
6899 // alignment, the last possible start value is the next lowest multiple
6900 // of the alignment. Note: The computations below cannot overflow,
6901 // because if they would there's no possible start address for the
6902 // object.
6903 APInt MaxVal = APInt::getMaxValue(BitWidth) - APInt(BitWidth, ObjSize);
6904 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6905 uint64_t Rem = MaxVal.urem(Align);
6906 MaxVal -= APInt(BitWidth, Rem);
6907 APInt MinVal = APInt::getZero(BitWidth);
6908 if (llvm::isKnownNonZero(V, DL))
6909 MinVal = Align;
6910 ConservativeResult = ConservativeResult.intersectWith(
6911 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6912 }
6913 }
6914
6915 // The range of a Phi is a subset of the union of the ranges of its inputs.
6916 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6917 // Make sure that we do not run over cycled Phis.
6918 if (PendingPhiRanges.insert(Phi).second) {
6919 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6920
6921 for (const auto &Op : Phi->operands()) {
6922 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6923 RangeFromOps = RangeFromOps.unionWith(OpRange);
6924 // No point to continue if we already have a full set.
6925 if (RangeFromOps.isFullSet())
6926 break;
6927 }
6928 ConservativeResult =
6929 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6930 bool Erased = PendingPhiRanges.erase(Phi);
6931 assert(Erased && "Failed to erase Phi properly?");
6932 (void)Erased;
6933 }
6934 }
6935
6936 // vscale can't be equal to zero
6937 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6938 if (II->getIntrinsicID() == Intrinsic::vscale) {
6939 ConstantRange Disallowed = APInt::getZero(BitWidth);
6940 ConservativeResult = ConservativeResult.difference(Disallowed);
6941 }
6942
6943 return setRange(U, SignHint, std::move(ConservativeResult));
6944 }
6945 case scCouldNotCompute:
6946 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6947 }
6948
6949 return setRange(S, SignHint, std::move(ConservativeResult));
6950}
6951
6952 // Given a StartRange, Step and MaxBECount for an expression, compute a range
6953 // of values that the expression can take. Initially, the expression has a
6954 // value from StartRange and is then changed by Step up to MaxBECount times.
6955 // The Signed argument defines whether we treat Step as signed or unsigned.
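// For example, with StartRange = [0, 10), Step = 3 and MaxBECount = 5, the
// value can grow by at most 3 * 5 = 15, so the computed range is [0, 25).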
6956static ConstantRange getRangeForAffineARHelper(APInt Step,
6957 const ConstantRange &StartRange,
6958 const APInt &MaxBECount,
6959 bool Signed) {
6960 unsigned BitWidth = Step.getBitWidth();
6961 assert(BitWidth == StartRange.getBitWidth() &&
6962 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6963 // If either Step or MaxBECount is 0, then the expression won't change, and we
6964 // just need to return the initial range.
6965 if (Step == 0 || MaxBECount == 0)
6966 return StartRange;
6967
6968 // If we don't know anything about the initial value (i.e. StartRange is
6969 // FullRange), then we don't know anything about the final range either.
6970 // Return FullRange.
6971 if (StartRange.isFullSet())
6972 return ConstantRange::getFull(BitWidth);
6973
6974 // If Step is signed and negative, then we use its absolute value, but we also
6975 // note that we're moving in the opposite direction.
6976 bool Descending = Signed && Step.isNegative();
6977
6978 if (Signed)
6979 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6980 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6981 // These equations hold true due to the well-defined wrap-around behavior of
6982 // APInt.
6983 Step = Step.abs();
6984
6985 // Check if Offset is more than the full span of BitWidth. If it is, the
6986 // expression is guaranteed to overflow.
6987 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6988 return ConstantRange::getFull(BitWidth);
6989
6990 // Offset is by how much the expression can change. Checks above guarantee no
6991 // overflow here.
6992 APInt Offset = Step * MaxBECount;
6993
6994 // Minimum value of the final range will match the minimal value of StartRange
6995 // if the expression is increasing and will be decreased by Offset otherwise.
6996 // Maximum value of the final range will match the maximal value of StartRange
6997 // if the expression is decreasing and will be increased by Offset otherwise.
6998 APInt StartLower = StartRange.getLower();
6999 APInt StartUpper = StartRange.getUpper() - 1;
7000 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7001 : (StartUpper + std::move(Offset));
7002
7003 // It's possible that the new minimum/maximum value will fall into the initial
7004 // range (due to wrap around). This means that the expression can take any
7005 // value in this bitwidth, and we have to return full range.
7006 if (StartRange.contains(MovedBoundary))
7007 return ConstantRange::getFull(BitWidth);
7008
7009 APInt NewLower =
7010 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7011 APInt NewUpper =
7012 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7013 NewUpper += 1;
7014
7015 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7016 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7017}
7018
7019ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7020 const SCEV *Step,
7021 const APInt &MaxBECount) {
7022 assert(getTypeSizeInBits(Start->getType()) ==
7023 getTypeSizeInBits(Step->getType()) &&
7024 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7025 "mismatched bit widths");
7026
7027 // First, consider step signed.
7028 ConstantRange StartSRange = getSignedRange(Start);
7029 ConstantRange StepSRange = getSignedRange(Step);
7030
7031 // If Step can be both positive and negative, we need to find ranges for the
7032 // maximum absolute step values in both directions and union them.
7033 ConstantRange SR = getRangeForAffineARHelper(
7034 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7035 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7036 StartSRange, MaxBECount,
7037 /* Signed = */ true));
7038
7039 // Next, consider step unsigned.
7040 ConstantRange UR = getRangeForAffineARHelper(
7041 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7042 /* Signed = */ false);
7043
7044 // Finally, intersect signed and unsigned ranges.
7045 return SR.intersectWith(UR, ConstantRange::Smallest);
7046}
7047
7048ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7049 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7050 ScalarEvolution::RangeSignHint SignHint) {
7051 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7052 assert(AddRec->hasNoSelfWrap() &&
7053 "This only works for non-self-wrapping AddRecs!");
7054 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7055 const SCEV *Step = AddRec->getStepRecurrence(*this);
7056 // Only deal with constant step to save compile time.
7057 if (!isa<SCEVConstant>(Step))
7058 return ConstantRange::getFull(BitWidth);
7059 // Let's make sure that we can prove that we do not self-wrap during
7060 // MaxBECount iterations. We need this because MaxBECount is a maximum
7061 // iteration count estimate, and we might infer nw from some exit for which we
7062 // do not know max exit count (or any other side reasoning).
7063 // TODO: Turn into assert at some point.
7064 if (getTypeSizeInBits(MaxBECount->getType()) >
7065 getTypeSizeInBits(AddRec->getType()))
7066 return ConstantRange::getFull(BitWidth);
7067 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7068 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7069 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7070 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7071 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7072 MaxItersWithoutWrap))
7073 return ConstantRange::getFull(BitWidth);
7074
7075 ICmpInst::Predicate LEPred =
7076 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7077 ICmpInst::Predicate GEPred =
7078 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7079 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7080
7081 // We know that there is no self-wrap. Let's take Start and End values and
7082 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7083 // the iteration. They either lie inside the range [Min(Start, End),
7084 // Max(Start, End)] or outside it:
7085 //
7086 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7087 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7088 //
7089 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7090 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7091 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7092 // Start <= End and step is positive, or Start >= End and step is negative.
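// For instance, with Start = 0, a positive constant step of 4 and a maximum
// backedge-taken count of 10, End = 40 and Case 1 applies, so every
// intermediate value lies within the union of the start and end ranges.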
7093 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7094 ConstantRange StartRange = getRangeRef(Start, SignHint);
7095 ConstantRange EndRange = getRangeRef(End, SignHint);
7096 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7097 // If they already cover full iteration space, we will know nothing useful
7098 // even if we prove what we want to prove.
7099 if (RangeBetween.isFullSet())
7100 return RangeBetween;
7101 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7102 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7103 : RangeBetween.isWrappedSet();
7104 if (IsWrappedSet)
7105 return ConstantRange::getFull(BitWidth);
7106
7107 if (isKnownPositive(Step) &&
7108 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7109 return RangeBetween;
7110 if (isKnownNegative(Step) &&
7111 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7112 return RangeBetween;
7113 return ConstantRange::getFull(BitWidth);
7114}
7115
7116ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7117 const SCEV *Step,
7118 const APInt &MaxBECount) {
7119 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7120 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7121
7122 unsigned BitWidth = MaxBECount.getBitWidth();
7123 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7124 getTypeSizeInBits(Step->getType()) == BitWidth &&
7125 "mismatched bit widths");
7126
7127 struct SelectPattern {
7128 Value *Condition = nullptr;
7129 APInt TrueValue;
7130 APInt FalseValue;
7131
7132 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7133 const SCEV *S) {
7134 std::optional<unsigned> CastOp;
7135 APInt Offset(BitWidth, 0);
7136
7138 "Should be!");
7139
7140 // Peel off a constant offset:
7141 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7142 // In the future we could consider being smarter here and handle
7143 // {Start+Step,+,Step} too.
7144 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7145 return;
7146
7147 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7148 S = SA->getOperand(1);
7149 }
7150
7151 // Peel off a cast operation
7152 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7153 CastOp = SCast->getSCEVType();
7154 S = SCast->getOperand();
7155 }
7156
7157 using namespace llvm::PatternMatch;
7158
7159 auto *SU = dyn_cast<SCEVUnknown>(S);
7160 const APInt *TrueVal, *FalseVal;
7161 if (!SU ||
7162 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7163 m_APInt(FalseVal)))) {
7164 Condition = nullptr;
7165 return;
7166 }
7167
7168 TrueValue = *TrueVal;
7169 FalseValue = *FalseVal;
7170
7171 // Re-apply the cast we peeled off earlier
7172 if (CastOp)
7173 switch (*CastOp) {
7174 default:
7175 llvm_unreachable("Unknown SCEV cast type!");
7176
7177 case scTruncate:
7178 TrueValue = TrueValue.trunc(BitWidth);
7179 FalseValue = FalseValue.trunc(BitWidth);
7180 break;
7181 case scZeroExtend:
7182 TrueValue = TrueValue.zext(BitWidth);
7183 FalseValue = FalseValue.zext(BitWidth);
7184 break;
7185 case scSignExtend:
7186 TrueValue = TrueValue.sext(BitWidth);
7187 FalseValue = FalseValue.sext(BitWidth);
7188 break;
7189 }
7190
7191 // Re-apply the constant offset we peeled off earlier
7192 TrueValue += Offset;
7193 FalseValue += Offset;
7194 }
7195
7196 bool isRecognized() { return Condition != nullptr; }
7197 };
7198
7199 SelectPattern StartPattern(*this, BitWidth, Start);
7200 if (!StartPattern.isRecognized())
7201 return ConstantRange::getFull(BitWidth);
7202
7203 SelectPattern StepPattern(*this, BitWidth, Step);
7204 if (!StepPattern.isRecognized())
7205 return ConstantRange::getFull(BitWidth);
7206
7207 if (StartPattern.Condition != StepPattern.Condition) {
7208 // We don't handle this case today; but we could, by considering four
7209 // possibilities below instead of two. I'm not sure if there are cases where
7210 // that will help over what getRange already does, though.
7211 return ConstantRange::getFull(BitWidth);
7212 }
7213
7214 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7215 // construct arbitrary general SCEV expressions here. This function is called
7216 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7217 // say) can end up caching a suboptimal value.
7218
7219 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7220 // C2352 and C2512 (otherwise it isn't needed).
7221
7222 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7223 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7224 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7225 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7226
7227 ConstantRange TrueRange =
7228 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7229 ConstantRange FalseRange =
7230 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7231
7232 return TrueRange.unionWith(FalseRange);
7233}
7234
7235SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7236 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7237 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7238
7239 // Return early if there are no flags to propagate to the SCEV.
7240 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7241 if (BinOp->hasNoUnsignedWrap())
7242 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7243 if (BinOp->hasNoSignedWrap())
7244 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7245 if (Flags == SCEV::FlagAnyWrap)
7246 return SCEV::FlagAnyWrap;
7247
7248 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7249}
7250
7251const Instruction *
7252ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7253 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7254 return &*AddRec->getLoop()->getHeader()->begin();
7255 if (auto *U = dyn_cast<SCEVUnknown>(S))
7256 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7257 return I;
7258 return nullptr;
7259}
7260
7261const Instruction *
7262ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7263 bool &Precise) {
7264 Precise = true;
7265 // Do a bounded search of the def relation of the requested SCEVs.
7266 SmallSet<const SCEV *, 16> Visited;
7267 SmallVector<const SCEV *> Worklist;
7268 auto pushOp = [&](const SCEV *S) {
7269 if (!Visited.insert(S).second)
7270 return;
7271 // Threshold of 30 here is arbitrary.
7272 if (Visited.size() > 30) {
7273 Precise = false;
7274 return;
7275 }
7276 Worklist.push_back(S);
7277 };
7278
7279 for (const auto *S : Ops)
7280 pushOp(S);
7281
7282 const Instruction *Bound = nullptr;
7283 while (!Worklist.empty()) {
7284 auto *S = Worklist.pop_back_val();
7285 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7286 if (!Bound || DT.dominates(Bound, DefI))
7287 Bound = DefI;
7288 } else {
7289 for (const auto *Op : S->operands())
7290 pushOp(Op);
7291 }
7292 }
7293 return Bound ? Bound : &*F.getEntryBlock().begin();
7294}
7295
7296const Instruction *
7297ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7298 bool Discard;
7299 return getDefiningScopeBound(Ops, Discard);
7300}
7301
7302bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7303 const Instruction *B) {
7304 if (A->getParent() == B->getParent() &&
7305 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7306 B->getIterator()))
7307 return true;
7308
7309 auto *BLoop = LI.getLoopFor(B->getParent());
7310 if (BLoop && BLoop->getHeader() == B->getParent() &&
7311 BLoop->getLoopPreheader() == A->getParent() &&
7312 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7313 A->getParent()->end()) &&
7314 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7315 B->getIterator()))
7316 return true;
7317 return false;
7318}
7319
7320
7321bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7322 // Only proceed if we can prove that I does not yield poison.
7323 if (!programUndefinedIfPoison(I))
7324 return false;
7325
7326 // At this point we know that if I is executed, then it does not wrap
7327 // according to at least one of NSW or NUW. If I is not executed, then we do
7328 // not know if the calculation that I represents would wrap. Multiple
7329 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7330 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7331 // derived from other instructions that map to the same SCEV. We cannot make
7332 // that guarantee for cases where I is not executed. So we need to find an
7333 // upper bound on the defining scope for the SCEV, and prove that I is
7334 // executed every time we enter that scope. When the bounding scope is a
7335 // loop (the common case), this is equivalent to proving I executes on every
7336 // iteration of that loop.
7337 SmallVector<const SCEV *> SCEVOps;
7338 for (const Use &Op : I->operands()) {
7339 // I could be an extractvalue from a call to an overflow intrinsic.
7340 // TODO: We can do better here in some cases.
7341 if (isSCEVable(Op->getType()))
7342 SCEVOps.push_back(getSCEV(Op));
7343 }
7344 auto *DefI = getDefiningScopeBound(SCEVOps);
7345 return isGuaranteedToTransferExecutionTo(DefI, I);
7346}
7347
7348bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7349 // If we know that \c I can never be poison period, then that's enough.
7350 if (isSCEVExprNeverPoison(I))
7351 return true;
7352
7353 // If the loop only has one exit, then we know that, if the loop is entered,
7354 // any instruction dominating that exit will be executed. If any such
7355 // instruction would result in UB, the addrec cannot be poison.
7356 //
7357 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7358 // also handles uses outside the loop header (they just need to dominate the
7359 // single exit).
7360
7361 auto *ExitingBB = L->getExitingBlock();
7362 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7363 return false;
7364
7365 SmallPtrSet<const Instruction *, 16> KnownPoison;
7366 SmallVector<const Instruction *, 8> Worklist;
7367
7368 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7369 // things that are known to be poison under that assumption go on the
7370 // Worklist.
7371 KnownPoison.insert(I);
7372 Worklist.push_back(I);
7373
7374 while (!Worklist.empty()) {
7375 const Instruction *Poison = Worklist.pop_back_val();
7376
7377 for (const Use &U : Poison->uses()) {
7378 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7379 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7380 DT.dominates(PoisonUser->getParent(), ExitingBB))
7381 return true;
7382
7383 if (propagatesPoison(U) && L->contains(PoisonUser))
7384 if (KnownPoison.insert(PoisonUser).second)
7385 Worklist.push_back(PoisonUser);
7386 }
7387 }
7388
7389 return false;
7390}
7391
7392ScalarEvolution::LoopProperties
7393ScalarEvolution::getLoopProperties(const Loop *L) {
7394 using LoopProperties = ScalarEvolution::LoopProperties;
7395
7396 auto Itr = LoopPropertiesCache.find(L);
7397 if (Itr == LoopPropertiesCache.end()) {
7398 auto HasSideEffects = [](Instruction *I) {
7399 if (auto *SI = dyn_cast<StoreInst>(I))
7400 return !SI->isSimple();
7401
7402 return I->mayThrow() || I->mayWriteToMemory();
7403 };
7404
7405 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7406 /*HasNoSideEffects*/ true};
7407
7408 for (auto *BB : L->getBlocks())
7409 for (auto &I : *BB) {
7410 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7411 LP.HasNoAbnormalExits = false;
7412 if (HasSideEffects(&I))
7413 LP.HasNoSideEffects = false;
7414 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7415 break; // We're already as pessimistic as we can get.
7416 }
7417
7418 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7419 assert(InsertPair.second && "We just checked!");
7420 Itr = InsertPair.first;
7421 }
7422
7423 return Itr->second;
7424}
7425
7426bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7427 // A mustprogress loop without side effects must be finite.
7428 // TODO: The check used here is very conservative. It's only *specific*
7429 // side effects which are well defined in infinite loops.
7430 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7431}
7432
7433const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7434 // Worklist item with a Value and a bool indicating whether all operands have
7435 // been visited already.
7436 using PointerTy = PointerIntPair<Value *, 1, bool>;
7437 SmallVector<PointerTy> Stack;
7438
7439 Stack.emplace_back(V, true);
7440 Stack.emplace_back(V, false);
7441 while (!Stack.empty()) {
7442 auto E = Stack.pop_back_val();
7443 Value *CurV = E.getPointer();
7444
7445 if (getExistingSCEV(CurV))
7446 continue;
7447
7448 SmallVector<Value *> Ops;
7449 const SCEV *CreatedSCEV = nullptr;
7450 // If all operands have been visited already, create the SCEV.
7451 if (E.getInt()) {
7452 CreatedSCEV = createSCEV(CurV);
7453 } else {
7454 // Otherwise get the operands we need to create SCEVs for before creating
7455 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7456 // just use it.
7457 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7458 }
7459
7460 if (CreatedSCEV) {
7461 insertValueToMap(CurV, CreatedSCEV);
7462 } else {
7463 // Queue CurV for SCEV creation, followed by its operands which need to
7464 // be constructed first.
7465 Stack.emplace_back(CurV, true);
7466 for (Value *Op : Ops)
7467 Stack.emplace_back(Op, false);
7468 }
7469 }
7470
7471 return getExistingSCEV(V);
7472}
7473
7474const SCEV *
7475ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7476 if (!isSCEVable(V->getType()))
7477 return getUnknown(V);
7478
7479 if (Instruction *I = dyn_cast<Instruction>(V)) {
7480 // Don't attempt to analyze instructions in blocks that aren't
7481 // reachable. Such instructions don't matter, and they aren't required
7482 // to obey basic rules for definitions dominating uses which this
7483 // analysis depends on.
7484 if (!DT.isReachableFromEntry(I->getParent()))
7485 return getUnknown(PoisonValue::get(V->getType()));
7486 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7487 return getConstant(CI);
7488 else if (isa<GlobalAlias>(V))
7489 return getUnknown(V);
7490 else if (!isa<ConstantExpr>(V))
7491 return getUnknown(V);
7492
7493 Operator *U = cast<Operator>(V);
7494 if (auto BO =
7495 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7496 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7497 switch (BO->Opcode) {
7498 case Instruction::Add:
7499 case Instruction::Mul: {
7500 // For additions and multiplications, traverse add/mul chains for which we
7501 // can potentially create a single SCEV, to reduce the number of
7502 // get{Add,Mul}Expr calls.
7503 do {
7504 if (BO->Op) {
7505 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7506 Ops.push_back(BO->Op);
7507 break;
7508 }
7509 }
7510 Ops.push_back(BO->RHS);
7511 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7512 dyn_cast<Instruction>(V));
7513 if (!NewBO ||
7514 (BO->Opcode == Instruction::Add &&
7515 (NewBO->Opcode != Instruction::Add &&
7516 NewBO->Opcode != Instruction::Sub)) ||
7517 (BO->Opcode == Instruction::Mul &&
7518 NewBO->Opcode != Instruction::Mul)) {
7519 Ops.push_back(BO->LHS);
7520 break;
7521 }
7522 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7523 // requires a SCEV for the LHS.
7524 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7525 auto *I = dyn_cast<Instruction>(BO->Op);
7526 if (I && programUndefinedIfPoison(I)) {
7527 Ops.push_back(BO->LHS);
7528 break;
7529 }
7530 }
7531 BO = NewBO;
7532 } while (true);
7533 return nullptr;
7534 }
7535 case Instruction::Sub:
7536 case Instruction::UDiv:
7537 case Instruction::URem:
7538 break;
7539 case Instruction::AShr:
7540 case Instruction::Shl:
7541 case Instruction::Xor:
7542 if (!IsConstArg)
7543 return nullptr;
7544 break;
7545 case Instruction::And:
7546 case Instruction::Or:
7547 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7548 return nullptr;
7549 break;
7550 case Instruction::LShr:
7551 return getUnknown(V);
7552 default:
7553 llvm_unreachable("Unhandled binop");
7554 break;
7555 }
7556
7557 Ops.push_back(BO->LHS);
7558 Ops.push_back(BO->RHS);
7559 return nullptr;
7560 }
7561
7562 switch (U->getOpcode()) {
7563 case Instruction::Trunc:
7564 case Instruction::ZExt:
7565 case Instruction::SExt:
7566 case Instruction::PtrToInt:
7567 Ops.push_back(U->getOperand(0));
7568 return nullptr;
7569
7570 case Instruction::BitCast:
7571 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7572 Ops.push_back(U->getOperand(0));
7573 return nullptr;
7574 }
7575 return getUnknown(V);
7576
7577 case Instruction::SDiv:
7578 case Instruction::SRem:
7579 Ops.push_back(U->getOperand(0));
7580 Ops.push_back(U->getOperand(1));
7581 return nullptr;
7582
7583 case Instruction::GetElementPtr:
7584 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7585 "GEP source element type must be sized");
7586 for (Value *Index : U->operands())
7587 Ops.push_back(Index);
7588 return nullptr;
7589
7590 case Instruction::IntToPtr:
7591 return getUnknown(V);
7592
7593 case Instruction::PHI:
7594 // Keep constructing SCEVs for phis recursively for now.
7595 return nullptr;
7596
7597 case Instruction::Select: {
7598 // Check if U is a select that can be simplified to a SCEVUnknown.
7599 auto CanSimplifyToUnknown = [this, U]() {
7600 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7601 return false;
7602
7603 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7604 if (!ICI)
7605 return false;
7606 Value *LHS = ICI->getOperand(0);
7607 Value *RHS = ICI->getOperand(1);
7608 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7609 ICI->getPredicate() == CmpInst::ICMP_NE) {
7610 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7611 return true;
7612 } else if (getTypeSizeInBits(LHS->getType()) >
7613 getTypeSizeInBits(U->getType()))
7614 return true;
7615 return false;
7616 };
7617 if (CanSimplifyToUnknown())
7618 return getUnknown(U);
7619
7620 for (Value *Inc : U->operands())
7621 Ops.push_back(Inc);
7622 return nullptr;
7623 break;
7624 }
7625 case Instruction::Call:
7626 case Instruction::Invoke:
7627 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7628 Ops.push_back(RV);
7629 return nullptr;
7630 }
7631
7632 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7633 switch (II->getIntrinsicID()) {
7634 case Intrinsic::abs:
7635 Ops.push_back(II->getArgOperand(0));
7636 return nullptr;
7637 case Intrinsic::umax:
7638 case Intrinsic::umin:
7639 case Intrinsic::smax:
7640 case Intrinsic::smin:
7641 case Intrinsic::usub_sat:
7642 case Intrinsic::uadd_sat:
7643 Ops.push_back(II->getArgOperand(0));
7644 Ops.push_back(II->getArgOperand(1));
7645 return nullptr;
7646 case Intrinsic::start_loop_iterations:
7647 case Intrinsic::annotation:
7648 case Intrinsic::ptr_annotation:
7649 Ops.push_back(II->getArgOperand(0));
7650 return nullptr;
7651 default:
7652 break;
7653 }
7654 }
7655 break;
7656 }
7657
7658 return nullptr;
7659}
7660
7661const SCEV *ScalarEvolution::createSCEV(Value *V) {
7662 if (!isSCEVable(V->getType()))
7663 return getUnknown(V);
7664
7665 if (Instruction *I = dyn_cast<Instruction>(V)) {
7666 // Don't attempt to analyze instructions in blocks that aren't
7667 // reachable. Such instructions don't matter, and they aren't required
7668 // to obey basic rules for definitions dominating uses which this
7669 // analysis depends on.
7670 if (!DT.isReachableFromEntry(I->getParent()))
7671 return getUnknown(PoisonValue::get(V->getType()));
7672 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7673 return getConstant(CI);
7674 else if (isa<GlobalAlias>(V))
7675 return getUnknown(V);
7676 else if (!isa<ConstantExpr>(V))
7677 return getUnknown(V);
7678
7679 const SCEV *LHS;
7680 const SCEV *RHS;
7681
7682 Operator *U = cast<Operator>(V);
7683 if (auto BO =
7684 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7685 switch (BO->Opcode) {
7686 case Instruction::Add: {
7687 // The simple thing to do would be to just call getSCEV on both operands
7688 // and call getAddExpr with the result. However if we're looking at a
7689 // bunch of things all added together, this can be quite inefficient,
7690 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7691 // Instead, gather up all the operands and make a single getAddExpr call.
7692 // LLVM IR canonical form means we need only traverse the left operands.
7693 SmallVector<const SCEV *, 4> AddOps;
7694 do {
7695 if (BO->Op) {
7696 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7697 AddOps.push_back(OpSCEV);
7698 break;
7699 }
7700
7701 // If a NUW or NSW flag can be applied to the SCEV for this
7702 // addition, then compute the SCEV for this addition by itself
7703 // with a separate call to getAddExpr. We need to do that
7704 // instead of pushing the operands of the addition onto AddOps,
7705 // since the flags are only known to apply to this particular
7706 // addition - they may not apply to other additions that can be
7707 // formed with operands from AddOps.
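// For example, given (a +nuw b) + c, only the inner a +nuw b is built with
// the no-wrap flag; the outer addition is built without it.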
7708 const SCEV *RHS = getSCEV(BO->RHS);
7709 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7710 if (Flags != SCEV::FlagAnyWrap) {
7711 const SCEV *LHS = getSCEV(BO->LHS);
7712 if (BO->Opcode == Instruction::Sub)
7713 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7714 else
7715 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7716 break;
7717 }
7718 }
7719
7720 if (BO->Opcode == Instruction::Sub)
7721 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7722 else
7723 AddOps.push_back(getSCEV(BO->RHS));
7724
7725 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7726 dyn_cast<Instruction>(V));
7727 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7728 NewBO->Opcode != Instruction::Sub)) {
7729 AddOps.push_back(getSCEV(BO->LHS));
7730 break;
7731 }
7732 BO = NewBO;
7733 } while (true);
7734
7735 return getAddExpr(AddOps);
7736 }
7737
7738 case Instruction::Mul: {
7739 SmallVector<const SCEV *, 4> MulOps;
7740 do {
7741 if (BO->Op) {
7742 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7743 MulOps.push_back(OpSCEV);
7744 break;
7745 }
7746
7747 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7748 if (Flags != SCEV::FlagAnyWrap) {
7749 LHS = getSCEV(BO->LHS);
7750 RHS = getSCEV(BO->RHS);
7751 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7752 break;
7753 }
7754 }
7755
7756 MulOps.push_back(getSCEV(BO->RHS));
7757 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7758 dyn_cast<Instruction>(V));
7759 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7760 MulOps.push_back(getSCEV(BO->LHS));
7761 break;
7762 }
7763 BO = NewBO;
7764 } while (true);
7765
7766 return getMulExpr(MulOps);
7767 }
7768 case Instruction::UDiv:
7769 LHS = getSCEV(BO->LHS);
7770 RHS = getSCEV(BO->RHS);
7771 return getUDivExpr(LHS, RHS);
7772 case Instruction::URem:
7773 LHS = getSCEV(BO->LHS);
7774 RHS = getSCEV(BO->RHS);
7775 return getURemExpr(LHS, RHS);
7776 case Instruction::Sub: {
7777 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7778 if (BO->Op)
7779 Flags = getNoWrapFlagsFromUB(BO->Op);
7780 LHS = getSCEV(BO->LHS);
7781 RHS = getSCEV(BO->RHS);
7782 return getMinusSCEV(LHS, RHS, Flags);
7783 }
7784 case Instruction::And:
7785 // For an expression like x&255 that merely masks off the high bits,
7786 // use zext(trunc(x)) as the SCEV expression.
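// For instance, x & 0xF0 on an i32 value becomes
//   (zext i4 (x /u 16) to i32) * 16
// (possibly with the division folded if x is itself a multiply by a constant).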
7787 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7788 if (CI->isZero())
7789 return getSCEV(BO->RHS);
7790 if (CI->isMinusOne())
7791 return getSCEV(BO->LHS);
7792 const APInt &A = CI->getValue();
7793
7794 // Instcombine's ShrinkDemandedConstant may strip bits out of
7795 // constants, obscuring what would otherwise be a low-bits mask.
7796 // Use computeKnownBits to compute what ShrinkDemandedConstant
7797 // knew about to reconstruct a low-bits mask value.
7798 unsigned LZ = A.countl_zero();
7799 unsigned TZ = A.countr_zero();
7800 unsigned BitWidth = A.getBitWidth();
7801 KnownBits Known(BitWidth);
7802 computeKnownBits(BO->LHS, Known, getDataLayout(),
7803 0, &AC, nullptr, &DT);
7804
7805 APInt EffectiveMask =
7806 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7807 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7808 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7809 const SCEV *LHS = getSCEV(BO->LHS);
7810 const SCEV *ShiftedLHS = nullptr;
7811 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7812 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7813 // For an expression like (x * 8) & 8, simplify the multiply.
7814 unsigned MulZeros = OpC->getAPInt().countr_zero();
7815 unsigned GCD = std::min(MulZeros, TZ);
7816 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7817 SmallVector<const SCEV *, 4> MulOps;
7818 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7819 append_range(MulOps, LHSMul->operands().drop_front());
7820 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7821 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7822 }
7823 }
7824 if (!ShiftedLHS)
7825 ShiftedLHS = getUDivExpr(LHS, MulCount);
7826 return getMulExpr(
7827 getZeroExtendExpr(
7828 getTruncateExpr(ShiftedLHS,
7829 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7830 BO->LHS->getType()),
7831 MulCount);
7832 }
7833 }
7834 // Binary `and` is a bit-wise `umin`.
7835 if (BO->LHS->getType()->isIntegerTy(1)) {
7836 LHS = getSCEV(BO->LHS);
7837 RHS = getSCEV(BO->RHS);
7838 return getUMinExpr(LHS, RHS);
7839 }
7840 break;
7841
7842 case Instruction::Or:
7843 // Binary `or` is a bit-wise `umax`.
7844 if (BO->LHS->getType()->isIntegerTy(1)) {
7845 LHS = getSCEV(BO->LHS);
7846 RHS = getSCEV(BO->RHS);
7847 return getUMaxExpr(LHS, RHS);
7848 }
7849 break;
7850
7851 case Instruction::Xor:
7852 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7853 // If the RHS of xor is -1, then this is a not operation.
7854 if (CI->isMinusOne())
7855 return getNotSCEV(getSCEV(BO->LHS));
7856
7857 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7858 // This is a variant of the check for xor with -1, and it handles
7859 // the case where instcombine has trimmed non-demanded bits out
7860 // of an xor with -1.
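// For instance, xor(and(x, 255), 255), whose LHS is modeled as
// zext(trunc(x) to i8), becomes zext(not(trunc(x) to i8)).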
7861 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7862 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7863 if (LBO->getOpcode() == Instruction::And &&
7864 LCI->getValue() == CI->getValue())
7865 if (const SCEVZeroExtendExpr *Z =
7866 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7867 Type *UTy = BO->LHS->getType();
7868 const SCEV *Z0 = Z->getOperand();
7869 Type *Z0Ty = Z0->getType();
7870 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7871
7872 // If C is a low-bits mask, the zero extend is serving to
7873 // mask off the high bits. Complement the operand and
7874 // re-apply the zext.
7875 if (CI->getValue().isMask(Z0TySize))
7876 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7877
7878 // If C is a single bit, it may be in the sign-bit position
7879 // before the zero-extend. In this case, represent the xor
7880 // using an add, which is equivalent, and re-apply the zext.
7881 APInt Trunc = CI->getValue().trunc(Z0TySize);
7882 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7883 Trunc.isSignMask())
7884 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7885 UTy);
7886 }
7887 }
7888 break;
7889
7890 case Instruction::Shl:
7891 // Turn shift left of a constant amount into a multiply.
7892 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7893 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7894
7895 // If the shift count is not less than the bitwidth, the result of
7896 // the shift is undefined. Don't try to analyze it, because the
7897 // resolution chosen here may differ from the resolution chosen in
7898 // other parts of the compiler.
7899 if (SA->getValue().uge(BitWidth))
7900 break;
7901
7902 // We can safely preserve the nuw flag in all cases. It's also safe to
7903 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7904 // requires special handling. It can be preserved as long as we're not
7905 // left shifting by bitwidth - 1.
7906 auto Flags = SCEV::FlagAnyWrap;
7907 if (BO->Op) {
7908 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7909 if ((MulFlags & SCEV::FlagNSW) &&
7910 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7911 Flags = setFlags(Flags, SCEV::FlagNSW);
7912 if (MulFlags & SCEV::FlagNUW)
7913 Flags = setFlags(Flags, SCEV::FlagNUW);
7914 }
7915
7916 ConstantInt *X = ConstantInt::get(
7917 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7918 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7919 }
7920 break;
7921
7922 case Instruction::AShr:
7923 // AShr X, C, where C is a constant.
7924 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7925 if (!CI)
7926 break;
7927
7928 Type *OuterTy = BO->LHS->getType();
7929 unsigned BitWidth = getTypeSizeInBits(OuterTy);
7930 // If the shift count is not less than the bitwidth, the result of
7931 // the shift is undefined. Don't try to analyze it, because the
7932 // resolution chosen here may differ from the resolution chosen in
7933 // other parts of the compiler.
7934 if (CI->getValue().uge(BitWidth))
7935 break;
7936
7937 if (CI->isZero())
7938 return getSCEV(BO->LHS); // shift by zero --> noop
7939
7940 uint64_t AShrAmt = CI->getZExtValue();
7941 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7942
7943 Operator *L = dyn_cast<Operator>(BO->LHS);
7944 const SCEV *AddTruncateExpr = nullptr;
7945 ConstantInt *ShlAmtCI = nullptr;
7946 const SCEV *AddConstant = nullptr;
7947
7948 if (L && L->getOpcode() == Instruction::Add) {
7949 // X = Shl A, n
7950 // Y = Add X, c
7951 // Z = AShr Y, m
7952 // n, c and m are constants.
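// For example: %x = shl %a, 4; %y = add %x, 192; %z = ashr %y, 2.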
7953
7954 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7955 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7956 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7957 if (AddOperandCI) {
7958 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7959 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7960 // Since we truncate to TruncTy, the AddConstant should be of the
7961 // same type, so create a new Constant with the same type as TruncTy.
7962 // Also, the Add constant should be shifted right by the AShr amount.
7963 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
7964 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
7965 // We model the expression as sext(add(trunc(A), c << n)). Since the
7966 // sext(trunc) part is already handled below, we create an
7967 // AddExpr(TruncExp) which will be used later.
7968 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7969 }
7970 }
7971 } else if (L && L->getOpcode() == Instruction::Shl) {
7972 // X = Shl A, n
7973 // Y = AShr X, m
7974 // Both n and m are constant.
7975
7976 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7977 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7978 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7979 }
7980
7981 if (AddTruncateExpr && ShlAmtCI) {
7982 // We can merge the two given cases into a single SCEV statement,
7983 // in case n = m, the mul expression will be 2^0, so it gets resolved to
7984 // a simpler case. The following code handles the two cases:
7985 //
7986 // 1) For a two-shift sext-inreg, i.e. n = m,
7987 // use sext(trunc(x)) as the SCEV expression.
7988 //
7989 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
7990 // expression. We already checked that ShlAmt < BitWidth, so
7991 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7992 // ShlAmt - AShrAmt < BitWidth - AShrAmt.
7993 const APInt &ShlAmt = ShlAmtCI->getValue();
7994 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
7995 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7996 ShlAmtCI->getZExtValue() - AShrAmt);
7997 const SCEV *CompositeExpr =
7998 getMulExpr(AddTruncateExpr, getConstant(Mul));
7999 if (L->getOpcode() != Instruction::Shl)
8000 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8001
8002 return getSignExtendExpr(CompositeExpr, OuterTy);
8003 }
8004 }
8005 break;
8006 }
8007 }
8008
8009 switch (U->getOpcode()) {
8010 case Instruction::Trunc:
8011 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8012
8013 case Instruction::ZExt:
8014 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8015
8016 case Instruction::SExt:
8017 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8018 dyn_cast<Instruction>(V))) {
8019 // The NSW flag of a subtract does not always survive the conversion to
8020 // A + (-1)*B. By pushing sign extension onto its operands we are much
8021 // more likely to preserve NSW and allow later AddRec optimisations.
8022 //
8023 // NOTE: This is effectively duplicating this logic from getSignExtend:
8024 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8025 // but by that point the NSW information has potentially been lost.
8026 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8027 Type *Ty = U->getType();
8028 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8029 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8030 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8031 }
8032 }
8033 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8034
8035 case Instruction::BitCast:
8036 // BitCasts are no-op casts so we just eliminate the cast.
8037 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8038 return getSCEV(U->getOperand(0));
8039 break;
8040
8041 case Instruction::PtrToInt: {
8042 // Pointer-to-integer casts are straightforward, so we do model them.
8043 const SCEV *Op = getSCEV(U->getOperand(0));
8044 Type *DstIntTy = U->getType();
8045 // But only if effective SCEV (integer) type is wide enough to represent
8046 // all possible pointer values.
8047 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8048 if (isa<SCEVCouldNotCompute>(IntOp))
8049 return getUnknown(V);
8050 return IntOp;
8051 }
8052 case Instruction::IntToPtr:
8053 // Just don't deal with inttoptr casts.
8054 return getUnknown(V);
8055
8056 case Instruction::SDiv:
8057 // If both operands are non-negative, this is just an udiv.
8058 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8059 isKnownNonNegative(getSCEV(U->getOperand(1))))
8060 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8061 break;
8062
8063 case Instruction::SRem:
8064 // If both operands are non-negative, this is just an urem.
8065 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8066 isKnownNonNegative(getSCEV(U->getOperand(1))))
8067 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8068 break;
8069
8070 case Instruction::GetElementPtr:
8071 return createNodeForGEP(cast<GEPOperator>(U));
8072
8073 case Instruction::PHI:
8074 return createNodeForPHI(cast<PHINode>(U));
8075
8076 case Instruction::Select:
8077 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8078 U->getOperand(2));
8079
8080 case Instruction::Call:
8081 case Instruction::Invoke:
8082 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8083 return getSCEV(RV);
8084
8085 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8086 switch (II->getIntrinsicID()) {
8087 case Intrinsic::abs:
8088 return getAbsExpr(
8089 getSCEV(II->getArgOperand(0)),
8090 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8091 case Intrinsic::umax:
8092 LHS = getSCEV(II->getArgOperand(0));
8093 RHS = getSCEV(II->getArgOperand(1));
8094 return getUMaxExpr(LHS, RHS);
8095 case Intrinsic::umin:
8096 LHS = getSCEV(II->getArgOperand(0));
8097 RHS = getSCEV(II->getArgOperand(1));
8098 return getUMinExpr(LHS, RHS);
8099 case Intrinsic::smax:
8100 LHS = getSCEV(II->getArgOperand(0));
8101 RHS = getSCEV(II->getArgOperand(1));
8102 return getSMaxExpr(LHS, RHS);
8103 case Intrinsic::smin:
8104 LHS = getSCEV(II->getArgOperand(0));
8105 RHS = getSCEV(II->getArgOperand(1));
8106 return getSMinExpr(LHS, RHS);
8107 case Intrinsic::usub_sat: {
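// usub.sat(X, Y) == X - umin(X, Y), and that subtraction can never wrap,
// hence the NUW flag below.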
8108 const SCEV *X = getSCEV(II->getArgOperand(0));
8109 const SCEV *Y = getSCEV(II->getArgOperand(1));
8110 const SCEV *ClampedY = getUMinExpr(X, Y);
8111 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8112 }
8113 case Intrinsic::uadd_sat: {
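// uadd.sat(X, Y) == umin(X, ~Y) + Y; clamping X to ~Y (i.e. UINT_MAX - Y)
// first makes the final addition NUW.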
8114 const SCEV *X = getSCEV(II->getArgOperand(0));
8115 const SCEV *Y = getSCEV(II->getArgOperand(1));
8116 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8117 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8118 }
8119 case Intrinsic::start_loop_iterations:
8120 case Intrinsic::annotation:
8121 case Intrinsic::ptr_annotation:
8122 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8123 // just equivalent to the first operand for SCEV purposes.
8124 return getSCEV(II->getArgOperand(0));
8125 case Intrinsic::vscale:
8126 return getVScale(II->getType());
8127 default:
8128 break;
8129 }
8130 }
8131 break;
8132 }
8133
8134 return getUnknown(V);
8135}
8136
8137//===----------------------------------------------------------------------===//
8138// Iteration Count Computation Code
8139//
8140
8141 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8142 if (isa<SCEVCouldNotCompute>(ExitCount))
8143 return getCouldNotCompute();
8144
8145 auto *ExitCountType = ExitCount->getType();
8146 assert(ExitCountType->isIntegerTy());
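// Evaluate the count in a type one bit wider than the exit count so that
// adding one below cannot overflow.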
8147 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8148 1 + ExitCountType->getScalarSizeInBits());
8149 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8150}
8151
8152 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8153 Type *EvalTy,
8154 const Loop *L) {
8155 if (isa<SCEVCouldNotCompute>(ExitCount))
8156 return getCouldNotCompute();
8157
8158 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8159 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8160
8161 auto CanAddOneWithoutOverflow = [&]() {
8162 ConstantRange ExitCountRange =
8163 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8164 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8165 return true;
8166
8167 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8168 getMinusOne(ExitCount->getType()));
8169 };
8170
8171 // If we need to zero extend the backedge count, check if we can add one to
8172 // it prior to zero extending without overflow. Provided this is safe, it
8173 // allows better simplification of the +1.
8174 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8175 return getZeroExtendExpr(
8176 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8177
8178 // Get the total trip count from the count by adding 1. This may wrap.
8179 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8180}
8181
8182static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8183 if (!ExitCount)
8184 return 0;
8185
8186 ConstantInt *ExitConst = ExitCount->getValue();
8187
8188 // Guard against huge trip counts.
8189 if (ExitConst->getValue().getActiveBits() > 32)
8190 return 0;
8191
8192 // In case of integer overflow, this returns 0, which is correct.
8193 return ((unsigned)ExitConst->getZExtValue()) + 1;
8194}
8195
8196 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8197 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8198 return getConstantTripCount(ExitCount);
8199}
8200
8201unsigned
8202 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8203 const BasicBlock *ExitingBlock) {
8204 assert(ExitingBlock && "Must pass a non-null exiting block!");
8205 assert(L->isLoopExiting(ExitingBlock) &&
8206 "Exiting block must actually branch out of the loop!");
8207 const SCEVConstant *ExitCount =
8208 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8209 return getConstantTripCount(ExitCount);
8210}
8211
8212 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
8213 const auto *MaxExitCount =
8214 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
8215 return getConstantTripCount(MaxExitCount);
8216}
8217
8218 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8219 SmallVector<BasicBlock *, 8> ExitingBlocks;
8220 L->getExitingBlocks(ExitingBlocks);
8221
8222 std::optional<unsigned> Res;
8223 for (auto *ExitingBB : ExitingBlocks) {
8224 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8225 if (!Res)
8226 Res = Multiple;
8227 Res = (unsigned)std::gcd(*Res, Multiple);
8228 }
8229 return Res.value_or(1);
8230}
8231
8232 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8233 const SCEV *ExitCount) {
8234 if (ExitCount == getCouldNotCompute())
8235 return 1;
8236
8237 // Get the trip count
8238 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8239
8240 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8241 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8242 // the greatest power of 2 divisor less than 2^32.
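// For example, a trip count that is a multiple of 2^40 is reported as a
// multiple of 2^31.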
8243 return Multiple.getActiveBits() > 32
8244 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8245 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8246}
8247
8248/// Returns the largest constant divisor of the trip count of this loop as a
8249/// normal unsigned value, if possible. This means that the actual trip count is
8250/// always a multiple of the returned value (don't forget the trip count could
8251/// very well be zero as well!).
8252///
8253/// Returns 1 if the trip count is unknown or not guaranteed to be the
8254/// multiple of a constant (which is also the case if the trip count is simply
8255 /// constant; use getSmallConstantTripCount for that case). It will also return 1
8256/// if the trip count is very large (>= 2^32).
8257///
8258/// As explained in the comments for getSmallConstantTripCount, this assumes
8259/// that control exits the loop via ExitingBlock.
8260unsigned
8262 const BasicBlock *ExitingBlock) {
8263 assert(ExitingBlock && "Must pass a non-null exiting block!");
8264 assert(L->isLoopExiting(ExitingBlock) &&
8265 "Exiting block must actually branch out of the loop!");
8266 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8267 return getSmallConstantTripMultiple(L, ExitCount);
8268}
8269
8270 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8271 const BasicBlock *ExitingBlock,
8272 ExitCountKind Kind) {
8273 switch (Kind) {
8274 case Exact:
8275 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8276 case SymbolicMaximum:
8277 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8278 case ConstantMaximum:
8279 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8280 };
8281 llvm_unreachable("Invalid ExitCountKind!");
8282}
8283
8284const SCEV *
8287 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8288}
8289
8290 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8291 ExitCountKind Kind) {
8292 switch (Kind) {
8293 case Exact:
8294 return getBackedgeTakenInfo(L).getExact(L, this);
8295 case ConstantMaximum:
8296 return getBackedgeTakenInfo(L).getConstantMax(this);
8297 case SymbolicMaximum:
8298 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8299 };
8300 llvm_unreachable("Invalid ExitCountKind!");
8301}
8302
8305 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8306}
8307
8308 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8309 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8310}
8311
8312/// Push PHI nodes in the header of the given loop onto the given Worklist.
8313static void PushLoopPHIs(const Loop *L,
8314 SmallVectorImpl<Instruction *> &Worklist,
8315 SmallPtrSetImpl<Instruction *> &Visited) {
8316 BasicBlock *Header = L->getHeader();
8317
8318 // Push all Loop-header PHIs onto the Worklist stack.
8319 for (PHINode &PN : Header->phis())
8320 if (Visited.insert(&PN).second)
8321 Worklist.push_back(&PN);
8322}
8323
8324ScalarEvolution::BackedgeTakenInfo &
8325ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8326 auto &BTI = getBackedgeTakenInfo(L);
8327 if (BTI.hasFullInfo())
8328 return BTI;
8329
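// As in getBackedgeTakenInfo below, insert a conservative placeholder entry
// first; if an entry already exists, return it instead of recomputing.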
8330 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8331
8332 if (!Pair.second)
8333 return Pair.first->second;
8334
8335 BackedgeTakenInfo Result =
8336 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8337
8338 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8339}
8340
8341ScalarEvolution::BackedgeTakenInfo &
8342ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8343 // Initially insert an invalid entry for this loop. If the insertion
8344 // succeeds, proceed to actually compute a backedge-taken count and
8345 // update the value. The temporary CouldNotCompute value tells SCEV
8346 // code elsewhere that it shouldn't attempt to request a new
8347 // backedge-taken count, which could result in infinite recursion.
8348 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8349 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8350 if (!Pair.second)
8351 return Pair.first->second;
8352
8353 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8354 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8355 // must be cleared in this scope.
8356 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8357
8358 // Now that we know more about the trip count for this loop, forget any
8359 // existing SCEV values for PHI nodes in this loop since they are only
8360 // conservative estimates made without the benefit of trip count
8361 // information. This invalidation is not necessary for correctness, and is
8362 // only done to produce more precise results.
8363 if (Result.hasAnyInfo()) {
8364 // Invalidate any expression using an addrec in this loop.
8365 SmallVector<const SCEV *, 8> ToForget;
8366 auto LoopUsersIt = LoopUsers.find(L);
8367 if (LoopUsersIt != LoopUsers.end())
8368 append_range(ToForget, LoopUsersIt->second);
8369 forgetMemoizedResults(ToForget);
8370
8371 // Invalidate constant-evolved loop header phis.
8372 for (PHINode &PN : L->getHeader()->phis())
8373 ConstantEvolutionLoopExitValue.erase(&PN);
8374 }
8375
8376 // Re-lookup the insert position, since the call to
8377 // computeBackedgeTakenCount above could result in a
8378 // recursive call to getBackedgeTakenInfo (on a different
8379 // loop), which would invalidate the iterator computed
8380 // earlier.
8381 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8382}
8383
8384 void ScalarEvolution::forgetAllLoops() {
8385 // This method is intended to forget all info about loops. It should
8386 // invalidate caches as if the following happened:
8387 // - The trip counts of all loops have changed arbitrarily
8388 // - Every llvm::Value has been updated in place to produce a different
8389 // result.
8390 BackedgeTakenCounts.clear();
8391 PredicatedBackedgeTakenCounts.clear();
8392 BECountUsers.clear();
8393 LoopPropertiesCache.clear();
8394 ConstantEvolutionLoopExitValue.clear();
8395 ValueExprMap.clear();
8396 ValuesAtScopes.clear();
8397 ValuesAtScopesUsers.clear();
8398 LoopDispositions.clear();
8399 BlockDispositions.clear();
8400 UnsignedRanges.clear();
8401 SignedRanges.clear();
8402 ExprValueMap.clear();
8403 HasRecMap.clear();
8404 ConstantMultipleCache.clear();
8405 PredicatedSCEVRewrites.clear();
8406 FoldCache.clear();
8407 FoldCacheUser.clear();
8408}
8409void ScalarEvolution::visitAndClearUsers(
8410 SmallVectorImpl<Instruction *> &Worklist,
8411 SmallPtrSetImpl<Instruction *> &Visited,
8412 SmallVectorImpl<const SCEV *> &ToForget) {
8413 while (!Worklist.empty()) {
8414 Instruction *I = Worklist.pop_back_val();
8415 if (!isSCEVable(I->getType()))
8416 continue;
8417
8418 ValueExprMapType::iterator It =
8419 ValueExprMap.find_as(static_cast<Value *>(I));
8420 if (It != ValueExprMap.end()) {
8421 eraseValueFromMap(It->first);
8422 ToForget.push_back(It->second);
8423 if (PHINode *PN = dyn_cast<PHINode>(I))
8424 ConstantEvolutionLoopExitValue.erase(PN);
8425 }
8426
8427 PushDefUseChildren(I, Worklist, Visited);
8428 }
8429}
8430
8431 void ScalarEvolution::forgetLoop(const Loop *L) {
8432 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8433 SmallVector<Instruction *, 32> Worklist;
8434 SmallPtrSet<Instruction *, 16> Visited;
8435 SmallVector<const SCEV *, 16> ToForget;
8436
8437 // Iterate over all the loops and sub-loops to drop SCEV information.
8438 while (!LoopWorklist.empty()) {
8439 auto *CurrL = LoopWorklist.pop_back_val();
8440
8441 // Drop any stored trip count value.
8442 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8443 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8444
8445 // Drop information about predicated SCEV rewrites for this loop.
8446 for (auto I = PredicatedSCEVRewrites.begin();
8447 I != PredicatedSCEVRewrites.end();) {
8448 std::pair<const SCEV *, const Loop *> Entry = I->first;
8449 if (Entry.second == CurrL)
8450 PredicatedSCEVRewrites.erase(I++);
8451 else
8452 ++I;
8453 }
8454
8455 auto LoopUsersItr = LoopUsers.find(CurrL);
8456 if (LoopUsersItr != LoopUsers.end()) {
8457 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8458 LoopUsersItr->second.end());
8459 }
8460
8461 // Drop information about expressions based on loop-header PHIs.
8462 PushLoopPHIs(CurrL, Worklist, Visited);
8463 visitAndClearUsers(Worklist, Visited, ToForget);
8464
8465 LoopPropertiesCache.erase(CurrL);
8466 // Forget all contained loops too, to avoid dangling entries in the
8467 // ValuesAtScopes map.
8468 LoopWorklist.append(CurrL->begin(), CurrL->end());
8469 }
8470 forgetMemoizedResults(ToForget);
8471}
8472
8473 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8474 forgetLoop(L->getOutermostLoop());
8475}
8476
8477 void ScalarEvolution::forgetValue(Value *V) {
8478 Instruction *I = dyn_cast<Instruction>(V);
8479 if (!I) return;
8480
8481 // Drop information about expressions based on loop-header PHIs.
8482 SmallVector<Instruction *, 16> Worklist;
8483 SmallPtrSet<Instruction *, 8> Visited;
8484 SmallVector<const SCEV *, 8> ToForget;
8485 Worklist.push_back(I);
8486 Visited.insert(I);
8487 visitAndClearUsers(Worklist, Visited, ToForget);
8488
8489 forgetMemoizedResults(ToForget);
8490}
8491
8493 if (!isSCEVable(V->getType()))
8494 return;
8495
8496 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8497 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8498 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8499 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8500 if (const SCEV *S = getExistingSCEV(V)) {
8501 struct InvalidationRootCollector {
8502 Loop *L;
8503 SmallVector<const SCEV *, 8> Roots;
8504
8505 InvalidationRootCollector(Loop *L) : L(L) {}
8506
8507 bool follow(const SCEV *S) {
8508 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8509 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8510 if (L->contains(I))
8511 Roots.push_back(S);
8512 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8513 if (L->contains(AddRec->getLoop()))
8514 Roots.push_back(S);
8515 }
8516 return true;
8517 }
8518 bool isDone() const { return false; }
8519 };
8520
8521 InvalidationRootCollector C(L);
8522 visitAll(S, C);
8523 forgetMemoizedResults(C.Roots);
8524 }
8525
8526 // Also perform the normal invalidation.
8527 forgetValue(V);
8528}
8529
8530void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8531
8533 // Unless a specific value is passed to invalidation, completely clear both
8534 // caches.
8535 if (!V) {
8536 BlockDispositions.clear();
8537 LoopDispositions.clear();
8538 return;
8539 }
8540
8541 if (!isSCEVable(V->getType()))
8542 return;
8543
8544 const SCEV *S = getExistingSCEV(V);
8545 if (!S)
8546 return;
8547
8548 // Invalidate the block and loop dispositions cached for S. Dispositions of
8549 // S's users may change if S's disposition changes (i.e. a user may change to
8550 // loop-invariant, if S changes to loop invariant), so also invalidate
8551 // dispositions of S's users recursively.
8552 SmallVector<const SCEV *, 8> Worklist = {S};
8554 while (!Worklist.empty()) {
8555 const SCEV *Curr = Worklist.pop_back_val();
8556 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8557 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8558 if (!LoopDispoRemoved && !BlockDispoRemoved)
8559 continue;
8560 auto Users = SCEVUsers.find(Curr);
8561 if (Users != SCEVUsers.end())
8562 for (const auto *User : Users->second)
8563 if (Seen.insert(User).second)
8564 Worklist.push_back(User);
8565 }
8566}
8567
8568/// Get the exact loop backedge taken count considering all loop exits. A
8569/// computable result can only be returned for loops with all exiting blocks
8570/// dominating the latch. howFarToZero assumes that the limit of each loop test
8571/// is never skipped. This is a valid assumption as long as the loop exits via
8572/// that test. For precise results, it is the caller's responsibility to specify
8573/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8574const SCEV *
8575ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8577 // If any exits were not computable, the loop is not computable.
8578 if (!isComplete() || ExitNotTaken.empty())
8579 return SE->getCouldNotCompute();
8580
8581 const BasicBlock *Latch = L->getLoopLatch();
8582 // All exiting blocks we have collected must dominate the only backedge.
8583 if (!Latch)
8584 return SE->getCouldNotCompute();
8585
8586 // All exiting blocks we have gathered dominate the loop's latch, so the exact
8587 // trip count is simply the minimum of all the calculated exit counts.
8588 SmallVector<const SCEV *, 2> Ops;
8589 for (const auto &ENT : ExitNotTaken) {
8590 const SCEV *BECount = ENT.ExactNotTaken;
8591 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8592 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8593 "We should only have known counts for exiting blocks that dominate "
8594 "latch!");
8595
8596 Ops.push_back(BECount);
8597
8598 if (Preds)
8599 for (const auto *P : ENT.Predicates)
8600 Preds->push_back(P);
8601
8602 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8603 "Predicate should be always true!");
8604 }
8605
8606 // If an earlier exit exits on the first iteration (exit count zero), then
8607 // a later poison exit count should not propagate into the result. These are
8608 // exactly the semantics provided by umin_seq.
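// For example, with per-exit counts {0, poison} the overall backedge-taken
// count is 0 rather than poison.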
8609 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8610}
8611
8612/// Get the exact not taken count for this loop exit.
8613const SCEV *
8614ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8615 ScalarEvolution *SE) const {
8616 for (const auto &ENT : ExitNotTaken)
8617 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8618 return ENT.ExactNotTaken;
8619
8620 return SE->getCouldNotCompute();
8621}
8622
8623const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8624 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8625 for (const auto &ENT : ExitNotTaken)
8626 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8627 return ENT.ConstantMaxNotTaken;
8628
8629 return SE->getCouldNotCompute();
8630}
8631
8632const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8633 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8634 for (const auto &ENT : ExitNotTaken)
8635 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8636 return ENT.SymbolicMaxNotTaken;
8637
8638 return SE->getCouldNotCompute();
8639}
8640
8641/// getConstantMax - Get the constant max backedge taken count for the loop.
8642const SCEV *
8643ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8644 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8645 return !ENT.hasAlwaysTruePredicate();
8646 };
8647
8648 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8649 return SE->getCouldNotCompute();
8650
8651 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8652 isa<SCEVConstant>(getConstantMax())) &&
8653 "No point in having a non-constant max backedge taken count!");
8654 return getConstantMax();
8655}
8656
8657const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8658 const Loop *L, ScalarEvolution *SE,
8660 if (!SymbolicMax) {
8661 // Form an expression for the maximum exit count possible for this loop. We
8662 // merge the max and exact information to approximate a version of
8663 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8664 // constants.
8665 SmallVector<const SCEV *, 4> ExitCounts;
8666
8667 for (const auto &ENT : ExitNotTaken) {
8668 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8669 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8670 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8671 "We should only have known counts for exiting blocks that "
8672 "dominate latch!");
8673 ExitCounts.push_back(ExitCount);
8674 if (Predicates)
8675 for (const auto *P : ENT.Predicates)
8676 Predicates->push_back(P);
8677
8678 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8679 "Predicate should be always true!");
8680 }
8681 }
8682 if (ExitCounts.empty())
8683 SymbolicMax = SE->getCouldNotCompute();
8684 else
8685 SymbolicMax =
8686 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8687 }
8688 return SymbolicMax;
8689}
8690
8691bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8692 ScalarEvolution *SE) const {
8693 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8694 return !ENT.hasAlwaysTruePredicate();
8695 };
8696 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8697}
8698
8699 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8700 : ExitLimit(E, E, E, false, std::nullopt) {}
8701
8702 ScalarEvolution::ExitLimit::ExitLimit(
8703 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8704 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8705 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8706 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8707 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8708 // If we prove the max count is zero, so is the symbolic bound. This happens
8709 // in practice due to differences in a) how context sensitive we've chosen
8710 // to be and b) how we reason about bounds implied by UB.
8711 if (ConstantMaxNotTaken->isZero()) {
8712 this->ExactNotTaken = E = ConstantMaxNotTaken;
8713 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8714 }
8715
8716 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8717 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8718 "Exact is not allowed to be less precise than Constant Max");
8719 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8720 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8721 "Exact is not allowed to be less precise than Symbolic Max");
8722 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8723 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8724 "Symbolic Max is not allowed to be less precise than Constant Max");
8725 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8726 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8727 "No point in having a non-constant max backedge taken count!");
8728 for (const auto *PredSet : PredSetList)
8729 for (const auto *P : *PredSet)
8730 addPredicate(P);
8731 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8732 "Backedge count should be int");
8733 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8735 "Max backedge count should be int");
8736}
8737
8738 ScalarEvolution::ExitLimit::ExitLimit(
8739 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8740 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8741 const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
8742 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8743 { &PredSet }) {}
8744
8745/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8746/// computable exit into a persistent ExitNotTakenInfo array.
8747ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8749 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8750 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8751 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8752
8753 ExitNotTaken.reserve(ExitCounts.size());
8754 std::transform(ExitCounts.begin(), ExitCounts.end(),
8755 std::back_inserter(ExitNotTaken),
8756 [&](const EdgeExitInfo &EEI) {
8757 BasicBlock *ExitBB = EEI.first;
8758 const ExitLimit &EL = EEI.second;
8759 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8760 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8761 EL.Predicates);
8762 });
8763 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8764 isa<SCEVConstant>(ConstantMax)) &&
8765 "No point in having a non-constant max backedge taken count!");
8766}
8767
8768/// Compute the number of times the backedge of the specified loop will execute.
8769ScalarEvolution::BackedgeTakenInfo
8770ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8771 bool AllowPredicates) {
8772 SmallVector<BasicBlock *, 8> ExitingBlocks;
8773 L->getExitingBlocks(ExitingBlocks);
8774
8775 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8776
8777 SmallVector<EdgeExitInfo, 4> ExitCounts;
8778 bool CouldComputeBECount = true;
8779 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8780 const SCEV *MustExitMaxBECount = nullptr;
8781 const SCEV *MayExitMaxBECount = nullptr;
8782 bool MustExitMaxOrZero = false;
8783 bool IsOnlyExit = ExitingBlocks.size() == 1;
8784
8785 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8786 // and compute maxBECount.
8787 // Do a union of all the predicates here.
8788 for (BasicBlock *ExitBB : ExitingBlocks) {
8789 // We canonicalize untaken exits to br (constant), ignore them so that
8790 // proving an exit untaken doesn't negatively impact our ability to reason
8791 // about the loop as a whole.
8792 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8793 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8794 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8795 if (ExitIfTrue == CI->isZero())
8796 continue;
8797 }
8798
8799 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8800
8801 assert((AllowPredicates || EL.Predicates.empty()) &&
8802 "Predicated exit limit when predicates are not allowed!");
8803
8804 // 1. For each exit that can be computed, add an entry to ExitCounts.
8805 // CouldComputeBECount is true only if all exits can be computed.
8806 if (EL.ExactNotTaken != getCouldNotCompute())
8807 ++NumExitCountsComputed;
8808 else
8809 // We couldn't compute an exact value for this exit, so
8810 // we won't be able to compute an exact value for the loop.
8811 CouldComputeBECount = false;
8812 // Remember exit count if either exact or symbolic is known. Because
8813 // Exact always implies symbolic, only check symbolic.
8814 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8815 ExitCounts.emplace_back(ExitBB, EL);
8816 else {
8817 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8818 "Exact is known but symbolic isn't?");
8819 ++NumExitCountsNotComputed;
8820 }
8821
8822 // 2. Derive the loop's MaxBECount from each exit's max number of
8823 // non-exiting iterations. Partition the loop exits into two kinds:
8824 // LoopMustExits and LoopMayExits.
8825 //
8826 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8827 // is a LoopMayExit. If any computable LoopMustExit is found, then
8828 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8829 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8830 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8831 // any computable EL.ConstantMaxNotTaken.
8833 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8834 DT.dominates(ExitBB, Latch)) {
8835 if (!MustExitMaxBECount) {
8836 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8837 MustExitMaxOrZero = EL.MaxOrZero;
8838 } else {
8839 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8840 EL.ConstantMaxNotTaken);
8841 }
8842 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8843 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8844 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8845 else {
8846 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8847 EL.ConstantMaxNotTaken);
8848 }
8849 }
8850 }
8851 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8852 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8853 // The loop backedge will be taken the maximum or zero times if there's
8854 // a single exit that must be taken the maximum or zero times.
8855 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8856
8857 // Remember which SCEVs are used in exit limits for invalidation purposes.
8858 // We only care about non-constant SCEVs here, so we can ignore
8859 // EL.ConstantMaxNotTaken and MaxBECount, which must be SCEVConstant.
8861 for (const auto &Pair : ExitCounts) {
8862 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8863 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8864 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8865 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8866 {L, AllowPredicates});
8867 }
8868 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8869 MaxBECount, MaxOrZero);
8870}
8871
8872 ScalarEvolution::ExitLimit
8873 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8874 bool IsOnlyExit, bool AllowPredicates) {
8875 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8876 // If our exiting block does not dominate the latch, then its connection with
8877 // the loop's exit limit may be far from trivial.
8878 const BasicBlock *Latch = L->getLoopLatch();
8879 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8880 return getCouldNotCompute();
8881
8882 Instruction *Term = ExitingBlock->getTerminator();
8883 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8884 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8885 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8886 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8887 "It should have one successor in loop and one exit block!");
8888 // Proceed to the next level to examine the exit condition expression.
8889 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8890 /*ControlsOnlyExit=*/IsOnlyExit,
8891 AllowPredicates);
8892 }
8893
8894 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8895 // For switch, make sure that there is a single exit from the loop.
8896 BasicBlock *Exit = nullptr;
8897 for (auto *SBB : successors(ExitingBlock))
8898 if (!L->contains(SBB)) {
8899 if (Exit) // Multiple exit successors.
8900 return getCouldNotCompute();
8901 Exit = SBB;
8902 }
8903 assert(Exit && "Exiting block must have at least one exit");
8904 return computeExitLimitFromSingleExitSwitch(
8905 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8906 }
8907
8908 return getCouldNotCompute();
8909}
8910
8911 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8912 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8913 bool AllowPredicates) {
8914 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8915 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8916 ControlsOnlyExit, AllowPredicates);
8917}
8918
8919std::optional<ScalarEvolution::ExitLimit>
8920ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8921 bool ExitIfTrue, bool ControlsOnlyExit,
8922 bool AllowPredicates) {
8923 (void)this->L;
8924 (void)this->ExitIfTrue;
8925 (void)this->AllowPredicates;
8926
8927 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8928 this->AllowPredicates == AllowPredicates &&
8929 "Variance in assumed invariant key components!");
8930 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8931 if (Itr == TripCountMap.end())
8932 return std::nullopt;
8933 return Itr->second;
8934}
8935
8936void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8937 bool ExitIfTrue,
8938 bool ControlsOnlyExit,
8939 bool AllowPredicates,
8940 const ExitLimit &EL) {
8941 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8942 this->AllowPredicates == AllowPredicates &&
8943 "Variance in assumed invariant key components!");
8944
8945 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
8946 assert(InsertResult.second && "Expected successful insertion!");
8947 (void)InsertResult;
8948 (void)ExitIfTrue;
8949}
8950
8951ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8952 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8953 bool ControlsOnlyExit, bool AllowPredicates) {
8954
8955 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8956 AllowPredicates))
8957 return *MaybeEL;
8958
8959 ExitLimit EL = computeExitLimitFromCondImpl(
8960 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
8961 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8962 return EL;
8963}
8964
8965ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8966 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8967 bool ControlsOnlyExit, bool AllowPredicates) {
8968 // Handle BinOp conditions (And, Or).
8969 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8970 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
8971 return *LimitFromBinOp;
8972
8973 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8974 // Proceed to the next level to examine the icmp.
8975 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8976 ExitLimit EL =
8977 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
8978 if (EL.hasFullInfo() || !AllowPredicates)
8979 return EL;
8980
8981 // Try again, but use SCEV predicates this time.
8982 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
8983 ControlsOnlyExit,
8984 /*AllowPredicates=*/true);
8985 }
8986
8987 // Check for a constant condition. These are normally stripped out by
8988 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8989 // preserve the CFG and is temporarily leaving constant conditions
8990 // in place.
8991 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8992 if (ExitIfTrue == !CI->getZExtValue())
8993 // The backedge is always taken.
8994 return getCouldNotCompute();
8995 // The backedge is never taken.
8996 return getZero(CI->getType());
8997 }
8998
8999 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9000 // with a constant step, we can form an equivalent icmp predicate and figure
9001 // out how many iterations will be taken before we exit.
9002 const WithOverflowInst *WO;
9003 const APInt *C;
9004 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9005 match(WO->getRHS(), m_APInt(C))) {
9006 ConstantRange NWR =
9007 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9008 WO->getNoWrapKind());
9009 CmpInst::Predicate Pred;
9010 APInt NewRHSC, Offset;
9011 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9012 if (!ExitIfTrue)
9013 Pred = ICmpInst::getInversePredicate(Pred);
9014 auto *LHS = getSCEV(WO->getLHS());
9015 if (Offset != 0)
9016 LHS = getAddExpr(LHS, getConstant(Offset));
9017 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9018 ControlsOnlyExit, AllowPredicates);
9019 if (EL.hasAnyInfo())
9020 return EL;
9021 }
9022
9023 // If it's not an integer or pointer comparison then compute it the hard way.
9024 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9025}
9026
9027std::optional<ScalarEvolution::ExitLimit>
9028ScalarEvolution::computeExitLimitFromCondFromBinOp(
9029 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9030 bool ControlsOnlyExit, bool AllowPredicates) {
9031 // Check if the controlling expression for this loop is an And or Or.
9032 Value *Op0, *Op1;
9033 bool IsAnd = false;
9034 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9035 IsAnd = true;
9036 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9037 IsAnd = false;
9038 else
9039 return std::nullopt;
9040
9041 // EitherMayExit is true in these two cases:
9042 // br (and Op0 Op1), loop, exit
9043 // br (or Op0 Op1), exit, loop
9044 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9045 ExitLimit EL0 = computeExitLimitFromCondCached(
9046 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9047 AllowPredicates);
9048 ExitLimit EL1 = computeExitLimitFromCondCached(
9049 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9050 AllowPredicates);
9051
9052 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9053 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
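// For example, for 'br (and i1 %x, true)' the exit behaves exactly like %x
// alone, so the limit computed for the non-constant operand is used.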
9054 if (isa<ConstantInt>(Op1))
9055 return Op1 == NeutralElement ? EL0 : EL1;
9056 if (isa<ConstantInt>(Op0))
9057 return Op0 == NeutralElement ? EL1 : EL0;
9058
9059 const SCEV *BECount = getCouldNotCompute();
9060 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9061 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9062 if (EitherMayExit) {
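// The select form of and/or does not evaluate its second operand once the
// first operand decides the exit, so a poison count from the second operand
// must not leak into the result; sequential umin provides those semantics.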
9063 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9064 // Both conditions must be the same for the loop to continue executing.
9065 // Choose the less conservative count.
9066 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9067 EL1.ExactNotTaken != getCouldNotCompute()) {
9068 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9069 UseSequentialUMin);
9070 }
9071 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9072 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9073 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9074 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9075 else
9076 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9077 EL1.ConstantMaxNotTaken);
9078 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9079 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9080 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9081 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9082 else
9083 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9084 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9085 } else {
9086 // Both conditions must be the same at the same time for the loop to exit.
9087 // For now, be conservative.
9088 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9089 BECount = EL0.ExactNotTaken;
9090 }
9091
9092 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9093 // to be more aggressive when computing BECount than when computing
9094 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken and
9095 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9096 // EL1.ConstantMaxNotTaken to not.
9098 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9099 !isa<SCEVCouldNotCompute>(BECount))
9100 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9101 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9102 SymbolicMaxBECount =
9103 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9104 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9105 { &EL0.Predicates, &EL1.Predicates });
9106}
9107
9108ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9109 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9110 bool AllowPredicates) {
9111 // If the condition was exit on true, convert the condition to exit on false
9112 ICmpInst::Predicate Pred;
9113 if (!ExitIfTrue)
9114 Pred = ExitCond->getPredicate();
9115 else
9116 Pred = ExitCond->getInversePredicate();
9117 const ICmpInst::Predicate OriginalPred = Pred;
9118
9119 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9120 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9121
9122 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9123 AllowPredicates);
9124 if (EL.hasAnyInfo())
9125 return EL;
9126
9127 auto *ExhaustiveCount =
9128 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9129
9130 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9131 return ExhaustiveCount;
9132
9133 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9134 ExitCond->getOperand(1), L, OriginalPred);
9135}
9136ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9137 const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9138 bool ControlsOnlyExit, bool AllowPredicates) {
9139
9140 // Try to evaluate any dependencies out of the loop.
9141 LHS = getSCEVAtScope(LHS, L);
9142 RHS = getSCEVAtScope(RHS, L);
9143
9144 // At this point, we would like to compute how many iterations of the
9145 // loop the predicate will return true for these inputs.
9146 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9147 // If there is a loop-invariant, force it into the RHS.
9148 std::swap(LHS, RHS);
9149 Pred = ICmpInst::getSwappedPredicate(Pred);
9150 }
9151
9152 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9153 loopIsFiniteByAssumption(L);
9154 // Simplify the operands before analyzing them.
9155 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9156
9157 // If we have a comparison of a chrec against a constant, try to use value
9158 // ranges to answer this query.
9159 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9160 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9161 if (AddRec->getLoop() == L) {
9162 // Form the constant range.
9163 ConstantRange CompRange =
9164 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9165
9166 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9167 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9168 }
9169
9170 // If this loop must exit based on this condition (or execute undefined
9171 // behaviour), and we can prove the test sequence produced must repeat
9172 // the same values on self-wrap of the IV, then we can infer that IV
9173 // doesn't self wrap because if it did, we'd have an infinite (undefined)
9174 // loop.
9175 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9176 // TODO: We can peel off any functions which are invertible *in L*. Loop
9177 // invariant terms are effectively constants for our purposes here.
9178 auto *InnerLHS = LHS;
9179 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9180 InnerLHS = ZExt->getOperand();
9181 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
9182 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
9183 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9184 StrideC && StrideC->getAPInt().isPowerOf2()) {
9185 auto Flags = AR->getNoWrapFlags();
9186 Flags = setFlags(Flags, SCEV::FlagNW);
9189 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9190 }
9191 }
9192 }
9193
9194 switch (Pred) {
9195 case ICmpInst::ICMP_NE: { // while (X != Y)
9196 // Convert to: while (X-Y != 0)
9197 if (LHS->getType()->isPointerTy()) {
9198 LHS = getLosslessPtrToIntExpr(LHS);
9199 if (isa<SCEVCouldNotCompute>(LHS))
9200 return LHS;
9201 }
9202 if (RHS->getType()->isPointerTy()) {
9203 RHS = getLosslessPtrToIntExpr(RHS);
9204 if (isa<SCEVCouldNotCompute>(RHS))
9205 return RHS;
9206 }
9207 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9208 AllowPredicates);
9209 if (EL.hasAnyInfo())
9210 return EL;
9211 break;
9212 }
9213 case ICmpInst::ICMP_EQ: { // while (X == Y)
9214 // Convert to: while (X-Y == 0)
9215 if (LHS->getType()->isPointerTy()) {
9216 LHS = getLosslessPtrToIntExpr(LHS);
9217 if (isa<SCEVCouldNotCompute>(LHS))
9218 return LHS;
9219 }
9220 if (RHS->getType()->isPointerTy()) {
9221 RHS = getLosslessPtrToIntExpr(RHS);
9222 if (isa<SCEVCouldNotCompute>(RHS))
9223 return RHS;
9224 }
9225 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9226 if (EL.hasAnyInfo()) return EL;
9227 break;
9228 }
9229 case ICmpInst::ICMP_SLE:
9230 case ICmpInst::ICMP_ULE:
9231 // Since the loop is finite, an invariant RHS cannot include the boundary
9232 // value, otherwise it would loop forever.
9233 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9234 !isLoopInvariant(RHS, L)) {
9235 // Otherwise, perform the addition in a wider type, to avoid overflow.
9236 // If the LHS is an addrec with the appropriate nowrap flag, the
9237 // extension will be sunk into it and the exit count can be analyzed.
9238 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9239 if (!OldType)
9240 break;
9241 // Prefer doubling the bitwidth over adding a single bit to make it more
9242 // likely that we use a legal type.
9243 auto *NewType =
9244 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9245 if (ICmpInst::isSigned(Pred)) {
9246 LHS = getSignExtendExpr(LHS, NewType);
9247 RHS = getSignExtendExpr(RHS, NewType);
9248 } else {
9249 LHS = getZeroExtendExpr(LHS, NewType);
9250 RHS = getZeroExtendExpr(RHS, NewType);
9251 }
9252 }
9253 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9254 [[fallthrough]];
9255 case ICmpInst::ICMP_SLT:
9256 case ICmpInst::ICMP_ULT: { // while (X < Y)
9257 bool IsSigned = ICmpInst::isSigned(Pred);
9258 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9259 AllowPredicates);
9260 if (EL.hasAnyInfo())
9261 return EL;
9262 break;
9263 }
9264 case ICmpInst::ICMP_SGE:
9265 case ICmpInst::ICMP_UGE:
9266 // Since the loop is finite, an invariant RHS cannot include the boundary
9267 // value, otherwise it would loop forever.
9268 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9269 !isLoopInvariant(RHS, L))
9270 break;
9271 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9272 [[fallthrough]];
9273 case ICmpInst::ICMP_SGT:
9274 case ICmpInst::ICMP_UGT: { // while (X > Y)
9275 bool IsSigned = ICmpInst::isSigned(Pred);
9276 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9277 AllowPredicates);
9278 if (EL.hasAnyInfo())
9279 return EL;
9280 break;
9281 }
9282 default:
9283 break;
9284 }
9285
9286 return getCouldNotCompute();
9287}
9288
9290ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9291 SwitchInst *Switch,
9292 BasicBlock *ExitingBlock,
9293 bool ControlsOnlyExit) {
9294 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9295
9296 // Give up if the exit is the default dest of a switch.
9297 if (Switch->getDefaultDest() == ExitingBlock)
9298 return getCouldNotCompute();
9299
9300 assert(L->contains(Switch->getDefaultDest()) &&
9301 "Default case must not exit the loop!");
9302 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9303 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9304
9305 // while (X != Y) --> while (X-Y != 0)
9306 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9307 if (EL.hasAnyInfo())
9308 return EL;
9309
9310 return getCouldNotCompute();
9311}
9312
9313static ConstantInt *
9314 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9315 ScalarEvolution &SE) {
9316 const SCEV *InVal = SE.getConstant(C);
9317 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9318 assert(isa<SCEVConstant>(Val) &&
9319 "Evaluation of SCEV at constant didn't fold correctly?");
9320 return cast<SCEVConstant>(Val)->getValue();
9321}
9322
9323ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9324 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9325 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9326 if (!RHS)
9327 return getCouldNotCompute();
9328
9329 const BasicBlock *Latch = L->getLoopLatch();
9330 if (!Latch)
9331 return getCouldNotCompute();
9332
9333 const BasicBlock *Predecessor = L->getLoopPredecessor();
9334 if (!Predecessor)
9335 return getCouldNotCompute();
9336
9337 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9338 // Return LHS in OutLHS and shift_op in OutOpCode.
9339 auto MatchPositiveShift =
9340 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9341
9342 using namespace PatternMatch;
9343
9344 ConstantInt *ShiftAmt;
9345 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9346 OutOpCode = Instruction::LShr;
9347 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9348 OutOpCode = Instruction::AShr;
9349 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9350 OutOpCode = Instruction::Shl;
9351 else
9352 return false;
9353
9354 return ShiftAmt->getValue().isStrictlyPositive();
9355 };
9356
9357 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9358 //
9359 // loop:
9360 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9361 // %iv.shifted = lshr i32 %iv, <positive constant>
9362 //
9363 // Return true on a successful match. Return the corresponding PHI node (%iv
9364 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9365 auto MatchShiftRecurrence =
9366 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9367 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9368
9369 {
9370 Instruction::BinaryOps OpC;
9371 Value *V;
9372
9373 // If we encounter a shift instruction, "peel off" the shift operation,
9374 // and remember that we did so. Later when we inspect %iv's backedge
9375 // value, we will make sure that the backedge value uses the same
9376 // operation.
9377 //
9378 // Note: the peeled shift operation does not have to be the same
9379 // instruction as the one feeding into the PHI's backedge value. We only
9380 // really care about it being the same *kind* of shift instruction --
9381 // that's all that is required for our later inferences to hold.
9382 if (MatchPositiveShift(LHS, V, OpC)) {
9383 PostShiftOpCode = OpC;
9384 LHS = V;
9385 }
9386 }
9387
9388 PNOut = dyn_cast<PHINode>(LHS);
9389 if (!PNOut || PNOut->getParent() != L->getHeader())
9390 return false;
9391
9392 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9393 Value *OpLHS;
9394
9395 return
9396 // The backedge value for the PHI node must be a shift by a positive
9397 // amount
9398 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9399
9400 // of the PHI node itself
9401 OpLHS == PNOut &&
9402
9403 // and the kind of shift should match the kind of shift we peeled
9404 // off, if any.
9405 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9406 };
9407
9408 PHINode *PN;
9409 Instruction::BinaryOps OpCode;
9410 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9411 return getCouldNotCompute();
9412
9413 const DataLayout &DL = getDataLayout();
9414
9415 // The key rationale for this optimization is that for some kinds of shift
9416 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9417 // within a finite number of iterations. If the condition guarding the
9418 // backedge (in the sense that the backedge is taken if the condition is true)
9419 // is false for the value the shift recurrence stabilizes to, then we know
9420 // that the backedge is taken only a finite number of times.
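 // [Editorial example, not part of the original source] For instance, the
 // recurrence {8,lshr,1} produces 8, 4, 2, 1, 0, 0, ... and stabilizes to 0
 // within at most bitwidth(8) iterations. If the backedge is guarded by
 // "%iv != 0", that condition becomes false once the value reaches 0, so the
 // backedge can only be taken a bounded number of times.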
9421
9422 ConstantInt *StableValue = nullptr;
9423 switch (OpCode) {
9424 default:
9425 llvm_unreachable("Impossible case!");
9426
9427 case Instruction::AShr: {
9428 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9429 // bitwidth(K) iterations.
9430 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9431 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9432 Predecessor->getTerminator(), &DT);
9433 auto *Ty = cast<IntegerType>(RHS->getType());
9434 if (Known.isNonNegative())
9435 StableValue = ConstantInt::get(Ty, 0);
9436 else if (Known.isNegative())
9437 StableValue = ConstantInt::get(Ty, -1, true);
9438 else
9439 return getCouldNotCompute();
9440
9441 break;
9442 }
9443 case Instruction::LShr:
9444 case Instruction::Shl:
9445 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9446 // stabilize to 0 in at most bitwidth(K) iterations.
9447 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9448 break;
9449 }
9450
9451 auto *Result =
9452 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9453 assert(Result->getType()->isIntegerTy(1) &&
9454 "Otherwise cannot be an operand to a branch instruction");
9455
9456 if (Result->isZeroValue()) {
9457 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9458 const SCEV *UpperBound =
9459 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9460 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9461 }
9462
9463 return getCouldNotCompute();
9464}
9465
9466/// Return true if we can constant fold an instruction of the specified type,
9467/// assuming that all operands were constants.
9468static bool CanConstantFold(const Instruction *I) {
9469 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9470 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9471 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9472 return true;
9473
9474 if (const CallInst *CI = dyn_cast<CallInst>(I))
9475 if (const Function *F = CI->getCalledFunction())
9476 return canConstantFoldCallTo(CI, F);
9477 return false;
9478}
9479
9480/// Determine whether this instruction can constant evolve within this loop
9481/// assuming its operands can all constant evolve.
9482static bool canConstantEvolve(Instruction *I, const Loop *L) {
9483 // An instruction outside of the loop can't be derived from a loop PHI.
9484 if (!L->contains(I)) return false;
9485
9486 if (isa<PHINode>(I)) {
9487 // We don't currently keep track of the control flow needed to evaluate
9488 // PHIs, so we cannot handle PHIs inside of loops.
9489 return L->getHeader() == I->getParent();
9490 }
9491
9492 // If we won't be able to constant fold this expression even if the operands
9493 // are constants, bail early.
9494 return CanConstantFold(I);
9495}
9496
9497/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9498/// recursing through each instruction operand until reaching a loop header phi.
9499static PHINode *
9500getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9501 DenseMap<Instruction *, PHINode *> &PHIMap,
9502 unsigned Depth) {
9503 if (Depth > MaxConstantEvolvingDepth)
9504 return nullptr;
9505
9506 // Otherwise, we can evaluate this instruction if all of its operands are
9507 // constant or derived from a PHI node themselves.
9508 PHINode *PHI = nullptr;
9509 for (Value *Op : UseInst->operands()) {
9510 if (isa<Constant>(Op)) continue;
9511
9512 Instruction *OpInst = dyn_cast<Instruction>(Op);
9513 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9514
9515 PHINode *P = dyn_cast<PHINode>(OpInst);
9516 if (!P)
9517 // If this operand is already visited, reuse the prior result.
9518 // We may have P != PHI if this is the deepest point at which the
9519 // inconsistent paths meet.
9520 P = PHIMap.lookup(OpInst);
9521 if (!P) {
9522 // Recurse and memoize the results, whether a phi is found or not.
9523 // This recursive call invalidates pointers into PHIMap.
9524 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9525 PHIMap[OpInst] = P;
9526 }
9527 if (!P)
9528 return nullptr; // Not evolving from PHI
9529 if (PHI && PHI != P)
9530 return nullptr; // Evolving from multiple different PHIs.
9531 PHI = P;
9532 }
9533 // This is an expression evolving from a constant PHI!
9534 return PHI;
9535}
9536
9537/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9538/// in the loop that V is derived from. We allow arbitrary operations along the
9539/// way, but the operands of an operation must either be constants or a value
9540/// derived from a constant PHI. If this expression does not fit with these
9541/// constraints, return null.
9542static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9543 Instruction *I = dyn_cast<Instruction>(V);
9544 if (!I || !canConstantEvolve(I, L)) return nullptr;
9545
9546 if (PHINode *PN = dyn_cast<PHINode>(I))
9547 return PN;
9548
9549 // Record non-constant instructions contained by the loop.
9550 DenseMap<Instruction *, PHINode *> PHIMap;
9551 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9552}
9553
9554/// EvaluateExpression - Given an expression that passes the
9555/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9556/// in the loop has the value PHIVal. If we can't fold this expression for some
9557/// reason, return null.
9558static Constant *EvaluateExpression(Value *V, const Loop *L,
9559 DenseMap<Instruction *, Constant *> &Vals,
9560 const DataLayout &DL,
9561 const TargetLibraryInfo *TLI) {
9562 // Convenient constant check, but redundant for recursive calls.
9563 if (Constant *C = dyn_cast<Constant>(V)) return C;
9564 Instruction *I = dyn_cast<Instruction>(V);
9565 if (!I) return nullptr;
9566
9567 if (Constant *C = Vals.lookup(I)) return C;
9568
9569 // An instruction inside the loop depends on a value outside the loop that we
9570 // weren't given a mapping for, or a value such as a call inside the loop.
9571 if (!canConstantEvolve(I, L)) return nullptr;
9572
9573 // An unmapped PHI can be due to a branch or another loop inside this loop,
9574 // or due to this not being the initial iteration through a loop where we
9575 // couldn't compute the evolution of this particular PHI last time.
9576 if (isa<PHINode>(I)) return nullptr;
9577
9578 std::vector<Constant*> Operands(I->getNumOperands());
9579
9580 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9581 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9582 if (!Operand) {
9583 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9584 if (!Operands[i]) return nullptr;
9585 continue;
9586 }
9587 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9588 Vals[Operand] = C;
9589 if (!C) return nullptr;
9590 Operands[i] = C;
9591 }
9592
9593 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9594 /*AllowNonDeterministic=*/false);
9595}
9596
9597
9598// If every incoming value to PN except the one for BB is a specific Constant,
9599// return that, else return nullptr.
9600static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9601 Constant *IncomingVal = nullptr;
9602
9603 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9604 if (PN->getIncomingBlock(i) == BB)
9605 continue;
9606
9607 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9608 if (!CurrentVal)
9609 return nullptr;
9610
9611 if (IncomingVal != CurrentVal) {
9612 if (IncomingVal)
9613 return nullptr;
9614 IncomingVal = CurrentVal;
9615 }
9616 }
9617
9618 return IncomingVal;
9619}
9620
9621/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9622/// in the header of its containing loop, we know the loop executes a
9623/// constant number of times, and the PHI node is just a recurrence
9624/// involving constants, fold it.
9625Constant *
9626ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9627 const APInt &BEs,
9628 const Loop *L) {
9629 auto I = ConstantEvolutionLoopExitValue.find(PN);
9630 if (I != ConstantEvolutionLoopExitValue.end())
9631 return I->second;
9632
9633 if (BEs.ugt(MaxBruteForceIterations))
9634 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9635
9636 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9637
9638 DenseMap<Instruction *, Constant *> CurrentIterVals;
9639 BasicBlock *Header = L->getHeader();
9640 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9641
9642 BasicBlock *Latch = L->getLoopLatch();
9643 if (!Latch)
9644 return nullptr;
9645
9646 for (PHINode &PHI : Header->phis()) {
9647 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9648 CurrentIterVals[&PHI] = StartCST;
9649 }
9650 if (!CurrentIterVals.count(PN))
9651 return RetVal = nullptr;
9652
9653 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9654
9655 // Execute the loop symbolically to determine the exit value.
9656 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9657 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9658
9659 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9660 unsigned IterationNum = 0;
9661 const DataLayout &DL = getDataLayout();
9662 for (; ; ++IterationNum) {
9663 if (IterationNum == NumIterations)
9664 return RetVal = CurrentIterVals[PN]; // Got exit value!
9665
9666 // Compute the value of the PHIs for the next iteration.
9667 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9668 DenseMap<Instruction *, Constant *> NextIterVals;
9669 Constant *NextPHI =
9670 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9671 if (!NextPHI)
9672 return nullptr; // Couldn't evaluate!
9673 NextIterVals[PN] = NextPHI;
9674
9675 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9676
9677 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9678 // cease to be able to evaluate one of them or if they stop evolving,
9679 // because that doesn't necessarily prevent us from computing PN.
9680 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9681 for (const auto &I : CurrentIterVals) {
9682 PHINode *PHI = dyn_cast<PHINode>(I.first);
9683 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9684 PHIsToCompute.emplace_back(PHI, I.second);
9685 }
9686 // We use two distinct loops because EvaluateExpression may invalidate any
9687 // iterators into CurrentIterVals.
9688 for (const auto &I : PHIsToCompute) {
9689 PHINode *PHI = I.first;
9690 Constant *&NextPHI = NextIterVals[PHI];
9691 if (!NextPHI) { // Not already computed.
9692 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9693 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9694 }
9695 if (NextPHI != I.second)
9696 StoppedEvolving = false;
9697 }
9698
9699 // If all entries in CurrentIterVals == NextIterVals then we can stop
9700 // iterating, the loop can't continue to change.
9701 if (StoppedEvolving)
9702 return RetVal = CurrentIterVals[PN];
9703
9704 CurrentIterVals.swap(NextIterVals);
9705 }
9706}
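// [Editorial note, not part of the original source] As an illustration of the
// brute-force evaluation above: for a header PHI such as
//   %x = phi i32 [ 3, %preheader ], [ %x.next, %latch ]
//   %x.next = mul i32 %x, 2
// with a known backedge-taken count of 4, the loop is executed symbolically
// (3 -> 6 -> 12 -> 24 -> 48) and the exit value folds to the constant 48.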
9707
9708const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9709 Value *Cond,
9710 bool ExitWhen) {
9711 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9712 if (!PN) return getCouldNotCompute();
9713
9714 // If the loop is canonicalized, the PHI will have exactly two entries.
9715 // That's the only form we support here.
9716 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9717
9718 DenseMap<Instruction *, Constant *> CurrentIterVals;
9719 BasicBlock *Header = L->getHeader();
9720 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9721
9722 BasicBlock *Latch = L->getLoopLatch();
9723 assert(Latch && "Should follow from NumIncomingValues == 2!");
9724
9725 for (PHINode &PHI : Header->phis()) {
9726 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9727 CurrentIterVals[&PHI] = StartCST;
9728 }
9729 if (!CurrentIterVals.count(PN))
9730 return getCouldNotCompute();
9731
9732 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9733 // the loop symbolically to determine when the condition gets a value of
9734 // "ExitWhen".
9735 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9736 const DataLayout &DL = getDataLayout();
9737 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9738 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9739 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9740
9741 // Couldn't symbolically evaluate.
9742 if (!CondVal) return getCouldNotCompute();
9743
9744 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9745 ++NumBruteForceTripCountsComputed;
9746 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9747 }
9748
9749 // Update all the PHI nodes for the next iteration.
9750 DenseMap<Instruction *, Constant *> NextIterVals;
9751
9752 // Create a list of which PHIs we need to compute. We want to do this before
9753 // calling EvaluateExpression on them because that may invalidate iterators
9754 // into CurrentIterVals.
9755 SmallVector<PHINode *, 8> PHIsToCompute;
9756 for (const auto &I : CurrentIterVals) {
9757 PHINode *PHI = dyn_cast<PHINode>(I.first);
9758 if (!PHI || PHI->getParent() != Header) continue;
9759 PHIsToCompute.push_back(PHI);
9760 }
9761 for (PHINode *PHI : PHIsToCompute) {
9762 Constant *&NextPHI = NextIterVals[PHI];
9763 if (NextPHI) continue; // Already computed!
9764
9765 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9766 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9767 }
9768 CurrentIterVals.swap(NextIterVals);
9769 }
9770
9771 // Too many iterations were needed to evaluate.
9772 return getCouldNotCompute();
9773}
9774
9775const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9776 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9777 ValuesAtScopes[V];
9778 // Check to see if we've folded this expression at this loop before.
9779 for (auto &LS : Values)
9780 if (LS.first == L)
9781 return LS.second ? LS.second : V;
9782
9783 Values.emplace_back(L, nullptr);
9784
9785 // Otherwise compute it.
9786 const SCEV *C = computeSCEVAtScope(V, L);
9787 for (auto &LS : reverse(ValuesAtScopes[V]))
9788 if (LS.first == L) {
9789 LS.second = C;
9790 if (!isa<SCEVConstant>(C))
9791 ValuesAtScopesUsers[C].push_back({L, V});
9792 break;
9793 }
9794 return C;
9795}
9796
9797/// This builds up a Constant using the ConstantExpr interface. That way, we
9798/// will return Constants for objects which aren't represented by a
9799/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9800/// Returns NULL if the SCEV isn't representable as a Constant.
9801static Constant *BuildConstantFromSCEV(const SCEV *V) {
9802 switch (V->getSCEVType()) {
9803 case scCouldNotCompute:
9804 case scAddRecExpr:
9805 case scVScale:
9806 return nullptr;
9807 case scConstant:
9808 return cast<SCEVConstant>(V)->getValue();
9809 case scUnknown:
9810 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9811 case scPtrToInt: {
9812 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9813 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9814 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9815
9816 return nullptr;
9817 }
9818 case scTruncate: {
9819 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9820 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9821 return ConstantExpr::getTrunc(CastOp, ST->getType());
9822 return nullptr;
9823 }
9824 case scAddExpr: {
9825 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9826 Constant *C = nullptr;
9827 for (const SCEV *Op : SA->operands()) {
9828 Constant *OpC = BuildConstantFromSCEV(Op);
9829 if (!OpC)
9830 return nullptr;
9831 if (!C) {
9832 C = OpC;
9833 continue;
9834 }
9835 assert(!C->getType()->isPointerTy() &&
9836 "Can only have one pointer, and it must be last");
9837 if (OpC->getType()->isPointerTy()) {
9838 // The offsets have been converted to bytes. We can add bytes using
9839 // an i8 GEP.
9840 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9841 OpC, C);
9842 } else {
9843 C = ConstantExpr::getAdd(C, OpC);
9844 }
9845 }
9846 return C;
9847 }
9848 case scMulExpr:
9849 case scSignExtend:
9850 case scZeroExtend:
9851 case scUDivExpr:
9852 case scSMaxExpr:
9853 case scUMaxExpr:
9854 case scSMinExpr:
9855 case scUMinExpr:
9856 case scSequentialUMinExpr:
9857 return nullptr;
9858 }
9859 llvm_unreachable("Unknown SCEV kind!");
9860}
9861
9862const SCEV *
9863ScalarEvolution::getWithOperands(const SCEV *S,
9864 ArrayRef<const SCEV *> NewOps) {
9865 switch (S->getSCEVType()) {
9866 case scTruncate:
9867 case scZeroExtend:
9868 case scSignExtend:
9869 case scPtrToInt:
9870 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9871 case scAddRecExpr: {
9872 auto *AddRec = cast<SCEVAddRecExpr>(S);
9873 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9874 }
9875 case scAddExpr:
9876 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9877 case scMulExpr:
9878 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9879 case scUDivExpr:
9880 return getUDivExpr(NewOps[0], NewOps[1]);
9881 case scUMaxExpr:
9882 case scSMaxExpr:
9883 case scUMinExpr:
9884 case scSMinExpr:
9885 return getMinMaxExpr(S->getSCEVType(), NewOps);
9886 case scSequentialUMinExpr:
9887 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9888 case scConstant:
9889 case scVScale:
9890 case scUnknown:
9891 return S;
9892 case scCouldNotCompute:
9893 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9894 }
9895 llvm_unreachable("Unknown SCEV kind!");
9896}
9897
9898const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9899 switch (V->getSCEVType()) {
9900 case scConstant:
9901 case scVScale:
9902 return V;
9903 case scAddRecExpr: {
9904 // If this is a loop recurrence for a loop that does not contain L, then we
9905 // are dealing with the final value computed by the loop.
9906 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9907 // First, attempt to evaluate each operand.
9908 // Avoid performing the look-up in the common case where the specified
9909 // expression has no loop-variant portions.
9910 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9911 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9912 if (OpAtScope == AddRec->getOperand(i))
9913 continue;
9914
9915 // Okay, at least one of these operands is loop variant but might be
9916 // foldable. Build a new instance of the folded commutative expression.
9917 SmallVector<const SCEV *, 8> NewOps;
9918 NewOps.reserve(AddRec->getNumOperands());
9919 append_range(NewOps, AddRec->operands().take_front(i));
9920 NewOps.push_back(OpAtScope);
9921 for (++i; i != e; ++i)
9922 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9923
9924 const SCEV *FoldedRec = getAddRecExpr(
9925 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
9926 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9927 // The addrec may be folded to a nonrecurrence, for example, if the
9928 // induction variable is multiplied by zero after constant folding. Go
9929 // ahead and return the folded value.
9930 if (!AddRec)
9931 return FoldedRec;
9932 break;
9933 }
9934
9935 // If the scope is outside the addrec's loop, evaluate it by using the
9936 // loop exit value of the addrec.
9937 if (!AddRec->getLoop()->contains(L)) {
9938 // To evaluate this recurrence, we need to know how many times the AddRec
9939 // loop iterates. Compute this now.
9940 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9941 if (BackedgeTakenCount == getCouldNotCompute())
9942 return AddRec;
9943
9944 // Then, evaluate the AddRec.
9945 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
9946 }
9947
9948 return AddRec;
9949 }
9950 case scTruncate:
9951 case scZeroExtend:
9952 case scSignExtend:
9953 case scPtrToInt:
9954 case scAddExpr:
9955 case scMulExpr:
9956 case scUDivExpr:
9957 case scUMaxExpr:
9958 case scSMaxExpr:
9959 case scUMinExpr:
9960 case scSMinExpr:
9961 case scSequentialUMinExpr: {
9962 ArrayRef<const SCEV *> Ops = V->operands();
9963 // Avoid performing the look-up in the common case where the specified
9964 // expression has no loop-variant portions.
9965 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
9966 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
9967 if (OpAtScope != Ops[i]) {
9968 // Okay, at least one of these operands is loop variant but might be
9969 // foldable. Build a new instance of the folded commutative expression.
9970 SmallVector<const SCEV *, 8> NewOps;
9971 NewOps.reserve(Ops.size());
9972 append_range(NewOps, Ops.take_front(i));
9973 NewOps.push_back(OpAtScope);
9974
9975 for (++i; i != e; ++i) {
9976 OpAtScope = getSCEVAtScope(Ops[i], L);
9977 NewOps.push_back(OpAtScope);
9978 }
9979
9980 return getWithOperands(V, NewOps);
9981 }
9982 }
9983 // If we got here, all operands are loop invariant.
9984 return V;
9985 }
9986 case scUnknown: {
9987 // If this instruction is evolved from a constant-evolving PHI, compute the
9988 // exit value from the loop without using SCEVs.
9989 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
9990 Instruction *I = dyn_cast<Instruction>(SU->getValue());
9991 if (!I)
9992 return V; // This is some other type of SCEVUnknown, just return it.
9993
9994 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9995 const Loop *CurrLoop = this->LI[I->getParent()];
9996 // Looking for loop exit value.
9997 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9998 PN->getParent() == CurrLoop->getHeader()) {
9999 // Okay, there is no closed form solution for the PHI node. Check
10000 // to see if the loop that contains it has a known backedge-taken
10001 // count. If so, we may be able to force computation of the exit
10002 // value.
10003 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10004 // This trivial case can show up in some degenerate cases where
10005 // the incoming IR has not yet been fully simplified.
10006 if (BackedgeTakenCount->isZero()) {
10007 Value *InitValue = nullptr;
10008 bool MultipleInitValues = false;
10009 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10010 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10011 if (!InitValue)
10012 InitValue = PN->getIncomingValue(i);
10013 else if (InitValue != PN->getIncomingValue(i)) {
10014 MultipleInitValues = true;
10015 break;
10016 }
10017 }
10018 }
10019 if (!MultipleInitValues && InitValue)
10020 return getSCEV(InitValue);
10021 }
10022 // Do we have a loop invariant value flowing around the backedge
10023 // for a loop which must execute the backedge?
10024 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10025 isKnownNonZero(BackedgeTakenCount) &&
10026 PN->getNumIncomingValues() == 2) {
10027
10028 unsigned InLoopPred =
10029 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10030 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10031 if (CurrLoop->isLoopInvariant(BackedgeVal))
10032 return getSCEV(BackedgeVal);
10033 }
10034 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10035 // Okay, we know how many times the containing loop executes. If
10036 // this is a constant evolving PHI node, get the final value at
10037 // the specified iteration number.
10038 Constant *RV =
10039 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10040 if (RV)
10041 return getSCEV(RV);
10042 }
10043 }
10044 }
10045
10046 // Okay, this is an expression that we cannot symbolically evaluate
10047 // into a SCEV. Check to see if it's possible to symbolically evaluate
10048 // the arguments into constants, and if so, try to constant propagate the
10049 // result. This is particularly useful for computing loop exit values.
10050 if (!CanConstantFold(I))
10051 return V; // This is some other type of SCEVUnknown, just return it.
10052
10053 SmallVector<Constant *, 4> Operands;
10054 Operands.reserve(I->getNumOperands());
10055 bool MadeImprovement = false;
10056 for (Value *Op : I->operands()) {
10057 if (Constant *C = dyn_cast<Constant>(Op)) {
10058 Operands.push_back(C);
10059 continue;
10060 }
10061
10062 // If any of the operands is non-constant and if they are
10063 // non-integer and non-pointer, don't even try to analyze them
10064 // with scev techniques.
10065 if (!isSCEVable(Op->getType()))
10066 return V;
10067
10068 const SCEV *OrigV = getSCEV(Op);
10069 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10070 MadeImprovement |= OrigV != OpV;
10071
10072 Constant *C = BuildConstantFromSCEV(OpV);
10073 if (!C)
10074 return V;
10075 assert(C->getType() == Op->getType() && "Type mismatch");
10076 Operands.push_back(C);
10077 }
10078
10079 // Check to see if getSCEVAtScope actually made an improvement.
10080 if (!MadeImprovement)
10081 return V; // This is some other type of SCEVUnknown, just return it.
10082
10083 Constant *C = nullptr;
10084 const DataLayout &DL = getDataLayout();
10085 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10086 /*AllowNonDeterministic=*/false);
10087 if (!C)
10088 return V;
10089 return getSCEV(C);
10090 }
10091 case scCouldNotCompute:
10092 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10093 }
10094 llvm_unreachable("Unknown SCEV type!");
10095}
10096
10097const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10098 return getSCEVAtScope(getSCEV(V), L);
10099}
10100
10101const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10102 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10103 return stripInjectiveFunctions(ZExt->getOperand());
10104 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10105 return stripInjectiveFunctions(SExt->getOperand());
10106 return S;
10107}
10108
10109/// Finds the minimum unsigned root of the following equation:
10110///
10111/// A * X = B (mod N)
10112///
10113/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10114/// A and B isn't important.
10115///
10116/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10117static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10118 ScalarEvolution &SE) {
10119 uint32_t BW = A.getBitWidth();
10120 assert(BW == SE.getTypeSizeInBits(B->getType()));
10121 assert(A != 0 && "A must be non-zero.");
10122
10123 // 1. D = gcd(A, N)
10124 //
10125 // The gcd of A and N may have only one prime factor: 2. The number of
10126 // trailing zeros in A is its multiplicity
10127 uint32_t Mult2 = A.countr_zero();
10128 // D = 2^Mult2
10129
10130 // 2. Check if B is divisible by D.
10131 //
10132 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10133 // is not less than multiplicity of this prime factor for D.
10134 if (SE.getMinTrailingZeros(B) < Mult2)
10135 return SE.getCouldNotCompute();
10136
10137 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10138 // modulo (N / D).
10139 //
10140 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10141 // (N / D) in general. The inverse itself always fits into BW bits, though,
10142 // so we immediately truncate it.
10143 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10144 APInt I = AD.multiplicativeInverse().zext(BW);
10145
10146 // 4. Compute the minimum unsigned root of the equation:
10147 // I * (B / D) mod (N / D)
10148 // To simplify the computation, we factor out the divide by D:
10149 // (I * B mod N) / D
10150 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10151 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10152}
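// [Editorial example, not part of the original source] A small worked instance
// of the steps above, with BW = 8: solve 4 * X = 8 (mod 256).
//   1. D = gcd(4, 256) = 4 (Mult2 = 2).
//   2. 8 has at least 2 trailing zero bits, so it is divisible by D.
//   3. AD = 4 / D = 1, whose multiplicative inverse I modulo 64 is 1.
//   4. X = (I * 8 mod 256) / 4 = 2, and indeed 4 * 2 = 8 (mod 256).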
10153
10154/// For a given quadratic addrec, generate coefficients of the corresponding
10155/// quadratic equation, multiplied by a common value to ensure that they are
10156/// integers.
10157/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10158/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10159/// were multiplied by, and BitWidth is the bit width of the original addrec
10160/// coefficients.
10161/// This function returns std::nullopt if the addrec coefficients are not
10162/// compile-time constants.
10163static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10164GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10165 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10166 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10167 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10168 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10169 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10170 << *AddRec << '\n');
10171
10172 // We currently can only solve this if the coefficients are constants.
10173 if (!LC || !MC || !NC) {
10174 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10175 return std::nullopt;
10176 }
10177
10178 APInt L = LC->getAPInt();
10179 APInt M = MC->getAPInt();
10180 APInt N = NC->getAPInt();
10181 assert(!N.isZero() && "This is not a quadratic addrec");
10182
10183 unsigned BitWidth = LC->getAPInt().getBitWidth();
10184 unsigned NewWidth = BitWidth + 1;
10185 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10186 << BitWidth << '\n');
10187 // The sign-extension (as opposed to a zero-extension) here matches the
10188 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10189 N = N.sext(NewWidth);
10190 M = M.sext(NewWidth);
10191 L = L.sext(NewWidth);
10192
10193 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10194 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10195 // L+M, L+2M+N, L+3M+3N, ...
10196 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10197 //
10198 // The equation Acc = 0 is then
10199 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10200 // In a quadratic form it becomes:
10201 // N n^2 + (2M-N) n + 2L = 0.
10202
10203 APInt A = N;
10204 APInt B = 2 * M - A;
10205 APInt C = 2 * L;
10206 APInt T = APInt(NewWidth, 2);
10207 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10208 << "x + " << C << ", coeff bw: " << NewWidth
10209 << ", multiplied by " << T << '\n');
10210 return std::make_tuple(A, B, C, T, BitWidth);
10211}
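// [Editorial example, not part of the original source] For the addrec
// {-4,+,1,+,2} (L = -4, M = 1, N = 2) the accumulated value after n iterations
// is -4 + n + n(n-1) = n^2 - 4, and the derivation above yields
//   N n^2 + (2M - N) n + 2L = 2n^2 + 0n - 8 = 0,
// i.e. the doubled equation 2(n^2 - 4) = 0, which is zero at n = 2.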
10212
10213/// Helper function to compare optional APInts:
10214/// (a) if X and Y both exist, return min(X, Y),
10215/// (b) if neither X nor Y exist, return std::nullopt,
10216/// (c) if exactly one of X and Y exists, return that value.
10217static std::optional<APInt> MinOptional(std::optional<APInt> X,
10218 std::optional<APInt> Y) {
10219 if (X && Y) {
10220 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10221 APInt XW = X->sext(W);
10222 APInt YW = Y->sext(W);
10223 return XW.slt(YW) ? *X : *Y;
10224 }
10225 if (!X && !Y)
10226 return std::nullopt;
10227 return X ? *X : *Y;
10228}
10229
10230/// Helper function to truncate an optional APInt to a given BitWidth.
10231/// When solving addrec-related equations, it is preferable to return a value
10232/// that has the same bit width as the original addrec's coefficients. If the
10233/// solution fits in the original bit width, truncate it (except for i1).
10234/// Returning a value of a different bit width may inhibit some optimizations.
10235///
10236/// In general, a solution to a quadratic equation generated from an addrec
10237/// may require BW+1 bits, where BW is the bit width of the addrec's
10238/// coefficients. The reason is that the coefficients of the quadratic
10239/// equation are BW+1 bits wide (to avoid truncation when converting from
10240/// the addrec to the equation).
10241static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10242 unsigned BitWidth) {
10243 if (!X)
10244 return std::nullopt;
10245 unsigned W = X->getBitWidth();
10246 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10247 return X->trunc(BitWidth);
10248 return X;
10249}
10250
10251/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10252/// iterations. The values L, M, N are assumed to be signed, and they
10253/// should all have the same bit widths.
10254/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10255/// where BW is the bit width of the addrec's coefficients.
10256/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10257/// returned as such, otherwise the bit width of the returned value may
10258/// be greater than BW.
10259///
10260/// This function returns std::nullopt if
10261/// (a) the addrec coefficients are not constant, or
10262/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10263/// like x^2 = 5, no integer solutions exist, in other cases an integer
10264/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10265static std::optional<APInt>
10266SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10267 APInt A, B, C, M;
10268 unsigned BitWidth;
10269 auto T = GetQuadraticEquation(AddRec);
10270 if (!T)
10271 return std::nullopt;
10272
10273 std::tie(A, B, C, M, BitWidth) = *T;
10274 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10275 std::optional<APInt> X =
10276 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10277 if (!X)
10278 return std::nullopt;
10279
10280 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10281 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10282 if (!V->isZero())
10283 return std::nullopt;
10284
10285 return TruncIfPossible(X, BitWidth);
10286}
10287
10288/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10289/// iterations. The values M, N are assumed to be signed, and they
10290/// should all have the same bit widths.
10291/// Find the least n such that c(n) does not belong to the given range,
10292/// while c(n-1) does.
10293///
10294/// This function returns std::nullopt if
10295/// (a) the addrec coefficients are not constant, or
10296/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10297/// bounds of the range.
10298static std::optional<APInt>
10299SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10300 const ConstantRange &Range, ScalarEvolution &SE) {
10301 assert(AddRec->getOperand(0)->isZero() &&
10302 "Starting value of addrec should be 0");
10303 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10304 << Range << ", addrec " << *AddRec << '\n');
10305 // This case is handled in getNumIterationsInRange. Here we can assume that
10306 // we start in the range.
10307 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10308 "Addrec's initial value should be in range");
10309
10310 APInt A, B, C, M;
10311 unsigned BitWidth;
10312 auto T = GetQuadraticEquation(AddRec);
10313 if (!T)
10314 return std::nullopt;
10315
10316 // Be careful about the return value: there can be two reasons for not
10317 // returning an actual number. First, if no solutions to the equations
10318 // were found, and second, if the solutions don't leave the given range.
10319 // The first case means that the actual solution is "unknown", the second
10320 // means that it's known, but not valid. If the solution is unknown, we
10321 // cannot make any conclusions.
10322 // Return a pair: the optional solution and a flag indicating if the
10323 // solution was found.
10324 auto SolveForBoundary =
10325 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10326 // Solve for signed overflow and unsigned overflow, pick the lower
10327 // solution.
10328 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10329 << Bound << " (before multiplying by " << M << ")\n");
10330 Bound *= M; // The quadratic equation multiplier.
10331
10332 std::optional<APInt> SO;
10333 if (BitWidth > 1) {
10334 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10335 "signed overflow\n");
10336 SO = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth);
10337 }
10338 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10339 "unsigned overflow\n");
10340 std::optional<APInt> UO =
10341 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10342
10343 auto LeavesRange = [&] (const APInt &X) {
10344 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10345 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10346 if (Range.contains(V0->getValue()))
10347 return false;
10348 // X should be at least 1, so X-1 is non-negative.
10349 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10350 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10351 if (Range.contains(V1->getValue()))
10352 return true;
10353 return false;
10354 };
10355
10356 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10357 // can be a solution, but the function failed to find it. We cannot treat it
10358 // as "no solution".
10359 if (!SO || !UO)
10360 return {std::nullopt, false};
10361
10362 // Check the smaller value first to see if it leaves the range.
10363 // At this point, both SO and UO must have values.
10364 std::optional<APInt> Min = MinOptional(SO, UO);
10365 if (LeavesRange(*Min))
10366 return { Min, true };
10367 std::optional<APInt> Max = Min == SO ? UO : SO;
10368 if (LeavesRange(*Max))
10369 return { Max, true };
10370
10371 // Solutions were found, but were eliminated, hence the "true".
10372 return {std::nullopt, true};
10373 };
10374
10375 std::tie(A, B, C, M, BitWidth) = *T;
10376 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10377 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10378 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10379 auto SL = SolveForBoundary(Lower);
10380 auto SU = SolveForBoundary(Upper);
10381 // If any of the solutions was unknown, no meaningful conclusions can
10382 // be made.
10383 if (!SL.second || !SU.second)
10384 return std::nullopt;
10385
10386 // Claim: The correct solution is not some value between Min and Max.
10387 //
10388 // Justification: Assuming that Min and Max are different values, one of
10389 // them is when the first signed overflow happens, the other is when the
10390 // first unsigned overflow happens. Crossing the range boundary is only
10391 // possible via an overflow (treating 0 as a special case of it, modeling
10392 // an overflow as crossing k*2^W for some k).
10393 //
10394 // The interesting case here is when Min was eliminated as an invalid
10395 // solution, but Max was not. The argument is that if there was another
10396 // overflow between Min and Max, it would also have been eliminated if
10397 // it was considered.
10398 //
10399 // For a given boundary, it is possible to have two overflows of the same
10400 // type (signed/unsigned) without having the other type in between: this
10401 // can happen when the vertex of the parabola is between the iterations
10402 // corresponding to the overflows. This is only possible when the two
10403 // overflows cross k*2^W for the same k. In such case, if the second one
10404 // left the range (and was the first one to do so), the first overflow
10405 // would have to enter the range, which would mean that either we had left
10406 // the range before or that we started outside of it. Both of these cases
10407 // are contradictions.
10408 //
10409 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10410 // solution is not some value between the Max for this boundary and the
10411 // Min of the other boundary.
10412 //
10413 // Justification: Assume that we had such Max_A and Min_B corresponding
10414 // to range boundaries A and B and such that Max_A < Min_B. If there was
10415 // a solution between Max_A and Min_B, it would have to be caused by an
10416 // overflow corresponding to either A or B. It cannot correspond to B,
10417 // since Min_B is the first occurrence of such an overflow. If it
10418 // corresponded to A, it would have to be either a signed or an unsigned
10419 // overflow that is larger than both eliminated overflows for A. But
10420 // between the eliminated overflows and this overflow, the values would
10421 // cover the entire value space, thus crossing the other boundary, which
10422 // is a contradiction.
10423
10424 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10425}
10426
10427ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10428 const Loop *L,
10429 bool ControlsOnlyExit,
10430 bool AllowPredicates) {
10431
10432 // This is only used for loops with a "x != y" exit test. The exit condition
10433 // is now expressed as a single expression, V = x-y. So the exit test is
10434 // effectively V != 0. We know, and take advantage of, the fact that this
10435 // expression is only ever used in a comparison-against-zero context.
10436
10437 SmallVector<const SCEVPredicate *> Predicates;
10438 // If the value is a constant
10439 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10440 // If the value is already zero, the branch will execute zero times.
10441 if (C->getValue()->isZero()) return C;
10442 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10443 }
10444
10445 const SCEVAddRecExpr *AddRec =
10446 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10447
10448 if (!AddRec && AllowPredicates)
10449 // Try to make this an AddRec using runtime tests, in the first X
10450 // iterations of this loop, where X is the SCEV expression found by the
10451 // algorithm below.
10452 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10453
10454 if (!AddRec || AddRec->getLoop() != L)
10455 return getCouldNotCompute();
10456
10457 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10458 // the quadratic equation to solve it.
10459 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10460 // We can only use this value if the chrec ends up with an exact zero
10461 // value at this index. When solving for "X*X != 5", for example, we
10462 // should not accept a root of 2.
10463 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10464 const auto *R = cast<SCEVConstant>(getConstant(*S));
10465 return ExitLimit(R, R, R, false, Predicates);
10466 }
10467 return getCouldNotCompute();
10468 }
10469
10470 // Otherwise we can only handle this if it is affine.
10471 if (!AddRec->isAffine())
10472 return getCouldNotCompute();
10473
10474 // If this is an affine expression, the execution count of this branch is
10475 // the minimum unsigned root of the following equation:
10476 //
10477 // Start + Step*N = 0 (mod 2^BW)
10478 //
10479 // equivalent to:
10480 //
10481 // Step*N = -Start (mod 2^BW)
10482 //
10483 // where BW is the common bit width of Start and Step.
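 // [Editorial example, not part of the original source] Concretely, for the
 // affine addrec {10,+,-2} over i8 (Step is 254 in two's complement), the
 // equation 254 * N = -10 (mod 256) gives N = 5: the values are
 // 10, 8, 6, 4, 2, 0, so the backedge is taken 5 times before reaching zero.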
10484
10485 // Get the initial value for the loop.
10486 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10487 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10488 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10489
10490 if (!isLoopInvariant(Step, L))
10491 return getCouldNotCompute();
10492
10493 // Specialize step for this loop so we get context sensitive facts below.
10494 const SCEV *StepWLG = applyLoopGuards(Step, L);
10495
10496 // For positive steps (counting up until unsigned overflow):
10497 // N = -Start/Step (as unsigned)
10498 // For negative steps (counting down to zero):
10499 // N = Start/-Step
10500 // First compute the unsigned distance from zero in the direction of Step.
10501 bool CountDown = isKnownNegative(StepWLG);
10502 if (!CountDown && !isKnownNonNegative(StepWLG))
10503 return getCouldNotCompute();
10504
10505 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10506 // Handle unitary steps, which cannot wraparound.
10507 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10508 // N = Distance (as unsigned)
10509 if (StepC &&
10510 (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
10511 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
10512 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10513
10514 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10515 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10516 // case, and see if we can improve the bound.
10517 //
10518 // Explicitly handling this here is necessary because getUnsignedRange
10519 // isn't context-sensitive; it doesn't know that we only care about the
10520 // range inside the loop.
10521 const SCEV *Zero = getZero(Distance->getType());
10522 const SCEV *One = getOne(Distance->getType());
10523 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10524 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10525 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10526 // as "unsigned_max(Distance + 1) - 1".
10527 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10528 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10529 }
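 // [Editorial example, not part of the original source] E.g. for i8, if the
 // loop entry is guarded by "Distance + 1 != 0", then Distance cannot be 255,
 // so the constant maximum improves to unsigned_max(Distance + 1) - 1 = 254
 // instead of the full-range bound of 255.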
10530 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10531 Predicates);
10532 }
10533
10534 // If the condition controls loop exit (the loop exits only if the expression
10535 // is true) and the addition is no-wrap we can use unsigned divide to
10536 // compute the backedge count. In this case, the step may not divide the
10537 // distance, but we don't care because if the condition is "missed" the loop
10538 // will have undefined behavior due to wrapping.
10539 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10540 loopHasNoAbnormalExits(AddRec->getLoop())) {
10541
10542 // If the stride is zero, the loop must be infinite. In C++, most loops
10543 // are finite by assumption, in which case the step being zero implies
10544 // UB must execute if the loop is entered.
10545 if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG))
10546 return getCouldNotCompute();
10547
10548 const SCEV *Exact =
10549 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10550 const SCEV *ConstantMax = getCouldNotCompute();
10551 if (Exact != getCouldNotCompute()) {
10552 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
10553 ConstantMax =
10554 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10555 }
10556 const SCEV *SymbolicMax =
10557 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10558 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10559 }
10560
10561 // Solve the general equation.
10562 if (!StepC || StepC->getValue()->isZero())
10563 return getCouldNotCompute();
10564 const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
10565 getNegativeSCEV(Start), *this);
10566
10567 const SCEV *M = E;
10568 if (E != getCouldNotCompute()) {
10569 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
10570 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10571 }
10572 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10573 return ExitLimit(E, M, S, false, Predicates);
10574}
10575
10576ScalarEvolution::ExitLimit
10577ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10578 // Loops that look like: while (X == 0) are very strange indeed. We don't
10579 // handle them yet except for the trivial case. This could be expanded in the
10580 // future as needed.
10581
10582 // If the value is a constant, check to see if it is known to be non-zero
10583 // already. If so, the backedge will execute zero times.
10584 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10585 if (!C->getValue()->isZero())
10586 return getZero(C->getType());
10587 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10588 }
10589
10590 // We could implement others, but I really doubt anyone writes loops like
10591 // this, and if they did, they would already be constant folded.
10592 return getCouldNotCompute();
10593}
10594
10595std::pair<const BasicBlock *, const BasicBlock *>
10596ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10597 const {
10598 // If the block has a unique predecessor, then there is no path from the
10599 // predecessor to the block that does not go through the direct edge
10600 // from the predecessor to the block.
10601 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10602 return {Pred, BB};
10603
10604 // A loop's header is defined to be a block that dominates the loop.
10605 // If the header has a unique predecessor outside the loop, it must be
10606 // a block that has exactly one successor that can reach the loop.
10607 if (const Loop *L = LI.getLoopFor(BB))
10608 return {L->getLoopPredecessor(), L->getHeader()};
10609
10610 return {nullptr, nullptr};
10611}
10612
10613/// SCEV structural equivalence is usually sufficient for testing whether two
10614/// expressions are equal, however for the purposes of looking for a condition
10615/// guarding a loop, it can be useful to be a little more general, since a
10616/// front-end may have replicated the controlling expression.
10617static bool HasSameValue(const SCEV *A, const SCEV *B) {
10618 // Quick check to see if they are the same SCEV.
10619 if (A == B) return true;
10620
10621 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10622 // Not all instructions that are "identical" compute the same value. For
10623 // instance, two distinct alloca instructions allocating the same type are
10624 // identical and do not read memory; but compute distinct values.
10625 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10626 };
10627
10628 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10629 // two different instructions with the same value. Check for this case.
10630 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10631 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10632 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10633 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10634 if (ComputesEqualValues(AI, BI))
10635 return true;
10636
10637 // Otherwise assume they may have a different value.
10638 return false;
10639}
10640
10641static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10642 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10643 if (!Add || Add->getNumOperands() != 2)
10644 return false;
10645 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10646 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10647 LHS = Add->getOperand(1);
10648 RHS = ME->getOperand(1);
10649 return true;
10650 }
10651 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10652 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10653 LHS = Add->getOperand(0);
10654 RHS = ME->getOperand(1);
10655 return true;
10656 }
10657 return false;
10658}
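// [Editorial example, not part of the original source] E.g. the SCEV
// ((-1 * %a) + %b), i.e. %b - %a, is matched here with LHS = %b and RHS = %a,
// which lets SimplifyICmpOperands below fold "%b - %a == 0" into "%a == %b".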
10659
10660bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
10661 const SCEV *&LHS, const SCEV *&RHS,
10662 unsigned Depth) {
10663 bool Changed = false;
10664 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10665 // '0 != 0'.
10666 auto TrivialCase = [&](bool TriviallyTrue) {
10667 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10668 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10669 return true;
10670 };
10671 // If we hit the max recursion limit bail out.
10672 if (Depth >= 3)
10673 return false;
10674
10675 // Canonicalize a constant to the right side.
10676 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10677 // Check for both operands constant.
10678 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10679 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10680 return TrivialCase(false);
10681 return TrivialCase(true);
10682 }
10683 // Otherwise swap the operands to put the constant on the right.
10684 std::swap(LHS, RHS);
10685 Pred = ICmpInst::getSwappedPredicate(Pred);
10686 Changed = true;
10687 }
10688
10689 // If we're comparing an addrec with a value which is loop-invariant in the
10690 // addrec's loop, put the addrec on the left. Also make a dominance check,
10691 // as both operands could be addrecs loop-invariant in each other's loop.
10692 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10693 const Loop *L = AR->getLoop();
10694 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10695 std::swap(LHS, RHS);
10696 Pred = ICmpInst::getSwappedPredicate(Pred);
10697 Changed = true;
10698 }
10699 }
10700
10701 // If there's a constant operand, canonicalize comparisons with boundary
10702 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10703 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10704 const APInt &RA = RC->getAPInt();
10705
10706 bool SimplifiedByConstantRange = false;
10707
10708 if (!ICmpInst::isEquality(Pred)) {
10709 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10710 if (ExactCR.isFullSet())
10711 return TrivialCase(true);
10712 if (ExactCR.isEmptySet())
10713 return TrivialCase(false);
10714
10715 APInt NewRHS;
10716 CmpInst::Predicate NewPred;
10717 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10718 ICmpInst::isEquality(NewPred)) {
10719 // We were able to convert an inequality to an equality.
10720 Pred = NewPred;
10721 RHS = getConstant(NewRHS);
10722 Changed = SimplifiedByConstantRange = true;
10723 }
10724 }
10725
10726 if (!SimplifiedByConstantRange) {
10727 switch (Pred) {
10728 default:
10729 break;
10730 case ICmpInst::ICMP_EQ:
10731 case ICmpInst::ICMP_NE:
10732 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10733 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10734 Changed = true;
10735 break;
10736
10737 // The "Should have been caught earlier!" messages refer to the fact
10738 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10739 // should have fired on the corresponding cases, and canonicalized the
10740 // check to trivial case.
10741
10742 case ICmpInst::ICMP_UGE:
10743 assert(!RA.isMinValue() && "Should have been caught earlier!");
10744 Pred = ICmpInst::ICMP_UGT;
10745 RHS = getConstant(RA - 1);
10746 Changed = true;
10747 break;
10748 case ICmpInst::ICMP_ULE:
10749 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10750 Pred = ICmpInst::ICMP_ULT;
10751 RHS = getConstant(RA + 1);
10752 Changed = true;
10753 break;
10754 case ICmpInst::ICMP_SGE:
10755 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10756 Pred = ICmpInst::ICMP_SGT;
10757 RHS = getConstant(RA - 1);
10758 Changed = true;
10759 break;
10760 case ICmpInst::ICMP_SLE:
10761 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10762 Pred = ICmpInst::ICMP_SLT;
10763 RHS = getConstant(RA + 1);
10764 Changed = true;
10765 break;
10766 }
10767 }
10768 }
10769
10770 // Check for obvious equality.
10771 if (HasSameValue(LHS, RHS)) {
10772 if (ICmpInst::isTrueWhenEqual(Pred))
10773 return TrivialCase(true);
10774 if (ICmpInst::isFalseWhenEqual(Pred))
10775 return TrivialCase(false);
10776 }
10777
10778 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10779 // adding or subtracting 1 from one of the operands.
10780 switch (Pred) {
10781 case ICmpInst::ICMP_SLE:
10782 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10783 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10784 SCEV::FlagNSW);
10785 Pred = ICmpInst::ICMP_SLT;
10786 Changed = true;
10787 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10788 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10789 SCEV::FlagNSW);
10790 Pred = ICmpInst::ICMP_SLT;
10791 Changed = true;
10792 }
10793 break;
10794 case ICmpInst::ICMP_SGE:
10795 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10796 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10797 SCEV::FlagNSW);
10798 Pred = ICmpInst::ICMP_SGT;
10799 Changed = true;
10800 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10801 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10802 SCEV::FlagNSW);
10803 Pred = ICmpInst::ICMP_SGT;
10804 Changed = true;
10805 }
10806 break;
10807 case ICmpInst::ICMP_ULE:
10808 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10809 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10810 SCEV::FlagNUW);
10811 Pred = ICmpInst::ICMP_ULT;
10812 Changed = true;
10813 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10814 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10815 Pred = ICmpInst::ICMP_ULT;
10816 Changed = true;
10817 }
10818 break;
10819 case ICmpInst::ICMP_UGE:
10820 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10821 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10822 Pred = ICmpInst::ICMP_UGT;
10823 Changed = true;
10824 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10825 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10826 SCEV::FlagNUW);
10827 Pred = ICmpInst::ICMP_UGT;
10828 Changed = true;
10829 }
10830 break;
10831 default:
10832 break;
10833 }
10834
10835 // TODO: More simplifications are possible here.
10836
10837 // Recursively simplify until we either hit a recursion limit or nothing
10838 // changes.
10839 if (Changed)
10840 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10841
10842 return Changed;
10843}
10844
10845bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10846 return getSignedRangeMax(S).isNegative();
10847}
10848
10849bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10850 return getSignedRangeMin(S).isStrictlyPositive();
10851}
10852
10853bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10854 return !getSignedRangeMin(S).isNegative();
10855}
10856
10857bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10858 return !getSignedRangeMax(S).isStrictlyPositive();
10859}
10860
10861bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10862 // Query push down for cases where the unsigned range is
10863 // less than sufficient.
10864 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10865 return isKnownNonZero(SExt->getOperand(0));
10866 return getUnsignedRangeMin(S) != 0;
10867}
10868
10869std::pair<const SCEV *, const SCEV *>
10870ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10871 // Compute SCEV on entry of loop L.
10872 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10873 if (Start == getCouldNotCompute())
10874 return { Start, Start };
10875 // Compute post increment SCEV for loop L.
10876 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10877 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10878 return { Start, PostInc };
10879}
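// [Editorial example, not part of the original source] For S = {0,+,1}<L> this
// returns the pair { 0, {1,+,1}<L> }: the value of S on entry to L and its
// post-increment form after each backedge, which is what isKnownViaInduction
// below checks at the loop entry and at the backedge respectively.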
10880
10881bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
10882 const SCEV *LHS, const SCEV *RHS) {
10883 // First collect all loops.
10884 SmallPtrSet<const Loop *, 8> LoopsUsed;
10885 getUsedLoops(LHS, LoopsUsed);
10886 getUsedLoops(RHS, LoopsUsed);
10887
10888 if (LoopsUsed.empty())
10889 return false;
10890
10891 // Domination relationship must be a linear order on collected loops.
10892#ifndef NDEBUG
10893 for (const auto *L1 : LoopsUsed)
10894 for (const auto *L2 : LoopsUsed)
10895 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
10896 DT.dominates(L2->getHeader(), L1->getHeader())) &&
10897 "Domination relationship is not a linear order");
10898#endif
10899
10900 const Loop *MDL =
10901 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
10902 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
10903 });
10904
10905 // Get init and post increment value for LHS.
10906 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
10907 // If LHS contains an unknown non-invariant SCEV then bail out.
10908 if (SplitLHS.first == getCouldNotCompute())
10909 return false;
10910 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
10911 // Get init and post increment value for RHS.
10912 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
10913 // If RHS contains an unknown non-invariant SCEV then bail out.
10914 if (SplitRHS.first == getCouldNotCompute())
10915 return false;
10916 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
10917 // It is possible that init SCEV contains an invariant load but it does
10918 // not dominate MDL and is not available at MDL loop entry, so we should
10919 // check it here.
10920 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
10921 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
10922 return false;
10923
10924 // The backedge guard check appears to be faster than the entry check, so in
10925 // some cases it can speed up the whole estimation by short-circuiting.
10926 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
10927 SplitRHS.second) &&
10928 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
10929}
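// Editorial worked example (hypothetical values): to prove
// "{0,+,1}<%L> u< %n" with loop-invariant %n, the split above gives
//   LHS -> (0, {1,+,1}<%L>)   and   RHS -> (%n, %n),
// so the predicate is known if the backedge of %L is guarded by
// "{1,+,1}<%L> u< %n" and the entry of %L is guarded by "0 u< %n".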
10930
10931bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
10932 const SCEV *LHS, const SCEV *RHS) {
10933 // Canonicalize the inputs first.
10934 (void)SimplifyICmpOperands(Pred, LHS, RHS);
10935
10936 if (isKnownViaInduction(Pred, LHS, RHS))
10937 return true;
10938
10939 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
10940 return true;
10941
10942 // Otherwise see what can be done with some simple reasoning.
10943 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
10944}
10945
10946std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
10947 const SCEV *LHS,
10948 const SCEV *RHS) {
10949 if (isKnownPredicate(Pred, LHS, RHS))
10950 return true;
10951 if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
10952 return false;
10953 return std::nullopt;
10954}
10955
10956bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
10957 const SCEV *LHS, const SCEV *RHS,
10958 const Instruction *CtxI) {
10959 // TODO: Analyze guards and assumes from Context's block.
10960 return isKnownPredicate(Pred, LHS, RHS) ||
10961 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
10962}
10963
10964std::optional<bool>
10965ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
10966 const SCEV *RHS, const Instruction *CtxI) {
10967 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
10968 if (KnownWithoutContext)
10969 return KnownWithoutContext;
10970
10971 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
10972 return true;
10973 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
10974 ICmpInst::getInversePredicate(Pred),
10975 LHS, RHS))
10976 return false;
10977 return std::nullopt;
10978}
10979
10980bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
10981 const SCEVAddRecExpr *LHS,
10982 const SCEV *RHS) {
10983 const Loop *L = LHS->getLoop();
10984 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
10985 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
10986}
10987
10988std::optional<ScalarEvolution::MonotonicPredicateType>
10989ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
10990 ICmpInst::Predicate Pred) {
10991 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
10992
10993#ifndef NDEBUG
10994 // Verify an invariant: inverting the predicate should turn a monotonically
10995 // increasing change to a monotonically decreasing one, and vice versa.
10996 if (Result) {
10997 auto ResultSwapped =
10998 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
10999
11000 assert(*ResultSwapped != *Result &&
11001 "monotonicity should flip as we flip the predicate");
11002 }
11003#endif
11004
11005 return Result;
11006}
11007
11008std::optional<ScalarEvolution::MonotonicPredicateType>
11009ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11010 ICmpInst::Predicate Pred) {
11011 // A zero step value for LHS means the induction variable is essentially a
11012 // loop invariant value. We don't really depend on the predicate actually
11013 // flipping from false to true (for increasing predicates, and the other way
11014 // around for decreasing predicates); all we care about is that *if* the
11015 // predicate changes then it only changes from false to true.
11016 //
11017 // A zero step value in itself is not very useful, but there may be places
11018 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11019 // as general as possible.
11020
11021 // Only handle LE/LT/GE/GT predicates.
11022 if (!ICmpInst::isRelational(Pred))
11023 return std::nullopt;
11024
11025 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11026 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11027 "Should be greater or less!");
11028
11029 // Check that AR does not wrap.
11030 if (ICmpInst::isUnsigned(Pred)) {
11031 if (!LHS->hasNoUnsignedWrap())
11032 return std::nullopt;
11033 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11034 }
11035 assert(ICmpInst::isSigned(Pred) &&
11036 "Relational predicate is either signed or unsigned!");
11037 if (!LHS->hasNoSignedWrap())
11038 return std::nullopt;
11039
11040 const SCEV *Step = LHS->getStepRecurrence(*this);
11041
11042 if (isKnownNonNegative(Step))
11043 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11044
11045 if (isKnownNonPositive(Step))
11046 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11047
11048 return std::nullopt;
11049}
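// Editorial example (hypothetical values): for AR = {%a,+,1}<nuw>, the
// predicate "AR u< %n" can only change from true to false as the loop
// iterates, so (AR, ICMP_ULT) is classified above as MonotonicallyDecreasing,
// while (AR, ICMP_UGT) is MonotonicallyIncreasing.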
11050
11051std::optional<ScalarEvolution::LoopInvariantPredicate>
11052ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
11053 const SCEV *LHS, const SCEV *RHS,
11054 const Loop *L,
11055 const Instruction *CtxI) {
11056 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11057 if (!isLoopInvariant(RHS, L)) {
11058 if (!isLoopInvariant(LHS, L))
11059 return std::nullopt;
11060
11061 std::swap(LHS, RHS);
11062 Pred = ICmpInst::getSwappedPredicate(Pred);
11063 }
11064
11065 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11066 if (!ArLHS || ArLHS->getLoop() != L)
11067 return std::nullopt;
11068
11069 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11070 if (!MonotonicType)
11071 return std::nullopt;
11072 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11073 // true as the loop iterates, and the backedge is control dependent on
11074 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11075 //
11076 // * if the predicate was false in the first iteration then the predicate
11077 // is never evaluated again, since the loop exits without taking the
11078 // backedge.
11079 // * if the predicate was true in the first iteration then it will
11080 // continue to be true for all future iterations since it is
11081 // monotonically increasing.
11082 //
11083 // For both the above possibilities, we can replace the loop varying
11084 // predicate with its value on the first iteration of the loop (which is
11085 // loop invariant).
11086 //
11087 // A similar reasoning applies for a monotonically decreasing predicate, by
11088 // replacing true with false and false with true in the above two bullets.
11089 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11090 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11091
11092 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11093 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11094 RHS);
11095
11096 if (!CtxI)
11097 return std::nullopt;
11098 // Try to prove via context.
11099 // TODO: Support other cases.
11100 switch (Pred) {
11101 default:
11102 break;
11103 case ICmpInst::ICMP_ULE:
11104 case ICmpInst::ICMP_ULT: {
11105 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11106 // Given preconditions
11107 // (1) ArLHS does not cross the border of positive and negative parts of
11108 // range because of:
11109 // - Positive step; (TODO: lift this limitation)
11110 // - nuw - does not cross zero boundary;
11111 // - nsw - does not cross SINT_MAX boundary;
11112 // (2) ArLHS <s RHS
11113 // (3) RHS >=s 0
11114 // we can replace the loop variant ArLHS <u RHS condition with loop
11115 // invariant Start(ArLHS) <u RHS.
11116 //
11117 // Because of (1) there are two options:
11118 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11119 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11120 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11121 // Because of (2) ArLHS <u RHS is trivially true.
11122 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11123 // We can strengthen this to Start(ArLHS) <u RHS.
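// Editorial sketch (hypothetical values): with ArLHS = {4,+,1}<nuw><nsw> and
// RHS = %n, knowing at CtxI that %n >=s 0 and {4,+,1} s< %n lets the checks
// below replace the loop-variant "{4,+,1} u< %n" with the invariant
// "4 u< %n".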
11124 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11125 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11126 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11127 isKnownNonNegative(RHS) &&
11128 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11129 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11130 RHS);
11131 }
11132 }
11133
11134 return std::nullopt;
11135}
11136
11137std::optional<ScalarEvolution::LoopInvariantPredicate>
11138ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11139 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11140 const Instruction *CtxI, const SCEV *MaxIter) {
11141 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11142 Pred, LHS, RHS, L, CtxI, MaxIter))
11143 return LIP;
11144 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11145 // Number of iterations expressed as UMIN isn't always great for expressing
11146 // the value on the last iteration. If the straightforward approach didn't
11147 // work, try the following trick: if a predicate is invariant for X, it
11148 // is also invariant for umin(X, ...). So try to find something that works
11149 // among subexpressions of MaxIter expressed as umin.
11150 for (auto *Op : UMin->operands())
11151 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11152 Pred, LHS, RHS, L, CtxI, Op))
11153 return LIP;
11154 return std::nullopt;
11155}
11156
11157std::optional<ScalarEvolution::LoopInvariantPredicate>
11158ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11159 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11160 const Instruction *CtxI, const SCEV *MaxIter) {
11161 // Try to prove the following set of facts:
11162 // - The predicate is monotonic in the iteration space.
11163 // - If the check does not fail on the 1st iteration:
11164 // - No overflow will happen during first MaxIter iterations;
11165 // - It will not fail on the MaxIter'th iteration.
11166 // If the check does fail on the 1st iteration, we leave the loop and no
11167 // other checks matter.
11168
11169 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11170 if (!isLoopInvariant(RHS, L)) {
11171 if (!isLoopInvariant(LHS, L))
11172 return std::nullopt;
11173
11174 std::swap(LHS, RHS);
11175 Pred = ICmpInst::getSwappedPredicate(Pred);
11176 }
11177
11178 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11179 if (!AR || AR->getLoop() != L)
11180 return std::nullopt;
11181
11182 // The predicate must be relational (i.e. <, <=, >=, >).
11183 if (!ICmpInst::isRelational(Pred))
11184 return std::nullopt;
11185
11186 // TODO: Support steps other than +/- 1.
11187 const SCEV *Step = AR->getStepRecurrence(*this);
11188 auto *One = getOne(Step->getType());
11189 auto *MinusOne = getNegativeSCEV(One);
11190 if (Step != One && Step != MinusOne)
11191 return std::nullopt;
11192
11193 // Type mismatch here means that MaxIter is potentially larger than max
11194 // unsigned value in the start type, which means we cannot prove no-wrap for the
11195 // indvar.
11196 if (AR->getType() != MaxIter->getType())
11197 return std::nullopt;
11198
11199 // Value of IV on suggested last iteration.
11200 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11201 // Does it still meet the requirement?
11202 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11203 return std::nullopt;
11204 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11205 // not exceed max unsigned value of this type), this effectively proves
11206 // that there is no wrap during the iteration. To prove that there is no
11207 // signed/unsigned wrap, we need to check that
11208 // Start <= Last for step = 1 or Start >= Last for step = -1.
11209 ICmpInst::Predicate NoOverflowPred =
11210 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11211 if (Step == MinusOne)
11212 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
11213 const SCEV *Start = AR->getStart();
11214 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11215 return std::nullopt;
11216
11217 // Everything is fine.
11218 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11219}
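// Editorial worked example (hypothetical values): AR = {0,+,1}<%L>,
// Pred = ICMP_ULT, RHS = %len, MaxIter = %n of the same type. Then
// Last = %n; if the backedge is guarded by "%n u< %len" and "0 u<= %n" holds
// at CtxI (trivially true), the exit condition is invariant during the first
// %n iterations and reduces to "0 u< %len".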
11220
11221bool ScalarEvolution::isKnownPredicateViaConstantRanges(
11222 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
11223 if (HasSameValue(LHS, RHS))
11224 return ICmpInst::isTrueWhenEqual(Pred);
11225
11226 // This code is split out from isKnownPredicate because it is called from
11227 // within isLoopEntryGuardedByCond.
11228
11229 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11230 const ConstantRange &RangeRHS) {
11231 return RangeLHS.icmp(Pred, RangeRHS);
11232 };
11233
11234 // The check at the top of the function catches the case where the values are
11235 // known to be equal.
11236 if (Pred == CmpInst::ICMP_EQ)
11237 return false;
11238
11239 if (Pred == CmpInst::ICMP_NE) {
11240 auto SL = getSignedRange(LHS);
11241 auto SR = getSignedRange(RHS);
11242 if (CheckRanges(SL, SR))
11243 return true;
11244 auto UL = getUnsignedRange(LHS);
11245 auto UR = getUnsignedRange(RHS);
11246 if (CheckRanges(UL, UR))
11247 return true;
11248 auto *Diff = getMinusSCEV(LHS, RHS);
11249 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11250 }
11251
11252 if (CmpInst::isSigned(Pred)) {
11253 auto SL = getSignedRange(LHS);
11254 auto SR = getSignedRange(RHS);
11255 return CheckRanges(SL, SR);
11256 }
11257
11258 auto UL = getUnsignedRange(LHS);
11259 auto UR = getUnsignedRange(RHS);
11260 return CheckRanges(UL, UR);
11261}
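// Editorial example: if the unsigned range of LHS is [0, 8) and the unsigned
// range of RHS is [8, 16), then ConstantRange::icmp(ICMP_ULT, ...) holds for
// every pair of values in those ranges, so "LHS u< RHS" is known true above.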
11262
11263bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
11264 const SCEV *LHS,
11265 const SCEV *RHS) {
11266 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11267 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11268 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11269 // OutC1 and OutC2.
11270 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11271 APInt &OutC1, APInt &OutC2,
11272 SCEV::NoWrapFlags ExpectedFlags) {
11273 const SCEV *XNonConstOp, *XConstOp;
11274 const SCEV *YNonConstOp, *YConstOp;
11275 SCEV::NoWrapFlags XFlagsPresent;
11276 SCEV::NoWrapFlags YFlagsPresent;
11277
11278 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11279 XConstOp = getZero(X->getType());
11280 XNonConstOp = X;
11281 XFlagsPresent = ExpectedFlags;
11282 }
11283 if (!isa<SCEVConstant>(XConstOp) ||
11284 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11285 return false;
11286
11287 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11288 YConstOp = getZero(Y->getType());
11289 YNonConstOp = Y;
11290 YFlagsPresent = ExpectedFlags;
11291 }
11292
11293 if (!isa<SCEVConstant>(YConstOp) ||
11294 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11295 return false;
11296
11297 if (YNonConstOp != XNonConstOp)
11298 return false;
11299
11300 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11301 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11302
11303 return true;
11304 };
11305
11306 APInt C1;
11307 APInt C2;
11308
11309 switch (Pred) {
11310 default:
11311 break;
11312
11313 case ICmpInst::ICMP_SGE:
11314 std::swap(LHS, RHS);
11315 [[fallthrough]];
11316 case ICmpInst::ICMP_SLE:
11317 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11318 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11319 return true;
11320
11321 break;
11322
11323 case ICmpInst::ICMP_SGT:
11324 std::swap(LHS, RHS);
11325 [[fallthrough]];
11326 case ICmpInst::ICMP_SLT:
11327 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11328 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11329 return true;
11330
11331 break;
11332
11333 case ICmpInst::ICMP_UGE:
11334 std::swap(LHS, RHS);
11335 [[fallthrough]];
11336 case ICmpInst::ICMP_ULE:
11337 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11338 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11339 return true;
11340
11341 break;
11342
11343 case ICmpInst::ICMP_UGT:
11344 std::swap(LHS, RHS);
11345 [[fallthrough]];
11346 case ICmpInst::ICMP_ULT:
11347 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11348 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11349 return true;
11350 break;
11351 }
11352
11353 return false;
11354}
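// Editorial example (hypothetical %x): LHS = (%x + 1)<nsw> and
// RHS = (%x + 3)<nsw> match the (A + C1) / (A + C2) pattern with C1 = 1 and
// C2 = 3, so ICMP_SLT (and ICMP_SLE) is proven above because 1 s< 3 and
// neither addition can sign-wrap.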
11355
11356bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
11357 const SCEV *LHS,
11358 const SCEV *RHS) {
11359 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11360 return false;
11361
11362 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting on
11363 // the stack can result in exponential time complexity.
11364 SaveAndRestore Restore(ProvingSplitPredicate, true);
11365
11366 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11367 //
11368 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11369 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11370 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11371 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11372 // use isKnownPredicate later if needed.
11373 return isKnownNonNegative(RHS) &&
11374 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11375 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11376}
11377
11378bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
11379 ICmpInst::Predicate Pred,
11380 const SCEV *LHS, const SCEV *RHS) {
11381 // No need to even try if we know the module has no guards.
11382 if (!HasGuards)
11383 return false;
11384
11385 return any_of(*BB, [&](const Instruction &I) {
11386 using namespace llvm::PatternMatch;
11387
11388 Value *Condition;
11389 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11390 m_Value(Condition))) &&
11391 isImpliedCond(Pred, LHS, RHS, Condition, false);
11392 });
11393}
11394
11395/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11396/// protected by a conditional between LHS and RHS. This is used to
11397/// eliminate casts.
11398bool
11399ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11400 ICmpInst::Predicate Pred,
11401 const SCEV *LHS, const SCEV *RHS) {
11402 // Interpret a null as meaning no loop, where there is obviously no guard
11403 // (interprocedural conditions notwithstanding). Do not bother about
11404 // unreachable loops.
11405 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11406 return true;
11407
11408 if (VerifyIR)
11409 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11410 "This cannot be done on broken IR!");
11411
11412
11413 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11414 return true;
11415
11416 BasicBlock *Latch = L->getLoopLatch();
11417 if (!Latch)
11418 return false;
11419
11420 BranchInst *LoopContinuePredicate =
11421 dyn_cast<BranchInst>(Latch->getTerminator());
11422 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11423 isImpliedCond(Pred, LHS, RHS,
11424 LoopContinuePredicate->getCondition(),
11425 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11426 return true;
11427
11428 // We don't want more than one activation of the following loops on the stack
11429 // -- that can lead to O(n!) time complexity.
11430 if (WalkingBEDominatingConds)
11431 return false;
11432
11433 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11434
11435 // See if we can exploit a trip count to prove the predicate.
11436 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11437 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11438 if (LatchBECount != getCouldNotCompute()) {
11439 // We know that Latch branches back to the loop header exactly
11440 // LatchBECount times. This means the backedge condition at Latch is
11441 // equivalent to "{0,+,1} u< LatchBECount".
11442 Type *Ty = LatchBECount->getType();
11443 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11444 const SCEV *LoopCounter =
11445 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11446 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11447 LatchBECount))
11448 return true;
11449 }
11450
11451 // Check conditions due to any @llvm.assume intrinsics.
11452 for (auto &AssumeVH : AC.assumptions()) {
11453 if (!AssumeVH)
11454 continue;
11455 auto *CI = cast<CallInst>(AssumeVH);
11456 if (!DT.dominates(CI, Latch->getTerminator()))
11457 continue;
11458
11459 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11460 return true;
11461 }
11462
11463 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11464 return true;
11465
11466 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11467 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11468 assert(DTN && "should reach the loop header before reaching the root!");
11469
11470 BasicBlock *BB = DTN->getBlock();
11471 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11472 return true;
11473
11474 BasicBlock *PBB = BB->getSinglePredecessor();
11475 if (!PBB)
11476 continue;
11477
11478 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11479 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11480 continue;
11481
11482 Value *Condition = ContinuePredicate->getCondition();
11483
11484 // If we have an edge `E` within the loop body that dominates the only
11485 // latch, the condition guarding `E` also guards the backedge. This
11486 // reasoning works only for loops with a single latch.
11487
11488 BasicBlockEdge DominatingEdge(PBB, BB);
11489 if (DominatingEdge.isSingleEdge()) {
11490 // We're constructively (and conservatively) enumerating edges within the
11491 // loop body that dominate the latch. The dominator tree better agree
11492 // with us on this:
11493 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11494
11495 if (isImpliedCond(Pred, LHS, RHS, Condition,
11496 BB != ContinuePredicate->getSuccessor(0)))
11497 return true;
11498 }
11499 }
11500
11501 return false;
11502}
11503
11504bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11505 ICmpInst::Predicate Pred,
11506 const SCEV *LHS,
11507 const SCEV *RHS) {
11508 // Do not bother proving facts for unreachable code.
11509 if (!DT.isReachableFromEntry(BB))
11510 return true;
11511 if (VerifyIR)
11512 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11513 "This cannot be done on broken IR!");
11514
11515 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11516 // the facts (a >= b && a != b) separately. A typical situation is when the
11517 // non-strict comparison is known from ranges and non-equality is known from
11518 // dominating predicates. If we are proving strict comparison, we always try
11519 // to prove non-equality and non-strict comparison separately.
11520 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11521 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11522 bool ProvedNonStrictComparison = false;
11523 bool ProvedNonEquality = false;
11524
11525 auto SplitAndProve =
11526 [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
11527 if (!ProvedNonStrictComparison)
11528 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11529 if (!ProvedNonEquality)
11530 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11531 if (ProvedNonStrictComparison && ProvedNonEquality)
11532 return true;
11533 return false;
11534 };
11535
11536 if (ProvingStrictComparison) {
11537 auto ProofFn = [&](ICmpInst::Predicate P) {
11538 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11539 };
11540 if (SplitAndProve(ProofFn))
11541 return true;
11542 }
11543
11544 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11545 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11546 const Instruction *CtxI = &BB->front();
11547 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11548 return true;
11549 if (ProvingStrictComparison) {
11550 auto ProofFn = [&](ICmpInst::Predicate P) {
11551 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11552 };
11553 if (SplitAndProve(ProofFn))
11554 return true;
11555 }
11556 return false;
11557 };
11558
11559 // Starting at the block's predecessor, climb up the predecessor chain, as long
11560 // as there are predecessors that can be found that have unique successors
11561 // leading to the original block.
11562 const Loop *ContainingLoop = LI.getLoopFor(BB);
11563 const BasicBlock *PredBB;
11564 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11565 PredBB = ContainingLoop->getLoopPredecessor();
11566 else
11567 PredBB = BB->getSinglePredecessor();
11568 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11569 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11570 const BranchInst *BlockEntryPredicate =
11571 dyn_cast<BranchInst>(Pair.first->getTerminator());
11572 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11573 continue;
11574
11575 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11576 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11577 return true;
11578 }
11579
11580 // Check conditions due to any @llvm.assume intrinsics.
11581 for (auto &AssumeVH : AC.assumptions()) {
11582 if (!AssumeVH)
11583 continue;
11584 auto *CI = cast<CallInst>(AssumeVH);
11585 if (!DT.dominates(CI, BB))
11586 continue;
11587
11588 if (ProveViaCond(CI->getArgOperand(0), false))
11589 return true;
11590 }
11591
11592 // Check conditions due to any @llvm.experimental.guard intrinsics.
11593 auto *GuardDecl = F.getParent()->getFunction(
11594 Intrinsic::getName(Intrinsic::experimental_guard));
11595 if (GuardDecl)
11596 for (const auto *GU : GuardDecl->users())
11597 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11598 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11599 if (ProveViaCond(Guard->getArgOperand(0), false))
11600 return true;
11601 return false;
11602}
11603
11604bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
11605 ICmpInst::Predicate Pred,
11606 const SCEV *LHS,
11607 const SCEV *RHS) {
11608 // Interpret a null as meaning no loop, where there is obviously no guard
11609 // (interprocedural conditions notwithstanding).
11610 if (!L)
11611 return false;
11612
11613 // Both LHS and RHS must be available at loop entry.
11614 assert(isAvailableAtLoopEntry(LHS, L) &&
11615 "LHS is not available at Loop Entry");
11616 assert(isAvailableAtLoopEntry(RHS, L) &&
11617 "RHS is not available at Loop Entry");
11618
11619 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11620 return true;
11621
11622 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11623}
11624
11625bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11626 const SCEV *RHS,
11627 const Value *FoundCondValue, bool Inverse,
11628 const Instruction *CtxI) {
11629 // A false condition implies anything. Do not bother analyzing it further.
11630 if (FoundCondValue ==
11631 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11632 return true;
11633
11634 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11635 return false;
11636
11637 auto ClearOnExit =
11638 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11639
11640 // Recursively handle And and Or conditions.
11641 const Value *Op0, *Op1;
11642 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11643 if (!Inverse)
11644 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11645 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11646 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11647 if (Inverse)
11648 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11649 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11650 }
11651
11652 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11653 if (!ICI) return false;
11654
11655 // We have found a conditional branch that dominates the loop or controls
11656 // the loop latch. Check to see if it is the comparison we are looking for.
11657 ICmpInst::Predicate FoundPred;
11658 if (Inverse)
11659 FoundPred = ICI->getInversePredicate();
11660 else
11661 FoundPred = ICI->getPredicate();
11662
11663 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11664 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11665
11666 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11667}
11668
11669bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11670 const SCEV *RHS,
11671 ICmpInst::Predicate FoundPred,
11672 const SCEV *FoundLHS, const SCEV *FoundRHS,
11673 const Instruction *CtxI) {
11674 // Balance the types.
11675 if (getTypeSizeInBits(LHS->getType()) <
11676 getTypeSizeInBits(FoundLHS->getType())) {
11677 // For unsigned and equality predicates, try to prove that both found
11678 // operands fit into narrow unsigned range. If so, try to prove facts in
11679 // narrow types.
11680 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11681 !FoundRHS->getType()->isPointerTy()) {
11682 auto *NarrowType = LHS->getType();
11683 auto *WideType = FoundLHS->getType();
11684 auto BitWidth = getTypeSizeInBits(NarrowType);
11685 const SCEV *MaxValue = getZeroExtendExpr(
11686 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11687 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11688 MaxValue) &&
11689 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11690 MaxValue)) {
11691 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11692 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11693 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11694 TruncFoundRHS, CtxI))
11695 return true;
11696 }
11697 }
11698
11699 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11700 return false;
11701 if (CmpInst::isSigned(Pred)) {
11702 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11703 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11704 } else {
11705 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11706 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11707 }
11708 } else if (getTypeSizeInBits(LHS->getType()) >
11709 getTypeSizeInBits(FoundLHS->getType())) {
11710 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11711 return false;
11712 if (CmpInst::isSigned(FoundPred)) {
11713 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11714 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11715 } else {
11716 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11717 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11718 }
11719 }
11720 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11721 FoundRHS, CtxI);
11722}
11723
11724bool ScalarEvolution::isImpliedCondBalancedTypes(
11725 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11726 ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
11727 const Instruction *CtxI) {
11728 assert(getTypeSizeInBits(LHS->getType()) ==
11729 getTypeSizeInBits(FoundLHS->getType()) &&
11730 "Types should be balanced!");
11731 // Canonicalize the query to match the way instcombine will have
11732 // canonicalized the comparison.
11733 if (SimplifyICmpOperands(Pred, LHS, RHS))
11734 if (LHS == RHS)
11735 return CmpInst::isTrueWhenEqual(Pred);
11736 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11737 if (FoundLHS == FoundRHS)
11738 return CmpInst::isFalseWhenEqual(FoundPred);
11739
11740 // Check to see if we can make the LHS or RHS match.
11741 if (LHS == FoundRHS || RHS == FoundLHS) {
11742 if (isa<SCEVConstant>(RHS)) {
11743 std::swap(FoundLHS, FoundRHS);
11744 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
11745 } else {
11746 std::swap(LHS, RHS);
11747 Pred = ICmpInst::getSwappedPredicate(Pred);
11748 }
11749 }
11750
11751 // Check whether the found predicate is the same as the desired predicate.
11752 if (FoundPred == Pred)
11753 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11754
11755 // Check whether swapping the found predicate makes it the same as the
11756 // desired predicate.
11757 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
11758 // We can write the implication
11759 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11760 // using one of the following ways:
11761 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11762 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11763 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11764 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11765 // Forms 1. and 2. require swapping the operands of one condition. Don't
11766 // do this if it would break canonical constant/addrec ordering.
11767 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11768 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11769 CtxI);
11770 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11771 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11772
11773 // There's no clear preference between forms 3. and 4., try both. Avoid
11774 // forming getNotSCEV of pointer values as the resulting subtract is
11775 // not legal.
11776 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11777 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11778 FoundLHS, FoundRHS, CtxI))
11779 return true;
11780
11781 if (!FoundLHS->getType()->isPointerTy() &&
11782 !FoundRHS->getType()->isPointerTy() &&
11783 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11784 getNotSCEV(FoundRHS), CtxI))
11785 return true;
11786
11787 return false;
11788 }
11789
11790 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11791 CmpInst::Predicate P2) {
11792 assert(P1 != P2 && "Handled earlier!");
11793 return CmpInst::isRelational(P2) &&
11794 P1 == ICmpInst::getFlippedSignednessPredicate(P2);
11795 };
11796 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11797 // Unsigned comparison is the same as signed comparison when both the
11798 // operands are non-negative or negative.
11799 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11800 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11801 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11802 // Create local copies that we can freely swap and canonicalize our
11803 // conditions to "le/lt".
11804 ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11805 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11806 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11807 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11808 CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
11809 CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
11810 std::swap(CanonicalLHS, CanonicalRHS);
11811 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11812 }
11813 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11814 "Must be!");
11815 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11816 ICmpInst::isLE(CanonicalFoundPred)) &&
11817 "Must be!");
11818 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11819 // Use implication:
11820 // x <u y && y >=s 0 --> x <s y.
11821 // If we can prove the left part, the right part is also proven.
11822 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11823 CanonicalRHS, CanonicalFoundLHS,
11824 CanonicalFoundRHS);
11825 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11826 // Use implication:
11827 // x <s y && y <s 0 --> x <u y.
11828 // If we can prove the left part, the right part is also proven.
11829 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11830 CanonicalRHS, CanonicalFoundLHS,
11831 CanonicalFoundRHS);
11832 }
11833
11834 // Check if we can make progress by sharpening ranges.
11835 if (FoundPred == ICmpInst::ICMP_NE &&
11836 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11837
11838 const SCEVConstant *C = nullptr;
11839 const SCEV *V = nullptr;
11840
11841 if (isa<SCEVConstant>(FoundLHS)) {
11842 C = cast<SCEVConstant>(FoundLHS);
11843 V = FoundRHS;
11844 } else {
11845 C = cast<SCEVConstant>(FoundRHS);
11846 V = FoundLHS;
11847 }
11848
11849 // The guarding predicate tells us that C != V. If the known range
11850 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11851 // range we consider has to correspond to same signedness as the
11852 // predicate we're interested in folding.
11853
11854 APInt Min = ICmpInst::isSigned(Pred) ?
11855 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11856
11857 if (Min == C->getAPInt()) {
11858 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11859 // This is true even if (Min + 1) wraps around -- in case of
11860 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
11861
11862 APInt SharperMin = Min + 1;
11863
11864 switch (Pred) {
11865 case ICmpInst::ICMP_SGE:
11866 case ICmpInst::ICMP_UGE:
11867 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11868 // RHS, we're done.
11869 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11870 CtxI))
11871 return true;
11872 [[fallthrough]];
11873
11874 case ICmpInst::ICMP_SGT:
11875 case ICmpInst::ICMP_UGT:
11876 // We know from the range information that (V `Pred` Min ||
11877 // V == Min). We know from the guarding condition that !(V
11878 // == Min). This gives us
11879 //
11880 // V `Pred` Min || V == Min && !(V == Min)
11881 // => V `Pred` Min
11882 //
11883 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11884
11885 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
11886 return true;
11887 break;
11888
11889 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
11890 case ICmpInst::ICMP_SLE:
11891 case ICmpInst::ICMP_ULE:
11892 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11893 LHS, V, getConstant(SharperMin), CtxI))
11894 return true;
11895 [[fallthrough]];
11896
11897 case ICmpInst::ICMP_SLT:
11898 case ICmpInst::ICMP_ULT:
11899 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11900 LHS, V, getConstant(Min), CtxI))
11901 return true;
11902 break;
11903
11904 default:
11905 // No change
11906 break;
11907 }
11908 }
11909 }
11910
11911 // Check whether the actual condition is beyond sufficient.
11912 if (FoundPred == ICmpInst::ICMP_EQ)
11913 if (ICmpInst::isTrueWhenEqual(Pred))
11914 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11915 return true;
11916 if (Pred == ICmpInst::ICMP_NE)
11917 if (!ICmpInst::isTrueWhenEqual(FoundPred))
11918 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11919 return true;
11920
11921 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
11922 return true;
11923
11924 // Otherwise assume the worst.
11925 return false;
11926}
11927
11928bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
11929 const SCEV *&L, const SCEV *&R,
11930 SCEV::NoWrapFlags &Flags) {
11931 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
11932 if (!AE || AE->getNumOperands() != 2)
11933 return false;
11934
11935 L = AE->getOperand(0);
11936 R = AE->getOperand(1);
11937 Flags = AE->getNoWrapFlags();
11938 return true;
11939}
11940
11941std::optional<APInt>
11942ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
11943 // We avoid subtracting expressions here because this function is usually
11944 // fairly deep in the call stack (i.e. is called many times).
11945
11946 // X - X = 0.
11947 if (More == Less)
11948 return APInt(getTypeSizeInBits(More->getType()), 0);
11949
11950 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11951 const auto *LAR = cast<SCEVAddRecExpr>(Less);
11952 const auto *MAR = cast<SCEVAddRecExpr>(More);
11953
11954 if (LAR->getLoop() != MAR->getLoop())
11955 return std::nullopt;
11956
11957 // We look at affine expressions only; not for correctness but to keep
11958 // getStepRecurrence cheap.
11959 if (!LAR->isAffine() || !MAR->isAffine())
11960 return std::nullopt;
11961
11962 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11963 return std::nullopt;
11964
11965 Less = LAR->getStart();
11966 More = MAR->getStart();
11967
11968 // fall through
11969 }
11970
11971 if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11972 const auto &M = cast<SCEVConstant>(More)->getAPInt();
11973 const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11974 return M - L;
11975 }
11976
11977 SCEV::NoWrapFlags Flags;
11978 const SCEV *LLess = nullptr, *RLess = nullptr;
11979 const SCEV *LMore = nullptr, *RMore = nullptr;
11980 const SCEVConstant *C1 = nullptr, *C2 = nullptr;
11981 // Compare (X + C1) vs X.
11982 if (splitBinaryAdd(Less, LLess, RLess, Flags))
11983 if ((C1 = dyn_cast<SCEVConstant>(LLess)))
11984 if (RLess == More)
11985 return -(C1->getAPInt());
11986
11987 // Compare X vs (X + C2).
11988 if (splitBinaryAdd(More, LMore, RMore, Flags))
11989 if ((C2 = dyn_cast<SCEVConstant>(LMore)))
11990 if (RMore == Less)
11991 return C2->getAPInt();
11992
11993 // Compare (X + C1) vs (X + C2).
11994 if (C1 && C2 && RLess == RMore)
11995 return C2->getAPInt() - C1->getAPInt();
11996
11997 return std::nullopt;
11998}
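// Editorial example (hypothetical %x): computeConstantDifference of
// More = (5 + %x) and Less = %x returns 5 via the "X vs (X + C2)" case, and
// {(5 + %x),+,1}<%L> vs {%x,+,1}<%L> reduces to the same answer through the
// addrec starts, all without building a subtraction SCEV.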
11999
12000bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12001 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
12002 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
12003 // Try to recognize the following pattern:
12004 //
12005 // FoundRHS = ...
12006 // ...
12007 // loop:
12008 // FoundLHS = {Start,+,W}
12009 // context_bb: // Basic block from the same loop
12010 // known(Pred, FoundLHS, FoundRHS)
12011 //
12012 // If some predicate is known in the context of a loop, it is also known on
12013 // each iteration of this loop, including the first iteration. Therefore, in
12014 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12015 // prove the original pred using this fact.
12016 if (!CtxI)
12017 return false;
12018 const BasicBlock *ContextBB = CtxI->getParent();
12019 // Make sure AR varies in the context block.
12020 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12021 const Loop *L = AR->getLoop();
12022 // Make sure that context belongs to the loop and executes on 1st iteration
12023 // (if it ever executes at all).
12024 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12025 return false;
12026 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12027 return false;
12028 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12029 }
12030
12031 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12032 const Loop *L = AR->getLoop();
12033 // Make sure that context belongs to the loop and executes on 1st iteration
12034 // (if it ever executes at all).
12035 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12036 return false;
12037 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12038 return false;
12039 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12040 }
12041
12042 return false;
12043}
12044
12045bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
12046 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
12047 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12048 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12049 return false;
12050
12051 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12052 if (!AddRecLHS)
12053 return false;
12054
12055 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12056 if (!AddRecFoundLHS)
12057 return false;
12058
12059 // We'd like to let SCEV reason about control dependencies, so we constrain
12060 // both the inequalities to be about add recurrences on the same loop. This
12061 // way we can use isLoopEntryGuardedByCond later.
12062
12063 const Loop *L = AddRecFoundLHS->getLoop();
12064 if (L != AddRecLHS->getLoop())
12065 return false;
12066
12067 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12068 //
12069 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12070 // ... (2)
12071 //
12072 // Informal proof for (2), assuming (1) [*]:
12073 //
12074 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12075 //
12076 // Then
12077 //
12078 // FoundLHS s< FoundRHS s< INT_MIN - C
12079 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12080 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12081 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12082 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12083 // <=> FoundLHS + C s< FoundRHS + C
12084 //
12085 // [*]: (1) can be proved by ruling out overflow.
12086 //
12087 // [**]: This can be proved by analyzing all the four possibilities:
12088 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12089 // (A s>= 0, B s>= 0).
12090 //
12091 // Note:
12092 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12093 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12094 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12095 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12096 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12097 // C)".
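// Editorial numeric sketch of (1) in i8 (hypothetical values): with C = 10,
// -C is 246; if FoundLHS u< FoundRHS and FoundRHS u< 246, adding 10 to both
// sides cannot wrap, so (FoundLHS + 10) u< (FoundRHS + 10) still holds.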
12098
12099 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12100 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12101 if (!LDiff || !RDiff || *LDiff != *RDiff)
12102 return false;
12103
12104 if (LDiff->isMinValue())
12105 return true;
12106
12107 APInt FoundRHSLimit;
12108
12109 if (Pred == CmpInst::ICMP_ULT) {
12110 FoundRHSLimit = -(*RDiff);
12111 } else {
12112 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12113 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12114 }
12115
12116 // Try to prove (1) or (2), as needed.
12117 return isAvailableAtLoopEntry(FoundRHS, L) &&
12118 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12119 getConstant(FoundRHSLimit));
12120}
12121
12122bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
12123 const SCEV *LHS, const SCEV *RHS,
12124 const SCEV *FoundLHS,
12125 const SCEV *FoundRHS, unsigned Depth) {
12126 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12127
12128 auto ClearOnExit = make_scope_exit([&]() {
12129 if (LPhi) {
12130 bool Erased = PendingMerges.erase(LPhi);
12131 assert(Erased && "Failed to erase LPhi!");
12132 (void)Erased;
12133 }
12134 if (RPhi) {
12135 bool Erased = PendingMerges.erase(RPhi);
12136 assert(Erased && "Failed to erase RPhi!");
12137 (void)Erased;
12138 }
12139 });
12140
12141 // Find the respective Phis and check that they are not already pending.
12142 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12143 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12144 if (!PendingMerges.insert(Phi).second)
12145 return false;
12146 LPhi = Phi;
12147 }
12148 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12149 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12150 // If we detect a loop of Phi nodes being processed by this method, for
12151 // example:
12152 //
12153 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12154 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12155 //
12156 // we don't want to deal with a case that complex, so return the
12157 // conservative answer false.
12158 if (!PendingMerges.insert(Phi).second)
12159 return false;
12160 RPhi = Phi;
12161 }
12162
12163 // If none of LHS, RHS is a Phi, nothing to do here.
12164 if (!LPhi && !RPhi)
12165 return false;
12166
12167 // If there is a SCEVUnknown Phi we are interested in, make it left.
12168 if (!LPhi) {
12169 std::swap(LHS, RHS);
12170 std::swap(FoundLHS, FoundRHS);
12171 std::swap(LPhi, RPhi);
12172 Pred = ICmpInst::getSwappedPredicate(Pred);
12173 }
12174
12175 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12176 const BasicBlock *LBB = LPhi->getParent();
12177 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12178
12179 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12180 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12181 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12182 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12183 };
12184
12185 if (RPhi && RPhi->getParent() == LBB) {
12186 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12187 // If we compare two Phis from the same block, and for each entry block
12188 // the predicate is true for incoming values from this block, then the
12189 // predicate is also true for the Phis.
12190 for (const BasicBlock *IncBB : predecessors(LBB)) {
12191 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12192 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12193 if (!ProvedEasily(L, R))
12194 return false;
12195 }
12196 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12197 // Case two: RHS is an AddRec whose loop's header is LBB. This means the
12198 // loop has both an AddRec and Unknown PHIs; for it, we can compare the
12199 // incoming values of the AddRec from above the loop and from the latch
12200 // with the respective incoming values of LPhi.
12201 // TODO: Generalize to handle loops with many inputs in a header.
12202 if (LPhi->getNumIncomingValues() != 2) return false;
12203
12204 auto *RLoop = RAR->getLoop();
12205 auto *Predecessor = RLoop->getLoopPredecessor();
12206 assert(Predecessor && "Loop with AddRec with no predecessor?");
12207 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12208 if (!ProvedEasily(L1, RAR->getStart()))
12209 return false;
12210 auto *Latch = RLoop->getLoopLatch();
12211 assert(Latch && "Loop with AddRec with no latch?");
12212 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12213 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12214 return false;
12215 } else {
12216 // In all other cases go over inputs of LHS and compare each of them to RHS,
12217 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12218 // At this point RHS is either a non-Phi, or it is a Phi from some block
12219 // different from LBB.
12220 for (const BasicBlock *IncBB : predecessors(LBB)) {
12221 // Check that RHS is available in this block.
12222 if (!dominates(RHS, IncBB))
12223 return false;
12224 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12225 // Make sure L does not refer to a value from a potentially previous
12226 // iteration of a loop.
12227 if (!properlyDominates(L, LBB))
12228 return false;
12229 if (!ProvedEasily(L, RHS))
12230 return false;
12231 }
12232 }
12233 return true;
12234}
12235
12236bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
12237 const SCEV *LHS,
12238 const SCEV *RHS,
12239 const SCEV *FoundLHS,
12240 const SCEV *FoundRHS) {
12241 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12242 // sure that we are dealing with the same LHS.
12243 if (RHS == FoundRHS) {
12244 std::swap(LHS, RHS);
12245 std::swap(FoundLHS, FoundRHS);
12246 Pred = ICmpInst::getSwappedPredicate(Pred);
12247 }
12248 if (LHS != FoundLHS)
12249 return false;
12250
12251 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12252 if (!SUFoundRHS)
12253 return false;
12254
12255 Value *Shiftee, *ShiftValue;
12256
12257 using namespace PatternMatch;
12258 if (match(SUFoundRHS->getValue(),
12259 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12260 auto *ShifteeS = getSCEV(Shiftee);
12261 // Prove one of the following:
12262 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12263 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12264 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12265 // ---> LHS <s RHS
12266 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12267 // ---> LHS <=s RHS
12268 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12269 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12270 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12271 if (isKnownNonNegative(ShifteeS))
12272 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12273 }
12274
12275 return false;
12276}
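// Editorial example (hypothetical values): from the guarding fact
// "%i u< (%len >> 2)" and an already-known "%len u<= %n", the helper above
// concludes "%i u< %n", because (%len >> 2) u<= %len u<= %n.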
12277
12278bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
12279 const SCEV *LHS, const SCEV *RHS,
12280 const SCEV *FoundLHS,
12281 const SCEV *FoundRHS,
12282 const Instruction *CtxI) {
12283 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12284 return true;
12285
12286 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12287 return true;
12288
12289 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12290 return true;
12291
12292 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12293 CtxI))
12294 return true;
12295
12296 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12297 FoundLHS, FoundRHS);
12298}
12299
12300/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12301template <typename MinMaxExprType>
12302static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12303 const SCEV *Candidate) {
12304 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12305 if (!MinMaxExpr)
12306 return false;
12307
12308 return is_contained(MinMaxExpr->operands(), Candidate);
12309}
12310
12311static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12312 ICmpInst::Predicate Pred,
12313 const SCEV *LHS, const SCEV *RHS) {
12314 // If both sides are affine addrecs for the same loop, with equal
12315 // steps, and we know the recurrences don't wrap, then we only
12316 // need to check the predicate on the starting values.
12317
12318 if (!ICmpInst::isRelational(Pred))
12319 return false;
12320
12321 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12322 if (!LAR)
12323 return false;
12324 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12325 if (!RAR)
12326 return false;
12327 if (LAR->getLoop() != RAR->getLoop())
12328 return false;
12329 if (!LAR->isAffine() || !RAR->isAffine())
12330 return false;
12331
12332 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12333 return false;
12334
12335 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12336 SCEV::FlagNSW : SCEV::FlagNUW;
12337 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12338 return false;
12339
12340 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12341}
12342
12343/// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
12344/// expression?
12345static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
12346 ICmpInst::Predicate Pred,
12347 const SCEV *LHS, const SCEV *RHS) {
12348 switch (Pred) {
12349 default:
12350 return false;
12351
12352 case ICmpInst::ICMP_SGE:
12353 std::swap(LHS, RHS);
12354 [[fallthrough]];
12355 case ICmpInst::ICMP_SLE:
12356 return
12357 // min(A, ...) <= A
12358 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12359 // A <= max(A, ...)
12360 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12361
12362 case ICmpInst::ICMP_UGE:
12363 std::swap(LHS, RHS);
12364 [[fallthrough]];
12365 case ICmpInst::ICMP_ULE:
12366 return
12367 // min(A, ...) <= A
12368 // FIXME: what about umin_seq?
12369 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12370 // A <= max(A, ...)
12371 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12372 }
12373
12374 llvm_unreachable("covered switch fell through?!");
12375}
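// Editorial example: "smin(%a, %b) s<= %a" and "%a u<= umax(%a, %b)" are both
// recognized above because the plain operand appears among the min/max
// operands; the swapped >= forms reduce to the same two checks.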
12376
12377bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
12378 const SCEV *LHS, const SCEV *RHS,
12379 const SCEV *FoundLHS,
12380 const SCEV *FoundRHS,
12381 unsigned Depth) {
12382 assert(getTypeSizeInBits(LHS->getType()) ==
12383 getTypeSizeInBits(RHS->getType()) &&
12384 "LHS and RHS have different sizes?");
12385 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12386 getTypeSizeInBits(FoundRHS->getType()) &&
12387 "FoundLHS and FoundRHS have different sizes?");
12388 // We want to avoid hurting the compile time with analysis of too big trees.
12389 if (Depth > MaxSCEVOperationsImplicationDepth)
12390 return false;
12391
12392 // We only want to work with GT comparison so far.
12393 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12394 Pred = CmpInst::getSwappedPredicate(Pred);
12395 std::swap(LHS, RHS);
12396 std::swap(FoundLHS, FoundRHS);
12397 }
12398
12399 // For unsigned, try to reduce it to corresponding signed comparison.
12400 if (Pred == ICmpInst::ICMP_UGT)
12401 // We can replace unsigned predicate with its signed counterpart if all
12402 // involved values are non-negative.
12403 // TODO: We could have better support for unsigned.
12404 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12405 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12406 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12407 // use this fact to prove that LHS and RHS are non-negative.
12408 const SCEV *MinusOne = getMinusOne(LHS->getType());
12409 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12410 FoundRHS) &&
12411 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12412 FoundRHS))
12413 Pred = ICmpInst::ICMP_SGT;
12414 }
12415
12416 if (Pred != ICmpInst::ICMP_SGT)
12417 return false;
12418
12419 auto GetOpFromSExt = [&](const SCEV *S) {
12420 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12421 return Ext->getOperand();
12422 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12423 // the constant in some cases.
12424 return S;
12425 };
12426
12427 // Acquire values from extensions.
12428 auto *OrigLHS = LHS;
12429 auto *OrigFoundLHS = FoundLHS;
12430 LHS = GetOpFromSExt(LHS);
12431 FoundLHS = GetOpFromSExt(FoundLHS);
12432
12433 // Checks whether the SGT predicate can be proved trivially or using the found context.
12434 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12435 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12436 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12437 FoundRHS, Depth + 1);
12438 };
12439
12440 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12441 // We want to avoid creation of any new non-constant SCEV. Since we are
12442 // going to compare the operands to RHS, we should be certain that we don't
12443 // need any size extensions for this. So let's decline all cases when the
12444 // sizes of types of LHS and RHS do not match.
12445 // TODO: Maybe try to get RHS from sext to catch more cases?
12446 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12447 return false;
12448
12449 // Should not overflow.
12450 if (!LHSAddExpr->hasNoSignedWrap())
12451 return false;
12452
12453 auto *LL = LHSAddExpr->getOperand(0);
12454 auto *LR = LHSAddExpr->getOperand(1);
12455 auto *MinusOne = getMinusOne(RHS->getType());
12456
12457 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12458 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12459 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12460 };
12461 // Try to prove the following rule:
12462 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12463 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
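// For illustration (values chosen arbitrarily): with RHS = 5, LL = 3 and
// LR = 7 we have LL >= 0 and LR > RHS, so LHS = LL + LR = 10 > 5 = RHS,
// which is exactly the first rule above.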
12464 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12465 return true;
12466 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12467 Value *LL, *LR;
12468 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12469
12470 using namespace llvm::PatternMatch;
12471
12472 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12473 // Rules for division.
12474 // We are going to perform some comparisons with Denominator and its
12475 // derivative expressions. In the general case, creating a SCEV for it may
12476 // lead to a complex analysis of the entire graph, and in particular it
12477 // can request trip count recalculation for the same loop. That would be
12478 // cached as SCEVCouldNotCompute to avoid infinite recursion. To avoid
12479 // this, we only want to create SCEVs that are constants in this section.
12480 // So we bail out if Denominator is not a constant.
12481 if (!isa<ConstantInt>(LR))
12482 return false;
12483
12484 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12485
12486 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12487 // then a SCEV for the numerator already exists and matches with FoundLHS.
12488 auto *Numerator = getExistingSCEV(LL);
12489 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12490 return false;
12491
12492 // Make sure that the numerator matches with FoundLHS and the denominator
12493 // is positive.
12494 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12495 return false;
12496
12497 auto *DTy = Denominator->getType();
12498 auto *FRHSTy = FoundRHS->getType();
12499 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12500 // One of types is a pointer and another one is not. We cannot extend
12501 // them properly to a wider type, so let us just reject this case.
12502 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12503 // to avoid this check.
12504 return false;
12505
12506 // Given that:
12507 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12508 auto *WTy = getWiderType(DTy, FRHSTy);
12509 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12510 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12511
12512 // Try to prove the following rule:
12513 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12514 // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
12515 // divide it by Denominator < 4, we will have at least 1.
12516 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12517 if (isKnownNonPositive(RHS) &&
12518 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12519 return true;
12520
12521 // Try to prove the following rule:
12522 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12523 // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12524 // If we divide it by Denominator > 2, then:
12525 // 1. If FoundLHS is negative, then the result is 0.
12526 // 2. If FoundLHS is non-negative, then the result is non-negative.
12527 // Either way, the result is non-negative.
12528 auto *MinusOne = getMinusOne(WTy);
12529 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12530 if (isKnownNegative(RHS) &&
12531 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12532 return true;
12533 }
12534 }
12535
12536 // If our expression contained SCEVUnknown Phis, and we split it down and now
12537 // need to prove something for them, try to prove the predicate for every
12538 // possible incoming values of those Phis.
12539 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12540 return true;
12541
12542 return false;
12543}
12544
12545 bool ScalarEvolution::isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
12546 const SCEV *LHS, const SCEV *RHS) {
12547 // zext x u<= sext x, sext x s<= zext x
12548 switch (Pred) {
12549 case ICmpInst::ICMP_SGE:
12550 std::swap(LHS, RHS);
12551 [[fallthrough]];
12552 case ICmpInst::ICMP_SLE: {
12553 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12554 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12555 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12556 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12557 return true;
12558 break;
12559 }
12560 case ICmpInst::ICMP_UGE:
12561 std::swap(LHS, RHS);
12562 [[fallthrough]];
12563 case ICmpInst::ICMP_ULE: {
12564 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12565 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12566 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12567 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12568 return true;
12569 break;
12570 }
12571 default:
12572 break;
12573 };
12574 return false;
12575}
12576
12577bool
12578ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
12579 const SCEV *LHS, const SCEV *RHS) {
12580 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12581 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12582 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12583 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12584 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12585}
12586
12587bool
12588ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
12589 const SCEV *LHS, const SCEV *RHS,
12590 const SCEV *FoundLHS,
12591 const SCEV *FoundRHS) {
12592 switch (Pred) {
12593 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
12594 case ICmpInst::ICMP_EQ:
12595 case ICmpInst::ICMP_NE:
12596 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12597 return true;
12598 break;
12599 case ICmpInst::ICMP_SLT:
12600 case ICmpInst::ICMP_SLE:
12601 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12602 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12603 return true;
12604 break;
12605 case ICmpInst::ICMP_SGT:
12606 case ICmpInst::ICMP_SGE:
12607 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12608 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12609 return true;
12610 break;
12611 case ICmpInst::ICMP_ULT:
12612 case ICmpInst::ICMP_ULE:
12613 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12614 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12615 return true;
12616 break;
12617 case ICmpInst::ICMP_UGT:
12618 case ICmpInst::ICMP_UGE:
12619 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12620 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12621 return true;
12622 break;
12623 }
12624
12625 // Maybe it can be proved via operations?
12626 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12627 return true;
12628
12629 return false;
12630}
12631
12632bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
12633 const SCEV *LHS,
12634 const SCEV *RHS,
12635 ICmpInst::Predicate FoundPred,
12636 const SCEV *FoundLHS,
12637 const SCEV *FoundRHS) {
12638 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12639 // The restriction on `FoundRHS` could be lifted easily -- it exists only to
12640 // reduce the compile time impact of this optimization.
12641 return false;
12642
12643 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12644 if (!Addend)
12645 return false;
12646
12647 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12648
12649 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12650 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12651 ConstantRange FoundLHSRange =
12652 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12653
12654 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12655 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12656
12657 // We can also compute the range of values for `LHS` that satisfy the
12658 // consequent, "`LHS` `Pred` `RHS`":
12659 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12660 // The antecedent implies the consequent if every value of `LHS` that
12661 // satisfies the antecedent also satisfies the consequent.
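// For illustration (values chosen arbitrarily): if the antecedent is
// "FoundLHS <u 10", FoundLHSRange is [0, 10); if LHS = FoundLHS + 5,
// LHSRange is [5, 15); a consequent such as "LHS <u 20" then holds for
// every value in [5, 15), so the implication is established.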
12662 return LHSRange.icmp(Pred, ConstRHS);
12663}
12664
12665bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12666 bool IsSigned) {
12667 assert(isKnownPositive(Stride) && "Positive stride expected!");
12668
12669 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12670 const SCEV *One = getOne(Stride->getType());
12671
12672 if (IsSigned) {
12673 APInt MaxRHS = getSignedRangeMax(RHS);
12674 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12675 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12676
12677 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
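// For illustration with an 8-bit IV (values chosen arbitrarily): if
// MaxRHS = 100 and the stride can be as large as 30, then
// 127 - 29 = 98 <s 100, so the IV could step past SMAX and we
// conservatively report that overflow is possible.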
12678 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12679 }
12680
12681 APInt MaxRHS = getUnsignedRangeMax(RHS);
12682 APInt MaxValue = APInt::getMaxValue(BitWidth);
12683 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12684
12685 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12686 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12687}
12688
12689bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12690 bool IsSigned) {
12691
12692 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12693 const SCEV *One = getOne(Stride->getType());
12694
12695 if (IsSigned) {
12696 APInt MinRHS = getSignedRangeMin(RHS);
12697 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12698 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12699
12700 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12701 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12702 }
12703
12704 APInt MinRHS = getUnsignedRangeMin(RHS);
12705 APInt MinValue = APInt::getMinValue(BitWidth);
12706 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12707
12708 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12709 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12710}
12711
12712 const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12713 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12714 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12715 // expression fixes the case of N=0.
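// For illustration: N = 7, D = 3 gives umin(7, 1) = 1 and
// (7 - 1) /u 3 = 2, i.e. 3 = ceil(7 / 3); N = 0 gives 0, whereas the
// naive "1 + (N - 1) /u D" form would wrap on the subtraction.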
12716 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12717 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12718 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12719}
12720
12721const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12722 const SCEV *Stride,
12723 const SCEV *End,
12724 unsigned BitWidth,
12725 bool IsSigned) {
12726 // The logic in this function assumes we can represent a positive stride.
12727 // If we can't, the backedge-taken count must be zero.
12728 if (IsSigned && BitWidth == 1)
12729 return getZero(Stride->getType());
12730
12731 // The code below has only been closely audited for negative strides in the
12732 // unsigned comparison case; it may be correct for signed comparison, but
12733 // that needs to be established.
12734 if (IsSigned && isKnownNegative(Stride))
12735 return getCouldNotCompute();
12736
12737 // Calculate the maximum backedge count based on the range of values
12738 // permitted by Start, End, and Stride.
12739 APInt MinStart =
12740 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12741
12742 APInt MinStride =
12743 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12744
12745 // We assume either the stride is positive, or the backedge-taken count
12746 // is zero. So force StrideForMaxBECount to be at least one.
12747 APInt One(BitWidth, 1);
12748 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12749 : APIntOps::umax(One, MinStride);
12750
12751 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12752 : APInt::getMaxValue(BitWidth);
12753 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12754
12755 // Although End can be a MAX expression we estimate MaxEnd considering only
12756 // the case End = RHS of the loop termination condition. This is safe because
12757 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12758 // taken count.
12759 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12760 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12761
12762 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
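// For illustration (unsigned case, values chosen arbitrarily): with
// MinStart = 2, MaxEnd = 10 and StrideForMaxBECount = 4 this yields
// ceil((10 - 2) / 4) = 2, matching an IV {2,+,4} that takes the
// backedge for the values 2 and 6 and exits once it reaches 10.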
12763 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12764 : APIntOps::umax(MaxEnd, MinStart);
12765
12766 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12767 getConstant(StrideForMaxBECount) /* Step */);
12768}
12769
12771ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12772 const Loop *L, bool IsSigned,
12773 bool ControlsOnlyExit, bool AllowPredicates) {
12774 SmallVector<const SCEVPredicate *> Predicates;
12775
12776 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12777 bool PredicatedIV = false;
12778
12779 auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
12780 // Can we prove this loop *must* be UB if overflow of IV occurs?
12781 // Reasoning goes as follows:
12782 // * Suppose the IV did self wrap.
12783 // * If Stride evenly divides the iteration space, then once wrap
12784 // occurs, the loop must revisit the same values.
12785 // * We know that RHS is invariant, and that none of those values
12786 // caused this exit to be taken previously. Thus, this exit is
12787 // dynamically dead.
12788 // * If this is the sole exit, then a dead exit implies the loop
12789 // must be infinite if there are no abnormal exits.
12790 // * If the loop were infinite, then it must either not be mustprogress
12791 // or have side effects. Otherwise, it must be UB.
12792 // * It can't (by assumption), be UB so we have contradicted our
12793 // premise and can conclude the IV did not in fact self-wrap.
12794 if (!isLoopInvariant(RHS, L))
12795 return false;
12796
12797 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
12798 if (!StrideC || !StrideC->getAPInt().isPowerOf2())
12799 return false;
12800
12801 if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
12802 return false;
12803
12804 return loopIsFiniteByAssumption(L);
12805 };
12806
12807 if (!IV) {
12808 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12809 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12810 if (AR && AR->getLoop() == L && AR->isAffine()) {
12811 auto canProveNUW = [&]() {
12812 // We can use the comparison to infer no-wrap flags only if it fully
12813 // controls the loop exit.
12814 if (!ControlsOnlyExit)
12815 return false;
12816
12817 if (!isLoopInvariant(RHS, L))
12818 return false;
12819
12820 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12821 // We need the sequence defined by AR to strictly increase in the
12822 // unsigned integer domain for the logic below to hold.
12823 return false;
12824
12825 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12826 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12827 // If RHS <=u Limit, then there must exist a value V in the sequence
12828 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12829 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12830 // overflow occurs. This limit also implies that a signed comparison
12831 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12832 // the high bits on both sides must be zero.
12833 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12834 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12835 Limit = Limit.zext(OuterBitWidth);
12836 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12837 };
12838 auto Flags = AR->getNoWrapFlags();
12839 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12840 Flags = setFlags(Flags, SCEV::FlagNUW);
12841
12842 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12843 if (AR->hasNoUnsignedWrap()) {
12844 // Emulate what getZeroExtendExpr would have done during construction
12845 // if we'd been able to infer the fact just above at that time.
12846 const SCEV *Step = AR->getStepRecurrence(*this);
12847 Type *Ty = ZExt->getType();
12848 auto *S = getAddRecExpr(
12849 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12850 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12851 IV = dyn_cast<SCEVAddRecExpr>(S);
12852 }
12853 }
12854 }
12855 }
12856
12857
12858 if (!IV && AllowPredicates) {
12859 // Try to make this an AddRec using runtime tests, in the first X
12860 // iterations of this loop, where X is the SCEV expression found by the
12861 // algorithm below.
12862 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12863 PredicatedIV = true;
12864 }
12865
12866 // Avoid weird loops
12867 if (!IV || IV->getLoop() != L || !IV->isAffine())
12868 return getCouldNotCompute();
12869
12870 // A precondition of this method is that the condition being analyzed
12871 // reaches an exiting branch which dominates the latch. Given that, we can
12872 // assume that an increment which violates the nowrap specification and
12873 // produces poison must cause undefined behavior when the resulting poison
12874 // value is branched upon and thus we can conclude that the backedge is
12875 // taken no more often than would be required to produce that poison value.
12876 // Note that a well defined loop can exit on the iteration which violates
12877 // the nowrap specification if there is another exit (either explicit or
12878 // implicit/exceptional) which causes the loop to execute before the
12879 // exiting instruction we're analyzing would trigger UB.
12880 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12881 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
12882 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
12883
12884 const SCEV *Stride = IV->getStepRecurrence(*this);
12885
12886 bool PositiveStride = isKnownPositive(Stride);
12887
12888 // Avoid negative or zero stride values.
12889 if (!PositiveStride) {
12890 // We can compute the correct backedge taken count for loops with unknown
12891 // strides if we can prove that the loop is not an infinite loop with side
12892 // effects. Here's the loop structure we are trying to handle -
12893 //
12894 // i = start
12895 // do {
12896 // A[i] = i;
12897 // i += s;
12898 // } while (i < end);
12899 //
12900 // The backedge taken count for such loops is evaluated as -
12901 // (max(end, start + stride) - start - 1) /u stride
12902 //
12903 // The additional preconditions that we need to check to prove correctness
12904 // of the above formula is as follows -
12905 //
12906 // a) IV is either nuw or nsw depending upon signedness (indicated by the
12907 // NoWrap flag).
12908 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
12909 // no side effects within the loop)
12910 // c) loop has a single static exit (with no abnormal exits)
12911 //
12912 // Precondition a) implies that if the stride is negative, this is a single
12913 // trip loop. The backedge taken count formula reduces to zero in this case.
12914 //
12915 // Precondition b) and c) combine to imply that if rhs is invariant in L,
12916 // then a zero stride means the backedge can't be taken without executing
12917 // undefined behavior.
12918 //
12919 // The positive stride case is the same as isKnownPositive(Stride) returning
12920 // true (original behavior of the function).
12921 //
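// For illustration of the formula above (values chosen arbitrarily):
// start = 0, stride = 3, end = 10 runs the body for i = 0, 3, 6, 9 and
// takes the backedge 3 times; (max(10, 0 + 3) - 0 - 1) /u 3 = 3 matches.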
12922 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
12924 return getCouldNotCompute();
12925
12926 if (!isKnownNonZero(Stride)) {
12927 // If we have a step of zero, and RHS isn't invariant in L, we don't know
12928 // if it might eventually be greater than start and if so, on which
12929 // iteration. We can't even produce a useful upper bound.
12930 if (!isLoopInvariant(RHS, L))
12931 return getCouldNotCompute();
12932
12933 // We allow a potentially zero stride, but we need to divide by stride
12934 // below. Since the loop can't be infinite and this check must control
12935 // the sole exit, we can infer the exit must be taken on the first
12936 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
12937 // we know the numerator in the divides below must be zero, so we can
12938 // pick an arbitrary non-zero value for the denominator (e.g. stride)
12939 // and produce the right result.
12940 // FIXME: Handle the case where Stride is poison?
12941 auto wouldZeroStrideBeUB = [&]() {
12942 // Proof by contradiction. Suppose the stride were zero. If we can
12943 // prove that the backedge *is* taken on the first iteration, then since
12944 // we know this condition controls the sole exit, we must have an
12945 // infinite loop. We can't have a (well defined) infinite loop per
12946 // check just above.
12947 // Note: The (Start - Stride) term is used to get the start' term from
12948 // (start' + stride,+,stride). Remember that we only care about the
12949 // result of this expression when stride == 0 at runtime.
12950 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
12951 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
12952 };
12953 if (!wouldZeroStrideBeUB()) {
12954 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
12955 }
12956 }
12957 } else if (!Stride->isOne() && !NoWrap) {
12958 auto isUBOnWrap = [&]() {
12959 // From no-self-wrap, we need to then prove no-(un)signed-wrap. This
12960 // follows trivially from the fact that every (un)signed-wrapped, but
12961 // not self-wrapped, value must be less than the last value before the
12962 // (un)signed wrap. Since we know that last value didn't exit, neither
12963 // will any smaller one.
12964 return canAssumeNoSelfWrap(IV);
12965 };
12966
12967 // Avoid proven overflow cases: this will ensure that the backedge taken
12968 // count will not generate any unsigned overflow. Relaxed no-overflow
12969 // conditions exploit NoWrapFlags, allowing to optimize in presence of
12970 // undefined behaviors like the case of C language.
12971 if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
12972 return getCouldNotCompute();
12973 }
12974
12975 // On all paths just preceding, we established the following invariant:
12976 // IV can be assumed not to overflow up to and including the exiting
12977 // iteration. We proved this in one of two ways:
12978 // 1) We can show overflow doesn't occur before the exiting iteration
12979 // 1a) canIVOverflowOnLT, and b) step of one
12980 // 2) We can show that if overflow occurs, the loop must execute UB
12981 // before any possible exit.
12982 // Note that we have not yet proved RHS invariant (in general).
12983
12984 const SCEV *Start = IV->getStart();
12985
12986 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
12987 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
12988 // Use integer-typed versions for actual computation; we can't subtract
12989 // pointers in general.
12990 const SCEV *OrigStart = Start;
12991 const SCEV *OrigRHS = RHS;
12992 if (Start->getType()->isPointerTy()) {
12993 Start = getLosslessPtrToIntExpr(Start);
12994 if (isa<SCEVCouldNotCompute>(Start))
12995 return Start;
12996 }
12997 if (RHS->getType()->isPointerTy()) {
12998 RHS = getLosslessPtrToIntExpr(RHS);
12999 if (isa<SCEVCouldNotCompute>(RHS))
13000 return RHS;
13001 }
13002
13003 const SCEV *End = nullptr, *BECount = nullptr,
13004 *BECountIfBackedgeTaken = nullptr;
13005 if (!isLoopInvariant(RHS, L)) {
13006 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13007 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13008 RHSAddRec->getNoWrapFlags()) {
13009 // The structure of loop we are trying to calculate backedge count of:
13010 //
13011 // left = left_start
13012 // right = right_start
13013 //
13014 // while(left < right){
13015 // ... do something here ...
13016 // left += s1; // stride of left is s1 (s1 > 0)
13017 // right += s2; // stride of right is s2 (s2 < 0)
13018 // }
13019 //
13020
13021 const SCEV *RHSStart = RHSAddRec->getStart();
13022 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13023
13024 // If Stride - RHSStride is positive and does not overflow, we can write
13025 // backedge count as ->
13026 // ceil((End - Start) /u (Stride - RHSStride))
13027 // Where, End = max(RHSStart, Start)
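// For illustration (values chosen arbitrarily): left_start = 0, s1 = 2,
// right_start = 10, s2 = -1 gives Stride - RHSStride = 3 and
// End = max(10, 0) = 10, so the backedge count is ceil(10 /u 3) = 4;
// indeed 2*n < 10 - n holds exactly for n = 0, 1, 2, 3, so the backedge
// is taken four times.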
13028
13029 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13030 if (isKnownNegative(RHSStride) &&
13031 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13032 RHSStride)) {
13033
13034 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13035 if (isKnownPositive(Denominator)) {
13036 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13037 : getUMaxExpr(RHSStart, Start);
13038
13039 // We can do this because End >= Start, as End = max(RHSStart, Start)
13040 const SCEV *Delta = getMinusSCEV(End, Start);
13041
13042 BECount = getUDivCeilSCEV(Delta, Denominator);
13043 BECountIfBackedgeTaken =
13044 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13045 }
13046 }
13047 }
13048 if (BECount == nullptr) {
13049 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13050 // given the start, stride and max value for the end bound of the
13051 // loop (RHS), and the fact that IV does not overflow (which is
13052 // checked above).
13053 const SCEV *MaxBECount = computeMaxBECountForLT(
13054 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13055 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13056 MaxBECount, false /*MaxOrZero*/, Predicates);
13057 }
13058 } else {
13059 // We use the expression (max(End,Start)-Start)/Stride to describe the
13060 // backedge count, as if the backedge is taken at least once
13061 // max(End,Start) is End and so the result is as above, and if not
13062 // max(End,Start) is Start so we get a backedge count of zero.
13063 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13064 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13065 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13066 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13067 // Can we prove max(RHS,Start) > Start - Stride?
13068 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13069 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13070 // In this case, we can use a refined formula for computing backedge
13071 // taken count. The general formula remains:
13072 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13073 // We want to use the alternate formula:
13074 // "((End - 1) - (Start - Stride)) /u Stride"
13075 // Let's do a quick case analysis to show these are equivalent under
13076 // our precondition that max(RHS,Start) > Start - Stride.
13077 // * For RHS <= Start, the backedge-taken count must be zero.
13078 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13079 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13080 // "(Stride - 1) /u Stride" which is indeed zero for all non-zero values
13081 // of Stride. For a zero stride, we've used umax(Stride, 1) above,
13082 // reducing this to the stride-of-1 case.
13083 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13084 // Stride".
13085 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13086 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13087 // "((RHS - (Start - Stride) - 1) /u Stride".
13088 // Our preconditions trivially imply no overflow in that form.
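// For illustration: Start = 2, Stride = 3, RHS = 10 gives the usual
// form ceil((10 - 2) / 3) = 3 and the alternate form
// ((10 - 1) - (2 - 3)) /u 3 = 10 /u 3 = 3 as well.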
13089 const SCEV *MinusOne = getMinusOne(Stride->getType());
13090 const SCEV *Numerator =
13091 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13092 BECount = getUDivExpr(Numerator, Stride);
13093 }
13094
13095 if (!BECount) {
13096 auto canProveRHSGreaterThanEqualStart = [&]() {
13097 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13098 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13099 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13100
13101 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13102 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13103 return true;
13104
13105 // (RHS > Start - 1) implies RHS >= Start.
13106 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13107 // "Start - 1" doesn't overflow.
13108 // * For signed comparison, if Start - 1 does overflow, it's equal
13109 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13110 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13111 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13112 //
13113 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13114 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13115 auto *StartMinusOne =
13116 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13117 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13118 };
13119
13120 // If we know that RHS >= Start in the context of loop, then we know
13121 // that max(RHS, Start) = RHS at this point.
13122 if (canProveRHSGreaterThanEqualStart()) {
13123 End = RHS;
13124 } else {
13125 // If RHS < Start, the backedge will be taken zero times. So in
13126 // general, we can write the backedge-taken count as:
13127 //
13128 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13129 //
13130 // We convert it to the following to make it more convenient for SCEV:
13131 //
13132 // ceil(max(RHS, Start) - Start) / Stride
13133 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13134
13135 // See what would happen if we assume the backedge is taken. This is
13136 // used to compute MaxBECount.
13137 BECountIfBackedgeTaken =
13138 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13139 }
13140
13141 // At this point, we know:
13142 //
13143 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13144 // 2. The index variable doesn't overflow.
13145 //
13146 // Therefore, we know N exists such that
13147 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13148 // doesn't overflow.
13149 //
13150 // Using this information, try to prove whether the addition in
13151 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13152 const SCEV *One = getOne(Stride->getType());
13153 bool MayAddOverflow = [&] {
13154 if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
13155 if (StrideC->getAPInt().isPowerOf2()) {
13156 // Suppose Stride is a power of two, and Start/End are unsigned
13157 // integers. Let UMAX be the largest representable unsigned
13158 // integer.
13159 //
13160 // By the preconditions of this function, we know
13161 // "(Start + Stride * N) >= End", and this doesn't overflow.
13162 // As a formula:
13163 //
13164 // End <= (Start + Stride * N) <= UMAX
13165 //
13166 // Subtracting Start from all the terms:
13167 //
13168 // End - Start <= Stride * N <= UMAX - Start
13169 //
13170 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13171 //
13172 // End - Start <= Stride * N <= UMAX
13173 //
13174 // Stride * N is a multiple of Stride. Therefore,
13175 //
13176 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13177 //
13178 // Since Stride is a power of two, UMAX + 1 is divisible by
13179 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13180 // write:
13181 //
13182 // End - Start <= Stride * N <= UMAX - Stride + 1
13183 //
13184 // Dropping the middle term:
13185 //
13186 // End - Start <= UMAX - Stride + 1
13187 //
13188 // Adding Stride - 1 to both sides:
13189 //
13190 // (End - Start) + (Stride - 1) <= UMAX
13191 //
13192 // In other words, the addition doesn't have unsigned overflow.
13193 //
13194 // A similar proof works if we treat Start/End as signed values.
13195 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13196 // to use signed max instead of unsigned max. Note that we're
13197 // trying to prove a lack of unsigned overflow in either case.
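// For illustration with 8-bit values: Stride = 4 gives
// UMAX mod Stride = 255 mod 4 = 3 = Stride - 1, so
// End - Start <= 255 - 4 + 1 = 252, and adding Stride - 1 = 3 stays
// within UMAX = 255.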
13198 return false;
13199 }
13200 }
13201 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13202 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13203 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13204 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13205 // 1 <s End.
13206 //
13207 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13208 // End.
13209 return false;
13210 }
13211 return true;
13212 }();
13213
13214 const SCEV *Delta = getMinusSCEV(End, Start);
13215 if (!MayAddOverflow) {
13216 // floor((D + (S - 1)) / S)
13217 // We prefer this formulation if it's legal because it's fewer
13218 // operations.
13219 BECount =
13220 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13221 } else {
13222 BECount = getUDivCeilSCEV(Delta, Stride);
13223 }
13224 }
13225 }
13226
13227 const SCEV *ConstantMaxBECount;
13228 bool MaxOrZero = false;
13229 if (isa<SCEVConstant>(BECount)) {
13230 ConstantMaxBECount = BECount;
13231 } else if (BECountIfBackedgeTaken &&
13232 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13233 // If we know exactly how many times the backedge will be taken if it's
13234 // taken at least once, then the backedge count will either be that or
13235 // zero.
13236 ConstantMaxBECount = BECountIfBackedgeTaken;
13237 MaxOrZero = true;
13238 } else {
13239 ConstantMaxBECount = computeMaxBECountForLT(
13240 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13241 }
13242
13243 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13244 !isa<SCEVCouldNotCompute>(BECount))
13245 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13246
13247 const SCEV *SymbolicMaxBECount =
13248 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13249 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13250 Predicates);
13251}
13252
13253ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13254 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13255 bool ControlsOnlyExit, bool AllowPredicates) {
13256 SmallVector<const SCEVPredicate *> Predicates;
13257 // We handle only IV > Invariant
13258 if (!isLoopInvariant(RHS, L))
13259 return getCouldNotCompute();
13260
13261 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13262 if (!IV && AllowPredicates)
13263 // Try to make this an AddRec using runtime tests, in the first X
13264 // iterations of this loop, where X is the SCEV expression found by the
13265 // algorithm below.
13266 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13267
13268 // Avoid weird loops
13269 if (!IV || IV->getLoop() != L || !IV->isAffine())
13270 return getCouldNotCompute();
13271
13272 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13273 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13274 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13275
13276 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13277
13278 // Avoid negative or zero stride values
13279 if (!isKnownPositive(Stride))
13280 return getCouldNotCompute();
13281
13282 // Avoid proven overflow cases: this will ensure that the backedge taken count
13283 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13284 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13285 // behaviors like the case of C language.
13286 if (!Stride->isOne() && !NoWrap)
13287 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13288 return getCouldNotCompute();
13289
13290 const SCEV *Start = IV->getStart();
13291 const SCEV *End = RHS;
13292 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13293 // If we know that Start >= RHS in the context of loop, then we know that
13294 // min(RHS, Start) = RHS at this point.
13295 if (isLoopEntryGuardedByCond(
13296 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13297 End = RHS;
13298 else
13299 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13300 }
13301
13302 if (Start->getType()->isPointerTy()) {
13303 Start = getLosslessPtrToIntExpr(Start);
13304 if (isa<SCEVCouldNotCompute>(Start))
13305 return Start;
13306 }
13307 if (End->getType()->isPointerTy()) {
13308 End = getLosslessPtrToIntExpr(End);
13309 if (isa<SCEVCouldNotCompute>(End))
13310 return End;
13311 }
13312
13313 // Compute ((Start - End) + (Stride - 1)) / Stride.
13314 // FIXME: This can overflow. Holding off on fixing this for now;
13315 // howManyGreaterThans will hopefully be gone soon.
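// For illustration (values chosen arbitrarily): Start = 10, End = 2,
// Stride = 3 gives ((10 - 2) + (3 - 1)) /u 3 = 10 /u 3 = 3, i.e.
// ceil(8 / 3); the IV keeps the loop running for the values 10, 7 and 4
// and exits once it reaches 1.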
13316 const SCEV *One = getOne(Stride->getType());
13317 const SCEV *BECount = getUDivExpr(
13318 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13319
13320 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13321 : getUnsignedRangeMax(Start);
13322
13323 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13324 : getUnsignedRangeMin(Stride);
13325
13326 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13327 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13328 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13329
13330 // Although End can be a MIN expression we estimate MinEnd considering only
13331 // the case End = RHS. This is safe because in the other case (Start - End)
13332 // is zero, leading to a zero maximum backedge taken count.
13333 APInt MinEnd =
13334 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13335 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13336
13337 const SCEV *ConstantMaxBECount =
13338 isa<SCEVConstant>(BECount)
13339 ? BECount
13340 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13341 getConstant(MinStride));
13342
13343 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13344 ConstantMaxBECount = BECount;
13345 const SCEV *SymbolicMaxBECount =
13346 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13347
13348 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13349 Predicates);
13350}
13351
13352 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13353 ScalarEvolution &SE) const {
13354 if (Range.isFullSet()) // Infinite loop.
13355 return SE.getCouldNotCompute();
13356
13357 // If the start is a non-zero constant, shift the range to simplify things.
13358 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13359 if (!SC->getValue()->isZero()) {
13360 SmallVector<const SCEV *, 4> Operands(operands());
13361 Operands[0] = SE.getZero(SC->getType());
13362 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13363 getNoWrapFlags(FlagNW));
13364 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13365 return ShiftedAddRec->getNumIterationsInRange(
13366 Range.subtract(SC->getAPInt()), SE);
13367 // This is strange and shouldn't happen.
13368 return SE.getCouldNotCompute();
13369 }
13370
13371 // The only time we can solve this is when we have all constant indices.
13372 // Otherwise, we cannot determine the overflow conditions.
13373 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13374 return SE.getCouldNotCompute();
13375
13376 // Okay at this point we know that all elements of the chrec are constants and
13377 // that the start element is zero.
13378
13379 // First check to see if the range contains zero. If not, the first
13380 // iteration exits.
13381 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13382 if (!Range.contains(APInt(BitWidth, 0)))
13383 return SE.getZero(getType());
13384
13385 if (isAffine()) {
13386 // If this is an affine expression then we have this situation:
13387 // Solve {0,+,A} in Range === Ax in Range
13388
13389 // We know that zero is in the range. If A is positive then we know that
13390 // the upper value of the range must be the first possible exit value.
13391 // If A is negative then the lower of the range is the last possible loop
13392 // value. Also note that we already checked for a full range.
13393 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13394 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13395
13396 // The exit value should be (End+A)/A.
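// For illustration: Range = [0, 10) and A = 3 gives End = 9 and
// ExitVal = (9 + 3) /u 3 = 4; the values 0, 3, 6 and 9 of {0,+,3} lie
// in the range, and iteration 4 produces 12, the first value outside it.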
13397 APInt ExitVal = (End + A).udiv(A);
13398 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13399
13400 // Evaluate at the exit value. If we really did fall out of the valid
13401 // range, then we computed our trip count, otherwise wrap around or other
13402 // things must have happened.
13403 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13404 if (Range.contains(Val->getValue()))
13405 return SE.getCouldNotCompute(); // Something strange happened
13406
13407 // Ensure that the previous value is in the range.
13410 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13411 "Linear scev computation is off in a bad way!");
13412 return SE.getConstant(ExitValue);
13413 }
13414
13415 if (isQuadratic()) {
13416 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13417 return SE.getConstant(*S);
13418 }
13419
13420 return SE.getCouldNotCompute();
13421}
13422
13423const SCEVAddRecExpr *
13424 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13425 assert(getNumOperands() > 1 && "AddRec with zero step?");
13426 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13427 // but in this case we cannot guarantee that the value returned will be an
13428 // AddRec because SCEV does not have a fixed point where it stops
13429 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13430 // may happen if we reach arithmetic depth limit while simplifying. So we
13431 // construct the returned value explicitly.
13433 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13434 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13435 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13436 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13437 // We know that the last operand is not a constant zero (otherwise it would
13438 // have been popped out earlier). This guarantees us that if the result has
13439 // the same last operand, then it will also not be popped out, meaning that
13440 // the returned value will be an AddRec.
13441 const SCEV *Last = getOperand(getNumOperands() - 1);
13442 assert(!Last->isZero() && "Recurrency with zero step?");
13443 Ops.push_back(Last);
13444 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13446}
13447
13448// Return true when S contains at least an undef value.
13450 return SCEVExprContains(S, [](const SCEV *S) {
13451 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13452 return isa<UndefValue>(SU->getValue());
13453 return false;
13454 });
13455}
13456
13457// Return true when S contains a value that is a nullptr.
13459 return SCEVExprContains(S, [](const SCEV *S) {
13460 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13461 return SU->getValue() == nullptr;
13462 return false;
13463 });
13464}
13465
13466/// Return the size of an element read or written by Inst.
13468 Type *Ty;
13469 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13470 Ty = Store->getValueOperand()->getType();
13471 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13472 Ty = Load->getType();
13473 else
13474 return nullptr;
13475
13477 return getSizeOfExpr(ETy, Ty);
13478}
13479
13480//===----------------------------------------------------------------------===//
13481// SCEVCallbackVH Class Implementation
13482//===----------------------------------------------------------------------===//
13483
13484void ScalarEvolution::SCEVCallbackVH::deleted() {
13485 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13486 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13487 SE->ConstantEvolutionLoopExitValue.erase(PN);
13488 SE->eraseValueFromMap(getValPtr());
13489 // this now dangles!
13490}
13491
13492void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13493 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13494
13495 // Forget all the expressions associated with users of the old value,
13496 // so that future queries will recompute the expressions using the new
13497 // value.
13498 SE->forgetValue(getValPtr());
13499 // this now dangles!
13500}
13501
13502ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13503 : CallbackVH(V), SE(se) {}
13504
13505//===----------------------------------------------------------------------===//
13506// ScalarEvolution Class Implementation
13507//===----------------------------------------------------------------------===//
13508
13509 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13510 AssumptionCache &AC, DominatorTree &DT,
13511 LoopInfo &LI)
13512 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13513 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13514 LoopDispositions(64), BlockDispositions(64) {
13515 // To use guards for proving predicates, we need to scan every instruction in
13516 // relevant basic blocks, and not just terminators. Doing this is a waste of
13517 // time if the IR does not actually contain any calls to
13518 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13519 //
13520 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13521 // to _add_ guards to the module when there weren't any before, and wants
13522 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13523 // efficient in lieu of being smart in that rather obscure case.
13524
13525 auto *GuardDecl = F.getParent()->getFunction(
13526 Intrinsic::getName(Intrinsic::experimental_guard));
13527 HasGuards = GuardDecl && !GuardDecl->use_empty();
13528}
13529
13530 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13531 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13532 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13533 ValueExprMap(std::move(Arg.ValueExprMap)),
13534 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13535 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13536 PendingMerges(std::move(Arg.PendingMerges)),
13537 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13538 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13539 PredicatedBackedgeTakenCounts(
13540 std::move(Arg.PredicatedBackedgeTakenCounts)),
13541 BECountUsers(std::move(Arg.BECountUsers)),
13542 ConstantEvolutionLoopExitValue(
13543 std::move(Arg.ConstantEvolutionLoopExitValue)),
13544 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13545 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13546 LoopDispositions(std::move(Arg.LoopDispositions)),
13547 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13548 BlockDispositions(std::move(Arg.BlockDispositions)),
13549 SCEVUsers(std::move(Arg.SCEVUsers)),
13550 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13551 SignedRanges(std::move(Arg.SignedRanges)),
13552 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13553 UniquePreds(std::move(Arg.UniquePreds)),
13554 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13555 LoopUsers(std::move(Arg.LoopUsers)),
13556 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13557 FirstUnknown(Arg.FirstUnknown) {
13558 Arg.FirstUnknown = nullptr;
13559}
13560
13561 ScalarEvolution::~ScalarEvolution() {
13562 // Iterate through all the SCEVUnknown instances and call their
13563 // destructors, so that they release their references to their values.
13564 for (SCEVUnknown *U = FirstUnknown; U;) {
13565 SCEVUnknown *Tmp = U;
13566 U = U->Next;
13567 Tmp->~SCEVUnknown();
13568 }
13569 FirstUnknown = nullptr;
13570
13571 ExprValueMap.clear();
13572 ValueExprMap.clear();
13573 HasRecMap.clear();
13574 BackedgeTakenCounts.clear();
13575 PredicatedBackedgeTakenCounts.clear();
13576
13577 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13578 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13579 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13580 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13581 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13582}
13583
13584 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13585 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13586}
13587
13588/// When printing a top-level SCEV for trip counts, it's helpful to include
13589/// a type for constants which are otherwise hard to disambiguate.
13590static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13591 if (isa<SCEVConstant>(S))
13592 OS << *S->getType() << " ";
13593 OS << *S;
13594}
13595
13596 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13597 const Loop *L) {
13598 // Print all inner loops first
13599 for (Loop *I : *L)
13600 PrintLoopInfo(OS, SE, I);
13601
13602 OS << "Loop ";
13603 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13604 OS << ": ";
13605
13606 SmallVector<BasicBlock *, 8> ExitingBlocks;
13607 L->getExitingBlocks(ExitingBlocks);
13608 if (ExitingBlocks.size() != 1)
13609 OS << "<multiple exits> ";
13610
13611 auto *BTC = SE->getBackedgeTakenCount(L);
13612 if (!isa<SCEVCouldNotCompute>(BTC)) {
13613 OS << "backedge-taken count is ";
13614 PrintSCEVWithTypeHint(OS, BTC);
13615 } else
13616 OS << "Unpredictable backedge-taken count.";
13617 OS << "\n";
13618
13619 if (ExitingBlocks.size() > 1)
13620 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13621 OS << " exit count for " << ExitingBlock->getName() << ": ";
13622 PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
13623 OS << "\n";
13624 }
13625
13626 OS << "Loop ";
13627 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13628 OS << ": ";
13629
13630 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13631 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13632 OS << "constant max backedge-taken count is ";
13633 PrintSCEVWithTypeHint(OS, ConstantBTC);
13635 OS << ", actual taken count either this or zero.";
13636 } else {
13637 OS << "Unpredictable constant max backedge-taken count. ";
13638 }
13639
13640 OS << "\n"
13641 "Loop ";
13642 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13643 OS << ": ";
13644
13645 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13646 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13647 OS << "symbolic max backedge-taken count is ";
13648 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13650 OS << ", actual taken count either this or zero.";
13651 } else {
13652 OS << "Unpredictable symbolic max backedge-taken count. ";
13653 }
13654 OS << "\n";
13655
13656 if (ExitingBlocks.size() > 1)
13657 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13658 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13659 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13661 PrintSCEVWithTypeHint(OS, ExitBTC);
13662 OS << "\n";
13663 }
13664
13666 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13667 if (PBT != BTC || !Preds.empty()) {
13668 OS << "Loop ";
13669 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13670 OS << ": ";
13671 if (!isa<SCEVCouldNotCompute>(PBT)) {
13672 OS << "Predicated backedge-taken count is ";
13673 PrintSCEVWithTypeHint(OS, PBT);
13674 } else
13675 OS << "Unpredictable predicated backedge-taken count.";
13676 OS << "\n";
13677 OS << " Predicates:\n";
13678 for (const auto *P : Preds)
13679 P->print(OS, 4);
13680 }
13681
13682 Preds.clear();
13683 auto *PredSymbolicMax =
13684 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13685 if (SymbolicBTC != PredSymbolicMax) {
13686 OS << "Loop ";
13687 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13688 OS << ": ";
13689 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13690 OS << "Predicated symbolic max backedge-taken count is ";
13691 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13692 } else
13693 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13694 OS << "\n";
13695 OS << " Predicates:\n";
13696 for (const auto *P : Preds)
13697 P->print(OS, 4);
13698 }
13699
13701 OS << "Loop ";
13702 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13703 OS << ": ";
13704 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13705 }
13706}
13707
13708namespace llvm {
13709 raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13710 switch (LD) {
13711 case ScalarEvolution::LoopVariant:
13712 OS << "Variant";
13713 break;
13714 case ScalarEvolution::LoopInvariant:
13715 OS << "Invariant";
13716 break;
13717 case ScalarEvolution::LoopComputable:
13718 OS << "Computable";
13719 break;
13720 }
13721 return OS;
13722}
13723
13724 raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13725 switch (BD) {
13726 case ScalarEvolution::DoesNotDominateBlock:
13727 OS << "DoesNotDominate";
13728 break;
13729 case ScalarEvolution::DominatesBlock:
13730 OS << "Dominates";
13731 break;
13732 case ScalarEvolution::ProperlyDominatesBlock:
13733 OS << "ProperlyDominates";
13734 break;
13735 }
13736 return OS;
13737}
13738} // namespace llvm
13739
13740 void ScalarEvolution::print(raw_ostream &OS) const {
13741 // ScalarEvolution's implementation of the print method is to print
13742 // out SCEV values of all instructions that are interesting. Doing
13743 // this potentially causes it to create new SCEV objects though,
13744 // which technically conflicts with the const qualifier. This isn't
13745 // observable from outside the class though, so casting away the
13746 // const isn't dangerous.
13747 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13748
13749 if (ClassifyExpressions) {
13750 OS << "Classifying expressions for: ";
13751 F.printAsOperand(OS, /*PrintType=*/false);
13752 OS << "\n";
13753 for (Instruction &I : instructions(F))
13754 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13755 OS << I << '\n';
13756 OS << " --> ";
13757 const SCEV *SV = SE.getSCEV(&I);
13758 SV->print(OS);
13759 if (!isa<SCEVCouldNotCompute>(SV)) {
13760 OS << " U: ";
13761 SE.getUnsignedRange(SV).print(OS);
13762 OS << " S: ";
13763 SE.getSignedRange(SV).print(OS);
13764 }
13765
13766 const Loop *L = LI.getLoopFor(I.getParent());
13767
13768 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13769 if (AtUse != SV) {
13770 OS << " --> ";
13771 AtUse->print(OS);
13772 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13773 OS << " U: ";
13774 SE.getUnsignedRange(AtUse).print(OS);
13775 OS << " S: ";
13776 SE.getSignedRange(AtUse).print(OS);
13777 }
13778 }
13779
13780 if (L) {
13781 OS << "\t\t" "Exits: ";
13782 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13783 if (!SE.isLoopInvariant(ExitValue, L)) {
13784 OS << "<<Unknown>>";
13785 } else {
13786 OS << *ExitValue;
13787 }
13788
13789 bool First = true;
13790 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13791 if (First) {
13792 OS << "\t\t" "LoopDispositions: { ";
13793 First = false;
13794 } else {
13795 OS << ", ";
13796 }
13797
13798 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13799 OS << ": " << SE.getLoopDisposition(SV, Iter);
13800 }
13801
13802 for (const auto *InnerL : depth_first(L)) {
13803 if (InnerL == L)
13804 continue;
13805 if (First) {
13806 OS << "\t\t" "LoopDispositions: { ";
13807 First = false;
13808 } else {
13809 OS << ", ";
13810 }
13811
13812 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13813 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13814 }
13815
13816 OS << " }";
13817 }
13818
13819 OS << "\n";
13820 }
13821 }
13822
13823 OS << "Determining loop execution counts for: ";
13824 F.printAsOperand(OS, /*PrintType=*/false);
13825 OS << "\n";
13826 for (Loop *I : LI)
13827 PrintLoopInfo(OS, &SE, I);
13828}
13829
13832 auto &Values = LoopDispositions[S];
13833 for (auto &V : Values) {
13834 if (V.getPointer() == L)
13835 return V.getInt();
13836 }
13837 Values.emplace_back(L, LoopVariant);
13838 LoopDisposition D = computeLoopDisposition(S, L);
13839 auto &Values2 = LoopDispositions[S];
13840 for (auto &V : llvm::reverse(Values2)) {
13841 if (V.getPointer() == L) {
13842 V.setInt(D);
13843 break;
13844 }
13845 }
13846 return D;
13847}
13848
13850ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13851 switch (S->getSCEVType()) {
13852 case scConstant:
13853 case scVScale:
13854 return LoopInvariant;
13855 case scAddRecExpr: {
13856 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13857
13858 // If L is the addrec's loop, it's computable.
13859 if (AR->getLoop() == L)
13860 return LoopComputable;
13861
13862 // Add recurrences are never invariant in the function-body (null loop).
13863 if (!L)
13864 return LoopVariant;
13865
13866 // Everything that is not defined at loop entry is variant.
13867 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13868 return LoopVariant;
13869 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13870 " dominate the contained loop's header?");
13871
13872 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13873 if (AR->getLoop()->contains(L))
13874 return LoopInvariant;
13875
13876 // This recurrence is variant w.r.t. L if any of its operands
13877 // are variant.
13878 for (const auto *Op : AR->operands())
13879 if (!isLoopInvariant(Op, L))
13880 return LoopVariant;
13881
13882 // Otherwise it's loop-invariant.
13883 return LoopInvariant;
13884 }
13885 case scTruncate:
13886 case scZeroExtend:
13887 case scSignExtend:
13888 case scPtrToInt:
13889 case scAddExpr:
13890 case scMulExpr:
13891 case scUDivExpr:
13892 case scUMaxExpr:
13893 case scSMaxExpr:
13894 case scUMinExpr:
13895 case scSMinExpr:
13896 case scSequentialUMinExpr: {
13897 bool HasVarying = false;
13898 for (const auto *Op : S->operands()) {
13899 LoopDisposition D = getLoopDisposition(Op, L);
13900 if (D == LoopVariant)
13901 return LoopVariant;
13902 if (D == LoopComputable)
13903 HasVarying = true;
13904 }
13905 return HasVarying ? LoopComputable : LoopInvariant;
13906 }
13907 case scUnknown:
13908 // All non-instruction values are loop invariant. All instructions are loop
13909 // invariant if they are not contained in the specified loop.
13910 // Instructions are never considered invariant in the function body
13911 // (null loop) because they are defined within the "loop".
13912 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13913 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13914 return LoopInvariant;
13915 case scCouldNotCompute:
13916 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13917 }
13918 llvm_unreachable("Unknown SCEV kind!");
13919}
13920
13921bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13922 return getLoopDisposition(S, L) == LoopInvariant;
13923}
13924
13925bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13926 return getLoopDisposition(S, L) == LoopComputable;
13927}
13928
13929ScalarEvolution::BlockDisposition
13930ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13931 auto &Values = BlockDispositions[S];
13932 for (auto &V : Values) {
13933 if (V.getPointer() == BB)
13934 return V.getInt();
13935 }
13936 Values.emplace_back(BB, DoesNotDominateBlock);
13937 BlockDisposition D = computeBlockDisposition(S, BB);
13938 auto &Values2 = BlockDispositions[S];
13939 for (auto &V : llvm::reverse(Values2)) {
13940 if (V.getPointer() == BB) {
13941 V.setInt(D);
13942 break;
13943 }
13944 }
13945 return D;
13946}
13947
13948ScalarEvolution::BlockDisposition
13949ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13950 switch (S->getSCEVType()) {
13951 case scConstant:
13952 case scVScale:
13953 return ProperlyDominatesBlock;
13954 case scAddRecExpr: {
13955 // This uses a "dominates" query instead of "properly dominates" query
13956 // to test for proper dominance too, because the instruction which
13957 // produces the addrec's value is a PHI, and a PHI effectively properly
13958 // dominates its entire containing block.
13959 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13960 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13961 return DoesNotDominateBlock;
13962
13963 // Fall through into SCEVNAryExpr handling.
13964 [[fallthrough]];
13965 }
13966 case scTruncate:
13967 case scZeroExtend:
13968 case scSignExtend:
13969 case scPtrToInt:
13970 case scAddExpr:
13971 case scMulExpr:
13972 case scUDivExpr:
13973 case scUMaxExpr:
13974 case scSMaxExpr:
13975 case scUMinExpr:
13976 case scSMinExpr:
13977 case scSequentialUMinExpr: {
13978 bool Proper = true;
13979 for (const SCEV *NAryOp : S->operands()) {
13980 BlockDisposition D = getBlockDisposition(NAryOp, BB);
13981 if (D == DoesNotDominateBlock)
13982 return DoesNotDominateBlock;
13983 if (D == DominatesBlock)
13984 Proper = false;
13985 }
13986 return Proper ? ProperlyDominatesBlock : DominatesBlock;
13987 }
13988 case scUnknown:
13989 if (Instruction *I =
13990 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13991 if (I->getParent() == BB)
13992 return DominatesBlock;
13993 if (DT.properlyDominates(I->getParent(), BB))
13994 return ProperlyDominatesBlock;
13995 return DoesNotDominateBlock;
13996 }
13997 return ProperlyDominatesBlock;
13998 case scCouldNotCompute:
13999 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14000 }
14001 llvm_unreachable("Unknown SCEV kind!");
14002}
14003
14004bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14005 return getBlockDisposition(S, BB) >= DominatesBlock;
14006}
14007
14008bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14009 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14010}
14011
14012bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14013 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14014}
14015
14016void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14017 bool Predicated) {
14018 auto &BECounts =
14019 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14020 auto It = BECounts.find(L);
14021 if (It != BECounts.end()) {
14022 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14023 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14024 if (!isa<SCEVConstant>(S)) {
14025 auto UserIt = BECountUsers.find(S);
14026 assert(UserIt != BECountUsers.end());
14027 UserIt->second.erase({L, Predicated});
14028 }
14029 }
14030 }
14031 BECounts.erase(It);
14032 }
14033}
14034
14035void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14036 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
14037 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14038
14039 while (!Worklist.empty()) {
14040 const SCEV *Curr = Worklist.pop_back_val();
14041 auto Users = SCEVUsers.find(Curr);
14042 if (Users != SCEVUsers.end())
14043 for (const auto *User : Users->second)
14044 if (ToForget.insert(User).second)
14045 Worklist.push_back(User);
14046 }
14047
14048 for (const auto *S : ToForget)
14049 forgetMemoizedResultsImpl(S);
14050
14051 for (auto I = PredicatedSCEVRewrites.begin();
14052 I != PredicatedSCEVRewrites.end();) {
14053 std::pair<const SCEV *, const Loop *> Entry = I->first;
14054 if (ToForget.count(Entry.first))
14055 PredicatedSCEVRewrites.erase(I++);
14056 else
14057 ++I;
14058 }
14059}
14060
14061void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14062 LoopDispositions.erase(S);
14063 BlockDispositions.erase(S);
14064 UnsignedRanges.erase(S);
14065 SignedRanges.erase(S);
14066 HasRecMap.erase(S);
14067 ConstantMultipleCache.erase(S);
14068
14069 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14070 UnsignedWrapViaInductionTried.erase(AR);
14071 SignedWrapViaInductionTried.erase(AR);
14072 }
14073
14074 auto ExprIt = ExprValueMap.find(S);
14075 if (ExprIt != ExprValueMap.end()) {
14076 for (Value *V : ExprIt->second) {
14077 auto ValueIt = ValueExprMap.find_as(V);
14078 if (ValueIt != ValueExprMap.end())
14079 ValueExprMap.erase(ValueIt);
14080 }
14081 ExprValueMap.erase(ExprIt);
14082 }
14083
14084 auto ScopeIt = ValuesAtScopes.find(S);
14085 if (ScopeIt != ValuesAtScopes.end()) {
14086 for (const auto &Pair : ScopeIt->second)
14087 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14088 llvm::erase(ValuesAtScopesUsers[Pair.second],
14089 std::make_pair(Pair.first, S));
14090 ValuesAtScopes.erase(ScopeIt);
14091 }
14092
14093 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14094 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14095 for (const auto &Pair : ScopeUserIt->second)
14096 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14097 ValuesAtScopesUsers.erase(ScopeUserIt);
14098 }
14099
14100 auto BEUsersIt = BECountUsers.find(S);
14101 if (BEUsersIt != BECountUsers.end()) {
14102 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14103 auto Copy = BEUsersIt->second;
14104 for (const auto &Pair : Copy)
14105 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14106 BECountUsers.erase(BEUsersIt);
14107 }
14108
14109 auto FoldUser = FoldCacheUser.find(S);
14110 if (FoldUser != FoldCacheUser.end())
14111 for (auto &KV : FoldUser->second)
14112 FoldCache.erase(KV);
14113 FoldCacheUser.erase(S);
14114}
14115
14116void
14117ScalarEvolution::getUsedLoops(const SCEV *S,
14118 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14119 struct FindUsedLoops {
14120 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14121 : LoopsUsed(LoopsUsed) {}
14122 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14123 bool follow(const SCEV *S) {
14124 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14125 LoopsUsed.insert(AR->getLoop());
14126 return true;
14127 }
14128
14129 bool isDone() const { return false; }
14130 };
14131
14132 FindUsedLoops F(LoopsUsed);
14133 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14134}
14135
14136void ScalarEvolution::getReachableBlocks(
14137 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14138 SmallVector<BasicBlock *> Worklist;
14139 Worklist.push_back(&F.getEntryBlock());
14140 while (!Worklist.empty()) {
14141 BasicBlock *BB = Worklist.pop_back_val();
14142 if (!Reachable.insert(BB).second)
14143 continue;
14144
14145 Value *Cond;
14146 BasicBlock *TrueBB, *FalseBB;
14147 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14148 m_BasicBlock(FalseBB)))) {
14149 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14150 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14151 continue;
14152 }
14153
14154 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14155 const SCEV *L = getSCEV(Cmp->getOperand(0));
14156 const SCEV *R = getSCEV(Cmp->getOperand(1));
14157 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
14158 Worklist.push_back(TrueBB);
14159 continue;
14160 }
14161 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
14162 R)) {
14163 Worklist.push_back(FalseBB);
14164 continue;
14165 }
14166 }
14167 }
14168
14169 append_range(Worklist, successors(BB));
14170 }
14171}
14172
14173void ScalarEvolution::verify() const {
14174 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14175 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14176
14177 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14178
14179 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14180 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14181 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14182
14183 const SCEV *visitConstant(const SCEVConstant *Constant) {
14184 return SE.getConstant(Constant->getAPInt());
14185 }
14186
14187 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14188 return SE.getUnknown(Expr->getValue());
14189 }
14190
14191 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14192 return SE.getCouldNotCompute();
14193 }
14194 };
14195
14196 SCEVMapper SCM(SE2);
14197 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14198 SE2.getReachableBlocks(ReachableBlocks, F);
14199
14200 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14201 if (containsUndefs(Old) || containsUndefs(New)) {
14202 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14203 // not propagate undef aggressively). This means we can (and do) fail
14204 // verification in cases where a transform makes a value go from "undef"
14205 // to "undef+1" (say). The transform is fine, since in both cases the
14206 // result is "undef", but SCEV thinks the value increased by 1.
14207 return nullptr;
14208 }
14209
14210 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14211 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14212 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14213 return nullptr;
14214
14215 return Delta;
14216 };
14217
14218 while (!LoopStack.empty()) {
14219 auto *L = LoopStack.pop_back_val();
14220 llvm::append_range(LoopStack, *L);
14221
14222 // Only verify BECounts in reachable loops. For an unreachable loop,
14223 // any BECount is legal.
14224 if (!ReachableBlocks.contains(L->getHeader()))
14225 continue;
14226
14227 // Only verify cached BECounts. Computing new BECounts may change the
14228 // results of subsequent SCEV uses.
14229 auto It = BackedgeTakenCounts.find(L);
14230 if (It == BackedgeTakenCounts.end())
14231 continue;
14232
14233 auto *CurBECount =
14234 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14235 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14236
14237 if (CurBECount == SE2.getCouldNotCompute() ||
14238 NewBECount == SE2.getCouldNotCompute()) {
14239 // NB! This situation is legal, but is very suspicious -- whatever pass
14240 // changed the loop to make a trip count go from could not compute to
14241 // computable or vice-versa *should have* invalidated SCEV. However, we
14242 // choose not to assert here (for now) since we don't want false
14243 // positives.
14244 continue;
14245 }
14246
14247 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14248 SE.getTypeSizeInBits(NewBECount->getType()))
14249 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14250 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14251 SE.getTypeSizeInBits(NewBECount->getType()))
14252 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14253
14254 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14255 if (Delta && !Delta->isZero()) {
14256 dbgs() << "Trip Count for " << *L << " Changed!\n";
14257 dbgs() << "Old: " << *CurBECount << "\n";
14258 dbgs() << "New: " << *NewBECount << "\n";
14259 dbgs() << "Delta: " << *Delta << "\n";
14260 std::abort();
14261 }
14262 }
14263
14264 // Collect all valid loops currently in LoopInfo.
14265 SmallPtrSet<Loop *, 32> ValidLoops;
14266 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14267 while (!Worklist.empty()) {
14268 Loop *L = Worklist.pop_back_val();
14269 if (ValidLoops.insert(L).second)
14270 Worklist.append(L->begin(), L->end());
14271 }
14272 for (const auto &KV : ValueExprMap) {
14273#ifndef NDEBUG
14274 // Check for SCEV expressions referencing invalid/deleted loops.
14275 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14276 assert(ValidLoops.contains(AR->getLoop()) &&
14277 "AddRec references invalid loop");
14278 }
14279#endif
14280
14281 // Check that the value is also part of the reverse map.
14282 auto It = ExprValueMap.find(KV.second);
14283 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14284 dbgs() << "Value " << *KV.first
14285 << " is in ValueExprMap but not in ExprValueMap\n";
14286 std::abort();
14287 }
14288
14289 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14290 if (!ReachableBlocks.contains(I->getParent()))
14291 continue;
14292 const SCEV *OldSCEV = SCM.visit(KV.second);
14293 const SCEV *NewSCEV = SE2.getSCEV(I);
14294 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14295 if (Delta && !Delta->isZero()) {
14296 dbgs() << "SCEV for value " << *I << " changed!\n"
14297 << "Old: " << *OldSCEV << "\n"
14298 << "New: " << *NewSCEV << "\n"
14299 << "Delta: " << *Delta << "\n";
14300 std::abort();
14301 }
14302 }
14303 }
14304
14305 for (const auto &KV : ExprValueMap) {
14306 for (Value *V : KV.second) {
14307 auto It = ValueExprMap.find_as(V);
14308 if (It == ValueExprMap.end()) {
14309 dbgs() << "Value " << *V
14310 << " is in ExprValueMap but not in ValueExprMap\n";
14311 std::abort();
14312 }
14313 if (It->second != KV.first) {
14314 dbgs() << "Value " << *V << " mapped to " << *It->second
14315 << " rather than " << *KV.first << "\n";
14316 std::abort();
14317 }
14318 }
14319 }
14320
14321 // Verify integrity of SCEV users.
14322 for (const auto &S : UniqueSCEVs) {
14323 for (const auto *Op : S.operands()) {
14324 // We do not store dependencies of constants.
14325 if (isa<SCEVConstant>(Op))
14326 continue;
14327 auto It = SCEVUsers.find(Op);
14328 if (It != SCEVUsers.end() && It->second.count(&S))
14329 continue;
14330 dbgs() << "Use of operand " << *Op << " by user " << S
14331 << " is not being tracked!\n";
14332 std::abort();
14333 }
14334 }
14335
14336 // Verify integrity of ValuesAtScopes users.
14337 for (const auto &ValueAndVec : ValuesAtScopes) {
14338 const SCEV *Value = ValueAndVec.first;
14339 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14340 const Loop *L = LoopAndValueAtScope.first;
14341 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14342 if (!isa<SCEVConstant>(ValueAtScope)) {
14343 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14344 if (It != ValuesAtScopesUsers.end() &&
14345 is_contained(It->second, std::make_pair(L, Value)))
14346 continue;
14347 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14348 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14349 std::abort();
14350 }
14351 }
14352 }
14353
14354 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14355 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14356 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14357 const Loop *L = LoopAndValue.first;
14358 const SCEV *Value = LoopAndValue.second;
14359 assert(!isa<SCEVConstant>(Value));
14360 auto It = ValuesAtScopes.find(Value);
14361 if (It != ValuesAtScopes.end() &&
14362 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14363 continue;
14364 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14365 << *ValueAtScope << " missing in ValuesAtScopes\n";
14366 std::abort();
14367 }
14368 }
14369
14370 // Verify integrity of BECountUsers.
14371 auto VerifyBECountUsers = [&](bool Predicated) {
14372 auto &BECounts =
14373 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14374 for (const auto &LoopAndBEInfo : BECounts) {
14375 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14376 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14377 if (!isa<SCEVConstant>(S)) {
14378 auto UserIt = BECountUsers.find(S);
14379 if (UserIt != BECountUsers.end() &&
14380 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14381 continue;
14382 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14383 << " missing from BECountUsers\n";
14384 std::abort();
14385 }
14386 }
14387 }
14388 }
14389 };
14390 VerifyBECountUsers(/* Predicated */ false);
14391 VerifyBECountUsers(/* Predicated */ true);
14392
14393 // Verify integrity of the loop disposition cache.
14394 for (auto &[S, Values] : LoopDispositions) {
14395 for (auto [Loop, CachedDisposition] : Values) {
14396 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14397 if (CachedDisposition != RecomputedDisposition) {
14398 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14399 << " is incorrect: cached " << CachedDisposition << ", actual "
14400 << RecomputedDisposition << "\n";
14401 std::abort();
14402 }
14403 }
14404 }
14405
14406 // Verify integrity of the block disposition cache.
14407 for (auto &[S, Values] : BlockDispositions) {
14408 for (auto [BB, CachedDisposition] : Values) {
14409 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14410 if (CachedDisposition != RecomputedDisposition) {
14411 dbgs() << "Cached disposition of " << *S << " for block %"
14412 << BB->getName() << " is incorrect: cached " << CachedDisposition
14413 << ", actual " << RecomputedDisposition << "\n";
14414 std::abort();
14415 }
14416 }
14417 }
14418
14419 // Verify FoldCache/FoldCacheUser caches.
14420 for (auto [FoldID, Expr] : FoldCache) {
14421 auto I = FoldCacheUser.find(Expr);
14422 if (I == FoldCacheUser.end()) {
14423 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14424 << "!\n";
14425 std::abort();
14426 }
14427 if (!is_contained(I->second, FoldID)) {
14428 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14429 std::abort();
14430 }
14431 }
14432 for (auto [Expr, IDs] : FoldCacheUser) {
14433 for (auto &FoldID : IDs) {
14434 auto I = FoldCache.find(FoldID);
14435 if (I == FoldCache.end()) {
14436 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14437 << "!\n";
14438 std::abort();
14439 }
14440 if (I->second != Expr) {
14441 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14442 << *I->second << " != " << *Expr << "!\n";
14443 std::abort();
14444 }
14445 }
14446 }
14447
14448 // Verify that ConstantMultipleCache computations are correct. We check that
14449 // cached multiples and recomputed multiples are multiples of each other to
14450 // verify correctness. It is possible that a recomputed multiple is different
14451 // from the cached multiple due to strengthened no wrap flags or changes in
14452 // KnownBits computations.
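  // For example, a cached multiple of 4 and a recomputed multiple of 8 pass
  // this check (8 is a multiple of 4), while 4 and 6 would hit the failure
  // path below.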
14453 for (auto [S, Multiple] : ConstantMultipleCache) {
14454 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14455 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14456 Multiple.urem(RecomputedMultiple) != 0 &&
14457 RecomputedMultiple.urem(Multiple) != 0)) {
14458 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14459 << *S << " : Computed " << RecomputedMultiple
14460 << " but cache contains " << Multiple << "!\n";
14461 std::abort();
14462 }
14463 }
14464}
14465
14466bool ScalarEvolution::invalidate(
14467 Function &F, const PreservedAnalyses &PA,
14468 FunctionAnalysisManager::Invalidator &Inv) {
14469 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14470 // of its dependencies is invalidated.
14471 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14472 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14473 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14474 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14475 Inv.invalidate<LoopAnalysis>(F, PA);
14476}
14477
14478AnalysisKey ScalarEvolutionAnalysis::Key;
14479
14480ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14481 FunctionAnalysisManager &AM) {
14482 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14483 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14484 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14485 auto &LI = AM.getResult<LoopAnalysis>(F);
14486 return ScalarEvolution(F, TLI, AC, DT, LI);
14487}
14488
14489PreservedAnalyses
14490ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14491 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14492 return PreservedAnalyses::all();
14493}
14494
14497 // For compatibility with opt's -analyze feature under legacy pass manager
14498 // which was not ported to NPM. This keeps tests using
14499 // update_analyze_test_checks.py working.
14500 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14501 << F.getName() << "':\n";
14503 return PreservedAnalyses::all();
14504}
14505
14507 "Scalar Evolution Analysis", false, true)
14513 "Scalar Evolution Analysis", false, true)
14514
14516
14519}
14520
14522 SE.reset(new ScalarEvolution(
14523 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14524 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14525 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14526 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14527 return false;
14528}
14529
14530void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14531
14532void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14533 SE->print(OS);
14534}
14535
14536void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14537 if (!VerifySCEV)
14538 return;
14539
14540 SE->verify();
14541}
14542
14543void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14544 AU.setPreservesAll();
14549}
14550
14551const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14552 const SCEV *RHS) {
14553 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14554}
14555
14556const SCEVPredicate *
14557ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14558 const SCEV *LHS, const SCEV *RHS) {
14559 FoldingSetNodeID ID;
14560 assert(LHS->getType() == RHS->getType() &&
14561 "Type mismatch between LHS and RHS");
14562 // Unique this node based on the arguments
14563 ID.AddInteger(SCEVPredicate::P_Compare);
14564 ID.AddInteger(Pred);
14565 ID.AddPointer(LHS);
14566 ID.AddPointer(RHS);
14567 void *IP = nullptr;
14568 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14569 return S;
14570 SCEVComparePredicate *Eq = new (SCEVAllocator)
14571 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14572 UniquePreds.InsertNode(Eq, IP);
14573 return Eq;
14574}
14575
14576const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14577 const SCEVAddRecExpr *AR,
14578 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14579 FoldingSetNodeID ID;
14580 // Unique this node based on the arguments
14581 ID.AddInteger(SCEVPredicate::P_Wrap);
14582 ID.AddPointer(AR);
14583 ID.AddInteger(AddedFlags);
14584 void *IP = nullptr;
14585 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14586 return S;
14587 auto *OF = new (SCEVAllocator)
14588 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14589 UniquePreds.InsertNode(OF, IP);
14590 return OF;
14591}
14592
14593namespace {
14594
14595class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14596public:
14597
14598 /// Rewrites \p S in the context of a loop L and the SCEV predication
14599 /// infrastructure.
14600 ///
14601 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14602 /// equivalences present in \p Pred.
14603 ///
14604 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14605 /// \p NewPreds such that the result will be an AddRecExpr.
14606 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14607 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14608 const SCEVPredicate *Pred) {
14609 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14610 return Rewriter.visit(S);
14611 }
14612
14613 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14614 if (Pred) {
14615 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14616 for (const auto *Pred : U->getPredicates())
14617 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14618 if (IPred->getLHS() == Expr &&
14619 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14620 return IPred->getRHS();
14621 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14622 if (IPred->getLHS() == Expr &&
14623 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14624 return IPred->getRHS();
14625 }
14626 }
14627 return convertToAddRecWithPreds(Expr);
14628 }
14629
14630 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14631 const SCEV *Operand = visit(Expr->getOperand());
14632 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14633 if (AR && AR->getLoop() == L && AR->isAffine()) {
14634 // This couldn't be folded because the operand didn't have the nuw
14635 // flag. Add the nusw flag as an assumption that we could make.
14636 const SCEV *Step = AR->getStepRecurrence(SE);
14637 Type *Ty = Expr->getType();
14638 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14639 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14640 SE.getSignExtendExpr(Step, Ty), L,
14641 AR->getNoWrapFlags());
14642 }
14643 return SE.getZeroExtendExpr(Operand, Expr->getType());
14644 }
14645
14646 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14647 const SCEV *Operand = visit(Expr->getOperand());
14648 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14649 if (AR && AR->getLoop() == L && AR->isAffine()) {
14650 // This couldn't be folded because the operand didn't have the nsw
14651 // flag. Add the nssw flag as an assumption that we could make.
14652 const SCEV *Step = AR->getStepRecurrence(SE);
14653 Type *Ty = Expr->getType();
14654 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14655 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14656 SE.getSignExtendExpr(Step, Ty), L,
14657 AR->getNoWrapFlags());
14658 }
14659 return SE.getSignExtendExpr(Operand, Expr->getType());
14660 }
14661
14662private:
14663 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14664 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14665 const SCEVPredicate *Pred)
14666 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14667
14668 bool addOverflowAssumption(const SCEVPredicate *P) {
14669 if (!NewPreds) {
14670 // Check if we've already made this assumption.
14671 return Pred && Pred->implies(P);
14672 }
14673 NewPreds->insert(P);
14674 return true;
14675 }
14676
14677 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14678 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14679 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14680 return addOverflowAssumption(A);
14681 }
14682
14683 // If \p Expr represents a PHINode, we try to see if it can be represented
14684 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14685 // to add this predicate as a runtime overflow check, we return the AddRec.
14686 // If \p Expr does not meet these conditions (is not a PHI node, or we
14687 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14688 // return \p Expr.
14689 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14690 if (!isa<PHINode>(Expr->getValue()))
14691 return Expr;
14692 std::optional<
14693 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14694 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14695 if (!PredicatedRewrite)
14696 return Expr;
14697 for (const auto *P : PredicatedRewrite->second){
14698 // Wrap predicates from outer loops are not supported.
14699 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14700 if (L != WP->getExpr()->getLoop())
14701 return Expr;
14702 }
14703 if (!addOverflowAssumption(P))
14704 return Expr;
14705 }
14706 return PredicatedRewrite->first;
14707 }
14708
14709 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
14710 const SCEVPredicate *Pred;
14711 const Loop *L;
14712};
14713
14714} // end anonymous namespace
14715
14716const SCEV *
14717ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14718 const SCEVPredicate &Preds) {
14719 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14720}
14721
14722const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14723 const SCEV *S, const Loop *L,
14724 SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
14725 SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
14726 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14727 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14728
14729 if (!AddRec)
14730 return nullptr;
14731
14732 // Since the transformation was successful, we can now transfer the SCEV
14733 // predicates.
14734 for (const auto *P : TransformPreds)
14735 Preds.insert(P);
14736
14737 return AddRec;
14738}
14739
14740/// SCEV predicates
14741SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
14742 SCEVPredicateKind Kind)
14743 : FastID(ID), Kind(Kind) {}
14744
14745SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14746 const ICmpInst::Predicate Pred,
14747 const SCEV *LHS, const SCEV *RHS)
14748 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14749 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14750 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14751}
14752
14753bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14754 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14755
14756 if (!Op)
14757 return false;
14758
14759 if (Pred != ICmpInst::ICMP_EQ)
14760 return false;
14761
14762 return Op->LHS == LHS && Op->RHS == RHS;
14763}
14764
14765bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14766
14767void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
14768 if (Pred == ICmpInst::ICMP_EQ)
14769 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14770 else
14771 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14772 << *RHS << "\n";
14773
14774}
14775
14776SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14777 const SCEVAddRecExpr *AR,
14778 IncrementWrapFlags Flags)
14779 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14780
14781const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14782
14783bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14784 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14785
14786 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14787}
14788
14789bool SCEVWrapPredicate::isAlwaysTrue() const {
14790 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
14791 IncrementWrapFlags IFlags = Flags;
14792
14793 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
14794 IFlags = clearFlags(IFlags, IncrementNSSW);
14795
14796 return IFlags == IncrementAnyWrap;
14797}
14798
14799void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
14800 OS.indent(Depth) << *getExpr() << " Added Flags: ";
14801 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
14802 OS << "<nusw>";
14803 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
14804 OS << "<nssw>";
14805 OS << "\n";
14806}
14807
14808SCEVWrapPredicate::IncrementWrapFlags
14809SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
14810 ScalarEvolution &SE) {
14811 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
14812 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
14813
14814 // We can safely transfer the NSW flag as NSSW.
14815 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
14816 ImpliedFlags = IncrementNSSW;
14817
14818 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
14819 // If the increment is positive, the SCEV NUW flag will also imply the
14820 // WrapPredicate NUSW flag.
14821 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
14822 if (Step->getValue()->getValue().isNonNegative())
14823 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
14824 }
14825
14826 return ImpliedFlags;
14827}
14828
14829/// Union predicates don't get cached so create a dummy set ID for it.
14830SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14831 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14832 for (const auto *P : Preds)
14833 add(P);
14834}
14835
14836bool SCEVUnionPredicate::isAlwaysTrue() const {
14837 return all_of(Preds,
14838 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14839}
14840
14841bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
14842 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14843 return all_of(Set->Preds,
14844 [this](const SCEVPredicate *I) { return this->implies(I); });
14845
14846 return any_of(Preds,
14847 [N](const SCEVPredicate *I) { return I->implies(N); });
14848}
14849
14850void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
14851 for (const auto *Pred : Preds)
14852 Pred->print(OS, Depth);
14853}
14854
14855void SCEVUnionPredicate::add(const SCEVPredicate *N) {
14856 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
14857 for (const auto *Pred : Set->Preds)
14858 add(Pred);
14859 return;
14860 }
14861
14862 // Only add predicate if it is not already implied by this union predicate.
14863 if (!implies(N))
14864 Preds.push_back(N);
14865}
14866
14867PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
14868 Loop &L)
14869 : SE(SE), L(L) {
14870 SmallVector<const SCEVPredicate*, 4> Empty;
14871 Preds = std::make_unique<SCEVUnionPredicate>(Empty);
14872}
14873
14874void ScalarEvolution::registerUser(const SCEV *User,
14875 ArrayRef<const SCEV *> Ops) {
14876 for (const auto *Op : Ops)
14877 // We do not expect that forgetting cached data for SCEVConstants will ever
14878 // open any prospects for sharpening or introduce any correctness issues,
14879 // so we don't bother storing their dependencies.
14880 if (!isa<SCEVConstant>(Op))
14881 SCEVUsers[Op].insert(User);
14882}
14883
14884const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
14885 const SCEV *Expr = SE.getSCEV(V);
14886 RewriteEntry &Entry = RewriteMap[Expr];
14887
14888 // If we already have an entry and the version matches, return it.
14889 if (Entry.second && Generation == Entry.first)
14890 return Entry.second;
14891
14892 // We found an entry but it's stale. Rewrite the stale entry
14893 // according to the current predicate.
14894 if (Entry.second)
14895 Expr = Entry.second;
14896
14897 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
14898 Entry = {Generation, NewSCEV};
14899
14900 return NewSCEV;
14901}
14902
14903const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
14904 if (!BackedgeCount) {
14905 SmallVector<const SCEVPredicate *, 4> Preds;
14906 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
14907 for (const auto *P : Preds)
14908 addPredicate(*P);
14909 }
14910 return BackedgeCount;
14911}
14912
14913const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
14914 if (!SymbolicMaxBackedgeCount) {
14915 SmallVector<const SCEVPredicate *, 4> Preds;
14916 SymbolicMaxBackedgeCount =
14917 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
14918 for (const auto *P : Preds)
14919 addPredicate(*P);
14920 }
14921 return SymbolicMaxBackedgeCount;
14922}
14923
14924void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
14925 if (Preds->implies(&Pred))
14926 return;
14927
14928 auto &OldPreds = Preds->getPredicates();
14929 SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
14930 NewPreds.push_back(&Pred);
14931 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
14932 updateGeneration();
14933}
14934
14935const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
14936 return *Preds;
14937}
14938
14939void PredicatedScalarEvolution::updateGeneration() {
14940 // If the generation number wrapped recompute everything.
14941 if (++Generation == 0) {
14942 for (auto &II : RewriteMap) {
14943 const SCEV *Rewritten = II.second.second;
14944 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
14945 }
14946 }
14947}
14948
14949void PredicatedScalarEvolution::setNoOverflow(
14950 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14951 const SCEV *Expr = getSCEV(V);
14952 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14953
14954 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
14955
14956 // Clear the statically implied flags.
14957 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
14958 addPredicate(*SE.getWrapPredicate(AR, Flags));
14959
14960 auto II = FlagsMap.insert({V, Flags});
14961 if (!II.second)
14962 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
14963}
14964
14965bool PredicatedScalarEvolution::hasNoOverflow(
14966 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14967 const SCEV *Expr = getSCEV(V);
14968 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14969
14970 Flags = SCEVWrapPredicate::clearFlags(
14971 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
14972
14973 auto II = FlagsMap.find(V);
14974
14975 if (II != FlagsMap.end())
14976 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
14977
14978 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
14979}
14980
14981const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
14982 const SCEV *Expr = this->getSCEV(V);
14983 SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
14984 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
14985
14986 if (!New)
14987 return nullptr;
14988
14989 for (const auto *P : NewPreds)
14990 addPredicate(*P);
14991
14992 RewriteMap[SE.getSCEV(V)] = {Generation, New};
14993 return New;
14994}
14995
14996PredicatedScalarEvolution::PredicatedScalarEvolution(
14997 const PredicatedScalarEvolution &Init)
14998 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
14999 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15000 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15001 for (auto I : Init.FlagsMap)
15002 FlagsMap.insert(I);
15003}
15004
15005void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15006 // For each block.
15007 for (auto *BB : L.getBlocks())
15008 for (auto &I : *BB) {
15009 if (!SE.isSCEVable(I.getType()))
15010 continue;
15011
15012 auto *Expr = SE.getSCEV(&I);
15013 auto II = RewriteMap.find(Expr);
15014
15015 if (II == RewriteMap.end())
15016 continue;
15017
15018 // Don't print things that are not interesting.
15019 if (II->second.second == Expr)
15020 continue;
15021
15022 OS.indent(Depth) << "[PSE]" << I << ":\n";
15023 OS.indent(Depth + 2) << *Expr << "\n";
15024 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15025 }
15026}
15027
15028// Match the mathematical pattern A - (A / B) * B, where A and B can be
15029// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15030// for URem with constant power-of-2 second operands.
15031// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15032// 4, A / B becomes X / 8).
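// For example, "%a urem 8" typically reaches here as the SCEV
// ((-8 * (%a /u 8)) + %a); matchURem decomposes it into LHS = %a and RHS = 8.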
15033bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15034 const SCEV *&RHS) {
15035 if (Expr->getType()->isPointerTy())
15036 return false;
15037
15038 // Try to match 'zext (trunc A to iB) to iY', which is used
15039 // for URem with constant power-of-2 second operands. Make sure the size of
15040 // the operand A matches the size of the whole expressions.
15041 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15042 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15043 LHS = Trunc->getOperand();
15044 // Bail out if the type of the LHS is larger than the type of the
15045 // expression for now.
15046 if (getTypeSizeInBits(LHS->getType()) >
15047 getTypeSizeInBits(Expr->getType()))
15048 return false;
15049 if (LHS->getType() != Expr->getType())
15050 LHS = getZeroExtendExpr(LHS, Expr->getType());
15051 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15052 << getTypeSizeInBits(Trunc->getType()));
15053 return true;
15054 }
15055 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15056 if (Add == nullptr || Add->getNumOperands() != 2)
15057 return false;
15058
15059 const SCEV *A = Add->getOperand(1);
15060 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15061
15062 if (Mul == nullptr)
15063 return false;
15064
15065 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15066 // (SomeExpr + (-(SomeExpr / B) * B)).
15067 if (Expr == getURemExpr(A, B)) {
15068 LHS = A;
15069 RHS = B;
15070 return true;
15071 }
15072 return false;
15073 };
15074
15075 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15076 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15077 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15078 MatchURemWithDivisor(Mul->getOperand(2));
15079
15080 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15081 if (Mul->getNumOperands() == 2)
15082 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15083 MatchURemWithDivisor(Mul->getOperand(0)) ||
15084 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15085 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15086 return false;
15087}
15088
15089/// A rewriter to replace SCEV expressions in Map with the corresponding entry
15090/// in the map. It skips AddRecExpr because we cannot guarantee that the
15091/// replacement is loop invariant in the loop of the AddRec.
15092class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15093 const DenseMap<const SCEV *, const SCEV *> &Map;
15094
15095 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15096
15097public:
15098 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15099 DenseMap<const SCEV *, const SCEV *> &M,
15100 bool PreserveNUW, bool PreserveNSW)
15101 : SCEVRewriteVisitor(SE), Map(M) {
15102 if (PreserveNUW)
15103 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15104 if (PreserveNSW)
15105 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15106 }
15107
15108 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15109
15110 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15111 auto I = Map.find(Expr);
15112 if (I == Map.end())
15113 return Expr;
15114 return I->second;
15115 }
15116
15117 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15118 auto I = Map.find(Expr);
15119 if (I == Map.end()) {
15120 // If we didn't find the exact ZExt expr in the map, check if there's an
15121 // entry for a smaller ZExt we can use instead.
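  // For example, when rewriting (zext i8 %x to i64) and the map only holds a
  // rewrite for (zext i8 %x to i32), the narrower entry is reused and then
  // re-extended to i64.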
15122 Type *Ty = Expr->getType();
15123 const SCEV *Op = Expr->getOperand(0);
15124 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15125 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15126 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15127 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15128 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15129 auto I = Map.find(NarrowExt);
15130 if (I != Map.end())
15131 return SE.getZeroExtendExpr(I->second, Ty);
15132 Bitwidth = Bitwidth / 2;
15133 }
15134
15135 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15136 Expr);
15137 }
15138 return I->second;
15139 }
15140
15141 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15142 auto I = Map.find(Expr);
15143 if (I == Map.end())
15144 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15145 Expr);
15146 return I->second;
15147 }
15148
15149 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15150 auto I = Map.find(Expr);
15151 if (I == Map.end())
15152 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
15153 return I->second;
15154 }
15155
15156 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15157 auto I = Map.find(Expr);
15158 if (I == Map.end())
15159 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15160 return I->second;
15161 }
15162
15163 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15164 SmallVector<const SCEV *, 2> Operands;
15165 bool Changed = false;
15166 for (const auto *Op : Expr->operands()) {
15167 Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15168 Changed |= Op != Operands.back();
15169 }
15170 // We are only replacing operands with equivalent values, so transfer the
15171 // flags from the original expression.
15172 return !Changed
15173 ? Expr
15174 : SE.getAddExpr(Operands, ScalarEvolution::maskFlags(
15175 Expr->getNoWrapFlags(), FlagMask));
15176 }
15177
15178 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15179 SmallVector<const SCEV *, 2> Operands;
15180 bool Changed = false;
15181 for (const auto *Op : Expr->operands()) {
15182 Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15183 Changed |= Op != Operands.back();
15184 }
15185 // We are only replacing operands with equivalent values, so transfer the
15186 // flags from the original expression.
15187 return !Changed
15188 ? Expr
15189 : SE.getMulExpr(Operands, ScalarEvolution::maskFlags(
15190 Expr->getNoWrapFlags(), FlagMask));
15191 }
15192};
15193
15194const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15195 SmallVector<const SCEV *> ExprsToRewrite;
15196 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15197 const SCEV *RHS,
15198 DenseMap<const SCEV *, const SCEV *>
15199 &RewriteMap) {
15200 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15201 // replacement SCEV which isn't directly implied by the structure of that
15202 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15203 // legal. See the scoping rules for flags in the header to understand why.
15204
15205 // If LHS is a constant, apply information to the other expression.
15206 if (isa<SCEVConstant>(LHS)) {
15207 std::swap(LHS, RHS);
15208 Predicate = CmpInst::getSwappedPredicate(Predicate);
15209 }
15210
15211 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15212 // create this form when combining two checks of the form (X u< C2 + C1) and
15213 // (X >=u C1).
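  // For example, the guard "(-4 + %x) u< 8" constrains %x to [4, 12), so %x
  // is rewritten below to umax(4, umin(%x, 11)).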
15214 auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
15215 &ExprsToRewrite]() {
15216 auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
15217 if (!AddExpr || AddExpr->getNumOperands() != 2)
15218 return false;
15219
15220 auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
15221 auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
15222 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15223 if (!C1 || !C2 || !LHSUnknown)
15224 return false;
15225
15226 auto ExactRegion =
15227 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15228 .sub(C1->getAPInt());
15229
15230 // Bail out, unless we have a non-wrapping, monotonic range.
15231 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15232 return false;
15233 auto I = RewriteMap.find(LHSUnknown);
15234 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15235 RewriteMap[LHSUnknown] = getUMaxExpr(
15236 getConstant(ExactRegion.getUnsignedMin()),
15237 getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
15238 ExprsToRewrite.push_back(LHSUnknown);
15239 return true;
15240 };
15241 if (MatchRangeCheckIdiom())
15242 return;
15243
15244 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15245 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15246 // the non-constant operand and in \p LHS the constant operand.
15247 auto IsMinMaxSCEVWithNonNegativeConstant =
15248 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15249 const SCEV *&RHS) {
15250 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15251 if (MinMax->getNumOperands() != 2)
15252 return false;
15253 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15254 if (C->getAPInt().isNegative())
15255 return false;
15256 SCTy = MinMax->getSCEVType();
15257 LHS = MinMax->getOperand(0);
15258 RHS = MinMax->getOperand(1);
15259 return true;
15260 }
15261 }
15262 return false;
15263 };
15264
15265 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15266 // constant, and returns their APInt in ExprVal and in DivisorVal.
15267 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15268 APInt &ExprVal, APInt &DivisorVal) {
15269 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15270 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15271 if (!ConstExpr || !ConstDivisor)
15272 return false;
15273 ExprVal = ConstExpr->getAPInt();
15274 DivisorVal = ConstDivisor->getAPInt();
15275 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15276 };
15277
15278 // Return a new SCEV that modifies \p Expr to the closest multiple of
15279 // \p Divisor that is greater than or equal to \p Expr.
15280 // For now, only handle constant Expr and Divisor.
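  // For example, with Expr = 5 and Divisor = 4 this returns the constant 8.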
15281 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15282 const SCEV *Divisor) {
15283 APInt ExprVal;
15284 APInt DivisorVal;
15285 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15286 return Expr;
15287 APInt Rem = ExprVal.urem(DivisorVal);
15288 if (!Rem.isZero())
15289 // return the SCEV: Expr + Divisor - Expr % Divisor
15290 return getConstant(ExprVal + DivisorVal - Rem);
15291 return Expr;
15292 };
15293
15294 // Return a new SCEV that modifies \p Expr to the closest multiple of
15295 // \p Divisor that is less than or equal to \p Expr.
15296 // For now, only handle constant Expr and Divisor.
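  // For example, with Expr = 5 and Divisor = 4 this returns the constant 4.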
15297 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15298 const SCEV *Divisor) {
15299 APInt ExprVal;
15300 APInt DivisorVal;
15301 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15302 return Expr;
15303 APInt Rem = ExprVal.urem(DivisorVal);
15304 // return the SCEV: Expr - Expr % Divisor
15305 return getConstant(ExprVal - Rem);
15306 };
15307
15308 // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15309 // recursively. This is done by aligning up/down the constant value to the
15310 // Divisor.
15311 std::function<const SCEV *(const SCEV *, const SCEV *)>
15312 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15313 const SCEV *Divisor) {
15314 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15315 SCEVTypes SCTy;
15316 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15317 MinMaxRHS))
15318 return MinMaxExpr;
15319 auto IsMin =
15320 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15321 assert(isKnownNonNegative(MinMaxLHS) &&
15322 "Expected non-negative operand!");
15323 auto *DivisibleExpr =
15324 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15325 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15326 SmallVector<const SCEV *> Ops = {
15327 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15328 return getMinMaxExpr(SCTy, Ops);
15329 };
15330
15331 // If we have LHS == 0, check if LHS is computing a property of some unknown
15332 // SCEV %v which we can rewrite %v to express explicitly.
15333 const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15334 if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15335 RHSC->getValue()->isNullValue()) {
15336 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15337 // explicitly express that.
15338 const SCEV *URemLHS = nullptr;
15339 const SCEV *URemRHS = nullptr;
15340 if (matchURem(LHS, URemLHS, URemRHS)) {
15341 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15342 auto I = RewriteMap.find(LHSUnknown);
15343 const SCEV *RewrittenLHS =
15344 I != RewriteMap.end() ? I->second : LHSUnknown;
15345 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15346 const auto *Multiple =
15347 getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15348 RewriteMap[LHSUnknown] = Multiple;
15349 ExprsToRewrite.push_back(LHSUnknown);
15350 return;
15351 }
15352 }
15353 }
15354
15355 // Do not apply information for constants or if RHS contains an AddRec.
15356 if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
15357 return;
15358
15359 // If RHS is SCEVUnknown, make sure the information is applied to it.
15360 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15361 std::swap(LHS, RHS);
15362 Predicate = CmpInst::getSwappedPredicate(Predicate);
15363 }
15364
15365 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15366 // and \p FromRewritten are the same (i.e. there has been no rewrite
15367 // registered for \p From), then puts this value in the list of rewritten
15368 // expressions.
15369 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15370 const SCEV *To) {
15371 if (From == FromRewritten)
15372 ExprsToRewrite.push_back(From);
15373 RewriteMap[From] = To;
15374 };
15375
15376 // Checks whether \p S has already been rewritten. In that case returns the
15377 // existing rewrite because we want to chain further rewrites onto the
15378 // already rewritten value. Otherwise returns \p S.
15379 auto GetMaybeRewritten = [&](const SCEV *S) {
15380 auto I = RewriteMap.find(S);
15381 return I != RewriteMap.end() ? I->second : S;
15382 };
15383
15384 // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15385 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15386 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15387 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15388 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15389 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15390 // DividesBy.
15391 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15392 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15393 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15394 if (Mul->getNumOperands() != 2)
15395 return false;
15396 auto *MulLHS = Mul->getOperand(0);
15397 auto *MulRHS = Mul->getOperand(1);
15398 if (isa<SCEVConstant>(MulLHS))
15399 std::swap(MulLHS, MulRHS);
15400 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15401 if (Div->getOperand(1) == MulRHS) {
15402 DividesBy = MulRHS;
15403 return true;
15404 }
15405 }
15406 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15407 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15408 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15409 return false;
15410 };
15411
15412 // Return true if \p Expr is known to be divisible by \p DividesBy.
15413 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15414 [&](const SCEV *Expr, const SCEV *DividesBy) {
15415 if (getURemExpr(Expr, DividesBy)->isZero())
15416 return true;
15417 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15418 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15419 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15420 return false;
15421 };
15422
15423 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15424 const SCEV *DividesBy = nullptr;
15425 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15426 // Check that the whole expression is divided by DividesBy
15427 DividesBy =
15428 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15429
15430 // Collect rewrites for LHS and its transitive operands based on the
15431 // condition.
15432 // For min/max expressions, also apply the guard to its operands:
15433 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15434 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15435 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15436 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15437
15438 // We cannot express strict predicates in SCEV, so instead we replace them
15439 // with non-strict ones against plus or minus one of RHS depending on the
15440 // predicate.
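  // For example, "%x u< %n" is treated below as "%x u<= (umax(%n, 1) - 1)",
  // and "%x s> %n" as "%x s>= %n + 1".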
15441 const SCEV *One = getOne(RHS->getType());
15442 switch (Predicate) {
15443 case CmpInst::ICMP_ULT:
15444 if (RHS->getType()->isPointerTy())
15445 return;
15446 RHS = getUMaxExpr(RHS, One);
15447 [[fallthrough]];
15448 case CmpInst::ICMP_SLT: {
15449 RHS = getMinusSCEV(RHS, One);
15450 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15451 break;
15452 }
15453 case CmpInst::ICMP_UGT:
15454 case CmpInst::ICMP_SGT:
15455 RHS = getAddExpr(RHS, One);
15456 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15457 break;
15458 case CmpInst::ICMP_ULE:
15459 case CmpInst::ICMP_SLE:
15460 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15461 break;
15462 case CmpInst::ICMP_UGE:
15463 case CmpInst::ICMP_SGE:
15464 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15465 break;
15466 default:
15467 break;
15468 }
15469
15470 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15471 SmallPtrSet<const SCEV *, 16> Visited;
15472
15473 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15474 append_range(Worklist, S->operands());
15475 };
15476
15477 while (!Worklist.empty()) {
15478 const SCEV *From = Worklist.pop_back_val();
15479 if (isa<SCEVConstant>(From))
15480 continue;
15481 if (!Visited.insert(From).second)
15482 continue;
15483 const SCEV *FromRewritten = GetMaybeRewritten(From);
15484 const SCEV *To = nullptr;
15485
15486 switch (Predicate) {
15487 case CmpInst::ICMP_ULT:
15488 case CmpInst::ICMP_ULE:
15489 To = getUMinExpr(FromRewritten, RHS);
15490 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15491 EnqueueOperands(UMax);
15492 break;
15493 case CmpInst::ICMP_SLT:
15494 case CmpInst::ICMP_SLE:
15495 To = getSMinExpr(FromRewritten, RHS);
15496 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15497 EnqueueOperands(SMax);
15498 break;
15499 case CmpInst::ICMP_UGT:
15500 case CmpInst::ICMP_UGE:
15501 To = getUMaxExpr(FromRewritten, RHS);
15502 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15503 EnqueueOperands(UMin);
15504 break;
15505 case CmpInst::ICMP_SGT:
15506 case CmpInst::ICMP_SGE:
15507 To = getSMaxExpr(FromRewritten, RHS);
15508 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15509 EnqueueOperands(SMin);
15510 break;
15511 case CmpInst::ICMP_EQ:
15512 if (isa<SCEVConstant>(RHS))
15513 To = RHS;
15514 break;
15515 case CmpInst::ICMP_NE:
15516 if (isa<SCEVConstant>(RHS) &&
15517 cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15518 const SCEV *OneAlignedUp =
15519 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15520 To = getUMaxExpr(FromRewritten, OneAlignedUp);
15521 }
15522 break;
15523 default:
15524 break;
15525 }
15526
15527 if (To)
15528 AddRewrite(From, FromRewritten, To);
15529 }
15530 };
15531
15532 BasicBlock *Header = L->getHeader();
15533 SmallVector<std::pair<Value *, bool>> Terms;
15534 // First, collect information from assumptions dominating the loop.
15535 for (auto &AssumeVH : AC.assumptions()) {
15536 if (!AssumeVH)
15537 continue;
15538 auto *AssumeI = cast<CallInst>(AssumeVH);
15539 if (!DT.dominates(AssumeI, Header))
15540 continue;
15541 Terms.emplace_back(AssumeI->getOperand(0), true);
15542 }
15543
15544 // Second, collect information from llvm.experimental.guards dominating the loop.
15545 auto *GuardDecl = F.getParent()->getFunction(
15546 Intrinsic::getName(Intrinsic::experimental_guard));
15547 if (GuardDecl)
15548 for (const auto *GU : GuardDecl->users())
15549 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15550 if (Guard->getFunction() == Header->getParent() && DT.dominates(Guard, Header))
15551 Terms.emplace_back(Guard->getArgOperand(0), true);
15552
15553 // Third, collect conditions from dominating branches. Starting at the loop
15554 // predecessor, climb up the predecessor chain, as long as there are
15555 // predecessors that can be found that have unique successors leading to the
15556 // original header.
15557 // TODO: share this logic with isLoopEntryGuardedByCond.
15558 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15559 L->getLoopPredecessor(), Header);
15560 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15561
15562 const BranchInst *LoopEntryPredicate =
15563 dyn_cast<BranchInst>(Pair.first->getTerminator());
15564 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15565 continue;
15566
15567 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15568 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15569 }
15570
15571 // Now apply the information from the collected conditions to RewriteMap.
15572 // Conditions are processed in reverse order, so the earliest conditions are
15573 // processed first. This ensures the SCEVs with the shortest dependency chains
15574 // are constructed first.
15575 DenseMap<const SCEV *, const SCEV *> RewriteMap;
15576 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15577 SmallVector<Value *, 8> Worklist;
15578 SmallPtrSet<Value *, 8> Visited;
15579 Worklist.push_back(Term);
15580 while (!Worklist.empty()) {
15581 Value *Cond = Worklist.pop_back_val();
15582 if (!Visited.insert(Cond).second)
15583 continue;
15584
15585 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15586 auto Predicate =
15587 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15588 const auto *LHS = getSCEV(Cmp->getOperand(0));
15589 const auto *RHS = getSCEV(Cmp->getOperand(1));
15590 CollectCondition(Predicate, LHS, RHS, RewriteMap);
15591 continue;
15592 }
15593
15594 Value *L, *R;
15595 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15596 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15597 Worklist.push_back(L);
15598 Worklist.push_back(R);
15599 }
15600 }
15601 }
15602
15603 if (RewriteMap.empty())
15604 return Expr;
15605
15606 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15607 // the replacement expressions are contained in the ranges of the replaced
15608 // expressions.
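// For example, if X has unsigned range [0, 100] and is rewritten to
// (X umin 15), the new range [0, 15] is contained in the old one, so an
// existing NUW flag remains valid.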
15609 bool PreserveNUW = true;
15610 bool PreserveNSW = true;
15611 for (const SCEV *Expr : ExprsToRewrite) {
15612 const SCEV *RewriteTo = RewriteMap[Expr];
15613 PreserveNUW &= getUnsignedRange(Expr).contains(getUnsignedRange(RewriteTo));
15614 PreserveNSW &= getSignedRange(Expr).contains(getSignedRange(RewriteTo));
15615 }
15616
15617 // Now that all rewrite information is collected, rewrite the collected
15618 // expressions with the information in the map. This applies information to
15619 // sub-expressions.
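// For example, if both A and (A + B) have entries, rewriting the replacement
// of (A + B) with the remaining map lets the information collected about A
// flow into the (A + B) entry as well.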
15620 if (ExprsToRewrite.size() > 1) {
15621 for (const SCEV *Expr : ExprsToRewrite) {
15622 const SCEV *RewriteTo = RewriteMap[Expr];
15623 RewriteMap.erase(Expr);
15624 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW,
15625 PreserveNSW);
15626 RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
15627 }
15628 }
15629 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
15630 return Rewriter.visit(Expr);
15631}
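As a rough illustration of how the guard rewriting above is consumed, the
sketch below applies it to a value's SCEV before querying its range. Here
`SE`, `L`, and `V` stand in for a ScalarEvolution instance, a loop, and a
value obtained elsewhere (for example inside a pass); the rewritten form
shown in the comment is only one possible outcome.

  // Sketch (not part of ScalarEvolution.cpp): apply dominating loop-guard
  // information to the SCEV of a value before querying its range.
  const SCEV *Expr = SE.getSCEV(V);
  const SCEV *Guarded = SE.applyLoopGuards(Expr, L);
  // If loop entry is guarded by e.g. `V u< 16`, Guarded may now be
  // (Expr umin 15), so SE.getUnsignedRange(Guarded) can yield a tighter
  // range than SE.getUnsignedRange(Expr).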
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static const LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
Expand Atomic instructions
basic Basic Alias true
block Block Frequency Analysis
BlockVerifier::State From
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition: Compiler.h:537
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
uint64_t Size
bool End
Definition: ELF_riscv.cpp:480
Generic implementation of equivalence classes through the use Tarjan's efficient union-find algorithm...
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iv Induction Variable Users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:528
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
if(VerifyEach)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
static bool rewrite(Function &F)
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
This file contains some templates that are useful if you are working with the STL at all.
raw_pwrite_stream & OS
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static int CompareValueComplexity(EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static bool CollectAddOperandsWithScales(DenseMap< const SCEV *, APInt > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
This defines the Use class.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Virtual Register Rewriter
Definition: VirtRegMap.cpp:237
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
A rewriter to replace SCEV expressions in Map with the corresponding entry in the map.
const SCEV * visitAddRecExpr(const SCEVAddRecExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
SCEVLoopGuardRewriter(ScalarEvolution &SE, DenseMap< const SCEV *, const SCEV * > &M, bool PreserveNUW, bool PreserveNSW)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
Class for arbitrary precision integers.
Definition: APInt.h:77
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1941
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1543
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:981
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:402
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1499
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1371
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:608
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1002
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1471
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:906
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:185
APInt abs() const
Get the absolute value.
Definition: APInt.h:1752
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1161
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:359
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:445
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1636
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1447
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1090
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:188
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:195
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:308
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1145
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:198
unsigned countTrailingZeros() const
Definition: APInt.h:1605
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:335
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:806
APInt multiplicativeInverse() const
Definition: APInt.cpp:1244
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1129
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:954
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:852
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:285
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:320
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1109
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:179
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:411
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:218
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1200
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:49
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:292
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:310
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:405
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:228
iterator end() const
Definition: ArrayRef.h:154
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
iterator begin() const
Definition: ArrayRef.h:153
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:438
const Instruction & front() const
Definition: BasicBlock.h:461
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:457
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:209
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:229
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:442
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1062
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:757
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:786
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:787
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:781
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:780
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:784
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:782
@ ICMP_EQ
equal
Definition: InstrTypes.h:778
@ ICMP_NE
not equal
Definition: InstrTypes.h:779
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:785
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:783
bool isSigned() const
Definition: InstrTypes.h:1007
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:909
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1056
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:953
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:871
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:847
Predicate getFlippedSignednessPredicate()
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert.
Definition: InstrTypes.h:1050
bool isUnsigned() const
Definition: InstrTypes.h:1013
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:1003
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2555
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2217
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1240
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2561
static Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2549
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2203
This is the shared class of boolean and integer constants.
Definition: Constants.h:81
bool isMinusOne() const
This function will return true iff every bit in this constant is set to true.
Definition: Constants.h:218
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
Definition: Constants.h:212
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:206
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:857
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:155
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:146
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:864
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
ConstantRange subtract(const APInt &CI) const
Subtract the specified constant from the endpoints of this constant range.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:41
bool isNullValue() const
Return true if this is the value that would be returned by getNullValue.
Definition: Constants.cpp:90
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:720
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:878
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:774
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:905
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:672
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:202
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:155
bool erase(const KeyT &Val)
Definition: DenseMap.h:345
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:180
bool empty() const
Definition: DenseMap.h:98
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:151
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:145
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:220
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:290
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:327
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
const BasicBlock & getEntryBlock() const
Definition: Function.h:800
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:406
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:403
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
Class to represent integer types.
Definition: DerivedTypes.h:40
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:278
An instruction for reading from memory.
Definition: Instructions.h:173
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:571
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:598
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:44
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:61
Metadata node.
Definition: Metadata.h:1067
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
Definition: Module.cpp:193
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:32
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:42
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:77
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:110
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:104
iterator_range< const_block_iterator > blocks() const
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:662
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1814
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:116
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
bool implies(const SCEVPredicate *N) const override
Implementation of the SCEVPredicate interface.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
virtual bool implies(const SCEVPredicate *N) const =0
Returns true if this predicate implies N.
SCEVPredicate(const SCEVPredicate &)=default
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds)
Union predicates don't get cached so create a dummy set ID for it.
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
unsigned getSmallConstantMaxTripCount(const Loop *L)
Returns the upper bound of the loop trip count as a normal unsigned value.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
std::optional< bool > evaluatePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS, and RHS.
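A small illustrative wrapper (assumptions: SE is a live ScalarEvolution, and LHS/RHS are SCEVs already obtained from it):

// Hypothetical helper; requires "llvm/Analysis/ScalarEvolution.h" and
// "llvm/IR/Instructions.h".
static bool provablySignedLess(llvm::ScalarEvolution &SE,
                               const llvm::SCEV *LHS, const llvm::SCEV *RHS) {
  // isKnownPredicate() only answers "provably true"; evaluatePredicate()
  // can additionally report "provably false" via std::optional<bool>.
  return SE.isKnownPredicate(llvm::ICmpInst::ICMP_SLT, LHS, RHS);
}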
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
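A sketch (not taken from this file) of building the add recurrence {0,+,1} for a loop; the i64 type and the decision not to assert any wrap flags are arbitrary choices for the example:

// Hypothetical helper; requires "llvm/Analysis/ScalarEvolution.h".
static const llvm::SCEV *makeZeroBasedIV(llvm::ScalarEvolution &SE,
                                         const llvm::Loop *L) {
  llvm::Type *I64 = llvm::Type::getInt64Ty(SE.getContext());
  // Builds {0,+,1}<L>; no no-wrap flags are claimed here.
  return SE.getAddRecExpr(SE.getZero(I64), SE.getOne(I64), L,
                          llvm::SCEV::FlagAnyWrap);
}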
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop of the recurrence LHS.
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the given type, for which isSCEVable must return true.
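The usual client-side guard pattern, shown as a hedged sketch (Ty is an arbitrary IR type supplied by the caller):

// Hypothetical helper; requires "llvm/Analysis/ScalarEvolution.h".
static void inspectType(llvm::ScalarEvolution &SE, llvm::Type *Ty) {
  if (!SE.isSCEVable(Ty))
    return;                                        // size queries require SCEVable types
  uint64_t Bits = SE.getTypeSizeInBits(Ty);        // legal: isSCEVable(Ty) holds
  llvm::Type *EffTy = SE.getEffectiveSCEVType(Ty); // integer type SCEV uses, e.g. for pointers
  (void)Bits;
  (void)EffTy;
}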
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the same instruction.
std::optional< bool > evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are required to be true in order for the answer to be correct.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its value, or if it is deleted.
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operation with them.
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallPtrSetImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as required.
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly computable, return SCEVCouldNotCompute.
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
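A sketch of an unrolling-style heuristic built on these queries; the threshold of 4 is arbitrary, and a return value of 0 from either query means "unknown or not a small constant":

// Hypothetical helper; requires "llvm/Analysis/ScalarEvolution.h".
static bool runsAtMostFourTimes(llvm::ScalarEvolution &SE, const llvm::Loop *L) {
  unsigned Exact = SE.getSmallConstantTripCount(L);
  unsigned Max = SE.getSmallConstantMaxTripCount(L);
  return (Exact != 0 && Exact <= 4) || (Max != 0 && Max <= 4);
}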
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a single pointer operand.
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of times the backedge of a loop has executed before the loop is exited.
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution namespace.
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Computes LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn't.
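Illustration only (A and B are caller-supplied Values of SCEVable type; none of these names come from this file):

// Hypothetical helper; requires "llvm/Analysis/ScalarEvolution.h" and <optional>.
static std::optional<llvm::APInt> constantDistance(llvm::ScalarEvolution &SE,
                                                   llvm::Value *A,
                                                   llvm::Value *B) {
  // Yields A - B as an APInt when the difference is a compile-time constant.
  return SE.computeConstantDifference(SE.getSCEV(A), SE.getSCEV(B));
}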
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
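A minimal sketch of calling it (S is a SCEV involving add recurrences of loop L; the comments paraphrase the header's description of the two halves):

// Hypothetical helper; requires "llvm/Analysis/ScalarEvolution.h".
static void splitExample(llvm::ScalarEvolution &SE, const llvm::Loop *L,
                         const llvm::SCEV *S) {
  // First element: S with every AddRec of L replaced by its start value.
  // Second element: S with every AddRec of L replaced by its post-increment form.
  auto [Init, PostInc] = SE.SplitIntoInitAndPostInc(L, S);
  (void)Init;
  (void)PostInc;
}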
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are required to be true in order for the answer to be correct.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e. a conservative over-approximation of) the exact backedge-taken count.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:323
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:344
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:418
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
Definition: SmallPtrSet.h:479
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less than N).
Definition: SmallSet.h:135
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
size_type size() const
Definition: SmallSet.h:161
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void reserve(size_type N)
Definition: SmallVector.h:676
iterator erase(const_iterator CI)
Definition: SmallVector.h:750
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:818
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
An instruction for storing to memory.
Definition: Instructions.h:289
Used to lazily calculate structure layout information for a target machine, based on the DataLayout structure.
Definition: DataLayout.h:622
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:651
TypeSize getSizeInBits() const
Definition: DataLayout.h:631
Class to represent struct types.
Definition: DerivedTypes.h:216
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:255
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:243
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:228
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:242
Use & Op()
Definition: User.h:133
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5105
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1074
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2193
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2198
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2203
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g. 32 for i32).
Definition: APInt.cpp:2781
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2208
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:767
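A small, self-contained example of the APIntOps helpers listed here (values are arbitrary):

// Requires "llvm/ADT/APInt.h".
static void apintHelpersExample() {
  llvm::APInt A(32, 12), B(32, 18);
  llvm::APInt G = llvm::APIntOps::GreatestCommonDivisor(A, B); // 6
  const llvm::APInt &M = llvm::APIntOps::umax(A, B);           // 18 (unsigned compare)
  (void)G;
  (void)M;
}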
@ Entry
Definition: COFF.h:811
@ Exit
Definition: COFF.h:812
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:1042
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:822
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
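A hedged sketch of combining these matchers; BI is assumed to be a block terminator handed in by the caller:

// Hypothetical helper; requires "llvm/IR/PatternMatch.h".
static bool branchesOnLogicalAnd(llvm::Value *BI) {
  using namespace llvm::PatternMatch;
  llvm::BasicBlock *TrueBB, *FalseBB;
  // Fires when BI is a conditional branch whose condition is a logical AND,
  // i.e. either `and i1 L, R` or `select i1 L, R, false`.
  return match(BI, m_Br(m_LogicalAnd(), TrueBB, FalseBB));
}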
@ ReallyHidden
Definition: CommandLine.h:138
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:463
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:31
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
Definition: DynamicAPInt.h:383
void stable_sort(R &&Range)
Definition: STLExtras.h:1995
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it is even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7095
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2067
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0s from the least significant bit to the most significant, stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of WO's result is used only along the paths control dependent on...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
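A short sketch of the intended use (PN is assumed to be a PHI node in a loop header):

// Hypothetical helper; requires "llvm/Analysis/ValueTracking.h".
static bool isSimpleIV(const llvm::PHINode *PN) {
  llvm::BinaryOperator *BO;
  llvm::Value *Start, *Step;
  // On success, BO is the increment instruction and Start/Step are the
  // recurrence's initial value and step operand.
  return llvm::matchSimpleRecurrence(PN, BO, Start, Step);
}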
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2059
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1729
bool getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, const TargetLibraryInfo *TLI, ObjectSizeOpts Opts={})
Compute the size of the object pointed by Ptr.
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:419
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1150
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1140
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:120
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range function attribute.
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:244
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the given range.
Definition: STLExtras.h:1914
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOne bit sets.
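A hedged sketch of a minimal query, supplying only the DataLayout and leaving the optional AssumptionCache/DominatorTree parameters at their defaults (so the result may be conservative):

// Hypothetical helper; requires "llvm/Analysis/ValueTracking.h".
static bool knownNonNegative(const llvm::Value *V, const llvm::DataLayout &DL) {
  llvm::KnownBits Known(V->getType()->getScalarSizeInBits());
  llvm::computeKnownBits(V, Known, DL);
  return Known.isNonNegative(); // sign bit known to be zero
}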
DWARFExpression::Operation Op
auto max_element(R &&Range)
Definition: STLExtras.h:1986
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:293
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:191
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1849
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one of its successors (including the next instruction that follows within a basic block).
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given predicate occurs in a range.
Definition: STLExtras.h:1921
auto predecessors(const MachineBasicBlock *BB)
bool isAllocationFn(const Value *V, const TargetLibraryInfo *TLI)
Tests if a value is a call or invoke to a library function that allocates or reallocates memory (either malloc, calloc, realloc, or strdup like).
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1879
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:28
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:290
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:97
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:428
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:40
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:370
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:185
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:134
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:118
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:94
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:285
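A self-contained example of the KnownBits helpers listed above; all bits of both inputs are known, so the shifted result is fully known as well (values are arbitrary):

// Requires "llvm/Support/KnownBits.h".
static llvm::APInt knownBitsExample() {
  llvm::KnownBits Five = llvm::KnownBits::makeConstant(llvm::APInt(8, 5));
  llvm::KnownBits One = llvm::KnownBits::makeConstant(llvm::APInt(8, 1));
  llvm::KnownBits Shifted = llvm::KnownBits::lshr(Five, One);
  return Shifted.getMaxValue(); // 2, i.e. 5 >> 1
}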
Various options to control the behavior of getObjectSize.
bool NullIsUnknownSize
If this is true, null pointers in address space 0 will be treated as though they can't be evaluated.
bool RoundToAlign
Whether to round the result up to the alignment of allocas, byval arguments, and global variables.
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to the not-taken path.
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
void addPredicate(const SCEVPredicate *P)