1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
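// For example, for a loop "for (i = 0; i != n; ++i)" the PHI node for i is
// represented (in SCEV's printed notation) as the add recurrence
// {0,+,1}<%loop>: it starts at 0 and advances by 1 on every iteration.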
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
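// For example, the folders canonicalize an expression such as ((%x + 2) + 3)
// into (5 + %x), so syntactically different but equivalent expressions end up
// as the same SCEV object.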
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
static cl::opt<unsigned>
    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                            cl::desc("Maximum number of iterations SCEV will "
                                     "symbolically execute a constant "
                                     "derived loop"),
                            cl::init(100));

static cl::opt<bool, true> VerifySCEVOpt(
    "verify-scev", cl::Hidden, cl::location(VerifySCEV),
    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool> VerifySCEVStrict(
    "verify-scev-strict", cl::Hidden,
    cl::desc("Enable stricter verification when -verify-scev is passed"));

static cl::opt<bool> VerifyIR(
    "scev-verify-ir", cl::Hidden,
    cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
    cl::init(false));

static cl::opt<unsigned> MulOpsInlineThreshold(
    "scev-mulops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining multiplication operands into a SCEV"),
    cl::init(32));

static cl::opt<unsigned> AddOpsInlineThreshold(
    "scev-addops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining addition operands into a SCEV"),
    cl::init(500));

static cl::opt<unsigned> MaxSCEVCompareDepth(
    "scalar-evolution-max-scev-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
    cl::init(32));

static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
    cl::init(2));

static cl::opt<unsigned> MaxValueCompareDepth(
    "scalar-evolution-max-value-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive value complexity comparisons"),
    cl::init(2));

static cl::opt<unsigned>
    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
                  cl::desc("Maximum depth of recursive arithmetics"),
                  cl::init(32));

static cl::opt<unsigned> MaxConstantEvolvingDepth(
    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));

static cl::opt<unsigned>
    MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
                 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
                 cl::init(8));

static cl::opt<unsigned>
    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
                  cl::desc("Max coefficients in AddRec during evolving"),
                  cl::init(8));

static cl::opt<unsigned>
    HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
                      cl::desc("Size of the expression which is considered huge"),
                      cl::init(4096));

static cl::opt<unsigned> RangeIterThreshold(
    "scev-range-iter-threshold", cl::Hidden,
    cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
    cl::init(32));

static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
    "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
    cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));

static cl::opt<bool> ClassifyExpressions(
    "scalar-evolution-classify-expressions", cl::Hidden, cl::init(true),
    cl::desc("When printing analysis, include information on every instruction"));

static cl::opt<bool> UseExpensiveRangeSharpening(
    "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
    cl::init(false),
    cl::desc("Use more powerful methods of sharpening expression ranges. May "
             "be costly in terms of compile time"));

static cl::opt<unsigned> MaxPhiSCCAnalysisDepth(
    "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
    cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
             "Phi strongly connected components"),
    cl::init(8));

static cl::opt<bool>
    EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
                            cl::desc("Handle <= and >= in finite loops"),
                            cl::init(true));

static cl::opt<bool> UseContextForNoWrapFlagInference(
    "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
    cl::desc("Infer nuw/nsw flags using context where suitable"),
    cl::init(true));
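// Note: these knobs are developer options; they can be appended to an 'opt'
// invocation that prints the analysis, for example (illustrative values):
//   opt -passes='print<scalar-evolution>' -disable-output \
//       -scalar-evolution-max-arith-depth=16 input.ll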
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}
#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToAddr:
281 case scPtrToInt: {
282 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
283 const SCEV *Op = PtrCast->getOperand();
284 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
285 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
286 << *PtrCast->getType() << ")";
287 return;
288 }
289 case scTruncate: {
290 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
291 const SCEV *Op = Trunc->getOperand();
292 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
293 << *Trunc->getType() << ")";
294 return;
295 }
296 case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
299 OS << "(zext " << *Op->getType() << " " << *Op << " to "
300 << *ZExt->getType() << ")";
301 return;
302 }
303 case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
306 OS << "(sext " << *Op->getType() << " " << *Op << " to "
307 << *SExt->getType() << ")";
308 return;
309 }
310 case scAddRecExpr: {
311 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
312 OS << "{" << *AR->getOperand(0);
313 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
314 OS << ",+," << *AR->getOperand(i);
315 OS << "}<";
316 if (AR->hasNoUnsignedWrap())
317 OS << "nuw><";
318 if (AR->hasNoSignedWrap())
319 OS << "nsw><";
    if (AR->hasNoSelfWrap() &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
323 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
324 OS << ">";
325 return;
326 }
327 case scAddExpr:
328 case scMulExpr:
329 case scUMaxExpr:
330 case scSMaxExpr:
331 case scUMinExpr:
332 case scSMinExpr:
  case scSequentialUMinExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
335 const char *OpStr = nullptr;
336 switch (NAry->getSCEVType()) {
337 case scAddExpr: OpStr = " + "; break;
338 case scMulExpr: OpStr = " * "; break;
339 case scUMaxExpr: OpStr = " umax "; break;
340 case scSMaxExpr: OpStr = " smax "; break;
341 case scUMinExpr:
342 OpStr = " umin ";
343 break;
344 case scSMinExpr:
345 OpStr = " smin ";
346 break;
    case scSequentialUMinExpr:
      OpStr = " umin_seq ";
349 break;
350 default:
351 llvm_unreachable("There are no other nary expression types.");
352 }
353 OS << "("
355 << ")";
356 switch (NAry->getSCEVType()) {
357 case scAddExpr:
358 case scMulExpr:
359 if (NAry->hasNoUnsignedWrap())
360 OS << "<nuw>";
361 if (NAry->hasNoSignedWrap())
362 OS << "<nsw>";
363 break;
364 default:
365 // Nothing to print for other nary expressions.
366 break;
367 }
368 return;
369 }
370 case scUDivExpr: {
371 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
372 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
373 return;
374 }
375 case scUnknown:
376 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
377 return;
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
380 return;
381 }
382 llvm_unreachable("Unknown SCEV kind!");
383}
384
Type *SCEV::getType() const {
  switch (getSCEVType()) {
387 case scConstant:
388 return cast<SCEVConstant>(this)->getType();
389 case scVScale:
390 return cast<SCEVVScale>(this)->getType();
391 case scPtrToAddr:
392 case scPtrToInt:
393 case scTruncate:
394 case scZeroExtend:
395 case scSignExtend:
396 return cast<SCEVCastExpr>(this)->getType();
397 case scAddRecExpr:
398 return cast<SCEVAddRecExpr>(this)->getType();
399 case scMulExpr:
400 return cast<SCEVMulExpr>(this)->getType();
401 case scUMaxExpr:
402 case scSMaxExpr:
403 case scUMinExpr:
404 case scSMinExpr:
405 return cast<SCEVMinMaxExpr>(this)->getType();
  case scSequentialUMinExpr:
    return cast<SCEVSequentialMinMaxExpr>(this)->getType();
408 case scAddExpr:
409 return cast<SCEVAddExpr>(this)->getType();
410 case scUDivExpr:
411 return cast<SCEVUDivExpr>(this)->getType();
412 case scUnknown:
413 return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
416 }
417 llvm_unreachable("Unknown SCEV kind!");
418}
419
ArrayRef<const SCEV *> SCEV::operands() const {
  switch (getSCEVType()) {
422 case scConstant:
423 case scVScale:
424 case scUnknown:
425 return {};
426 case scPtrToAddr:
427 case scPtrToInt:
428 case scTruncate:
429 case scZeroExtend:
430 case scSignExtend:
431 return cast<SCEVCastExpr>(this)->operands();
432 case scAddRecExpr:
433 case scAddExpr:
434 case scMulExpr:
435 case scUMaxExpr:
436 case scSMaxExpr:
437 case scUMinExpr:
438 case scSMinExpr:
  case scSequentialUMinExpr:
    return cast<SCEVNAryExpr>(this)->operands();
441 case scUDivExpr:
442 return cast<SCEVUDivExpr>(this)->operands();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
445 }
446 llvm_unreachable("Unknown SCEV kind!");
447}
448
449bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
450
451bool SCEV::isOne() const { return match(this, m_scev_One()); }
452
453bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
454
bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;
458
459 // If there is a constant factor, it will be first.
460 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
461 if (!SC) return false;
462
463 // Return true if the value is negative, this matches things like (-42 * V).
464 return SC->getAPInt().isNegative();
465}
466
469
bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
472}
473
const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
477 ID.AddPointer(V);
478 void *IP = nullptr;
479 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
480 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
481 UniqueSCEVs.InsertNode(S, IP);
482 return S;
483}
484
const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
487}
488
const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
492 // TODO: Avoid implicit trunc?
493 // See https://github.com/llvm/llvm-project/issues/112510.
494 return getConstant(
495 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
496}
497
const SCEV *ScalarEvolution::getVScale(Type *Ty) {
  FoldingSetNodeID ID;
  ID.AddInteger(scVScale);
501 ID.AddPointer(Ty);
502 void *IP = nullptr;
503 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
504 return S;
505 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
506 UniqueSCEVs.InsertNode(S, IP);
507 return S;
508}
509
const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC,
                                             SCEV::NoWrapFlags Flags) {
512 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
513 if (EC.isScalable())
514 Res = getMulExpr(Res, getVScale(Ty), Flags);
515 return Res;
516}
517
519 const SCEV *op, Type *ty)
520 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
521
522SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
523 const SCEV *Op, Type *ITy)
524 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
525 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
526 "Must be a non-bit-width-changing pointer-to-integer cast!");
527}
528
529SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
530 Type *ITy)
531 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
532 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
533 "Must be a non-bit-width-changing pointer-to-integer cast!");
534}
535
537 SCEVTypes SCEVTy, const SCEV *op,
538 Type *ty)
539 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
540
541SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
542 Type *ty)
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot truncate non-integer value!");
546}
547
548SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
549 const SCEV *op, Type *ty)
551 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
552 "Cannot zero extend non-integer value!");
553}
554
555SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
556 const SCEV *op, Type *ty)
558 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
559 "Cannot sign extend non-integer value!");
560}
561
void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
564 SE->forgetMemoizedResults(this);
565
566 // Remove this SCEVUnknown from the uniquing map.
567 SE->UniqueSCEVs.RemoveNode(this);
568
569 // Release the value.
570 setValPtr(nullptr);
571}
572
573void SCEVUnknown::allUsesReplacedWith(Value *New) {
574 // Clear this SCEVUnknown from various maps.
575 SE->forgetMemoizedResults(this);
576
577 // Remove this SCEVUnknown from the uniquing map.
578 SE->UniqueSCEVs.RemoveNode(this);
579
580 // Replace the value pointer in case someone is still using this SCEVUnknown.
581 setValPtr(New);
582}
583
584//===----------------------------------------------------------------------===//
585// SCEV Utilities
586//===----------------------------------------------------------------------===//
587
588/// Compare the two values \p LV and \p RV in terms of their "complexity" where
589/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
590/// operands in SCEV expressions.
591static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
592 Value *RV, unsigned Depth) {
  if (Depth > MaxValueCompareDepth)
    return 0;
595
596 // Order pointer values after integer values. This helps SCEVExpander form
597 // GEPs.
598 bool LIsPointer = LV->getType()->isPointerTy(),
599 RIsPointer = RV->getType()->isPointerTy();
600 if (LIsPointer != RIsPointer)
601 return (int)LIsPointer - (int)RIsPointer;
602
603 // Compare getValueID values.
604 unsigned LID = LV->getValueID(), RID = RV->getValueID();
605 if (LID != RID)
606 return (int)LID - (int)RID;
607
608 // Sort arguments by their position.
609 if (const auto *LA = dyn_cast<Argument>(LV)) {
610 const auto *RA = cast<Argument>(RV);
611 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
612 return (int)LArgNo - (int)RArgNo;
613 }
614
615 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
616 const auto *RGV = cast<GlobalValue>(RV);
617
618 if (auto L = LGV->getLinkage() - RGV->getLinkage())
619 return L;
620
621 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
622 auto LT = GV->getLinkage();
      return !(GlobalValue::isPrivateLinkage(LT) ||
               GlobalValue::isInternalLinkage(LT));
    };
626
627 // Use the names to distinguish the two values, but only if the
628 // names are semantically important.
629 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
630 return LGV->getName().compare(RGV->getName());
631 }
632
633 // For instructions, compare their loop depth, and their operand count. This
634 // is pretty loose.
635 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
636 const auto *RInst = cast<Instruction>(RV);
637
638 // Compare loop depths.
639 const BasicBlock *LParent = LInst->getParent(),
640 *RParent = RInst->getParent();
641 if (LParent != RParent) {
642 unsigned LDepth = LI->getLoopDepth(LParent),
643 RDepth = LI->getLoopDepth(RParent);
644 if (LDepth != RDepth)
645 return (int)LDepth - (int)RDepth;
646 }
647
648 // Compare the number of operands.
649 unsigned LNumOps = LInst->getNumOperands(),
650 RNumOps = RInst->getNumOperands();
651 if (LNumOps != RNumOps)
652 return (int)LNumOps - (int)RNumOps;
653
654 for (unsigned Idx : seq(LNumOps)) {
655 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
656 RInst->getOperand(Idx), Depth + 1);
657 if (Result != 0)
658 return Result;
659 }
660 }
661
662 return 0;
663}
664
665// Return negative, zero, or positive, if LHS is less than, equal to, or greater
666// than RHS, respectively. A three-way result allows recursive comparisons to be
667// more efficient.
668// If the max analysis depth was reached, return std::nullopt, assuming we do
669// not know if they are equivalent for sure.
670static std::optional<int>
671CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
672 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
673 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
674 if (LHS == RHS)
675 return 0;
676
677 // Primarily, sort the SCEVs by their getSCEVType().
678 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
679 if (LType != RType)
680 return (int)LType - (int)RType;
681
  if (Depth > MaxSCEVCompareDepth)
    return std::nullopt;
684
685 // Aside from the getSCEVType() ordering, the particular ordering
686 // isn't very important except that it's beneficial to be consistent,
687 // so that (a + b) and (b + a) don't end up as different expressions.
688 switch (LType) {
689 case scUnknown: {
690 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
691 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
692
693 int X =
694 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
695 return X;
696 }
697
698 case scConstant: {
    const SCEVConstant *LC = cast<SCEVConstant>(LHS);
    const SCEVConstant *RC = cast<SCEVConstant>(RHS);

702 // Compare constant values.
703 const APInt &LA = LC->getAPInt();
704 const APInt &RA = RC->getAPInt();
705 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
706 if (LBitWidth != RBitWidth)
707 return (int)LBitWidth - (int)RBitWidth;
708 return LA.ult(RA) ? -1 : 1;
709 }
710
711 case scVScale: {
712 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
713 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
714 return LTy->getBitWidth() - RTy->getBitWidth();
715 }
716
717 case scAddRecExpr: {
    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

721 // There is always a dominance between two recs that are used by one SCEV,
722 // so we can safely sort recs by loop header dominance. We require such
723 // order in getAddExpr.
724 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
725 if (LLoop != RLoop) {
726 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
727 assert(LHead != RHead && "Two loops share the same header?");
728 if (DT.dominates(LHead, RHead))
729 return 1;
730 assert(DT.dominates(RHead, LHead) &&
731 "No dominance between recurrences used by one SCEV?");
732 return -1;
733 }
734
735 [[fallthrough]];
736 }
737
738 case scTruncate:
739 case scZeroExtend:
740 case scSignExtend:
741 case scPtrToAddr:
742 case scPtrToInt:
743 case scAddExpr:
744 case scMulExpr:
745 case scUDivExpr:
746 case scSMaxExpr:
747 case scUMaxExpr:
748 case scSMinExpr:
749 case scUMinExpr:
  case scSequentialUMinExpr: {
    ArrayRef<const SCEV *> LOps = LHS->operands();
752 ArrayRef<const SCEV *> ROps = RHS->operands();
753
754 // Lexicographically compare n-ary-like expressions.
755 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
756 if (LNumOps != RNumOps)
757 return (int)LNumOps - (int)RNumOps;
758
759 for (unsigned i = 0; i != LNumOps; ++i) {
760 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
761 if (X != 0)
762 return X;
763 }
764 return 0;
765 }
766
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
769 }
770 llvm_unreachable("Unknown SCEV kind!");
771}
772
773/// Given a list of SCEV objects, order them by their complexity, and group
774/// objects of the same complexity together by value. When this routine is
775/// finished, we know that any duplicates in the vector are consecutive and that
776/// complexity is monotonically increasing.
777///
/// Note that we take special precautions to ensure that we get deterministic
779/// results from this routine. In other words, we don't want the results of
780/// this to depend on where the addresses of various SCEV objects happened to
781/// land in memory.
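///
/// For example, given the operands (%x, 42, %x) of an add, sorting by
/// complexity yields (42, %x, %x); the adjacent duplicates then let callers
/// such as getAddExpr fold the sum into (42 + 2 * %x).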
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI, DominatorTree &DT) {
784 if (Ops.size() < 2) return; // Noop
785
786 // Whether LHS has provably less complexity than RHS.
787 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
788 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
789 return Complexity && *Complexity < 0;
790 };
791 if (Ops.size() == 2) {
792 // This is the common case, which also happens to be trivially simple.
793 // Special case it.
794 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
795 if (IsLessComplex(RHS, LHS))
796 std::swap(LHS, RHS);
797 return;
798 }
799
800 // Do the rough sort by complexity.
801 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
802 return IsLessComplex(LHS, RHS);
803 });
804
805 // Now that we are sorted by complexity, group elements of the same
806 // complexity. Note that this is, at worst, N^2, but the vector is likely to
807 // be extremely short in practice. Note that we take this approach because we
808 // do not want to depend on the addresses of the objects we are grouping.
809 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
810 const SCEV *S = Ops[i];
811 unsigned Complexity = S->getSCEVType();
812
813 // If there are any objects of the same complexity and same value as this
814 // one, group them.
815 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
816 if (Ops[j] == S) { // Found a duplicate.
817 // Move it to immediately after i'th element.
818 std::swap(Ops[i+1], Ops[j]);
819 ++i; // no need to rescan it.
820 if (i == e-2) return; // Done!
821 }
822 }
823 }
824}
825
826/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
827/// least HugeExprThreshold nodes).
static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
  return any_of(Ops, [](const SCEV *S) {
    return S->getExpressionSize() >= HugeExprThreshold;
  });
}
833
834/// Performs a number of common optimizations on the passed \p Ops. If the
835/// whole expression reduces down to a single operand, it will be returned.
836///
837/// The following optimizations are performed:
838/// * Fold constants using the \p Fold function.
839/// * Remove identity constants satisfying \p IsIdentity.
840/// * If a constant satisfies \p IsAbsorber, return it.
841/// * Sort operands by complexity.
842template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
static const SCEV *
constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
                        SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
                        IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
847 const SCEVConstant *Folded = nullptr;
848 for (unsigned Idx = 0; Idx < Ops.size();) {
849 const SCEV *Op = Ops[Idx];
850 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
851 if (!Folded)
852 Folded = C;
853 else
854 Folded = cast<SCEVConstant>(
855 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
856 Ops.erase(Ops.begin() + Idx);
857 continue;
858 }
859 ++Idx;
860 }
861
862 if (Ops.empty()) {
863 assert(Folded && "Must have folded value");
864 return Folded;
865 }
866
867 if (Folded && IsAbsorber(Folded->getAPInt()))
868 return Folded;
869
870 GroupByComplexity(Ops, &LI, DT);
871 if (Folded && !IsIdentity(Folded->getAPInt()))
872 Ops.insert(Ops.begin(), Folded);
873
874 return Ops.size() == 1 ? Ops[0] : nullptr;
875}
876
877//===----------------------------------------------------------------------===//
878// Simple SCEV method implementations
879//===----------------------------------------------------------------------===//
880
881/// Compute BC(It, K). The result has width W. Assume, K > 0.
882static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
883 ScalarEvolution &SE,
884 Type *ResultTy) {
885 // Handle the simplest case efficiently.
886 if (K == 1)
887 return SE.getTruncateOrZeroExtend(It, ResultTy);
888
889 // We are using the following formula for BC(It, K):
890 //
891 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
892 //
893 // Suppose, W is the bitwidth of the return value. We must be prepared for
894 // overflow. Hence, we must assure that the result of our computation is
895 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
896 // safe in modular arithmetic.
897 //
898 // However, this code doesn't use exactly that formula; the formula it uses
899 // is something like the following, where T is the number of factors of 2 in
900 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
901 // exponentiation:
902 //
903 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
904 //
905 // This formula is trivially equivalent to the previous formula. However,
906 // this formula can be implemented much more efficiently. The trick is that
907 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
908 // arithmetic. To do exact division in modular arithmetic, all we have
909 // to do is multiply by the inverse. Therefore, this step can be done at
910 // width W.
911 //
912 // The next issue is how to safely do the division by 2^T. The way this
913 // is done is by doing the multiplication step at a width of at least W + T
914 // bits. This way, the bottom W+T bits of the product are accurate. Then,
915 // when we perform the division by 2^T (which is equivalent to a right shift
916 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
917 // truncated out after the division by 2^T.
918 //
919 // In comparison to just directly using the first formula, this technique
920 // is much more efficient; using the first formula requires W * K bits,
// but this formula needs less than W + K bits. Also, the first formula requires
922 // a division step, whereas this formula only requires multiplies and shifts.
923 //
924 // It doesn't matter whether the subtraction step is done in the calculation
925 // width or the input iteration count's width; if the subtraction overflows,
926 // the result must be zero anyway. We prefer here to do it in the width of
927 // the induction variable because it helps a lot for certain cases; CodeGen
928 // isn't smart enough to ignore the overflow, which leads to much less
929 // efficient code if the width of the subtraction is wider than the native
930 // register width.
931 //
932 // (It's possible to not widen at all by pulling out factors of 2 before
933 // the multiplication; for example, K=2 can be calculated as
934 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
935 // extra arithmetic, so it's not an obvious win, and it gets
936 // much more complicated for K > 3.)
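  //
  // Worked example (illustrative): for K = 3 we have K! = 6 = 2^1 * 3, so
  // T = 1 and the odd factor is 3. BC(It, 3) is then computed as
  //   trunc((It * (It - 1) * (It - 2)) /u 2, W) * (3^-1 mod 2^W)
  // with the product evaluated at W + 1 bits so the low W bits survive the
  // shift intact.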
937
938 // Protection from insane SCEVs; this bound is conservative,
939 // but it probably doesn't matter.
940 if (K > 1000)
941 return SE.getCouldNotCompute();
942
943 unsigned W = SE.getTypeSizeInBits(ResultTy);
944
945 // Calculate K! / 2^T and T; we divide out the factors of two before
946 // multiplying for calculating K! / 2^T to avoid overflow.
947 // Other overflow doesn't matter because we only care about the bottom
948 // W bits of the result.
949 APInt OddFactorial(W, 1);
950 unsigned T = 1;
951 for (unsigned i = 3; i <= K; ++i) {
952 unsigned TwoFactors = countr_zero(i);
953 T += TwoFactors;
954 OddFactorial *= (i >> TwoFactors);
955 }
956
957 // We need at least W + T bits for the multiplication step
958 unsigned CalculationBits = W + T;
959
960 // Calculate 2^T, at width T+W.
961 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
962
963 // Calculate the multiplicative inverse of K! / 2^T;
964 // this multiplication factor will perform the exact division by
965 // K! / 2^T.
966 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
967
968 // Calculate the product, at width T+W
969 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
970 CalculationBits);
971 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
972 for (unsigned i = 1; i != K; ++i) {
973 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
974 Dividend = SE.getMulExpr(Dividend,
975 SE.getTruncateOrZeroExtend(S, CalculationTy));
976 }
977
978 // Divide by 2^T
979 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
980
981 // Truncate the result, and divide by K! / 2^T.
982
983 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
984 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
985}
986
987/// Return the value of this chain of recurrences at the specified iteration
988/// number. We can evaluate this recurrence by multiplying each element in the
989/// chain by the binomial coefficient corresponding to it. In other words, we
990/// can evaluate {A,+,B,+,C,+,D} as:
991///
992/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
993///
994/// where BC(It, k) stands for binomial coefficient.
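///
/// For example, the affine recurrence {5,+,3} evaluated at iteration It = 4 is
/// 5*BC(4,0) + 3*BC(4,1) = 5 + 3*4 = 17.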
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
997 return evaluateAtIteration(operands(), It, SE);
998}
999
1000const SCEV *
SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
                                    const SCEV *It, ScalarEvolution &SE) {
1003 assert(Operands.size() > 0);
1004 const SCEV *Result = Operands[0];
1005 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1006 // The computation is correct in the face of overflow provided that the
1007 // multiplication is performed _after_ the evaluation of the binomial
1008 // coefficient.
1009 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1010 if (isa<SCEVCouldNotCompute>(Coeff))
1011 return Coeff;
1012
1013 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1014 }
1015 return Result;
1016}
1017
1018//===----------------------------------------------------------------------===//
1019// SCEV Expression folder implementations
1020//===----------------------------------------------------------------------===//
1021
1022/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1023/// which computes a pointer-typed value, and rewrites the whole expression
1024/// tree so that *all* the computations are done on integers, and the only
1025/// pointer-typed operands in the expression are SCEVUnknown.
1026/// The CreatePtrCast callback is invoked to create the actual conversion
1027/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
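///
/// For example (illustrative), asking for a ptrtoint of the pointer expression
/// (4 + %p) yields (4 + ptrtoint(%p)): the offset arithmetic becomes integer
/// arithmetic and the SCEVUnknown %p is the only pointer-typed leaf.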
class SCEVCastSinkingRewriter
    : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
  using Base = SCEVRewriteVisitor<SCEVCastSinkingRewriter>;
  using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
1032 Type *TargetTy;
1033 ConversionFn CreatePtrCast;
1034
1035public:
  SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy,
                          ConversionFn CreatePtrCast)
1038 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1039
1040 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1041 Type *TargetTy, ConversionFn CreatePtrCast) {
1042 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1043 return Rewriter.visit(Scev);
1044 }
1045
1046 const SCEV *visit(const SCEV *S) {
1047 Type *STy = S->getType();
1048 // If the expression is not pointer-typed, just keep it as-is.
1049 if (!STy->isPointerTy())
1050 return S;
1051 // Else, recursively sink the cast down into it.
1052 return Base::visit(S);
1053 }
1054
1055 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1056 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1057 // implementation drops.
    SmallVector<const SCEV *, 2> Operands;
    bool Changed = false;
1060 for (const auto *Op : Expr->operands()) {
1061 Operands.push_back(visit(Op));
1062 Changed |= Op != Operands.back();
1063 }
1064 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1065 }
1066
1067 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
    SmallVector<const SCEV *, 2> Operands;
    bool Changed = false;
1070 for (const auto *Op : Expr->operands()) {
1071 Operands.push_back(visit(Op));
1072 Changed |= Op != Operands.back();
1073 }
1074 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1075 }
1076
1077 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1078 assert(Expr->getType()->isPointerTy() &&
1079 "Should only reach pointer-typed SCEVUnknown's.");
1080 // Perform some basic constant folding. If the operand of the cast is a
1081 // null pointer, don't create a cast SCEV expression (that will be left
1082 // as-is), but produce a zero constant.
    if (isa<ConstantPointerNull>(Expr->getValue()))
      return SE.getZero(TargetTy);
1085 return CreatePtrCast(Expr);
1086 }
1087};
1088
const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op) {
  assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1091
1092 // It isn't legal for optimizations to construct new ptrtoint expressions
1093 // for non-integral pointers.
1094 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1095 return getCouldNotCompute();
1096
1097 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1098
1099 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1100 // is sufficiently wide to represent all possible pointer values.
1101 // We could theoretically teach SCEV to truncate wider pointers, but
1102 // that isn't implemented for now.
  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) >
      getDataLayout().getTypeSizeInBits(IntPtrTy))
1105 return getCouldNotCompute();
1106
1107 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
  const SCEV *IntOp = SCEVCastSinkingRewriter::rewrite(
      Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
        FoldingSetNodeID ID;
        ID.AddInteger(scPtrToInt);
1112 ID.AddPointer(U);
1113 void *IP = nullptr;
1114 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1115 return S;
1116 SCEV *S = new (SCEVAllocator)
1117 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1118 UniqueSCEVs.InsertNode(S, IP);
1119 registerUser(S, U);
1120 return static_cast<const SCEV *>(S);
1121 });
1122 assert(IntOp->getType()->isIntegerTy() &&
1123 "We must have succeeded in sinking the cast, "
1124 "and ending up with an integer-typed expression!");
1125 return IntOp;
1126}
1127
1129 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1130 Type *Ty = DL.getAddressType(Op->getType());
1131
  FoldingSetNodeID ID;
  ID.AddInteger(scPtrToAddr);
1134 ID.AddPointer(Op);
1135
1136 void *IP = nullptr;
1137
1138 // Is there already an expression for such a cast?
1139 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1140 return S;
1141
1142 // If not, is this expression something we can't reduce any further?
1143 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1144 // Perform some basic constant folding. If the operand of the ptr2addr cast
1145 // is a null pointer, don't create a ptr2addr SCEV expression (that will be
1146 // left as-is), but produce a zero constant.
1147 // NOTE: We could handle a more general case, but lack motivational cases.
1148 if (isa<ConstantPointerNull>(U->getValue()))
1149 return getZero(Ty);
1150 }
1151
1152 // Create an explicit cast node.
1153 // We can reuse the existing insert position since if we get here,
1154 // we won't have made any changes which would invalidate it.
1155 SCEV *S =
1156 new (SCEVAllocator) SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), Op, Ty);
1157 UniqueSCEVs.InsertNode(S, IP);
1158 registerUser(S, Op);
1159 return S;
1160}
1161
const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1164
1165 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1166 if (isa<SCEVCouldNotCompute>(IntOp))
1167 return IntOp;
1168
1169 return getTruncateOrZeroExtend(IntOp, Ty);
1170}
1171
const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
                                             unsigned Depth) {
1174 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1175 "This is not a truncating conversion!");
1176 assert(isSCEVable(Ty) &&
1177 "This is not a conversion to a SCEVable type!");
1178 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1179 Ty = getEffectiveSCEVType(Ty);
1180
  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
1183 ID.AddPointer(Op);
1184 ID.AddPointer(Ty);
1185 void *IP = nullptr;
1186 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1187
1188 // Fold if the operand is constant.
1189 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1190 return getConstant(
1191 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1192
1193 // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1196
1197 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1200
1201 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1204
1205 if (Depth > MaxCastDepth) {
1206 SCEV *S =
1207 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1208 UniqueSCEVs.InsertNode(S, IP);
1209 registerUser(S, Op);
1210 return S;
1211 }
1212
1213 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1214 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1215 // if after transforming we have at most one truncate, not counting truncates
1216 // that replace other casts.
  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<const SCEV *, 4> Operands;
    unsigned numTruncs = 0;
1221 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1222 ++i) {
1223 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1224 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
          isa<SCEVTruncateExpr>(S))
        numTruncs++;
1227 Operands.push_back(S);
1228 }
1229 if (numTruncs < 2) {
1230 if (isa<SCEVAddExpr>(Op))
1231 return getAddExpr(Operands);
1232 if (isa<SCEVMulExpr>(Op))
1233 return getMulExpr(Operands);
1234 llvm_unreachable("Unexpected SCEV type for Op.");
1235 }
1236 // Although we checked in the beginning that ID is not in the cache, it is
1237 // possible that during recursion and different modification ID was inserted
1238 // into the cache. So if we find it, just return it.
1239 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1240 return S;
1241 }
1242
1243 // If the input value is a chrec scev, truncate the chrec's operands.
1244 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
1247 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1248 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1249 }
1250
1251 // Return zero if truncating to known zeros.
1252 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1253 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1254 return getZero(Ty);
1255
1256 // The cast wasn't folded; create an explicit cast node. We can reuse
1257 // the existing insert position since if we get here, we won't have
1258 // made any changes which would invalidate it.
1259 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1260 Op, Ty);
1261 UniqueSCEVs.InsertNode(S, IP);
1262 registerUser(S, Op);
1263 return S;
1264}
1265
1266// Get the limit of a recurrence such that incrementing by Step cannot cause
1267// signed overflow as long as the value of the recurrence within the
1268// loop does not exceed this limit before incrementing.
1269static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1270 ICmpInst::Predicate *Pred,
1271 ScalarEvolution *SE) {
1272 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1273 if (SE->isKnownPositive(Step)) {
1274 *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRangeMax(Step));
1277 }
1278 if (SE->isKnownNegative(Step)) {
1279 *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRangeMin(Step));
1282 }
1283 return nullptr;
1284}
1285
1286// Get the limit of a recurrence such that incrementing by Step cannot cause
1287// unsigned overflow as long as the value of the recurrence within the loop does
1288// not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
1291 ScalarEvolution *SE) {
1292 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1293 *Pred = ICmpInst::ICMP_ULT;
1294
  return SE->getConstant(APInt::getMaxValue(BitWidth) -
                         SE->getUnsignedRangeMax(Step));
1297}
1298
1299namespace {
1300
1301struct ExtendOpTraitsBase {
1302 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1303 unsigned);
1304};
1305
1306// Used to make code generic over signed and unsigned overflow.
1307template <typename ExtendOp> struct ExtendOpTraits {
1308 // Members present:
1309 //
1310 // static const SCEV::NoWrapFlags WrapType;
1311 //
1312 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1313 //
1314 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1315 // ICmpInst::Predicate *Pred,
1316 // ScalarEvolution *SE);
1317};
1318
1319template <>
1320struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1321 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1322
1323 static const GetExtendExprTy GetExtendExpr;
1324
1325 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1326 ICmpInst::Predicate *Pred,
1327 ScalarEvolution *SE) {
1328 return getSignedOverflowLimitForStep(Step, Pred, SE);
1329 }
1330};
1331
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

1335template <>
1336struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1337 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1338
1339 static const GetExtendExprTy GetExtendExpr;
1340
1341 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1342 ICmpInst::Predicate *Pred,
1343 ScalarEvolution *SE) {
1344 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1345 }
1346};
1347
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;

1351} // end anonymous namespace
1352
1353// The recurrence AR has been shown to have no signed/unsigned wrap or something
1354// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1355// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1356// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
// Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1358// expression "Step + sext/zext(PreIncAR)" is congruent with
1359// "sext/zext(PostIncAR)"
1360template <typename ExtendOpTy>
1361static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1362 ScalarEvolution *SE, unsigned Depth) {
1363 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1364 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1365
1366 const Loop *L = AR->getLoop();
1367 const SCEV *Start = AR->getStart();
1368 const SCEV *Step = AR->getStepRecurrence(*SE);
1369
1370 // Check for a simple looking step prior to loop entry.
1371 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1372 if (!SA)
1373 return nullptr;
1374
1375 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1376 // subtraction is expensive. For this purpose, perform a quick and dirty
1377 // difference, by checking for Step in the operand list. Note, that
1378 // SA might have repeated ops, like %a + %a + ..., so only remove one.
  SmallVector<const SCEV *, 4> DiffOps(SA->operands());
  for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1381 if (*It == Step) {
1382 DiffOps.erase(It);
1383 break;
1384 }
1385
1386 if (DiffOps.size() == SA->getNumOperands())
1387 return nullptr;
1388
1389 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1390 // `Step`:
1391
1392 // 1. NSW/NUW flags on the step increment.
  auto PreStartFlags =
    ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1398
1399 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1400 // "S+X does not sign/unsign-overflow".
1401 //
1402
1403 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1404 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1405 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1406 return PreStart;
1407
1408 // 2. Direct overflow check on the step operation's expression.
1409 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1410 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1411 const SCEV *OperandExtendedStart =
1412 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1413 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1414 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1415 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1416 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1417 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1418 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1419 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1420 }
1421 return PreStart;
1422 }
1423
1424 // 3. Loop precondition.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
1427 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1428
1429 if (OverflowLimit &&
1430 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1431 return PreStart;
1432
1433 return nullptr;
1434}
1435
1436// Get the normalized zero or sign extended expression for this AddRec's Start.
1437template <typename ExtendOpTy>
1438static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1439 ScalarEvolution *SE,
1440 unsigned Depth) {
1441 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1442
1443 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1444 if (!PreStart)
1445 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1446
1447 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1448 Depth),
1449 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1450}
1451
1452// Try to prove away overflow by looking at "nearby" add recurrences. A
1453// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1454// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1455//
1456// Formally:
1457//
1458// {S,+,X} == {S-T,+,X} + T
1459// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1460//
1461// If ({S-T,+,X} + T) does not overflow ... (1)
1462//
1463// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1464//
1465// If {S-T,+,X} does not overflow ... (2)
1466//
1467// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1468// == {Ext(S-T)+Ext(T),+,Ext(X)}
1469//
1470// If (S-T)+T does not overflow ... (3)
1471//
1472// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1473// == {Ext(S),+,Ext(X)} == LHS
1474//
1475// Thus, if (1), (2) and (3) are true for some T, then
1476// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1477//
1478// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1479// does not overflow" restricted to the 0th iteration. Therefore we only need
1480// to check for (1) and (2).
1481//
1482// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1483// is `Delta` (defined below).
1484template <typename ExtendOpTy>
1485bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1486 const SCEV *Step,
1487 const Loop *L) {
1488 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1489
1490 // We restrict `Start` to a constant to prevent SCEV from spending too much
1491 // time here. It is correct (but more expensive) to continue with a
1492 // non-constant `Start` and do a general SCEV subtraction to compute
1493 // `PreStart` below.
1494 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1495 if (!StartC)
1496 return false;
1497
1498 APInt StartAI = StartC->getAPInt();
1499
1500 for (unsigned Delta : {-2, -1, 1, 2}) {
1501 const SCEV *PreStart = getConstant(StartAI - Delta);
1502
1503 FoldingSetNodeID ID;
1504 ID.AddInteger(scAddRecExpr);
1505 ID.AddPointer(PreStart);
1506 ID.AddPointer(Step);
1507 ID.AddPointer(L);
1508 void *IP = nullptr;
1509 const auto *PreAR =
1510 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1511
1512 // Give up if we don't already have the add recurrence we need because
1513 // actually constructing an add recurrence is relatively expensive.
1514 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1515 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1518 DeltaS, &Pred, this);
1519 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1520 return true;
1521 }
1522 }
1523
1524 return false;
1525}
1526
1527// Finds an integer D for an expression (C + x + y + ...) such that the top
1528// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1529// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1530// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1531// the (C + x + y + ...) expression is \p WholeAddExpr.
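// For example (illustrative): with C = 13 and every other operand known to be
// a multiple of 8, we get D = 5, i.e. 13 + x == 5 + (8 + x); since (8 + x) has
// its low three bits clear, adding 5 back cannot wrap.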
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const SCEVConstant *ConstantTerm,
1534 const SCEVAddExpr *WholeAddExpr) {
1535 const APInt &C = ConstantTerm->getAPInt();
1536 const unsigned BitWidth = C.getBitWidth();
1537 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1538 uint32_t TZ = BitWidth;
1539 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1540 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1541 if (TZ) {
1542 // Set D to be as many least significant bits of C as possible while still
1543 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1544 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1545 }
1546 return APInt(BitWidth, 0);
1547}
1548
1549// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1550// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1551// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1552// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const APInt &ConstantStart,
1555 const SCEV *Step) {
1556 const unsigned BitWidth = ConstantStart.getBitWidth();
1557 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1558 if (TZ)
1559 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1560 : ConstantStart;
1561 return APInt(BitWidth, 0);
1562}
1563
static void insertFoldCacheEntry(
    const ScalarEvolution::FoldID &ID, const SCEV *S,
    DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
    DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
        &FoldCacheUser) {
1569 auto I = FoldCache.insert({ID, S});
1570 if (!I.second) {
1571 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1572 // entry.
1573 auto &UserIDs = FoldCacheUser[I.first->second];
1574 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1575 for (unsigned I = 0; I != UserIDs.size(); ++I)
1576 if (UserIDs[I] == ID) {
1577 std::swap(UserIDs[I], UserIDs.back());
1578 break;
1579 }
1580 UserIDs.pop_back();
1581 I.first->second = S;
1582 }
1583 FoldCacheUser[S].push_back(ID);
1584}
1585
1586const SCEV *
ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1589 "This is not an extending conversion!");
1590 assert(isSCEVable(Ty) &&
1591 "This is not a conversion to a SCEVable type!");
1592 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1593 Ty = getEffectiveSCEVType(Ty);
1594
1595 FoldID ID(scZeroExtend, Op, Ty);
1596 if (const SCEV *S = FoldCache.lookup(ID))
1597 return S;
1598
1599 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
  if (!isa<SCEVZeroExtendExpr>(S))
    insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1602 return S;
1603}
1604
const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
                                                   unsigned Depth) {
1607 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1608 "This is not an extending conversion!");
1609 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1610 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1611
1612 // Fold if the operand is constant.
1613 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1614 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1615
1616 // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1619
1620 // Before doing any expensive analysis, check to see if we've already
1621 // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
1624 ID.AddPointer(Op);
1625 ID.AddPointer(Ty);
1626 void *IP = nullptr;
1627 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1628 if (Depth > MaxCastDepth) {
1629 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1630 Op, Ty);
1631 UniqueSCEVs.InsertNode(S, IP);
1632 registerUser(S, Op);
1633 return S;
1634 }
1635
  // zext(trunc(x)) --> zext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all zero bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
1643 unsigned NewBits = getTypeSizeInBits(Ty);
1644 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1645 CR.zextOrTrunc(NewBits)))
1646 return getTruncateOrZeroExtend(X, Ty, Depth);
1647 }
1648
1649 // If the input value is a chrec scev, and we can prove that the value
1650 // did not overflow the old, smaller, value, we can zero extend all of the
1651 // operands (often constants). This allows analysis of something like
1652 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
1655 const SCEV *Start = AR->getStart();
1656 const SCEV *Step = AR->getStepRecurrence(*this);
1657 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1658 const Loop *L = AR->getLoop();
1659
1660 // If we have special knowledge that this addrec won't overflow,
1661 // we don't need to do any further analysis.
1662 if (AR->hasNoUnsignedWrap()) {
      Start =
          getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
      Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1666 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1667 }
1668
1669 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1670 // Note that this serves two purposes: It filters out loops that are
1671 // simply not analyzable, and it covers the case where this code is
1672 // being called from within backedge-taken count analysis, such that
1673 // attempting to ask for the backedge-taken count would likely result
1674 // in infinite recursion. In the latter case, the analysis code will
1675 // cope with a conservative value, and it will take care to purge
1676 // that value once it has finished.
1677 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1678 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1679 // Manually compute the final value for AR, checking for overflow.
1680
1681 // Check whether the backedge-taken count can be losslessly casted to
1682 // the addrec's type. The count is always unsigned.
1683 const SCEV *CastedMaxBECount =
1684 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1685 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1686 CastedMaxBECount, MaxBECount->getType(), Depth);
1687 if (MaxBECount == RecastedMaxBECount) {
1688 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1689 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1690 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1691 SCEV::FlagAnyWrap, Depth + 1);
1692 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1693 SCEV::FlagAnyWrap,
1694 Depth + 1),
1695 WideTy, Depth + 1);
1696 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1697 const SCEV *WideMaxBECount =
1698 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1699 const SCEV *OperandExtendedAdd =
1700 getAddExpr(WideStart,
1701 getMulExpr(WideMaxBECount,
1702 getZeroExtendExpr(Step, WideTy, Depth + 1),
1703 SCEV::FlagAnyWrap, Depth + 1),
1704 SCEV::FlagAnyWrap, Depth + 1);
1705 if (ZAdd == OperandExtendedAdd) {
1706 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1707 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1708 // Return the expression with the addrec on the outside.
1709 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1710 Depth + 1);
1711 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1712 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1713 }
1714 // Similar to above, only this time treat the step value as signed.
1715 // This covers loops that count down.
1716 OperandExtendedAdd =
1717 getAddExpr(WideStart,
1718 getMulExpr(WideMaxBECount,
1719 getSignExtendExpr(Step, WideTy, Depth + 1),
1720 SCEV::FlagAnyWrap, Depth + 1),
1721 SCEV::FlagAnyWrap, Depth + 1);
1722 if (ZAdd == OperandExtendedAdd) {
1723 // Cache knowledge of AR NW, which is propagated to this AddRec.
1724 // Negative step causes unsigned wrap, but it still can't self-wrap.
1725 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1726 // Return the expression with the addrec on the outside.
1727 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1728 Depth + 1);
1729 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1730 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1731 }
1732 }
1733 }
1734
1735 // Normally, in the cases we can prove no-overflow via a
1736 // backedge guarding condition, we can also compute a backedge
1737 // taken count for the loop. The exceptions are assumptions and
1738 // guards present in the loop -- SCEV is not great at exploiting
1739 // these to compute max backedge taken counts, but can still use
1740 // these to prove lack of overflow. Use this fact to avoid
1741 // doing extra work that may not pay off.
1742 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1743 !AC.assumptions().empty()) {
1744
1745 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1746 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1747 if (AR->hasNoUnsignedWrap()) {
1748 // Same as the nuw case above - duplicated here to avoid a compile-time
1749 // issue. It's not clear that the order of checks matters, but it is
1750 // one of two possible causes of an issue that forced a change to be
1751 // reverted. Be conservative for the moment.
1752 Start =
1753 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1754 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1755 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1756 }
1757
1758 // For a negative step, we can extend the operands iff doing so only
1759 // traverses values in the range zext([0,UINT_MAX]).
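 // For instance, an i8 recurrence {255,+,-1} that is known to stay within
 // [0,255] takes exactly the values its zero-extended counterpart takes, so
 // the extension can be folded into the recurrence.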
1760 if (isKnownNegative(Step)) {
1761 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1762 getSignedRangeMin(Step));
1763 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1764 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1765 // Cache knowledge of AR NW, which is propagated to this
1766 // AddRec. Negative step causes unsigned wrap, but it
1767 // still can't self-wrap.
1768 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1769 // Return the expression with the addrec on the outside.
1770 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1771 Depth + 1);
1772 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1773 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1774 }
1775 }
1776 }
1777
1778 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1779 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1780 // where D maximizes the number of trailing zeros of (C - D + Step * n)
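 // For instance, zext({7,+,8}) can be split with D = 7 into
 // zext(7) + zext({0,+,8}), where every value of {0,+,8} has at least three
 // trailing zero bits.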
1781 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1782 const APInt &C = SC->getAPInt();
1783 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1784 if (D != 0) {
1785 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1786 const SCEV *SResidual =
1787 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1788 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1789 return getAddExpr(SZExtD, SZExtR,
1790 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1791 Depth + 1);
1792 }
1793 }
1794
1795 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1796 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1797 Start =
1798 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1799 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1800 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1801 }
1802 }
1803
1804 // zext(A % B) --> zext(A) % zext(B)
1805 {
1806 const SCEV *LHS;
1807 const SCEV *RHS;
1808 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1809 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1810 getZeroExtendExpr(RHS, Ty, Depth + 1));
1811 }
1812
1813 // zext(A / B) --> zext(A) / zext(B).
1814 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1815 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1816 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1817
1818 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1819 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1820 if (SA->hasNoUnsignedWrap()) {
1821 // If the addition does not overflow in the unsigned sense then we can,
1822 // by definition, commute the zero extension with the addition operation.
1823 SmallVector<const SCEV *, 4> Ops;
1824 for (const auto *Op : SA->operands())
1825 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1826 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1827 }
1828
1829 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1830 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1831 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1832 //
1833 // Address arithmetic often contains expressions like
1834 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1835 // This transformation is useful while proving that such expressions are
1836 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
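 // In the (zext (5 + (4 * X))) example above, D = 1 and the expression
 // becomes zext(1) + zext(4 + 4 * X) once the no-wrap requirement holds.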
1837 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1838 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1839 if (D != 0) {
1840 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1841 const SCEV *SResidual =
1842 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1843 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1844 return getAddExpr(SZExtD, SZExtR,
1845 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1846 Depth + 1);
1847 }
1848 }
1849 }
1850
1851 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1852 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1853 if (SM->hasNoUnsignedWrap()) {
1854 // If the multiply does not overflow in the unsigned sense then we can,
1855 // by definition, commute the zero extension with the multiply operation.
1856 SmallVector<const SCEV *, 4> Ops;
1857 for (const auto *Op : SM->operands())
1858 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1859 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1860 }
1861
1862 // zext(2^K * (trunc X to iN)) to iM ->
1863 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1864 //
1865 // Proof:
1866 //
1867 // zext(2^K * (trunc X to iN)) to iM
1868 // = zext((trunc X to iN) << K) to iM
1869 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1870 // (because shl removes the top K bits)
1871 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1872 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1873 //
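 // For example, with K = 2, N = 8 and M = 32:
 // zext(4 * (trunc X to i8)) to i32 == 4 * (zext(trunc X to i6) to i32).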
1874 const APInt *C;
1875 const SCEV *TruncRHS;
1876 if (match(SM,
1877 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1878 C->isPowerOf2()) {
1879 int NewTruncBits =
1880 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1881 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1882 return getMulExpr(
1883 getZeroExtendExpr(SM->getOperand(0), Ty),
1884 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1885 SCEV::FlagNUW, Depth + 1);
1886 }
1887 }
1888
1889 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1890 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1891 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1892 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1893 SmallVector<const SCEV *, 4> Operands;
1894 for (auto *Operand : MinMax->operands())
1895 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1896 if (isa<SCEVUMinExpr>(MinMax))
1897 return getUMinExpr(Operands);
1898 return getUMaxExpr(Operands);
1899 }
1900
1901 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1902 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1903 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1904 SmallVector<const SCEV *, 4> Operands;
1905 for (auto *Operand : MinMax->operands())
1906 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1907 return getUMinExpr(Operands, /*Sequential*/ true);
1908 }
1909
1910 // The cast wasn't folded; create an explicit cast node.
1911 // Recompute the insert position, as it may have been invalidated.
1912 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1913 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1914 Op, Ty);
1915 UniqueSCEVs.InsertNode(S, IP);
1916 registerUser(S, Op);
1917 return S;
1918}
1919
1920const SCEV *
1921ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1922 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1923 "This is not an extending conversion!");
1924 assert(isSCEVable(Ty) &&
1925 "This is not a conversion to a SCEVable type!");
1926 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1927 Ty = getEffectiveSCEVType(Ty);
1928
1929 FoldID ID(scSignExtend, Op, Ty);
1930 if (const SCEV *S = FoldCache.lookup(ID))
1931 return S;
1932
1933 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1934 if (!isa<SCEVSignExtendExpr>(S))
1935 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1936 return S;
1937}
1938
1939const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1940 unsigned Depth) {
1941 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1942 "This is not an extending conversion!");
1943 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1944 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1945 Ty = getEffectiveSCEVType(Ty);
1946
1947 // Fold if the operand is constant.
1948 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1949 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1950
1951 // sext(sext(x)) --> sext(x)
1952 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1953 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1954
1955 // sext(zext(x)) --> zext(x)
1956 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1957 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1958
1959 // Before doing any expensive analysis, check to see if we've already
1960 // computed a SCEV for this Op and Ty.
1961 FoldingSetNodeID ID;
1962 ID.AddInteger(scSignExtend);
1963 ID.AddPointer(Op);
1964 ID.AddPointer(Ty);
1965 void *IP = nullptr;
1966 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1967 // Limit recursion depth.
1968 if (Depth > MaxCastDepth) {
1969 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1970 Op, Ty);
1971 UniqueSCEVs.InsertNode(S, IP);
1972 registerUser(S, Op);
1973 return S;
1974 }
1975
1976 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1977 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1978 // It's possible the bits taken off by the truncate were all sign bits. If
1979 // so, we should be able to simplify this further.
1980 const SCEV *X = ST->getOperand();
1981 ConstantRange CR = getSignedRange(X);
1982 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1983 unsigned NewBits = getTypeSizeInBits(Ty);
1984 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1985 CR.sextOrTrunc(NewBits)))
1986 return getTruncateOrSignExtend(X, Ty, Depth);
1987 }
1988
1989 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1990 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1991 if (SA->hasNoSignedWrap()) {
1992 // If the addition does not sign overflow then we can, by definition,
1993 // commute the sign extension with the addition operation.
1994 SmallVector<const SCEV *, 4> Ops;
1995 for (const auto *Op : SA->operands())
1996 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1997 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1998 }
1999
2000 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2001 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2002 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2003 //
2004 // For instance, this will bring two seemingly different expressions:
2005 // 1 + sext(5 + 20 * %x + 24 * %y) and
2006 // sext(6 + 20 * %x + 24 * %y)
2007 // to the same form:
2008 // 2 + sext(4 + 20 * %x + 24 * %y)
2009 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2010 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2011 if (D != 0) {
2012 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2013 const SCEV *SResidual =
2014 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
2015 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2016 return getAddExpr(SSExtD, SSExtR,
2017 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2018 Depth + 1);
2019 }
2020 }
2021 }
2022 // If the input value is a chrec scev, and we can prove that the value
2023 // did not overflow the old, smaller, value, we can sign extend all of the
2024 // operands (often constants). This allows analysis of something like
2025 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2026 if (auto *AR = dyn_cast<SCEVAddRecExpr>(Op))
2027 if (AR->isAffine()) {
2028 const SCEV *Start = AR->getStart();
2029 const SCEV *Step = AR->getStepRecurrence(*this);
2030 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2031 const Loop *L = AR->getLoop();
2032
2033 // If we have special knowledge that this addrec won't overflow,
2034 // we don't need to do any further analysis.
2035 if (AR->hasNoSignedWrap()) {
2036 Start =
2037 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2038 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2039 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2040 }
2041
2042 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2043 // Note that this serves two purposes: It filters out loops that are
2044 // simply not analyzable, and it covers the case where this code is
2045 // being called from within backedge-taken count analysis, such that
2046 // attempting to ask for the backedge-taken count would likely result
2047 // in infinite recursion. In the latter case, the analysis code will
2048 // cope with a conservative value, and it will take care to purge
2049 // that value once it has finished.
2050 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2051 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2052 // Manually compute the final value for AR, checking for
2053 // overflow.
2054
2055 // Check whether the backedge-taken count can be losslessly casted to
2056 // the addrec's type. The count is always unsigned.
2057 const SCEV *CastedMaxBECount =
2058 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2059 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2060 CastedMaxBECount, MaxBECount->getType(), Depth);
2061 if (MaxBECount == RecastedMaxBECount) {
2062 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2063 // Check whether Start+Step*MaxBECount has no signed overflow.
2064 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2065 SCEV::FlagAnyWrap, Depth + 1);
2066 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2067 SCEV::FlagAnyWrap,
2068 Depth + 1),
2069 WideTy, Depth + 1);
2070 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2071 const SCEV *WideMaxBECount =
2072 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2073 const SCEV *OperandExtendedAdd =
2074 getAddExpr(WideStart,
2075 getMulExpr(WideMaxBECount,
2076 getSignExtendExpr(Step, WideTy, Depth + 1),
2077 SCEV::FlagAnyWrap, Depth + 1),
2078 SCEV::FlagAnyWrap, Depth + 1);
2079 if (SAdd == OperandExtendedAdd) {
2080 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2081 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2082 // Return the expression with the addrec on the outside.
2083 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2084 Depth + 1);
2085 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2086 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2087 }
2088 // Similar to above, only this time treat the step value as unsigned.
2089 // This covers loops that count up with an unsigned step.
2090 OperandExtendedAdd =
2091 getAddExpr(WideStart,
2092 getMulExpr(WideMaxBECount,
2093 getZeroExtendExpr(Step, WideTy, Depth + 1),
2096 if (SAdd == OperandExtendedAdd) {
2097 // If AR wraps around then
2098 //
2099 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2100 // => SAdd != OperandExtendedAdd
2101 //
2102 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2103 // (SAdd == OperandExtendedAdd => AR is NW)
2104
2105 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2106
2107 // Return the expression with the addrec on the outside.
2108 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2109 Depth + 1);
2110 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2111 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2112 }
2113 }
2114 }
2115
2116 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2117 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2118 if (AR->hasNoSignedWrap()) {
2119 // Same as the nsw case above - duplicated here to avoid a compile-time
2120 // issue. It's not clear that the order of checks matters, but it is
2121 // one of two possible causes of an issue that forced a change to be
2122 // reverted. Be conservative for the moment.
2123 Start =
2124 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2125 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2126 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2127 }
2128
2129 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2130 // if D + (C - D + Step * n) could be proven to not signed wrap
2131 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2132 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2133 const APInt &C = SC->getAPInt();
2134 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2135 if (D != 0) {
2136 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2137 const SCEV *SResidual =
2138 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2139 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2140 return getAddExpr(SSExtD, SSExtR,
2141 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2142 Depth + 1);
2143 }
2144 }
2145
2146 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2147 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2148 Start =
2149 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2150 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2151 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2152 }
2153 }
2154
2155 // If the input value is provably positive and we could not simplify
2156 // away the sext, build a zext instead.
2157 if (isKnownNonNegative(Op))
2158 return getZeroExtendExpr(Op, Ty, Depth + 1);
2159
2160 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2161 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2162 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2163 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2164 SmallVector<const SCEV *, 4> Operands;
2165 for (auto *Operand : MinMax->operands())
2166 Operands.push_back(getSignExtendExpr(Operand, Ty));
2167 if (isa<SCEVSMinExpr>(MinMax))
2168 return getSMinExpr(Operands);
2169 return getSMaxExpr(Operands);
2170 }
2171
2172 // The cast wasn't folded; create an explicit cast node.
2173 // Recompute the insert position, as it may have been invalidated.
2174 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2175 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2176 Op, Ty);
2177 UniqueSCEVs.InsertNode(S, IP);
2178 registerUser(S, { Op });
2179 return S;
2180}
2181
2182const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2183 Type *Ty) {
2184 switch (Kind) {
2185 case scTruncate:
2186 return getTruncateExpr(Op, Ty);
2187 case scZeroExtend:
2188 return getZeroExtendExpr(Op, Ty);
2189 case scSignExtend:
2190 return getSignExtendExpr(Op, Ty);
2191 case scPtrToInt:
2192 return getPtrToIntExpr(Op, Ty);
2193 default:
2194 llvm_unreachable("Not a SCEV cast expression!");
2195 }
2196}
2197
2198/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2199/// unspecified bits out to the given type.
2200const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2201 Type *Ty) {
2202 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2203 "This is not an extending conversion!");
2204 assert(isSCEVable(Ty) &&
2205 "This is not a conversion to a SCEVable type!");
2206 Ty = getEffectiveSCEVType(Ty);
2207
2208 // Sign-extend negative constants.
2209 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2210 if (SC->getAPInt().isNegative())
2211 return getSignExtendExpr(Op, Ty);
2212
2213 // Peel off a truncate cast.
2214 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2215 const SCEV *NewOp = T->getOperand();
2216 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2217 return getAnyExtendExpr(NewOp, Ty);
2218 return getTruncateOrNoop(NewOp, Ty);
2219 }
2220
2221 // Next try a zext cast. If the cast is folded, use it.
2222 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2223 if (!isa<SCEVZeroExtendExpr>(ZExt))
2224 return ZExt;
2225
2226 // Next try a sext cast. If the cast is folded, use it.
2227 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2228 if (!isa<SCEVSignExtendExpr>(SExt))
2229 return SExt;
2230
2231 // Force the cast to be folded into the operands of an addrec.
2232 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2233 SmallVector<const SCEV *, 4> Ops;
2234 for (const SCEV *Op : AR->operands())
2235 Ops.push_back(getAnyExtendExpr(Op, Ty));
2236 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2237 }
2238
2239 // If the expression is obviously signed, use the sext cast value.
2240 if (isa<SCEVSMaxExpr>(Op))
2241 return SExt;
2242
2243 // Absent any other information, use the zext cast value.
2244 return ZExt;
2245}
2246
2247/// Process the given Ops list, which is a list of operands to be added under
2248/// the given scale, and update the given map. This is a helper function for
2249/// getAddRecExpr. As an example of what it does, given a sequence of operands
2250/// that would form an add expression like this:
2251///
2252/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2253///
2254/// where A and B are constants, update the map with these values:
2255///
2256/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2257///
2258/// and add 13 + A*B*29 to AccumulatedConstant.
2259/// This will allow getAddRecExpr to produce this:
2260///
2261/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2262///
2263/// This form often exposes folding opportunities that are hidden in
2264/// the original operand list.
2265///
2266/// Return true iff it appears that any interesting folding opportunities
2267/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2268/// the common case where no interesting opportunities are present, and
2269/// is also used as a check to avoid infinite recursion.
2270static bool
2271CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2272 SmallVectorImpl<const SCEV *> &NewOps,
2273 APInt &AccumulatedConstant,
2274 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2275 ScalarEvolution &SE) {
2276 bool Interesting = false;
2277
2278 // Iterate over the add operands. They are sorted, with constants first.
2279 unsigned i = 0;
2280 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2281 ++i;
2282 // Pull a buried constant out to the outside.
2283 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2284 Interesting = true;
2285 AccumulatedConstant += Scale * C->getAPInt();
2286 }
2287
2288 // Next comes everything else. We're especially interested in multiplies
2289 // here, but they're in the middle, so just visit the rest with one loop.
2290 for (; i != Ops.size(); ++i) {
2291 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2292 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2293 APInt NewScale =
2294 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2295 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2296 // A multiplication of a constant with another add; recurse.
2297 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2298 Interesting |=
2299 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2300 Add->operands(), NewScale, SE);
2301 } else {
2302 // A multiplication of a constant with some other value. Update
2303 // the map.
2304 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2305 const SCEV *Key = SE.getMulExpr(MulOps);
2306 auto Pair = M.insert({Key, NewScale});
2307 if (Pair.second) {
2308 NewOps.push_back(Pair.first->first);
2309 } else {
2310 Pair.first->second += NewScale;
2311 // The map already had an entry for this value, which may indicate
2312 // a folding opportunity.
2313 Interesting = true;
2314 }
2315 }
2316 } else {
2317 // An ordinary operand. Update the map.
2318 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2319 M.insert({Ops[i], Scale});
2320 if (Pair.second) {
2321 NewOps.push_back(Pair.first->first);
2322 } else {
2323 Pair.first->second += Scale;
2324 // The map already had an entry for this value, which may indicate
2325 // a folding opportunity.
2326 Interesting = true;
2327 }
2328 }
2329 }
2330
2331 return Interesting;
2332}
2333
2334bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2335 const SCEV *LHS, const SCEV *RHS,
2336 const Instruction *CtxI) {
2337 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2338 SCEV::NoWrapFlags, unsigned);
2339 switch (BinOp) {
2340 default:
2341 llvm_unreachable("Unsupported binary op");
2342 case Instruction::Add:
2343 Operation = &ScalarEvolution::getAddExpr;
2344 break;
2345 case Instruction::Sub:
2346 Operation = &ScalarEvolution::getMinusSCEV;
2347 break;
2348 case Instruction::Mul:
2349 Operation = &ScalarEvolution::getMulExpr;
2350 break;
2351 }
2352
2353 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2354 Signed ? &ScalarEvolution::getSignExtendExpr
2355 : &ScalarEvolution::getZeroExtendExpr;
2356
2357 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2358 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2359 auto *WideTy =
2360 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2361
2362 const SCEV *A = (this->*Extension)(
2363 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2364 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2365 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2366 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2367 if (A == B)
2368 return true;
2369 // Can we use context to prove the fact we need?
2370 if (!CtxI)
2371 return false;
2372 // TODO: Support mul.
2373 if (BinOp == Instruction::Mul)
2374 return false;
2375 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2376 // TODO: Lift this limitation.
2377 if (!RHSC)
2378 return false;
2379 APInt C = RHSC->getAPInt();
2380 unsigned NumBits = C.getBitWidth();
2381 bool IsSub = (BinOp == Instruction::Sub);
2382 bool IsNegativeConst = (Signed && C.isNegative());
2383 // Compute the direction and magnitude by which we need to check overflow.
2384 bool OverflowDown = IsSub ^ IsNegativeConst;
2385 APInt Magnitude = C;
2386 if (IsNegativeConst) {
2387 if (C == APInt::getSignedMinValue(NumBits))
2388 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2389 // want to deal with that.
2390 return false;
2391 Magnitude = -C;
2392 }
2393
2394 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2395 if (OverflowDown) {
2396 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
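 // For example, for a signed i8 subtraction LHS - 10, overflow down is ruled
 // out when -128 + 10 = -118 <= LHS.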
2397 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2398 : APInt::getMinValue(NumBits);
2399 APInt Limit = Min + Magnitude;
2400 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2401 } else {
2402 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2403 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2404 : APInt::getMaxValue(NumBits);
2405 APInt Limit = Max - Magnitude;
2406 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2407 }
2408}
2409
2410std::optional<SCEV::NoWrapFlags>
2411ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2412 const OverflowingBinaryOperator *OBO) {
2413 // It cannot be done any better.
2414 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2415 return std::nullopt;
2416
2417 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2418
2419 if (OBO->hasNoUnsignedWrap())
2420 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2421 if (OBO->hasNoSignedWrap())
2422 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2423
2424 bool Deduced = false;
2425
2426 if (OBO->getOpcode() != Instruction::Add &&
2427 OBO->getOpcode() != Instruction::Sub &&
2428 OBO->getOpcode() != Instruction::Mul)
2429 return std::nullopt;
2430
2431 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2432 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2433
2434 const Instruction *CtxI =
2435 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2436 if (!OBO->hasNoUnsignedWrap() &&
2437 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2438 /* Signed */ false, LHS, RHS, CtxI)) {
2439 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2440 Deduced = true;
2441 }
2442
2443 if (!OBO->hasNoSignedWrap() &&
2444 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2445 /* Signed */ true, LHS, RHS, CtxI)) {
2446 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2447 Deduced = true;
2448 }
2449
2450 if (Deduced)
2451 return Flags;
2452 return std::nullopt;
2453}
2454
2455// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2456// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2457// can't-overflow flags for the operation if possible.
2458static SCEV::NoWrapFlags
2459StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2460 const ArrayRef<const SCEV *> Ops,
2461 SCEV::NoWrapFlags Flags) {
2462 using namespace std::placeholders;
2463
2464 using OBO = OverflowingBinaryOperator;
2465
2466 bool CanAnalyze =
2467 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2468 (void)CanAnalyze;
2469 assert(CanAnalyze && "don't call from other places!");
2470
2471 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2472 SCEV::NoWrapFlags SignOrUnsignWrap =
2473 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2474
2475 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2476 auto IsKnownNonNegative = [&](const SCEV *S) {
2477 return SE->isKnownNonNegative(S);
2478 };
2479
2480 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2481 Flags =
2482 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2483
2484 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2485
2486 if (SignOrUnsignWrap != SignOrUnsignMask &&
2487 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2488 isa<SCEVConstant>(Ops[0])) {
2489
2490 auto Opcode = [&] {
2491 switch (Type) {
2492 case scAddExpr:
2493 return Instruction::Add;
2494 case scMulExpr:
2495 return Instruction::Mul;
2496 default:
2497 llvm_unreachable("Unexpected SCEV op.");
2498 }
2499 }();
2500
2501 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2502
2503 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2504 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2505 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2506 Opcode, C, OBO::NoSignedWrap);
2507 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2508 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2509 }
2510
2511 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2512 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2514 Opcode, C, OBO::NoUnsignedWrap);
2515 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2516 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2517 }
2518 }
2519
2520 // <0,+,nonnegative><nw> is also nuw
2521 // TODO: Add corresponding nsw case
2522 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2523 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2524 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2525 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2526
2527 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2528 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2529 Ops.size() == 2) {
2530 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2531 if (UDiv->getOperand(1) == Ops[1])
2532 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2533 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2534 if (UDiv->getOperand(1) == Ops[0])
2535 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2536 }
2537
2538 return Flags;
2539}
2540
2541bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2542 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2543}
2544
2545/// Get a canonical add expression, or something simpler if possible.
2546const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2547 SCEV::NoWrapFlags OrigFlags,
2548 unsigned Depth) {
2549 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2550 "only nuw or nsw allowed");
2551 assert(!Ops.empty() && "Cannot get empty add!");
2552 if (Ops.size() == 1) return Ops[0];
2553#ifndef NDEBUG
2554 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2555 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2556 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2557 "SCEVAddExpr operand types don't match!");
2558 unsigned NumPtrs = count_if(
2559 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2560 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2561#endif
2562
2563 const SCEV *Folded = constantFoldAndGroupOps(
2564 *this, LI, DT, Ops,
2565 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2566 [](const APInt &C) { return C.isZero(); }, // identity
2567 [](const APInt &C) { return false; }); // absorber
2568 if (Folded)
2569 return Folded;
2570
2571 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2572
2573 // Delay expensive flag strengthening until necessary.
2574 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
2575 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2576 };
2577
2578 // Limit recursion calls depth.
2579 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2580 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2581
2582 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2583 // Don't strengthen flags if we have no new information.
2584 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2585 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2586 Add->setNoWrapFlags(ComputeFlags(Ops));
2587 return S;
2588 }
2589
2590 // Okay, check to see if the same value occurs in the operand list more than
2591 // once. If so, merge them together into a multiply expression. Since we
2592 // sorted the list, these values are required to be adjacent.
2593 Type *Ty = Ops[0]->getType();
2594 bool FoundMatch = false;
2595 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2596 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2597 // Scan ahead to count how many equal operands there are.
2598 unsigned Count = 2;
2599 while (i+Count != e && Ops[i+Count] == Ops[i])
2600 ++Count;
2601 // Merge the values into a multiply.
2602 const SCEV *Scale = getConstant(Ty, Count);
2603 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2604 if (Ops.size() == Count)
2605 return Mul;
2606 Ops[i] = Mul;
2607 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2608 --i; e -= Count - 1;
2609 FoundMatch = true;
2610 }
2611 if (FoundMatch)
2612 return getAddExpr(Ops, OrigFlags, Depth + 1);
2613
2614 // Check for truncates. If all the operands are truncated from the same
2615 // type, see if factoring out the truncate would permit the result to be
2616 // folded, e.g., n*trunc(x) + m*trunc(y) --> trunc(ext(n)*x + ext(m)*y)
2617 // if the contents of the resulting outer trunc fold to something simple.
2618 auto FindTruncSrcType = [&]() -> Type * {
2619 // We're ultimately looking to fold an addrec of truncs and muls of only
2620 // constants and truncs, so if we find any other types of SCEV
2621 // as operands of the addrec then we bail and return nullptr here.
2622 // Otherwise, we return the type of the operand of a trunc that we find.
2623 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2624 return T->getOperand()->getType();
2625 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2626 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2627 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2628 return T->getOperand()->getType();
2629 }
2630 return nullptr;
2631 };
2632 if (auto *SrcType = FindTruncSrcType()) {
2633 SmallVector<const SCEV *, 8> LargeOps;
2634 bool Ok = true;
2635 // Check all the operands to see if they can be represented in the
2636 // source type of the truncate.
2637 for (const SCEV *Op : Ops) {
2638 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2639 if (T->getOperand()->getType() != SrcType) {
2640 Ok = false;
2641 break;
2642 }
2643 LargeOps.push_back(T->getOperand());
2644 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2645 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2646 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2647 SmallVector<const SCEV *, 8> LargeMulOps;
2648 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2649 if (const SCEVTruncateExpr *T =
2650 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2651 if (T->getOperand()->getType() != SrcType) {
2652 Ok = false;
2653 break;
2654 }
2655 LargeMulOps.push_back(T->getOperand());
2656 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2657 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2658 } else {
2659 Ok = false;
2660 break;
2661 }
2662 }
2663 if (Ok)
2664 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2665 } else {
2666 Ok = false;
2667 break;
2668 }
2669 }
2670 if (Ok) {
2671 // Evaluate the expression in the larger type.
2672 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2673 // If it folds to something simple, use it. Otherwise, don't.
2674 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2675 return getTruncateExpr(Fold, Ty);
2676 }
2677 }
2678
2679 if (Ops.size() == 2) {
2680 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2681 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2682 // C1).
2683 const SCEV *A = Ops[0];
2684 const SCEV *B = Ops[1];
2685 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2686 auto *C = dyn_cast<SCEVConstant>(A);
2687 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2688 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2689 auto C2 = C->getAPInt();
2690 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2691
2692 APInt ConstAdd = C1 + C2;
2693 auto AddFlags = AddExpr->getNoWrapFlags();
2694 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2695 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2696 ConstAdd.ule(C1)) {
2697 PreservedFlags =
2698 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2699 }
2700
2701 // Adding a constant with the same sign and small magnitude is NSW, if the
2702 // original AddExpr was NSW.
2703 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2704 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2705 ConstAdd.abs().ule(C1.abs())) {
2706 PreservedFlags =
2707 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2708 }
2709
2710 if (PreservedFlags != SCEV::FlagAnyWrap) {
2711 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2712 NewOps[0] = getConstant(ConstAdd);
2713 return getAddExpr(NewOps, PreservedFlags);
2714 }
2715 }
2716
2717 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2718 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2719 const SCEVAddExpr *InnerAdd;
2720 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2721 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2722 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2723 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2724 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2725 SCEV::FlagAnyWrap),
2726 SCEV::FlagNUW)) {
2727 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2728 }
2729 }
2730 }
2731
2732 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
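 // That is, X - (X urem Y) rounds X down to a multiple of Y, which is exactly
 // Y * (X /u Y).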
2733 const SCEV *Y;
2734 if (Ops.size() == 2 &&
2735 match(Ops[0],
2736 m_scev_Mul(m_scev_AllOnes(),
2737 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2738 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2739
2740 // Skip past any other cast SCEVs.
2741 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2742 ++Idx;
2743
2744 // If there are add operands they would be next.
2745 if (Idx < Ops.size()) {
2746 bool DeletedAdd = false;
2747 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2748 // common NUW flag for expression after inlining. Other flags cannot be
2749 // preserved, because they may depend on the original order of operations.
2750 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2751 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2752 if (Ops.size() > AddOpsInlineThreshold ||
2753 Add->getNumOperands() > AddOpsInlineThreshold)
2754 break;
2755 // If we have an add, expand the add operands onto the end of the operands
2756 // list.
2757 Ops.erase(Ops.begin()+Idx);
2758 append_range(Ops, Add->operands());
2759 DeletedAdd = true;
2760 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2761 }
2762
2763 // If we deleted at least one add, we added operands to the end of the list,
2764 // and they are not necessarily sorted. Recurse to resort and resimplify
2765 // any operands we just acquired.
2766 if (DeletedAdd)
2767 return getAddExpr(Ops, CommonFlags, Depth + 1);
2768 }
2769
2770 // Skip over the add expression until we get to a multiply.
2771 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2772 ++Idx;
2773
2774 // Check to see if there are any folding opportunities present with
2775 // operands multiplied by constant values.
2776 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2777 uint64_t BitWidth = getTypeSizeInBits(Ty);
2778 DenseMap<const SCEV *, APInt> M;
2779 SmallVector<const SCEV *, 8> NewOps;
2780 APInt AccumulatedConstant(BitWidth, 0);
2781 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2782 Ops, APInt(BitWidth, 1), *this)) {
2783 struct APIntCompare {
2784 bool operator()(const APInt &LHS, const APInt &RHS) const {
2785 return LHS.ult(RHS);
2786 }
2787 };
2788
2789 // Some interesting folding opportunity is present, so it's worthwhile to
2790 // re-generate the operands list. Group the operands by constant scale,
2791 // to avoid multiplying by the same constant scale multiple times.
2792 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2793 for (const SCEV *NewOp : NewOps)
2794 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2795 // Re-generate the operands list.
2796 Ops.clear();
2797 if (AccumulatedConstant != 0)
2798 Ops.push_back(getConstant(AccumulatedConstant));
2799 for (auto &MulOp : MulOpLists) {
2800 if (MulOp.first == 1) {
2801 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2802 } else if (MulOp.first != 0) {
2803 Ops.push_back(getMulExpr(
2804 getConstant(MulOp.first),
2805 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2806 SCEV::FlagAnyWrap, Depth + 1));
2807 }
2808 }
2809 if (Ops.empty())
2810 return getZero(Ty);
2811 if (Ops.size() == 1)
2812 return Ops[0];
2813 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2814 }
2815 }
2816
2817 // If we are adding something to a multiply expression, make sure the
2818 // something is not already an operand of the multiply. If so, merge it into
2819 // the multiply.
2820 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2821 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2822 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2823 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2824 if (isa<SCEVConstant>(MulOpSCEV))
2825 continue;
2826 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2827 if (MulOpSCEV == Ops[AddOp]) {
2828 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2829 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2830 if (Mul->getNumOperands() != 2) {
2831 // If the multiply has more than two operands, we must get the
2832 // Y*Z term.
2833 SmallVector<const SCEV *, 4> MulOps(
2834 Mul->operands().take_front(MulOp));
2835 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2836 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2837 }
2838 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2839 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2840 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2841 SCEV::FlagAnyWrap, Depth + 1);
2842 if (Ops.size() == 2) return OuterMul;
2843 if (AddOp < Idx) {
2844 Ops.erase(Ops.begin()+AddOp);
2845 Ops.erase(Ops.begin()+Idx-1);
2846 } else {
2847 Ops.erase(Ops.begin()+Idx);
2848 Ops.erase(Ops.begin()+AddOp-1);
2849 }
2850 Ops.push_back(OuterMul);
2851 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2852 }
2853
2854 // Check this multiply against other multiplies being added together.
2855 for (unsigned OtherMulIdx = Idx+1;
2856 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2857 ++OtherMulIdx) {
2858 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2859 // If MulOp occurs in OtherMul, we can fold the two multiplies
2860 // together.
2861 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2862 OMulOp != e; ++OMulOp)
2863 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2864 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2865 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2866 if (Mul->getNumOperands() != 2) {
2867 SmallVector<const SCEV *, 4> MulOps(
2868 Mul->operands().take_front(MulOp));
2869 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2870 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2871 }
2872 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2873 if (OtherMul->getNumOperands() != 2) {
2874 SmallVector<const SCEV *, 4> MulOps(
2875 OtherMul->operands().take_front(OMulOp));
2876 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2877 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2878 }
2879 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2880 const SCEV *InnerMulSum =
2881 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2882 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2883 SCEV::FlagAnyWrap, Depth + 1);
2884 if (Ops.size() == 2) return OuterMul;
2885 Ops.erase(Ops.begin()+Idx);
2886 Ops.erase(Ops.begin()+OtherMulIdx-1);
2887 Ops.push_back(OuterMul);
2888 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2889 }
2890 }
2891 }
2892 }
2893
2894 // If there are any add recurrences in the operands list, see if any other
2895 // added values are loop invariant. If so, we can fold them into the
2896 // recurrence.
2897 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2898 ++Idx;
2899
2900 // Scan over all recurrences, trying to fold loop invariants into them.
2901 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2902 // Scan all of the other operands to this add and add them to the vector if
2903 // they are loop invariant w.r.t. the recurrence.
2904 SmallVector<const SCEV *, 8> LIOps;
2905 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2906 const Loop *AddRecLoop = AddRec->getLoop();
2907 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2908 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2909 LIOps.push_back(Ops[i]);
2910 Ops.erase(Ops.begin()+i);
2911 --i; --e;
2912 }
2913
2914 // If we found some loop invariants, fold them into the recurrence.
2915 if (!LIOps.empty()) {
2916 // Compute nowrap flags for the addition of the loop-invariant ops and
2917 // the addrec. Temporarily push it as an operand for that purpose. These
2918 // flags are valid in the scope of the addrec only.
2919 LIOps.push_back(AddRec);
2920 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2921 LIOps.pop_back();
2922
2923 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2924 LIOps.push_back(AddRec->getStart());
2925
2926 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2927
2928 // It is not in general safe to propagate flags valid on an add within
2929 // the addrec scope to one outside it. We must prove that the inner
2930 // scope is guaranteed to execute if the outer one does to be able to
2931 // safely propagate. We know the program is undefined if poison is
2932 // produced on the inner scoped addrec. We also know that *for this use*
2933 // the outer scoped add can't overflow (because of the flags we just
2934 // computed for the inner scoped add) without the program being undefined.
2935 // Proving that entry to the outer scope necessitates entry to the inner
2936 // scope, thus proves the program undefined if the flags would be violated
2937 // in the outer scope.
2938 SCEV::NoWrapFlags AddFlags = Flags;
2939 if (AddFlags != SCEV::FlagAnyWrap) {
2940 auto *DefI = getDefiningScopeBound(LIOps);
2941 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2942 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2943 AddFlags = SCEV::FlagAnyWrap;
2944 }
2945 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2946
2947 // Build the new addrec. Propagate the NUW and NSW flags if both the
2948 // outer add and the inner addrec are guaranteed to have no overflow.
2949 // Always propagate NW.
2950 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2951 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2952
2953 // If all of the other operands were loop invariant, we are done.
2954 if (Ops.size() == 1) return NewRec;
2955
2956 // Otherwise, add the folded AddRec by the non-invariant parts.
2957 for (unsigned i = 0;; ++i)
2958 if (Ops[i] == AddRec) {
2959 Ops[i] = NewRec;
2960 break;
2961 }
2962 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2963 }
2964
2965 // Okay, if there weren't any loop invariants to be folded, check to see if
2966 // there are multiple AddRec's with the same loop induction variable being
2967 // added together. If so, we can fold them.
2968 for (unsigned OtherIdx = Idx+1;
2969 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2970 ++OtherIdx) {
2971 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2972 // so that the 1st found AddRecExpr is dominated by all others.
2973 assert(DT.dominates(
2974 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2975 AddRec->getLoop()->getHeader()) &&
2976 "AddRecExprs are not sorted in reverse dominance order?");
2977 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2978 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2979 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2980 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2981 ++OtherIdx) {
2982 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2983 if (OtherAddRec->getLoop() == AddRecLoop) {
2984 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2985 i != e; ++i) {
2986 if (i >= AddRecOps.size()) {
2987 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2988 break;
2989 }
2990 SmallVector<const SCEV *, 2> TwoOps = {
2991 AddRecOps[i], OtherAddRec->getOperand(i)};
2992 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2993 }
2994 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2995 }
2996 }
2997 // Step size has changed, so we cannot guarantee no self-wraparound.
2998 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2999 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3000 }
3001 }
3002
3003 // Otherwise couldn't fold anything into this recurrence. Move onto the
3004 // next one.
3005 }
3006
3007 // Okay, it looks like we really DO need an add expr. Check to see if we
3008 // already have one, otherwise create a new one.
3009 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3010}
3011
3012const SCEV *
3013ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
3014 SCEV::NoWrapFlags Flags) {
3015 FoldingSetNodeID ID;
3016 ID.AddInteger(scAddExpr);
3017 for (const SCEV *Op : Ops)
3018 ID.AddPointer(Op);
3019 void *IP = nullptr;
3020 SCEVAddExpr *S =
3021 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3022 if (!S) {
3023 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3024 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3025 S = new (SCEVAllocator)
3026 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3027 UniqueSCEVs.InsertNode(S, IP);
3028 registerUser(S, Ops);
3029 }
3030 S->setNoWrapFlags(Flags);
3031 return S;
3032}
3033
3034const SCEV *
3035ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3036 const Loop *L, SCEV::NoWrapFlags Flags) {
3037 FoldingSetNodeID ID;
3038 ID.AddInteger(scAddRecExpr);
3039 for (const SCEV *Op : Ops)
3040 ID.AddPointer(Op);
3041 ID.AddPointer(L);
3042 void *IP = nullptr;
3043 SCEVAddRecExpr *S =
3044 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3045 if (!S) {
3046 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3047 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3048 S = new (SCEVAllocator)
3049 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3050 UniqueSCEVs.InsertNode(S, IP);
3051 LoopUsers[L].push_back(S);
3052 registerUser(S, Ops);
3053 }
3054 setNoWrapFlags(S, Flags);
3055 return S;
3056}
3057
3058const SCEV *
3059ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3060 SCEV::NoWrapFlags Flags) {
3061 FoldingSetNodeID ID;
3062 ID.AddInteger(scMulExpr);
3063 for (const SCEV *Op : Ops)
3064 ID.AddPointer(Op);
3065 void *IP = nullptr;
3066 SCEVMulExpr *S =
3067 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3068 if (!S) {
3069 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3070 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3071 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3072 O, Ops.size());
3073 UniqueSCEVs.InsertNode(S, IP);
3074 registerUser(S, Ops);
3075 }
3076 S->setNoWrapFlags(Flags);
3077 return S;
3078}
3079
3080static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3081 uint64_t k = i*j;
3082 if (j > 1 && k / j != i) Overflow = true;
3083 return k;
3084}
3085
3086/// Compute the result of "n choose k", the binomial coefficient. If an
3087/// intermediate computation overflows, Overflow will be set and the return will
3088/// be garbage. Overflow is not cleared on absence of overflow.
3089static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3090 // We use the multiplicative formula:
3091 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3092 // At each iteration, we multiply by the next term of the numerator and
3093 // divide by the next term of the denominator. This division always produces
3094 // an integral result, and helps reduce the chance of overflow in the
3095 // intermediate computations. However, we can still overflow even when the
3096 // final result would fit.
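 // For example, Choose(6, 2) is computed as r = 1 * 6 / 1 = 6, then
 // r = 6 * 5 / 2 = 15.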
3097
3098 if (n == 0 || n == k) return 1;
3099 if (k > n) return 0;
3100
3101 if (k > n/2)
3102 k = n-k;
3103
3104 uint64_t r = 1;
3105 for (uint64_t i = 1; i <= k; ++i) {
3106 r = umul_ov(r, n-(i-1), Overflow);
3107 r /= i;
3108 }
3109 return r;
3110}
3111
3112/// Determine if any of the operands in this SCEV are a constant or if
3113/// any of the add or multiply expressions in this SCEV contain a constant.
3114static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3115 struct FindConstantInAddMulChain {
3116 bool FoundConstant = false;
3117
3118 bool follow(const SCEV *S) {
3119 FoundConstant |= isa<SCEVConstant>(S);
3120 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3121 }
3122
3123 bool isDone() const {
3124 return FoundConstant;
3125 }
3126 };
3127
3128 FindConstantInAddMulChain F;
3129 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3130 ST.visitAll(StartExpr);
3131 return F.FoundConstant;
3132}
3133
3134/// Get a canonical multiply expression, or something simpler if possible.
3135const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3136 SCEV::NoWrapFlags OrigFlags,
3137 unsigned Depth) {
3138 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3139 "only nuw or nsw allowed");
3140 assert(!Ops.empty() && "Cannot get empty mul!");
3141 if (Ops.size() == 1) return Ops[0];
3142#ifndef NDEBUG
3143 Type *ETy = Ops[0]->getType();
3144 assert(!ETy->isPointerTy());
3145 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3146 assert(Ops[i]->getType() == ETy &&
3147 "SCEVMulExpr operand types don't match!");
3148#endif
3149
3150 const SCEV *Folded = constantFoldAndGroupOps(
3151 *this, LI, DT, Ops,
3152 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3153 [](const APInt &C) { return C.isOne(); }, // identity
3154 [](const APInt &C) { return C.isZero(); }); // absorber
3155 if (Folded)
3156 return Folded;
3157
3158 // Delay expensive flag strengthening until necessary.
3159 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
3160 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3161 };
3162
3163 // Limit recursion calls depth.
3164 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3165 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3166
3167 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3168 // Don't strengthen flags if we have no new information.
3169 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3170 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3171 Mul->setNoWrapFlags(ComputeFlags(Ops));
3172 return S;
3173 }
3174
3175 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3176 if (Ops.size() == 2) {
3177 // C1*(C2+V) -> C1*C2 + C1*V
3178 // If any of Add's ops are Adds or Muls with a constant, apply this
3179 // transformation as well.
3180 //
3181 // TODO: There are some cases where this transformation is not
3182 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3183 // this transformation should be narrowed down.
3184 const SCEV *Op0, *Op1;
3185 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3186 containsConstantInAddMulChain(Ops[1])) {
3187 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3188 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3189 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3190 }
3191
3192 if (Ops[0]->isAllOnesValue()) {
3193 // If we have a mul by -1 of an add, try distributing the -1 among the
3194 // add operands.
3195 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3196 SmallVector<const SCEV *, 4> NewOps;
3197 bool AnyFolded = false;
3198 for (const SCEV *AddOp : Add->operands()) {
3199 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3200 Depth + 1);
3201 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3202 NewOps.push_back(Mul);
3203 }
3204 if (AnyFolded)
3205 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3206 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3207 // Negation preserves a recurrence's no self-wrap property.
3208 SmallVector<const SCEV *, 4> Operands;
3209 for (const SCEV *AddRecOp : AddRec->operands())
3210 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3211 Depth + 1));
3212 // Let M be the minimum representable signed value. AddRec with nsw
3213 // multiplied by -1 can have signed overflow if and only if it takes a
3214 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3215 // maximum signed value. In all other cases signed overflow is
3216 // impossible.
3217 auto FlagsMask = SCEV::FlagNW;
3218 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3219 auto MinInt =
3220 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3221 if (getSignedRangeMin(AddRec) != MinInt)
3222 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3223 }
3224 return getAddRecExpr(Operands, AddRec->getLoop(),
3225 AddRec->getNoWrapFlags(FlagsMask));
3226 }
3227 }
3228
3229 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3230 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3231 const SCEVAddExpr *InnerAdd;
3232 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3233 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3234 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3235 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3236 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3237 SCEV::FlagAnyWrap),
3238 SCEV::FlagNUW)) {
3239 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3240 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3241 };
3242 }
3243
3244 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3245 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3246 // of C1, fold to (D /u (C2 /u C1)).
3247 const SCEV *D;
3248 APInt C1V = LHSC->getAPInt();
3249 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3250 // as -1 * 1, as it won't enable additional folds.
3251 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3252 C1V = C1V.abs();
3253 const SCEVConstant *C2;
3254 if (C1V.isPowerOf2() &&
3255 match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
3256 C2->getAPInt().isPowerOf2() &&
3257 C1V.logBase2() <= getMinTrailingZeros(D)) {
3258 const SCEV *NewMul = nullptr;
3259 if (C1V.uge(C2->getAPInt())) {
3260 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3261 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3262 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3263 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3264 }
3265 if (NewMul)
3266 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3267 }
3268 }
3269 }
3270
3271 // Skip over the add expression until we get to a multiply.
3272 unsigned Idx = 0;
3273 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3274 ++Idx;
3275
3276 // If there are mul operands inline them all into this expression.
3277 if (Idx < Ops.size()) {
3278 bool DeletedMul = false;
3279 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3280 if (Ops.size() > MulOpsInlineThreshold)
3281 break;
3282 // If we have a mul, expand the mul operands onto the end of the
3283 // operands list.
3284 Ops.erase(Ops.begin()+Idx);
3285 append_range(Ops, Mul->operands());
3286 DeletedMul = true;
3287 }
3288
3289 // If we deleted at least one mul, we added operands to the end of the
3290 // list, and they are not necessarily sorted. Recurse to resort and
3291 // resimplify any operands we just acquired.
3292 if (DeletedMul)
3293 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3294 }
3295
3296 // If there are any add recurrences in the operands list, see if any other
3297 // added values are loop invariant. If so, we can fold them into the
3298 // recurrence.
3299 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3300 ++Idx;
3301
3302 // Scan over all recurrences, trying to fold loop invariants into them.
3303 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3304 // Scan all of the other operands to this mul and add them to the vector
3305 // if they are loop invariant w.r.t. the recurrence.
3306 SmallVector<const SCEV *, 8> LIOps;
3307 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3308 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3309 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3310 LIOps.push_back(Ops[i]);
3311 Ops.erase(Ops.begin()+i);
3312 --i; --e;
3313 }
3314
3315 // If we found some loop invariants, fold them into the recurrence.
3316 if (!LIOps.empty()) {
3317 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3318 SmallVector<const SCEV *, 4> NewOps;
3319 NewOps.reserve(AddRec->getNumOperands());
3320 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3321
3322 // If both the mul and addrec are nuw, we can preserve nuw.
3323 // If both the mul and addrec are nsw, we can only preserve nsw if either
3324 // a) they are also nuw, or
3325 // b) all multiplications of addrec operands with scale are nsw.
3326 SCEV::NoWrapFlags Flags =
3327 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3328
3329 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3330 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3331 SCEV::FlagAnyWrap, Depth + 1));
3332
3333 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3334 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3335 Instruction::Mul, getSignedRange(Scale),
3336 OverflowingBinaryOperator::NoSignedWrap);
3337 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3338 Flags = clearFlags(Flags, SCEV::FlagNSW);
3339 }
3340 }
3341
3342 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3343
3344 // If all of the other operands were loop invariant, we are done.
3345 if (Ops.size() == 1) return NewRec;
3346
3347 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3348 for (unsigned i = 0;; ++i)
3349 if (Ops[i] == AddRec) {
3350 Ops[i] = NewRec;
3351 break;
3352 }
3353 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3354 }
3355
3356 // Okay, if there weren't any loop invariants to be folded, check to see
3357 // if there are multiple AddRec's with the same loop induction variable
3358 // being multiplied together. If so, we can fold them.
3359
3360 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3361 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3362 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3363 // ]]],+,...up to x=2n}.
3364 // Note that the arguments to choose() are always integers with values
3365 // known at compile time, never SCEV objects.
3366 //
3367 // The implementation avoids pointless extra computations when the two
3368 // addrec's are of different length (mathematically, it's equivalent to
3369 // an infinite stream of zeros on the right).
3370 bool OpsModified = false;
3371 for (unsigned OtherIdx = Idx+1;
3372 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3373 ++OtherIdx) {
3374 const SCEVAddRecExpr *OtherAddRec =
3375 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3376 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3377 continue;
3378
3379 // Limit max number of arguments to avoid creation of unreasonably big
3380 // SCEVAddRecs with very complex operands.
3381 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3382 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3383 continue;
3384
3385 bool Overflow = false;
3386 Type *Ty = AddRec->getType();
3387 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3388 SmallVector<const SCEV *, 7> AddRecOps;
3389 for (int x = 0, xe = AddRec->getNumOperands() +
3390 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3391 SmallVector<const SCEV *, 7> SumOps;
3392 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3393 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3394 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3395 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3396 z < ze && !Overflow; ++z) {
3397 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3398 uint64_t Coeff;
3399 if (LargerThan64Bits)
3400 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3401 else
3402 Coeff = Coeff1*Coeff2;
3403 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3404 const SCEV *Term1 = AddRec->getOperand(y-z);
3405 const SCEV *Term2 = OtherAddRec->getOperand(z);
3406 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3407 SCEV::FlagAnyWrap, Depth + 1));
3408 }
3409 }
3410 if (SumOps.empty())
3411 SumOps.push_back(getZero(Ty));
3412 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3413 }
3414 if (!Overflow) {
3415 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3416 SCEV::FlagAnyWrap);
3417 if (Ops.size() == 2) return NewAddRec;
3418 Ops[Idx] = NewAddRec;
3419 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3420 OpsModified = true;
3421 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3422 if (!AddRec)
3423 break;
3424 }
3425 }
3426 if (OpsModified)
3427 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3428
3429 // Otherwise couldn't fold anything into this recurrence. Move onto the
3430 // next one.
3431 }
3432
3433 // Okay, it looks like we really DO need a mul expr. Check to see if we
3434 // already have one, otherwise create a new one.
3435 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3436}
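// Editor's sketch (not in the original source): a minimal use of the
// C1*(C2+V) distribution performed above. Assumes a caller-provided
// ScalarEvolution &SE and an integer SCEV X; the function name is
// hypothetical.
[[maybe_unused]] static const SCEV *
exampleDistributeConstMul(ScalarEvolution &SE, const SCEV *X) {
  // 2 * (3 + X) is canonicalized to 6 + 2*X.
  const SCEV *Two = SE.getConstant(X->getType(), 2);
  const SCEV *Three = SE.getConstant(X->getType(), 3);
  return SE.getMulExpr(Two, SE.getAddExpr(Three, X));
}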
3437
3438/// Represents an unsigned remainder expression based on unsigned division.
3439const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3440 const SCEV *RHS) {
3441 assert(getEffectiveSCEVType(LHS->getType()) ==
3442 getEffectiveSCEVType(RHS->getType()) &&
3443 "SCEVURemExpr operand types don't match!");
3444
3445 // Short-circuit easy cases
3446 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3447 // If constant is one, the result is trivial
3448 if (RHSC->getValue()->isOne())
3449 return getZero(LHS->getType()); // X urem 1 --> 0
3450
3451 // If constant is a power of two, fold into a zext(trunc(LHS)).
3452 if (RHSC->getAPInt().isPowerOf2()) {
3453 Type *FullTy = LHS->getType();
3454 Type *TruncTy =
3455 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3456 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3457 }
3458 }
3459
3460 // Fall back to %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3461 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3462 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3463 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3464}
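// Editor's sketch (not in the original source): getURemExpr with a
// power-of-two divisor takes the zext(trunc(...)) fold above; any other
// divisor falls back to X - (X /u Y) * Y. Assumes an integer SCEV X wider
// than three bits; the function name is hypothetical.
[[maybe_unused]] static const SCEV *exampleURemByEight(ScalarEvolution &SE,
                                                       const SCEV *X) {
  // X urem 8 becomes zext(trunc(X) to i3) back to X's type, i.e. the low
  // three bits of X.
  return SE.getURemExpr(X, SE.getConstant(X->getType(), 8));
}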
3465
3466/// Get a canonical unsigned division expression, or something simpler if
3467/// possible.
3468const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3469 const SCEV *RHS) {
3470 assert(!LHS->getType()->isPointerTy() &&
3471 "SCEVUDivExpr operand can't be pointer!");
3472 assert(LHS->getType() == RHS->getType() &&
3473 "SCEVUDivExpr operand types don't match!");
3474
3475 FoldingSetNodeID ID;
3476 ID.AddInteger(scUDivExpr);
3477 ID.AddPointer(LHS);
3478 ID.AddPointer(RHS);
3479 void *IP = nullptr;
3480 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3481 return S;
3482
3483 // 0 udiv Y == 0
3484 if (match(LHS, m_scev_Zero()))
3485 return LHS;
3486
3487 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3488 if (RHSC->getValue()->isOne())
3489 return LHS; // X udiv 1 --> x
3490 // If the denominator is zero, the result of the udiv is undefined. Don't
3491 // try to analyze it, because the resolution chosen here may differ from
3492 // the resolution chosen in other parts of the compiler.
3493 if (!RHSC->getValue()->isZero()) {
3494 // Determine if the division can be folded into the operands of
3495 // LHS.
3496 // TODO: Generalize this to non-constants by using known-bits information.
3497 Type *Ty = LHS->getType();
3498 unsigned LZ = RHSC->getAPInt().countl_zero();
3499 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3500 // For non-power-of-two values, effectively round the value up to the
3501 // nearest power of two.
3502 if (!RHSC->getAPInt().isPowerOf2())
3503 ++MaxShiftAmt;
3504 IntegerType *ExtTy =
3505 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3506 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3507 if (const SCEVConstant *Step =
3508 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3509 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3510 const APInt &StepInt = Step->getAPInt();
3511 const APInt &DivInt = RHSC->getAPInt();
3512 if (!StepInt.urem(DivInt) &&
3513 getZeroExtendExpr(AR, ExtTy) ==
3514 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3515 getZeroExtendExpr(Step, ExtTy),
3516 AR->getLoop(), SCEV::FlagAnyWrap)) {
3517 SmallVector<const SCEV *, 4> Operands;
3518 for (const SCEV *Op : AR->operands())
3519 Operands.push_back(getUDivExpr(Op, RHS));
3520 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3521 }
3522 // Get a canonical UDivExpr for a recurrence.
3523 // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3524 const APInt *StartRem;
3525 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3526 m_scev_APInt(StartRem))) {
3527 bool NoWrap =
3528 getZeroExtendExpr(AR, ExtTy) ==
3529 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3530 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3531 SCEV::FlagAnyWrap);
3532
3533 // With N <= C and both N, C as powers-of-2, the transformation
3534 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3535 // if wrapping occurs, as the division results remain equivalent for
3536 // all offsets in [[(X - X%N), X).
3537 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3538 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3539 // Only fold if the subtraction can be folded in the start
3540 // expression.
3541 const SCEV *NewStart =
3542 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3543 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3544 !isa<SCEVAddExpr>(NewStart)) {
3545 const SCEV *NewLHS =
3546 getAddRecExpr(NewStart, Step, AR->getLoop(),
3547 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3548 if (LHS != NewLHS) {
3549 LHS = NewLHS;
3550
3551 // Reset the ID to include the new LHS, and check if it is
3552 // already cached.
3553 ID.clear();
3554 ID.AddInteger(scUDivExpr);
3555 ID.AddPointer(LHS);
3556 ID.AddPointer(RHS);
3557 IP = nullptr;
3558 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3559 return S;
3560 }
3561 }
3562 }
3563 }
3564 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3565 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3566 SmallVector<const SCEV *, 4> Operands;
3567 for (const SCEV *Op : M->operands())
3568 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3569 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3570 // Find an operand that's safely divisible.
3571 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3572 const SCEV *Op = M->getOperand(i);
3573 const SCEV *Div = getUDivExpr(Op, RHSC);
3574 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3575 Operands = SmallVector<const SCEV *, 4>(M->operands());
3576 Operands[i] = Div;
3577 return getMulExpr(Operands);
3578 }
3579 }
3580 }
3581
3582 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3583 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3584 if (auto *DivisorConstant =
3585 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3586 bool Overflow = false;
3587 APInt NewRHS =
3588 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3589 if (Overflow) {
3590 return getConstant(RHSC->getType(), 0, false);
3591 }
3592 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3593 }
3594 }
3595
3596 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3597 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3598 SmallVector<const SCEV *, 4> Operands;
3599 for (const SCEV *Op : A->operands())
3600 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3601 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3602 Operands.clear();
3603 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3604 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3605 if (isa<SCEVUDivExpr>(Op) ||
3606 getMulExpr(Op, RHS) != A->getOperand(i))
3607 break;
3608 Operands.push_back(Op);
3609 }
3610 if (Operands.size() == A->getNumOperands())
3611 return getAddExpr(Operands);
3612 }
3613 }
3614
3615 // Fold if both operands are constant.
3616 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3617 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3618 }
3619 }
3620
3621 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3622 const APInt *NegC, *C;
3623 if (match(LHS,
3624 m_scev_Add(m_scev_APInt(NegC),
3625 m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) &&
3626 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3627 return getZero(LHS->getType());
3628
3629 // TODO: Generalize to handle any common factors.
3630 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3631 const SCEV *NewLHS, *NewRHS;
3632 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3633 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3634 return getUDivExpr(NewLHS, NewRHS);
3635
3636 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3637 // changes). Make sure we get a new one.
3638 IP = nullptr;
3639 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3640 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3641 LHS, RHS);
3642 UniqueSCEVs.InsertNode(S, IP);
3643 registerUser(S, {LHS, RHS});
3644 return S;
3645}
3646
3647APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3648 APInt A = C1->getAPInt().abs();
3649 APInt B = C2->getAPInt().abs();
3650 uint32_t ABW = A.getBitWidth();
3651 uint32_t BBW = B.getBitWidth();
3652
3653 if (ABW > BBW)
3654 B = B.zext(ABW);
3655 else if (ABW < BBW)
3656 A = A.zext(BBW);
3657
3658 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3659}
3660
3661/// Get a canonical unsigned division expression, or something simpler if
3662/// possible. There is no representation for an exact udiv in SCEV IR, but we
3663/// can attempt to remove factors from the LHS and RHS. We can't do this when
3664/// it's not exact because the udiv may be clearing bits.
3666 const SCEV *RHS) {
3667 // TODO: we could try to find factors in all sorts of things, but for now we
3668 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3669 // end of this file for inspiration.
3670
3671 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3672 if (!Mul || !Mul->hasNoUnsignedWrap())
3673 return getUDivExpr(LHS, RHS);
3674
3675 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3676 // If the mulexpr multiplies by a constant, then that constant must be the
3677 // first element of the mulexpr.
3678 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3679 if (LHSCst == RHSCst) {
3680 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3681 return getMulExpr(Operands);
3682 }
3683
3684 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3685 // that there's a factor provided by one of the other terms. We need to
3686 // check.
3687 APInt Factor = gcd(LHSCst, RHSCst);
3688 if (!Factor.isIntN(1)) {
3689 LHSCst =
3690 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3691 RHSCst =
3692 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3693 SmallVector<const SCEV *, 2> Operands;
3694 Operands.push_back(LHSCst);
3695 append_range(Operands, Mul->operands().drop_front());
3696 LHS = getMulExpr(Operands);
3697 RHS = RHSCst;
3698 Mul = dyn_cast<SCEVMulExpr>(LHS);
3699 if (!Mul)
3700 return getUDivExactExpr(LHS, RHS);
3701 }
3702 }
3703 }
3704
3705 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3706 if (Mul->getOperand(i) == RHS) {
3707 SmallVector<const SCEV *, 2> Operands;
3708 append_range(Operands, Mul->operands().take_front(i));
3709 append_range(Operands, Mul->operands().drop_front(i + 1));
3710 return getMulExpr(Operands);
3711 }
3712 }
3713
3714 return getUDivExpr(LHS, RHS);
3715}
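// Editor's sketch (not in the original source): the exact-division helper
// above cancels a common constant factor instead of emitting a udiv node.
// The multiply must carry the nuw flag for the fold to apply; names are
// hypothetical and X is assumed to be an integer SCEV.
[[maybe_unused]] static const SCEV *exampleExactUDiv(ScalarEvolution &SE,
                                                     const SCEV *X) {
  // (6 *nuw X) /u(exact) 3 simplifies to 2 * X, since gcd(6, 3) = 3.
  const SCEV *Six = SE.getConstant(X->getType(), 6);
  const SCEV *Mul = SE.getMulExpr(Six, X, SCEV::FlagNUW);
  return SE.getUDivExactExpr(Mul, SE.getConstant(X->getType(), 3));
}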
3716
3717/// Get an add recurrence expression for the specified loop. Simplify the
3718/// expression as much as possible.
3719const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3720 const Loop *L,
3721 SCEV::NoWrapFlags Flags) {
3722 SmallVector<const SCEV *, 4> Operands;
3723 Operands.push_back(Start);
3724 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3725 if (StepChrec->getLoop() == L) {
3726 append_range(Operands, StepChrec->operands());
3727 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3728 }
3729
3730 Operands.push_back(Step);
3731 return getAddRecExpr(Operands, L, Flags);
3732}
3733
3734/// Get an add recurrence expression for the specified loop. Simplify the
3735/// expression as much as possible.
3736const SCEV *
3737ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3738 const Loop *L, SCEV::NoWrapFlags Flags) {
3739 if (Operands.size() == 1) return Operands[0];
3740#ifndef NDEBUG
3741 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3742 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3743 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3744 "SCEVAddRecExpr operand types don't match!");
3745 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3746 }
3747 for (const SCEV *Op : Operands)
3749 "SCEVAddRecExpr operand is not available at loop entry!");
3750#endif
3751
3752 if (Operands.back()->isZero()) {
3753 Operands.pop_back();
3754 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3755 }
3756
3757 // It's tempting to call getConstantMaxBackedgeTakenCount here and
3758 // use that information to infer NUW and NSW flags. However, computing a
3759 // BE count requires calling getAddRecExpr, so we may not yet have a
3760 // meaningful BE count at this point (and if we don't, we'd be stuck
3761 // with a SCEVCouldNotCompute as the cached BE count).
3762
3763 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3764
3765 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3766 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3767 const Loop *NestedLoop = NestedAR->getLoop();
3768 if (L->contains(NestedLoop)
3769 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3770 : (!NestedLoop->contains(L) &&
3771 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3772 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3773 Operands[0] = NestedAR->getStart();
3774 // AddRecs require their operands be loop-invariant with respect to their
3775 // loops. Don't perform this transformation if it would break this
3776 // requirement.
3777 bool AllInvariant = all_of(
3778 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3779
3780 if (AllInvariant) {
3781 // Create a recurrence for the outer loop with the same step size.
3782 //
3783 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3784 // inner recurrence has the same property.
3785 SCEV::NoWrapFlags OuterFlags =
3786 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3787
3788 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3789 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3790 return isLoopInvariant(Op, NestedLoop);
3791 });
3792
3793 if (AllInvariant) {
3794 // Ok, both add recurrences are valid after the transformation.
3795 //
3796 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3797 // the outer recurrence has the same property.
3798 SCEV::NoWrapFlags InnerFlags =
3799 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3800 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3801 }
3802 }
3803 // Reset Operands to its original state.
3804 Operands[0] = NestedAR;
3805 }
3806 }
3807
3808 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3809 // already have one, otherwise create a new one.
3810 return getOrCreateAddRecExpr(Operands, L, Flags);
3811}
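// Editor's sketch (not in the original source): building a simple affine
// recurrence with the helper above. Start must be available at L's entry
// (this is asserted); names are hypothetical.
[[maybe_unused]] static const SCEV *exampleAffineAddRec(ScalarEvolution &SE,
                                                        const SCEV *Start,
                                                        const Loop *L) {
  // {Start,+,4}<L>: Start on the first iteration, then +4 each iteration.
  // A zero step, {Start,+,0}<L>, would instead fold back to plain Start.
  return SE.getAddRecExpr(Start, SE.getConstant(Start->getType(), 4), L,
                          SCEV::FlagAnyWrap);
}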
3812
3813const SCEV *ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3814 ArrayRef<const SCEV *> IndexExprs) {
3815 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3816 // getSCEV(Base)->getType() has the same address space as Base->getType()
3817 // because SCEV::getType() preserves the address space.
3818 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3819 if (NW != GEPNoWrapFlags::none()) {
3820 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3821 // but to do that, we have to ensure that said flag is valid in the entire
3822 // defined scope of the SCEV.
3823 // TODO: non-instructions have global scope. We might be able to prove
3824 // some global scope cases
3825 auto *GEPI = dyn_cast<Instruction>(GEP);
3826 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3827 NW = GEPNoWrapFlags::none();
3828 }
3829
3830 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3831}
3832
3833const SCEV *ScalarEvolution::getGEPExpr(const SCEV *BaseExpr,
3834 ArrayRef<const SCEV *> IndexExprs,
3835 Type *SrcElementTy, GEPNoWrapFlags NW) {
3836 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3837 if (NW.hasNoUnsignedSignedWrap())
3838 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3839 if (NW.hasNoUnsignedWrap())
3840 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3841
3842 Type *CurTy = BaseExpr->getType();
3843 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3844 bool FirstIter = true;
3845 SmallVector<const SCEV *, 4> Offsets;
3846 for (const SCEV *IndexExpr : IndexExprs) {
3847 // Compute the (potentially symbolic) offset in bytes for this index.
3848 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3849 // For a struct, add the member offset.
3850 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3851 unsigned FieldNo = Index->getZExtValue();
3852 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3853 Offsets.push_back(FieldOffset);
3854
3855 // Update CurTy to the type of the field at Index.
3856 CurTy = STy->getTypeAtIndex(Index);
3857 } else {
3858 // Update CurTy to its element type.
3859 if (FirstIter) {
3860 assert(isa<PointerType>(CurTy) &&
3861 "The first index of a GEP indexes a pointer");
3862 CurTy = SrcElementTy;
3863 FirstIter = false;
3864 } else {
3865 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3866 }
3867 // For an array, add the element offset, explicitly scaled.
3868 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3869 // Getelementptr indices are signed.
3870 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3871
3872 // Multiply the index by the element size to compute the element offset.
3873 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3874 Offsets.push_back(LocalOffset);
3875 }
3876 }
3877
3878 // Handle degenerate case of GEP without offsets.
3879 if (Offsets.empty())
3880 return BaseExpr;
3881
3882 // Add the offsets together, assuming nsw if inbounds.
3883 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3884 // Add the base address and the offset. We cannot use the nsw flag, as the
3885 // base address is unsigned. However, if we know that the offset is
3886 // non-negative, we can use nuw.
3887 bool NUW = NW.hasNoUnsignedWrap() ||
3888 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3889 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3890 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3891 assert(BaseExpr->getType() == GEPExpr->getType() &&
3892 "GEP should not change type mid-flight.");
3893 return GEPExpr;
3894}
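// Editor's sketch (not in the original source): the SCEV built above for a
// GEP like "getelementptr i32, ptr %p, i64 %i" is the same value one would
// get by scaling the index by the element size by hand. Assumes P is a
// pointer SCEV and I is an index SCEV already in P's index type; the 4-byte
// element size corresponds to i32. Names are hypothetical.
[[maybe_unused]] static const SCEV *exampleGEPOffset(ScalarEvolution &SE,
                                                     const SCEV *P,
                                                     const SCEV *I) {
  Type *IntIdxTy = SE.getEffectiveSCEVType(P->getType());
  const SCEV *EltSize = SE.getConstant(IntIdxTy, 4); // sizeof(i32)
  return SE.getAddExpr(P, SE.getMulExpr(I, EltSize)); // P + 4*I
}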
3895
3896SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3897 ArrayRef<const SCEV *> Ops) {
3898 FoldingSetNodeID ID;
3899 ID.AddInteger(SCEVType);
3900 for (const SCEV *Op : Ops)
3901 ID.AddPointer(Op);
3902 void *IP = nullptr;
3903 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3904}
3905
3906const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3907 const SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3908 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3909}
3910
3911const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3912 SmallVectorImpl<const SCEV *> &Ops) {
3913 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3914 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3915 if (Ops.size() == 1) return Ops[0];
3916#ifndef NDEBUG
3917 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3918 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3919 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3920 "Operand types don't match!");
3921 assert(Ops[0]->getType()->isPointerTy() ==
3922 Ops[i]->getType()->isPointerTy() &&
3923 "min/max should be consistently pointerish");
3924 }
3925#endif
3926
3927 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3928 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3929
3930 const SCEV *Folded = constantFoldAndGroupOps(
3931 *this, LI, DT, Ops,
3932 [&](const APInt &C1, const APInt &C2) {
3933 switch (Kind) {
3934 case scSMaxExpr:
3935 return APIntOps::smax(C1, C2);
3936 case scSMinExpr:
3937 return APIntOps::smin(C1, C2);
3938 case scUMaxExpr:
3939 return APIntOps::umax(C1, C2);
3940 case scUMinExpr:
3941 return APIntOps::umin(C1, C2);
3942 default:
3943 llvm_unreachable("Unknown SCEV min/max opcode");
3944 }
3945 },
3946 [&](const APInt &C) {
3947 // identity
3948 if (IsMax)
3949 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3950 else
3951 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3952 },
3953 [&](const APInt &C) {
3954 // absorber
3955 if (IsMax)
3956 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3957 else
3958 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3959 });
3960 if (Folded)
3961 return Folded;
3962
3963 // Check if we have created the same expression before.
3964 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3965 return S;
3966 }
3967
3968 // Find the first operation of the same kind
3969 unsigned Idx = 0;
3970 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3971 ++Idx;
3972
3973 // Check to see if one of the operands is of the same kind. If so, expand its
3974 // operands onto our operand list, and recurse to simplify.
3975 if (Idx < Ops.size()) {
3976 bool DeletedAny = false;
3977 while (Ops[Idx]->getSCEVType() == Kind) {
3978 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3979 Ops.erase(Ops.begin()+Idx);
3980 append_range(Ops, SMME->operands());
3981 DeletedAny = true;
3982 }
3983
3984 if (DeletedAny)
3985 return getMinMaxExpr(Kind, Ops);
3986 }
3987
3988 // Okay, check to see if the same value occurs in the operand list twice. If
3989 // so, delete one. Since we sorted the list, these values are required to
3990 // be adjacent.
3991 llvm::CmpInst::Predicate GEPred =
3992 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3993 llvm::CmpInst::Predicate LEPred =
3994 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3995 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3996 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3997 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3998 if (Ops[i] == Ops[i + 1] ||
3999 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4000 // X op Y op Y --> X op Y
4001 // X op Y --> X, if we know X, Y are ordered appropriately
4002 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4003 --i;
4004 --e;
4005 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4006 Ops[i + 1])) {
4007 // X op Y --> Y, if we know X, Y are ordered appropriately
4008 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4009 --i;
4010 --e;
4011 }
4012 }
4013
4014 if (Ops.size() == 1) return Ops[0];
4015
4016 assert(!Ops.empty() && "Reduced smax down to nothing!");
4017
4018 // Okay, it looks like we really DO need an expr. Check to see if we
4019 // already have one, otherwise create a new one.
4020 FoldingSetNodeID ID;
4021 ID.AddInteger(Kind);
4022 for (const SCEV *Op : Ops)
4023 ID.AddPointer(Op);
4024 void *IP = nullptr;
4025 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4026 if (ExistingSCEV)
4027 return ExistingSCEV;
4028 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4029 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4030 SCEV *S = new (SCEVAllocator)
4031 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4032
4033 UniqueSCEVs.InsertNode(S, IP);
4034 registerUser(S, Ops);
4035 return S;
4036}
4037
4038namespace {
4039
4040class SCEVSequentialMinMaxDeduplicatingVisitor final
4041 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4042 std::optional<const SCEV *>> {
4043 using RetVal = std::optional<const SCEV *>;
4044 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
4045
4046 ScalarEvolution &SE;
4047 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4048 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4049 SmallPtrSet<const SCEV *, 16> SeenOps;
4050
4051 bool canRecurseInto(SCEVTypes Kind) const {
4052 // We can only recurse into the SCEV expression of the same effective type
4053 // as the type of our root SCEV expression.
4054 return RootKind == Kind || NonSequentialRootKind == Kind;
4055 };
4056
4057 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4059 "Only for min/max expressions.");
4060 SCEVTypes Kind = S->getSCEVType();
4061
4062 if (!canRecurseInto(Kind))
4063 return S;
4064
4065 auto *NAry = cast<SCEVNAryExpr>(S);
4066 SmallVector<const SCEV *> NewOps;
4067 bool Changed = visit(Kind, NAry->operands(), NewOps);
4068
4069 if (!Changed)
4070 return S;
4071 if (NewOps.empty())
4072 return std::nullopt;
4073
4074 return isa<SCEVSequentialMinMaxExpr>(S)
4075 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4076 : SE.getMinMaxExpr(Kind, NewOps);
4077 }
4078
4079 RetVal visit(const SCEV *S) {
4080 // Has the whole operand been seen already?
4081 if (!SeenOps.insert(S).second)
4082 return std::nullopt;
4083 return Base::visit(S);
4084 }
4085
4086public:
4087 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4088 SCEVTypes RootKind)
4089 : SE(SE), RootKind(RootKind),
4090 NonSequentialRootKind(
4091 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4092 RootKind)) {}
4093
4094 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4095 SmallVectorImpl<const SCEV *> &NewOps) {
4096 bool Changed = false;
4097 SmallVector<const SCEV *> Ops;
4098 Ops.reserve(OrigOps.size());
4099
4100 for (const SCEV *Op : OrigOps) {
4101 RetVal NewOp = visit(Op);
4102 if (NewOp != Op)
4103 Changed = true;
4104 if (NewOp)
4105 Ops.emplace_back(*NewOp);
4106 }
4107
4108 if (Changed)
4109 NewOps = std::move(Ops);
4110 return Changed;
4111 }
4112
4113 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4114
4115 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4116
4117 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4118
4119 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4120
4121 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4122
4123 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4124
4125 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4126
4127 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4128
4129 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4130
4131 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4132
4133 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4134
4135 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4136 return visitAnyMinMaxExpr(Expr);
4137 }
4138
4139 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4140 return visitAnyMinMaxExpr(Expr);
4141 }
4142
4143 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4144 return visitAnyMinMaxExpr(Expr);
4145 }
4146
4147 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4148 return visitAnyMinMaxExpr(Expr);
4149 }
4150
4151 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4152 return visitAnyMinMaxExpr(Expr);
4153 }
4154
4155 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4156
4157 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4158};
4159
4160} // namespace
4161
4162static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4163 switch (Kind) {
4164 case scConstant:
4165 case scVScale:
4166 case scTruncate:
4167 case scZeroExtend:
4168 case scSignExtend:
4169 case scPtrToAddr:
4170 case scPtrToInt:
4171 case scAddExpr:
4172 case scMulExpr:
4173 case scUDivExpr:
4174 case scAddRecExpr:
4175 case scUMaxExpr:
4176 case scSMaxExpr:
4177 case scUMinExpr:
4178 case scSMinExpr:
4179 case scUnknown:
4180 // If any operand is poison, the whole expression is poison.
4181 return true;
4182 case scSequentialUMinExpr:
4183 // FIXME: if the *first* operand is poison, the whole expression is poison.
4184 return false; // Pessimistically, say that it does not propagate poison.
4185 case scCouldNotCompute:
4186 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4187 }
4188 llvm_unreachable("Unknown SCEV kind!");
4189}
4190
4191namespace {
4192// The only way poison may be introduced in a SCEV expression is from a
4193// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4194// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4195// introduce poison -- they encode guaranteed, non-speculated knowledge.
4196//
4197// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4198// with the notable exception of umin_seq, where only poison from the first
4199// operand is (unconditionally) propagated.
4200struct SCEVPoisonCollector {
4201 bool LookThroughMaybePoisonBlocking;
4202 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4203 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4204 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4205
4206 bool follow(const SCEV *S) {
4207 if (!LookThroughMaybePoisonBlocking &&
4208 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4209 return false;
4210
4211 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4212 if (!isGuaranteedNotToBePoison(SU->getValue()))
4213 MaybePoison.insert(SU);
4214 }
4215 return true;
4216 }
4217 bool isDone() const { return false; }
4218};
4219} // namespace
4220
4221/// Return true if V is poison given that AssumedPoison is already poison.
4222static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4223 // First collect all SCEVs that might result in AssumedPoison to be poison.
4224 // We need to look through potentially poison-blocking operations here,
4225 // because we want to find all SCEVs that *might* result in poison, not only
4226 // those that are *required* to.
4227 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4228 visitAll(AssumedPoison, PC1);
4229
4230 // AssumedPoison is never poison. As the assumption is false, the implication
4231 // is true. Don't bother walking the other SCEV in this case.
4232 if (PC1.MaybePoison.empty())
4233 return true;
4234
4235 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4236 // as well. We cannot look through potentially poison-blocking operations
4237 // here, as their arguments only *may* make the result poison.
4238 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4239 visitAll(S, PC2);
4240
4241 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4242 // it will also make S poison by being part of PC2.MaybePoison.
4243 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4244}
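// Editor's note (illustrative, not in the original source): with
// AssumedPoison = (%x + 1) and S = (%x umin %y), the only value that can make
// AssumedPoison poison is %x, and %x being poison also makes S poison, so
// impliesPoison returns true. With S = %y alone it returns false: %y need not
// be poison when %x is.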
4245
4246void ScalarEvolution::getPoisonGeneratingValues(
4247 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4248 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4249 visitAll(S, PC);
4250 for (const SCEVUnknown *SU : PC.MaybePoison)
4251 Result.insert(SU->getValue());
4252}
4253
4254bool ScalarEvolution::canReuseInstruction(
4255 const SCEV *S, Instruction *I,
4256 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4257 // If the instruction cannot be poison, it's always safe to reuse.
4258 if (isGuaranteedNotToBePoison(I))
4259 return true;
4260
4261 // Otherwise, it is possible that I is more poisonous than S. Collect the
4262 // poison-contributors of S, and then check whether I has any additional
4263 // poison-contributors. Poison that is contributed through poison-generating
4264 // flags is handled by dropping those flags instead.
4265 SmallPtrSet<const Value *, 8> PoisonVals;
4266 getPoisonGeneratingValues(PoisonVals, S);
4267
4268 SmallVector<Value *> Worklist;
4269 SmallPtrSet<Value *, 8> Visited;
4270 Worklist.push_back(I);
4271 while (!Worklist.empty()) {
4272 Value *V = Worklist.pop_back_val();
4273 if (!Visited.insert(V).second)
4274 continue;
4275
4276 // Avoid walking large instruction graphs.
4277 if (Visited.size() > 16)
4278 return false;
4279
4280 // Either the value can't be poison, or the S would also be poison if it
4281 // is.
4282 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4283 continue;
4284
4285 auto *I = dyn_cast<Instruction>(V);
4286 if (!I)
4287 return false;
4288
4289 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4290 // can't replace an arbitrary add with disjoint or, even if we drop the
4291 // flag. We would need to convert the or into an add.
4292 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4293 if (PDI->isDisjoint())
4294 return false;
4295
4296 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4297 // because SCEV currently assumes it can't be poison. Remove this special
4298 // case once we properly model when vscale can be poison.
4299 if (auto *II = dyn_cast<IntrinsicInst>(I);
4300 II && II->getIntrinsicID() == Intrinsic::vscale)
4301 continue;
4302
4303 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4304 return false;
4305
4306 // If the instruction can't create poison, we can recurse to its operands.
4307 if (I->hasPoisonGeneratingAnnotations())
4308 DropPoisonGeneratingInsts.push_back(I);
4309
4310 llvm::append_range(Worklist, I->operands());
4311 }
4312 return true;
4313}
4314
4315const SCEV *
4316ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4317 SmallVectorImpl<const SCEV *> &Ops) {
4318 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4319 "Not a SCEVSequentialMinMaxExpr!");
4320 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4321 if (Ops.size() == 1)
4322 return Ops[0];
4323#ifndef NDEBUG
4324 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4325 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4326 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4327 "Operand types don't match!");
4328 assert(Ops[0]->getType()->isPointerTy() ==
4329 Ops[i]->getType()->isPointerTy() &&
4330 "min/max should be consistently pointerish");
4331 }
4332#endif
4333
4334 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4335 // so we can *NOT* do any kind of sorting of the expressions!
4336
4337 // Check if we have created the same expression before.
4338 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4339 return S;
4340
4341 // FIXME: there are *some* simplifications that we can do here.
4342
4343 // Keep only the first instance of an operand.
4344 {
4345 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4346 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4347 if (Changed)
4348 return getSequentialMinMaxExpr(Kind, Ops);
4349 }
4350
4351 // Check to see if one of the operands is of the same kind. If so, expand its
4352 // operands onto our operand list, and recurse to simplify.
4353 {
4354 unsigned Idx = 0;
4355 bool DeletedAny = false;
4356 while (Idx < Ops.size()) {
4357 if (Ops[Idx]->getSCEVType() != Kind) {
4358 ++Idx;
4359 continue;
4360 }
4361 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4362 Ops.erase(Ops.begin() + Idx);
4363 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4364 SMME->operands().end());
4365 DeletedAny = true;
4366 }
4367
4368 if (DeletedAny)
4369 return getSequentialMinMaxExpr(Kind, Ops);
4370 }
4371
4372 const SCEV *SaturationPoint;
4373 ICmpInst::Predicate Pred;
4374 switch (Kind) {
4375 case scSequentialUMinExpr:
4376 SaturationPoint = getZero(Ops[0]->getType());
4377 Pred = ICmpInst::ICMP_ULE;
4378 break;
4379 default:
4380 llvm_unreachable("Not a sequential min/max type.");
4381 }
4382
4383 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4384 if (!isGuaranteedNotToCauseUB(Ops[i]))
4385 continue;
4386 // We can replace %x umin_seq %y with %x umin %y if either:
4387 // * %y being poison implies %x is also poison.
4388 // * %x cannot be the saturating value (e.g. zero for umin).
4389 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4390 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4391 SaturationPoint)) {
4392 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4393 Ops[i - 1] = getMinMaxExpr(
4394 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4395 SeqOps);
4396 Ops.erase(Ops.begin() + i);
4397 return getSequentialMinMaxExpr(Kind, Ops);
4398 }
4399 // Fold %x umin_seq %y to %x if %x ule %y.
4400 // TODO: We might be able to prove the predicate for a later operand.
4401 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4402 Ops.erase(Ops.begin() + i);
4403 return getSequentialMinMaxExpr(Kind, Ops);
4404 }
4405 }
4406
4407 // Okay, it looks like we really DO need an expr. Check to see if we
4408 // already have one, otherwise create a new one.
4410 ID.AddInteger(Kind);
4411 for (const SCEV *Op : Ops)
4412 ID.AddPointer(Op);
4413 void *IP = nullptr;
4414 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4415 if (ExistingSCEV)
4416 return ExistingSCEV;
4417
4418 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4419 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4420 SCEV *S = new (SCEVAllocator)
4421 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4422
4423 UniqueSCEVs.InsertNode(S, IP);
4424 registerUser(S, Ops);
4425 return S;
4426}
4427
4428const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4429 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4430 return getSMaxExpr(Ops);
4431}
4432
4433const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4434 return getMinMaxExpr(scSMaxExpr, Ops);
4435}
4436
4437const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4438 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4439 return getUMaxExpr(Ops);
4440}
4441
4442const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4443 return getMinMaxExpr(scUMaxExpr, Ops);
4444}
4445
4446const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4447 const SCEV *RHS) {
4448 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4449 return getSMinExpr(Ops);
4450}
4451
4452const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4453 return getMinMaxExpr(scSMinExpr, Ops);
4454}
4455
4456const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4457 bool Sequential) {
4458 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4459 return getUMinExpr(Ops, Sequential);
4460}
4461
4462const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4463 bool Sequential) {
4464 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4465 : getMinMaxExpr(scUMinExpr, Ops);
4466}
4467
4468const SCEV *
4469ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4470 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4471 if (Size.isScalable())
4472 Res = getMulExpr(Res, getVScale(IntTy));
4473 return Res;
4474}
4475
4476const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4477 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4478}
4479
4480const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4481 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4482}
4483
4484const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4485 StructType *STy,
4486 unsigned FieldNo) {
4487 // We can bypass creating a target-independent constant expression and then
4488 // folding it back into a ConstantInt. This is just a compile-time
4489 // optimization.
4490 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4491 assert(!SL->getSizeInBits().isScalable() &&
4492 "Cannot get offset for structure containing scalable vector types");
4493 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4494}
4495
4496const SCEV *ScalarEvolution::getUnknown(Value *V) {
4497 // Don't attempt to do anything other than create a SCEVUnknown object
4498 // here. createSCEV only calls getUnknown after checking for all other
4499 // interesting possibilities, and any other code that calls getUnknown
4500 // is doing so in order to hide a value from SCEV canonicalization.
4501
4502 FoldingSetNodeID ID;
4503 ID.AddInteger(scUnknown);
4504 ID.AddPointer(V);
4505 void *IP = nullptr;
4506 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4507 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4508 "Stale SCEVUnknown in uniquing map!");
4509 return S;
4510 }
4511 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4512 FirstUnknown);
4513 FirstUnknown = cast<SCEVUnknown>(S);
4514 UniqueSCEVs.InsertNode(S, IP);
4515 return S;
4516}
4517
4518//===----------------------------------------------------------------------===//
4519// Basic SCEV Analysis and PHI Idiom Recognition Code
4520//
4521
4522/// Test if values of the given type are analyzable within the SCEV
4523/// framework. This primarily includes integer types, and it can optionally
4524/// include pointer types if the ScalarEvolution class has access to
4525/// target-specific information.
4526bool ScalarEvolution::isSCEVable(Type *Ty) const {
4527 // Integers and pointers are always SCEVable.
4528 return Ty->isIntOrPtrTy();
4529}
4530
4531/// Return the size in bits of the specified type, for which isSCEVable must
4532/// return true.
4533uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4534 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4535 if (Ty->isPointerTy())
4536 return getDataLayout().getIndexTypeSizeInBits(Ty);
4537 return getDataLayout().getTypeSizeInBits(Ty);
4538}
4539
4540/// Return a type with the same bitwidth as the given type and which represents
4541/// how SCEV will treat the given type, for which isSCEVable must return
4542/// true. For pointer types, this is the pointer index sized integer type.
4543Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4544 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4545
4546 if (Ty->isIntegerTy())
4547 return Ty;
4548
4549 // The only other supported type is pointer.
4550 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4551 return getDataLayout().getIndexType(Ty);
4552}
4553
4554Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4555 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4556}
4557
4558bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4559 const SCEV *B) {
4560 /// For a valid use point to exist, the defining scope of one operand
4561 /// must dominate the other.
4562 bool PreciseA, PreciseB;
4563 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4564 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4565 if (!PreciseA || !PreciseB)
4566 // Can't tell.
4567 return false;
4568 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4569 DT.dominates(ScopeB, ScopeA);
4570}
4571
4572const SCEV *ScalarEvolution::getCouldNotCompute() {
4573 return CouldNotCompute.get();
4574}
4575
4576bool ScalarEvolution::checkValidity(const SCEV *S) const {
4577 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4578 auto *SU = dyn_cast<SCEVUnknown>(S);
4579 return SU && SU->getValue() == nullptr;
4580 });
4581
4582 return !ContainsNulls;
4583}
4584
4585bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4586 HasRecMapType::iterator I = HasRecMap.find(S);
4587 if (I != HasRecMap.end())
4588 return I->second;
4589
4590 bool FoundAddRec =
4591 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4592 HasRecMap.insert({S, FoundAddRec});
4593 return FoundAddRec;
4594}
4595
4596/// Return the ValueOffsetPair set for \p S. \p S can be represented
4597/// by the value and offset from any ValueOffsetPair in the set.
4598ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4599 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4600 if (SI == ExprValueMap.end())
4601 return {};
4602 return SI->second.getArrayRef();
4603}
4604
4605/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4606/// cannot be used separately. eraseValueFromMap should be used to remove
4607/// V from ValueExprMap and ExprValueMap at the same time.
4608void ScalarEvolution::eraseValueFromMap(Value *V) {
4609 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4610 if (I != ValueExprMap.end()) {
4611 auto EVIt = ExprValueMap.find(I->second);
4612 bool Removed = EVIt->second.remove(V);
4613 (void) Removed;
4614 assert(Removed && "Value not in ExprValueMap?");
4615 ValueExprMap.erase(I);
4616 }
4617}
4618
4619void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4620 // A recursive query may have already computed the SCEV. It should be
4621 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4622 // inferred nowrap flags.
4623 auto It = ValueExprMap.find_as(V);
4624 if (It == ValueExprMap.end()) {
4625 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4626 ExprValueMap[S].insert(V);
4627 }
4628}
4629
4630/// Return an existing SCEV if it exists, otherwise analyze the expression and
4631/// create a new one.
4632const SCEV *ScalarEvolution::getSCEV(Value *V) {
4633 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4634
4635 if (const SCEV *S = getExistingSCEV(V))
4636 return S;
4637 return createSCEVIter(V);
4638}
4639
4640const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4641 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4642
4643 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4644 if (I != ValueExprMap.end()) {
4645 const SCEV *S = I->second;
4646 assert(checkValidity(S) &&
4647 "existing SCEV has not been properly invalidated");
4648 return S;
4649 }
4650 return nullptr;
4651}
4652
4653/// Return a SCEV corresponding to -V = -1*V
4654const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4655 SCEV::NoWrapFlags Flags) {
4656 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4657 return getConstant(
4658 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4659
4660 Type *Ty = V->getType();
4661 Ty = getEffectiveSCEVType(Ty);
4662 return getMulExpr(V, getMinusOne(Ty), Flags);
4663}
4664
4665/// If Expr computes ~A, return A else return nullptr
4666static const SCEV *MatchNotExpr(const SCEV *Expr) {
4667 const SCEV *MulOp;
4668 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4669 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4670 return MulOp;
4671 return nullptr;
4672}
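// Editor's sketch (not in the original source): ~X has no dedicated SCEV
// node; getNotSCEV (below) builds it as -1 - X, which canonicalizes to
// (-1 + (-1 * X)), exactly the shape MatchNotExpr recognizes. For a
// non-constant integer X the round trip typically recovers X. Names are
// hypothetical.
[[maybe_unused]] static bool exampleNotRoundTrips(ScalarEvolution &SE,
                                                  const SCEV *X) {
  const SCEV *NotX = SE.getNotSCEV(X);
  return MatchNotExpr(NotX) == X;
}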
4673
4674/// Return a SCEV corresponding to ~V = -1-V
4675const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4676 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4677
4678 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4679 return getConstant(
4680 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4681
4682 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4683 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4684 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4685 SmallVector<const SCEV *, 2> MatchedOperands;
4686 for (const SCEV *Operand : MME->operands()) {
4687 const SCEV *Matched = MatchNotExpr(Operand);
4688 if (!Matched)
4689 return (const SCEV *)nullptr;
4690 MatchedOperands.push_back(Matched);
4691 }
4692 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4693 MatchedOperands);
4694 };
4695 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4696 return Replaced;
4697 }
4698
4699 Type *Ty = V->getType();
4700 Ty = getEffectiveSCEVType(Ty);
4701 return getMinusSCEV(getMinusOne(Ty), V);
4702}
4703
4704const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4705 assert(P->getType()->isPointerTy());
4706
4707 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4708 // The base of an AddRec is the first operand.
4709 SmallVector<const SCEV *> Ops{AddRec->operands()};
4710 Ops[0] = removePointerBase(Ops[0]);
4711 // Don't try to transfer nowrap flags for now. We could in some cases
4712 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4713 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4714 }
4715 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4716 // The base of an Add is the pointer operand.
4717 SmallVector<const SCEV *> Ops{Add->operands()};
4718 const SCEV **PtrOp = nullptr;
4719 for (const SCEV *&AddOp : Ops) {
4720 if (AddOp->getType()->isPointerTy()) {
4721 assert(!PtrOp && "Cannot have multiple pointer ops");
4722 PtrOp = &AddOp;
4723 }
4724 }
4725 *PtrOp = removePointerBase(*PtrOp);
4726 // Don't try to transfer nowrap flags for now. We could in some cases
4727 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4728 return getAddExpr(Ops);
4729 }
4730 // Any other expression must be a pointer base.
4731 return getZero(P->getType());
4732}
4733
4734const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4735 SCEV::NoWrapFlags Flags,
4736 unsigned Depth) {
4737 // Fast path: X - X --> 0.
4738 if (LHS == RHS)
4739 return getZero(LHS->getType());
4740
4741 // If we subtract two pointers with different pointer bases, bail.
4742 // Eventually, we're going to add an assertion to getMulExpr that we
4743 // can't multiply by a pointer.
4744 if (RHS->getType()->isPointerTy()) {
4745 if (!LHS->getType()->isPointerTy() ||
4746 getPointerBase(LHS) != getPointerBase(RHS))
4747 return getCouldNotCompute();
4748 LHS = removePointerBase(LHS);
4749 RHS = removePointerBase(RHS);
4750 }
4751
4752 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4753 // makes it so that we cannot make much use of NUW.
4754 auto AddFlags = SCEV::FlagAnyWrap;
4755 const bool RHSIsNotMinSigned =
4756 !getSignedRangeMin(RHS).isMinSignedValue();
4757 if (hasFlags(Flags, SCEV::FlagNSW)) {
4758 // Let M be the minimum representable signed value. Then (-1)*RHS
4759 // signed-wraps if and only if RHS is M. That can happen even for
4760 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4761 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4762 // (-1)*RHS, we need to prove that RHS != M.
4763 //
4764 // If LHS is non-negative and we know that LHS - RHS does not
4765 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4766 // either by proving that RHS > M or that LHS >= 0.
4767 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4768 AddFlags = SCEV::FlagNSW;
4769 }
4770 }
4771
4772 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4773 // RHS is NSW and LHS >= 0.
4774 //
4775 // The difficulty here is that the NSW flag may have been proven
4776 // relative to a loop that is to be found in a recurrence in LHS and
4777 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4778 // larger scope than intended.
4779 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4780
4781 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4782}
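// Editor's sketch (not in the original source): pointer subtraction only
// folds when both sides share the same pointer base, as checked above.
// Assumes P is a plain pointer SCEVUnknown; names are hypothetical.
[[maybe_unused]] static const SCEV *examplePointerDiff(ScalarEvolution &SE,
                                                       const SCEV *P) {
  Type *IdxTy = SE.getEffectiveSCEVType(P->getType());
  const SCEV *PPlus8 = SE.getAddExpr(P, SE.getConstant(IdxTy, 8));
  return SE.getMinusSCEV(PPlus8, P); // same base, so this folds to 8
}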
4783
4784const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4785 unsigned Depth) {
4786 Type *SrcTy = V->getType();
4787 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4788 "Cannot truncate or zero extend with non-integer arguments!");
4789 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4790 return V; // No conversion
4791 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4792 return getTruncateExpr(V, Ty, Depth);
4793 return getZeroExtendExpr(V, Ty, Depth);
4794}
4795
4796const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4797 unsigned Depth) {
4798 Type *SrcTy = V->getType();
4799 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4800 "Cannot truncate or zero extend with non-integer arguments!");
4801 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4802 return V; // No conversion
4803 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4804 return getTruncateExpr(V, Ty, Depth);
4805 return getSignExtendExpr(V, Ty, Depth);
4806}
4807
4808const SCEV *
4809ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4810 Type *SrcTy = V->getType();
4811 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4812 "Cannot noop or zero extend with non-integer arguments!");
4814 "getNoopOrZeroExtend cannot truncate!");
4815 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4816 return V; // No conversion
4817 return getZeroExtendExpr(V, Ty);
4818}
4819
4820const SCEV *
4821ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4822 Type *SrcTy = V->getType();
4823 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4824 "Cannot noop or sign extend with non-integer arguments!");
4825 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4826 "getNoopOrSignExtend cannot truncate!");
4827 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4828 return V; // No conversion
4829 return getSignExtendExpr(V, Ty);
4830}
4831
4832const SCEV *
4833 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4834 Type *SrcTy = V->getType();
4835 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4836 "Cannot noop or any extend with non-integer arguments!");
4837 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4838 "getNoopOrAnyExtend cannot truncate!");
4839 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4840 return V; // No conversion
4841 return getAnyExtendExpr(V, Ty);
4842}
4843
4844const SCEV *
4845 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4846 Type *SrcTy = V->getType();
4847 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4848 "Cannot truncate or noop with non-integer arguments!");
4849 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4850 "getTruncateOrNoop cannot extend!");
4851 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4852 return V; // No conversion
4853 return getTruncateExpr(V, Ty);
4854}
4855
4856 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4857 const SCEV *RHS) {
4858 const SCEV *PromotedLHS = LHS;
4859 const SCEV *PromotedRHS = RHS;
4860
4861 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4862 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4863 else
4864 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4865
4866 return getUMaxExpr(PromotedLHS, PromotedRHS);
4867}
4868
4869 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4870 const SCEV *RHS,
4871 bool Sequential) {
4872 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4873 return getUMinFromMismatchedTypes(Ops, Sequential);
4874}
4875
4876const SCEV *
4877 ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4878 bool Sequential) {
4879 assert(!Ops.empty() && "At least one operand must be!");
4880 // Trivial case.
4881 if (Ops.size() == 1)
4882 return Ops[0];
4883
4884 // Find the max type first.
4885 Type *MaxType = nullptr;
4886 for (const auto *S : Ops)
4887 if (MaxType)
4888 MaxType = getWiderType(MaxType, S->getType());
4889 else
4890 MaxType = S->getType();
4891 assert(MaxType && "Failed to find maximum type!");
4892
4893 // Extend all ops to max type.
4894 SmallVector<const SCEV *, 2> PromotedOps;
4895 for (const auto *S : Ops)
4896 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4897
4898 // Generate umin.
4899 return getUMinExpr(PromotedOps, Sequential);
4900}
4901
4902 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4903 // A pointer operand may evaluate to a nonpointer expression, such as null.
4904 if (!V->getType()->isPointerTy())
4905 return V;
4906
4907 while (true) {
4908 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4909 V = AddRec->getStart();
4910 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4911 const SCEV *PtrOp = nullptr;
4912 for (const SCEV *AddOp : Add->operands()) {
4913 if (AddOp->getType()->isPointerTy()) {
4914 assert(!PtrOp && "Cannot have multiple pointer ops");
4915 PtrOp = AddOp;
4916 }
4917 }
4918 assert(PtrOp && "Must have pointer op");
4919 V = PtrOp;
4920 } else // Not something we can look further into.
4921 return V;
4922 }
4923}
4924
4925/// Push users of the given Instruction onto the given Worklist.
4926 static void PushDefUseChildren(Instruction *I,
4927 SmallVectorImpl<Instruction *> &Worklist,
4928 SmallPtrSetImpl<Instruction *> &Visited) {
4929 // Push the def-use children onto the Worklist stack.
4930 for (User *U : I->users()) {
4931 auto *UserInsn = cast<Instruction>(U);
4932 if (Visited.insert(UserInsn).second)
4933 Worklist.push_back(UserInsn);
4934 }
4935}
4936
4937namespace {
4938
4939/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4940 /// expression if its loop is L. If its loop is not L, keep the AddRec itself
4941 /// when IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
4942 /// If the SCEV contains a SCEVUnknown that is not invariant in L, the
4943 /// rewrite cannot be done.
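/// For example, rewriting {%start,+,%step}<L> yields %start, while an AddRec
/// on a different loop either stays as-is (IgnoreOtherLoops == true) or makes
/// the whole rewrite fail.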
4944class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4945public:
4946 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4947 bool IgnoreOtherLoops = true) {
4948 SCEVInitRewriter Rewriter(L, SE);
4949 const SCEV *Result = Rewriter.visit(S);
4950 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4951 return SE.getCouldNotCompute();
4952 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4953 ? SE.getCouldNotCompute()
4954 : Result;
4955 }
4956
4957 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4958 if (!SE.isLoopInvariant(Expr, L))
4959 SeenLoopVariantSCEVUnknown = true;
4960 return Expr;
4961 }
4962
4963 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4964 // Only re-write AddRecExprs for this loop.
4965 if (Expr->getLoop() == L)
4966 return Expr->getStart();
4967 SeenOtherLoops = true;
4968 return Expr;
4969 }
4970
4971 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4972
4973 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4974
4975private:
4976 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4977 : SCEVRewriteVisitor(SE), L(L) {}
4978
4979 const Loop *L;
4980 bool SeenLoopVariantSCEVUnknown = false;
4981 bool SeenOtherLoops = false;
4982};
4983
4984/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4985 /// increment expression if its loop is L; otherwise keep the AddRec itself.
4986 /// If the SCEV contains a SCEVUnknown that is not invariant in L, the
4987 /// rewrite cannot be done.
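/// For example, rewriting {%start,+,%step}<L> yields {%start + %step,+,%step}<L>.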
4988class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4989public:
4990 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4991 SCEVPostIncRewriter Rewriter(L, SE);
4992 const SCEV *Result = Rewriter.visit(S);
4993 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4994 ? SE.getCouldNotCompute()
4995 : Result;
4996 }
4997
4998 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4999 if (!SE.isLoopInvariant(Expr, L))
5000 SeenLoopVariantSCEVUnknown = true;
5001 return Expr;
5002 }
5003
5004 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5005 // Only re-write AddRecExprs for this loop.
5006 if (Expr->getLoop() == L)
5007 return Expr->getPostIncExpr(SE);
5008 SeenOtherLoops = true;
5009 return Expr;
5010 }
5011
5012 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5013
5014 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5015
5016private:
5017 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5018 : SCEVRewriteVisitor(SE), L(L) {}
5019
5020 const Loop *L;
5021 bool SeenLoopVariantSCEVUnknown = false;
5022 bool SeenOtherLoops = false;
5023};
5024
5025/// This class evaluates the compare condition by matching it against the
5026/// condition of loop latch. If there is a match we assume a true value
5027/// for the condition while building SCEV nodes.
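/// For example, if the loop latch branches back to the header when %cond is
/// true, a 'select %cond, %x, %y' whose condition is that latch condition is
/// folded to %x while building SCEV nodes.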
5028class SCEVBackedgeConditionFolder
5029 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5030public:
5031 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5032 ScalarEvolution &SE) {
5033 bool IsPosBECond = false;
5034 Value *BECond = nullptr;
5035 if (BasicBlock *Latch = L->getLoopLatch()) {
5036 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
5037 if (BI && BI->isConditional()) {
5038 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5039 "Both outgoing branches should not target same header!");
5040 BECond = BI->getCondition();
5041 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5042 } else {
5043 return S;
5044 }
5045 }
5046 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5047 return Rewriter.visit(S);
5048 }
5049
5050 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5051 const SCEV *Result = Expr;
5052 bool InvariantF = SE.isLoopInvariant(Expr, L);
5053
5054 if (!InvariantF) {
5055 Instruction *I = cast<Instruction>(Expr->getValue());
5056 switch (I->getOpcode()) {
5057 case Instruction::Select: {
5058 SelectInst *SI = cast<SelectInst>(I);
5059 std::optional<const SCEV *> Res =
5060 compareWithBackedgeCondition(SI->getCondition());
5061 if (Res) {
5062 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5063 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5064 }
5065 break;
5066 }
5067 default: {
5068 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5069 if (Res)
5070 Result = *Res;
5071 break;
5072 }
5073 }
5074 }
5075 return Result;
5076 }
5077
5078private:
5079 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5080 bool IsPosBECond, ScalarEvolution &SE)
5081 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5082 IsPositiveBECond(IsPosBECond) {}
5083
5084 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5085
5086 const Loop *L;
5087 /// Loop back condition.
5088 Value *BackedgeCond = nullptr;
5089 /// Set to true if loop back is on positive branch condition.
5090 bool IsPositiveBECond;
5091};
5092
5093std::optional<const SCEV *>
5094SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5095
5096 // If value matches the backedge condition for loop latch,
5097 // then return a constant evolution node based on loopback
5098 // branch taken.
5099 if (BackedgeCond == IC)
5100 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5101 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5102 return std::nullopt;
5103}
5104
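/// Rewrite an affine AddRec for loop L back by one iteration: {a,+,b}<L>
/// becomes {a-b,+,b}<L>. Any loop-variant SCEVUnknown, non-affine AddRec, or
/// AddRec for a different loop makes the rewrite invalid.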
5105class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5106public:
5107 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5108 ScalarEvolution &SE) {
5109 SCEVShiftRewriter Rewriter(L, SE);
5110 const SCEV *Result = Rewriter.visit(S);
5111 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5112 }
5113
5114 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5115 // Only allow AddRecExprs for this loop.
5116 if (!SE.isLoopInvariant(Expr, L))
5117 Valid = false;
5118 return Expr;
5119 }
5120
5121 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5122 if (Expr->getLoop() == L && Expr->isAffine())
5123 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5124 Valid = false;
5125 return Expr;
5126 }
5127
5128 bool isValid() { return Valid; }
5129
5130private:
5131 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5132 : SCEVRewriteVisitor(SE), L(L) {}
5133
5134 const Loop *L;
5135 bool Valid = true;
5136};
5137
5138} // end anonymous namespace
5139
5140 SCEV::NoWrapFlags
5141 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5142 if (!AR->isAffine())
5143 return SCEV::FlagAnyWrap;
5144
5145 using OBO = OverflowingBinaryOperator;
5146
5147 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5148
5149 if (!AR->hasNoSelfWrap()) {
5150 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5151 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5152 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5153 const APInt &BECountAP = BECountMax->getAPInt();
5154 unsigned NoOverflowBitWidth =
5155 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5156 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5157 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5158 }
5159 }
5160
5161 if (!AR->hasNoSignedWrap()) {
5162 ConstantRange AddRecRange = getSignedRange(AR);
5163 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5164
5165 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5166 Instruction::Add, IncRange, OBO::NoSignedWrap);
5167 if (NSWRegion.contains(AddRecRange))
5168 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5169 }
5170
5171 if (!AR->hasNoUnsignedWrap()) {
5172 ConstantRange AddRecRange = getUnsignedRange(AR);
5173 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5174
5175 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5176 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5177 if (NUWRegion.contains(AddRecRange))
5178 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5179 }
5180
5181 return Result;
5182}
5183
5184 SCEV::NoWrapFlags
5185 ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5186 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5187
5188 if (AR->hasNoSignedWrap())
5189 return Result;
5190
5191 if (!AR->isAffine())
5192 return Result;
5193
5194 // This function can be expensive, only try to prove NSW once per AddRec.
5195 if (!SignedWrapViaInductionTried.insert(AR).second)
5196 return Result;
5197
5198 const SCEV *Step = AR->getStepRecurrence(*this);
5199 const Loop *L = AR->getLoop();
5200
5201 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5202 // Note that this serves two purposes: It filters out loops that are
5203 // simply not analyzable, and it covers the case where this code is
5204 // being called from within backedge-taken count analysis, such that
5205 // attempting to ask for the backedge-taken count would likely result
5206 // in infinite recursion. In the latter case, the analysis code will
5207 // cope with a conservative value, and it will take care to purge
5208 // that value once it has finished.
5209 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5210
5211 // Normally, in the cases we can prove no-overflow via a
5212 // backedge guarding condition, we can also compute a backedge
5213 // taken count for the loop. The exceptions are assumptions and
5214 // guards present in the loop -- SCEV is not great at exploiting
5215 // these to compute max backedge taken counts, but can still use
5216 // these to prove lack of overflow. Use this fact to avoid
5217 // doing extra work that may not pay off.
5218
5219 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5220 AC.assumptions().empty())
5221 return Result;
5222
5223 // If the backedge is guarded by a comparison with the pre-inc value the
5224 // addrec is safe. Also, if the entry is guarded by a comparison with the
5225 // start value and the backedge is guarded by a comparison with the post-inc
5226 // value, the addrec is safe.
5227 ICmpInst::Predicate Pred;
5228 const SCEV *OverflowLimit =
5229 getSignedOverflowLimitForStep(Step, &Pred, this);
5230 if (OverflowLimit &&
5231 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5232 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5233 Result = setFlags(Result, SCEV::FlagNSW);
5234 }
5235 return Result;
5236}
5237 SCEV::NoWrapFlags
5238 ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5239 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5240
5241 if (AR->hasNoUnsignedWrap())
5242 return Result;
5243
5244 if (!AR->isAffine())
5245 return Result;
5246
5247 // This function can be expensive, only try to prove NUW once per AddRec.
5248 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5249 return Result;
5250
5251 const SCEV *Step = AR->getStepRecurrence(*this);
5252 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5253 const Loop *L = AR->getLoop();
5254
5255 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5256 // Note that this serves two purposes: It filters out loops that are
5257 // simply not analyzable, and it covers the case where this code is
5258 // being called from within backedge-taken count analysis, such that
5259 // attempting to ask for the backedge-taken count would likely result
5260 // in infinite recursion. In the latter case, the analysis code will
5261 // cope with a conservative value, and it will take care to purge
5262 // that value once it has finished.
5263 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5264
5265 // Normally, in the cases we can prove no-overflow via a
5266 // backedge guarding condition, we can also compute a backedge
5267 // taken count for the loop. The exceptions are assumptions and
5268 // guards present in the loop -- SCEV is not great at exploiting
5269 // these to compute max backedge taken counts, but can still use
5270 // these to prove lack of overflow. Use this fact to avoid
5271 // doing extra work that may not pay off.
5272
5273 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5274 AC.assumptions().empty())
5275 return Result;
5276
5277 // If the backedge is guarded by a comparison with the pre-inc value the
5278 // addrec is safe. Also, if the entry is guarded by a comparison with the
5279 // start value and the backedge is guarded by a comparison with the post-inc
5280 // value, the addrec is safe.
5281 if (isKnownPositive(Step)) {
5282 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5283 getUnsignedRangeMax(Step));
5284 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5285 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5286 Result = setFlags(Result, SCEV::FlagNUW);
5287 }
5288 }
5289
5290 return Result;
5291}
5292
5293namespace {
5294
5295/// Represents an abstract binary operation. This may exist as a
5296/// normal instruction or constant expression, or may have been
5297/// derived from an expression tree.
5298struct BinaryOp {
5299 unsigned Opcode;
5300 Value *LHS;
5301 Value *RHS;
5302 bool IsNSW = false;
5303 bool IsNUW = false;
5304
5305 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5306 /// constant expression.
5307 Operator *Op = nullptr;
5308
5309 explicit BinaryOp(Operator *Op)
5310 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5311 Op(Op) {
5312 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5313 IsNSW = OBO->hasNoSignedWrap();
5314 IsNUW = OBO->hasNoUnsignedWrap();
5315 }
5316 }
5317
5318 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5319 bool IsNUW = false)
5320 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5321};
5322
5323} // end anonymous namespace
5324
5325/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
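/// For example, an 'or disjoint %a, %b' instruction is mapped to
/// BinaryOp(Instruction::Add, %a, %b, /*IsNSW=*/true, /*IsNUW=*/true).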
5326static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5327 AssumptionCache &AC,
5328 const DominatorTree &DT,
5329 const Instruction *CxtI) {
5330 auto *Op = dyn_cast<Operator>(V);
5331 if (!Op)
5332 return std::nullopt;
5333
5334 // Implementation detail: all the cleverness here should happen without
5335 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5336 // SCEV expressions when possible, and we should not break that.
5337
5338 switch (Op->getOpcode()) {
5339 case Instruction::Add:
5340 case Instruction::Sub:
5341 case Instruction::Mul:
5342 case Instruction::UDiv:
5343 case Instruction::URem:
5344 case Instruction::And:
5345 case Instruction::AShr:
5346 case Instruction::Shl:
5347 return BinaryOp(Op);
5348
5349 case Instruction::Or: {
5350 // Convert or disjoint into add nuw nsw.
5351 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5352 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5353 /*IsNSW=*/true, /*IsNUW=*/true);
5354 return BinaryOp(Op);
5355 }
5356
5357 case Instruction::Xor:
5358 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5359 // If the RHS of the xor is a signmask, then this is just an add.
5360 // Instcombine turns add of signmask into xor as a strength reduction step.
5361 if (RHSC->getValue().isSignMask())
5362 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5363 // Binary `xor` is a bit-wise `add`.
5364 if (V->getType()->isIntegerTy(1))
5365 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5366 return BinaryOp(Op);
5367
5368 case Instruction::LShr:
5369 // Turn logical shift right of a constant into an unsigned divide.
5370 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5371 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5372
5373 // If the shift count is not less than the bitwidth, the result of
5374 // the shift is undefined. Don't try to analyze it, because the
5375 // resolution chosen here may differ from the resolution chosen in
5376 // other parts of the compiler.
5377 if (SA->getValue().ult(BitWidth)) {
5378 Constant *X =
5379 ConstantInt::get(SA->getContext(),
5380 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5381 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5382 }
5383 }
5384 return BinaryOp(Op);
5385
5386 case Instruction::ExtractValue: {
5387 auto *EVI = cast<ExtractValueInst>(Op);
5388 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5389 break;
5390
5391 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5392 if (!WO)
5393 break;
5394
5395 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5396 bool Signed = WO->isSigned();
5397 // TODO: Should add nuw/nsw flags for mul as well.
5398 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5399 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5400
5401 // Now that we know that all uses of the arithmetic-result component of
5402 // CI are guarded by the overflow check, we can go ahead and pretend
5403 // that the arithmetic is non-overflowing.
5404 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5405 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5406 }
5407
5408 default:
5409 break;
5410 }
5411
5412 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5413 // semantics as a Sub, return a binary sub expression.
5414 if (auto *II = dyn_cast<IntrinsicInst>(V))
5415 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5416 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5417
5418 return std::nullopt;
5419}
5420
5421/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5422/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5423/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5424/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5425/// follows one of the following patterns:
5426/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5427/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5428/// If the SCEV expression of \p Op conforms with one of the expected patterns
5429/// we return the type of the truncation operation, and indicate whether the
5430/// truncated type should be treated as signed/unsigned by setting
5431/// \p Signed to true/false, respectively.
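/// For example, for a 64-bit %SymbolicPHI and
///   Op == (SExt i32 (Trunc i64 (%SymbolicPHI) to i32) to i64)
/// this returns i32 and sets \p Signed to true.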
5432static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5433 bool &Signed, ScalarEvolution &SE) {
5434 // The case where Op == SymbolicPHI (that is, with no type conversions on
5435 // the way) is handled by the regular add recurrence creating logic and
5436 // would have already been triggered in createAddRecForPHI. Reaching it here
5437 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5438 // because one of the other operands of the SCEVAddExpr updating this PHI is
5439 // not invariant).
5440 //
5441 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5442 // this case predicates that allow us to prove that Op == SymbolicPHI will
5443 // be added.
5444 if (Op == SymbolicPHI)
5445 return nullptr;
5446
5447 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5448 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5449 if (SourceBits != NewBits)
5450 return nullptr;
5451
5452 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5453 Signed = true;
5454 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5455 }
5456 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5457 Signed = false;
5458 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5459 }
5460 return nullptr;
5461}
5462
5463static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5464 if (!PN->getType()->isIntegerTy())
5465 return nullptr;
5466 const Loop *L = LI.getLoopFor(PN->getParent());
5467 if (!L || L->getHeader() != PN->getParent())
5468 return nullptr;
5469 return L;
5470}
5471
5472// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5473// computation that updates the phi follows the following pattern:
5474// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5475// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5476// If so, try to see if it can be rewritten as an AddRecExpr under some
5477// Predicates. If successful, return them as a pair. Also cache the results
5478// of the analysis.
5479//
5480// Example usage scenario:
5481// Say the Rewriter is called for the following SCEV:
5482// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5483// where:
5484// %X = phi i64 (%Start, %BEValue)
5485// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5486// and call this function with %SymbolicPHI = %X.
5487//
5488// The analysis will find that the value coming around the backedge has
5489// the following SCEV:
5490// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5491// Upon concluding that this matches the desired pattern, the function
5492// will return the pair {NewAddRec, SmallPredsVec} where:
5493// NewAddRec = {%Start,+,%Step}
5494// SmallPredsVec = {P1, P2, P3} as follows:
5495// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5496// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5497// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5498// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5499// under the predicates {P1,P2,P3}.
5500// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5501// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5502//
5503// TODO's:
5504//
5505// 1) Extend the Induction descriptor to also support inductions that involve
5506// casts: When needed (namely, when we are called in the context of the
5507// vectorizer induction analysis), a Set of cast instructions will be
5508// populated by this method, and provided back to isInductionPHI. This is
5509// needed to allow the vectorizer to properly record them to be ignored by
5510// the cost model and to avoid vectorizing them (otherwise these casts,
5511// which are redundant under the runtime overflow checks, will be
5512// vectorized, which can be costly).
5513//
5514// 2) Support additional induction/PHISCEV patterns: We also want to support
5515// inductions where the sext-trunc / zext-trunc operations (partly) occur
5516// after the induction update operation (the induction increment):
5517//
5518// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5519// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5520//
5521// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5522// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5523//
5524// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5525std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5526ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5527 SmallVector<const SCEVPredicate *, 3> Predicates;
5528
5529 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5530 // return an AddRec expression under some predicate.
5531
5532 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5533 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5534 assert(L && "Expecting an integer loop header phi");
5535
5536 // The loop may have multiple entrances or multiple exits; we can analyze
5537 // this phi as an addrec if it has a unique entry value and a unique
5538 // backedge value.
5539 Value *BEValueV = nullptr, *StartValueV = nullptr;
5540 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5541 Value *V = PN->getIncomingValue(i);
5542 if (L->contains(PN->getIncomingBlock(i))) {
5543 if (!BEValueV) {
5544 BEValueV = V;
5545 } else if (BEValueV != V) {
5546 BEValueV = nullptr;
5547 break;
5548 }
5549 } else if (!StartValueV) {
5550 StartValueV = V;
5551 } else if (StartValueV != V) {
5552 StartValueV = nullptr;
5553 break;
5554 }
5555 }
5556 if (!BEValueV || !StartValueV)
5557 return std::nullopt;
5558
5559 const SCEV *BEValue = getSCEV(BEValueV);
5560
5561 // If the value coming around the backedge is an add with the symbolic
5562 // value we just inserted, possibly with casts that we can ignore under
5563 // an appropriate runtime guard, then we found a simple induction variable!
5564 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5565 if (!Add)
5566 return std::nullopt;
5567
5568 // If there is a single occurrence of the symbolic value, possibly
5569 // casted, replace it with a recurrence.
5570 unsigned FoundIndex = Add->getNumOperands();
5571 Type *TruncTy = nullptr;
5572 bool Signed;
5573 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5574 if ((TruncTy =
5575 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5576 if (FoundIndex == e) {
5577 FoundIndex = i;
5578 break;
5579 }
5580
5581 if (FoundIndex == Add->getNumOperands())
5582 return std::nullopt;
5583
5584 // Create an add with everything but the specified operand.
5585 SmallVector<const SCEV *, 8> Ops;
5586 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5587 if (i != FoundIndex)
5588 Ops.push_back(Add->getOperand(i));
5589 const SCEV *Accum = getAddExpr(Ops);
5590
5591 // The runtime checks will not be valid if the step amount is
5592 // varying inside the loop.
5593 if (!isLoopInvariant(Accum, L))
5594 return std::nullopt;
5595
5596 // *** Part2: Create the predicates
5597
5598 // Analysis was successful: we have a phi-with-cast pattern for which we
5599 // can return an AddRec expression under the following predicates:
5600 //
5601 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5602 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5603 // P2: An Equal predicate that guarantees that
5604 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5605 // P3: An Equal predicate that guarantees that
5606 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5607 //
5608 // As we next prove, the above predicates guarantee that:
5609 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5610 //
5611 //
5612 // More formally, we want to prove that:
5613 // Expr(i+1) = Start + (i+1) * Accum
5614 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5615 //
5616 // Given that:
5617 // 1) Expr(0) = Start
5618 // 2) Expr(1) = Start + Accum
5619 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5620 // 3) Induction hypothesis (step i):
5621 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5622 //
5623 // Proof:
5624 // Expr(i+1) =
5625 // = Start + (i+1)*Accum
5626 // = (Start + i*Accum) + Accum
5627 // = Expr(i) + Accum
5628 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5629 // :: from step i
5630 //
5631 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5632 //
5633 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5634 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5635 // + Accum :: from P3
5636 //
5637 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5638 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5639 //
5640 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5641 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5642 //
5643 // By induction, the same applies to all iterations 1<=i<n:
5644 //
5645
5646 // Create a truncated addrec for which we will add a no overflow check (P1).
5647 const SCEV *StartVal = getSCEV(StartValueV);
5648 const SCEV *PHISCEV =
5649 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5650 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5651
5652 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5653 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5654 // will be constant.
5655 //
5656 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5657 // add P1.
5658 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5659 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5660 Signed ? SCEVWrapPredicate::IncrementNSSW
5661 : SCEVWrapPredicate::IncrementNUSW;
5662 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5663 Predicates.push_back(AddRecPred);
5664 }
5665
5666 // Create the Equal Predicates P2,P3:
5667
5668 // It is possible that the predicates P2 and/or P3 are computable at
5669 // compile time due to StartVal and/or Accum being constants.
5670 // If either one is, then we can check that now and escape if either P2
5671 // or P3 is false.
5672
5673 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5674 // for each of StartVal and Accum
5675 auto getExtendedExpr = [&](const SCEV *Expr,
5676 bool CreateSignExtend) -> const SCEV * {
5677 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5678 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5679 const SCEV *ExtendedExpr =
5680 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5681 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5682 return ExtendedExpr;
5683 };
5684
5685 // Given:
5686 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5687 // = getExtendedExpr(Expr)
5688 // Determine whether the predicate P: Expr == ExtendedExpr
5689 // is known to be false at compile time
5690 auto PredIsKnownFalse = [&](const SCEV *Expr,
5691 const SCEV *ExtendedExpr) -> bool {
5692 return Expr != ExtendedExpr &&
5693 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5694 };
5695
5696 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5697 if (PredIsKnownFalse(StartVal, StartExtended)) {
5698 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5699 return std::nullopt;
5700 }
5701
5702 // The Step is always Signed (because the overflow checks are either
5703 // NSSW or NUSW)
5704 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5705 if (PredIsKnownFalse(Accum, AccumExtended)) {
5706 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5707 return std::nullopt;
5708 }
5709
5710 auto AppendPredicate = [&](const SCEV *Expr,
5711 const SCEV *ExtendedExpr) -> void {
5712 if (Expr != ExtendedExpr &&
5713 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5714 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5715 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5716 Predicates.push_back(Pred);
5717 }
5718 };
5719
5720 AppendPredicate(StartVal, StartExtended);
5721 AppendPredicate(Accum, AccumExtended);
5722
5723 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5724 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5725 // into NewAR if it will also add the runtime overflow checks specified in
5726 // Predicates.
5727 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5728
5729 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5730 std::make_pair(NewAR, Predicates);
5731 // Remember the result of the analysis for this SCEV at this location.
5732 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5733 return PredRewrite;
5734}
5735
5736std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5737 ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5738 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5739 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5740 if (!L)
5741 return std::nullopt;
5742
5743 // Check to see if we already analyzed this PHI.
5744 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5745 if (I != PredicatedSCEVRewrites.end()) {
5746 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5747 I->second;
5748 // Analysis was done before and failed to create an AddRec:
5749 if (Rewrite.first == SymbolicPHI)
5750 return std::nullopt;
5751 // Analysis was done before and succeeded to create an AddRec under
5752 // a predicate:
5753 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5754 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5755 return Rewrite;
5756 }
5757
5758 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5759 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5760
5761 // Record in the cache that the analysis failed
5762 if (!Rewrite) {
5763 SmallVector<const SCEVPredicate *, 3> Predicates;
5764 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5765 return std::nullopt;
5766 }
5767
5768 return Rewrite;
5769}
5770
5771// FIXME: This utility is currently required because the Rewriter currently
5772// does not rewrite this expression:
5773// {0, +, (sext ix (trunc iy to ix) to iy)}
5774// into {0, +, %step},
5775// even when the following Equal predicate exists:
5776// "%step == (sext ix (trunc iy to ix) to iy)".
5777 bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5778 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5779 if (AR1 == AR2)
5780 return true;
5781
5782 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5783 if (Expr1 != Expr2 &&
5784 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5785 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5786 return false;
5787 return true;
5788 };
5789
5790 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5791 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5792 return false;
5793 return true;
5794}
5795
5796/// A helper function for createAddRecFromPHI to handle simple cases.
5797///
5798/// This function tries to find an AddRec expression for the simplest (yet most
5799/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5800/// If it fails, createAddRecFromPHI will use a more general, but slow,
5801/// technique for finding the AddRec expression.
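/// For example, PN = phi [ %start, %preheader ], [ %pn.next, %latch ] with
/// %pn.next = add %pn, %step and a loop-invariant %step yields
/// {%start,+,%step}<L>.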
5802const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5803 Value *BEValueV,
5804 Value *StartValueV) {
5805 const Loop *L = LI.getLoopFor(PN->getParent());
5806 assert(L && L->getHeader() == PN->getParent());
5807 assert(BEValueV && StartValueV);
5808
5809 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5810 if (!BO)
5811 return nullptr;
5812
5813 if (BO->Opcode != Instruction::Add)
5814 return nullptr;
5815
5816 const SCEV *Accum = nullptr;
5817 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5818 Accum = getSCEV(BO->RHS);
5819 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5820 Accum = getSCEV(BO->LHS);
5821
5822 if (!Accum)
5823 return nullptr;
5824
5825 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5826 if (BO->IsNUW)
5827 Flags = setFlags(Flags, SCEV::FlagNUW);
5828 if (BO->IsNSW)
5829 Flags = setFlags(Flags, SCEV::FlagNSW);
5830
5831 const SCEV *StartVal = getSCEV(StartValueV);
5832 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5833 insertValueToMap(PN, PHISCEV);
5834
5835 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5836 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5837 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5838 proveNoWrapViaConstantRanges(AR)));
5839 }
5840
5841 // We can add Flags to the post-inc expression only if we
5842 // know that it is *undefined behavior* for BEValueV to
5843 // overflow.
5844 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5845 assert(isLoopInvariant(Accum, L) &&
5846 "Accum is defined outside L, but is not invariant?");
5847 if (isAddRecNeverPoison(BEInst, L))
5848 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5849 }
5850
5851 return PHISCEV;
5852}
5853
5854const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5855 const Loop *L = LI.getLoopFor(PN->getParent());
5856 if (!L || L->getHeader() != PN->getParent())
5857 return nullptr;
5858
5859 // The loop may have multiple entrances or multiple exits; we can analyze
5860 // this phi as an addrec if it has a unique entry value and a unique
5861 // backedge value.
5862 Value *BEValueV = nullptr, *StartValueV = nullptr;
5863 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5864 Value *V = PN->getIncomingValue(i);
5865 if (L->contains(PN->getIncomingBlock(i))) {
5866 if (!BEValueV) {
5867 BEValueV = V;
5868 } else if (BEValueV != V) {
5869 BEValueV = nullptr;
5870 break;
5871 }
5872 } else if (!StartValueV) {
5873 StartValueV = V;
5874 } else if (StartValueV != V) {
5875 StartValueV = nullptr;
5876 break;
5877 }
5878 }
5879 if (!BEValueV || !StartValueV)
5880 return nullptr;
5881
5882 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5883 "PHI node already processed?");
5884
5885 // First, try to find AddRec expression without creating a fictitious symbolic
5886 // value for PN.
5887 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5888 return S;
5889
5890 // Handle PHI node value symbolically.
5891 const SCEV *SymbolicName = getUnknown(PN);
5892 insertValueToMap(PN, SymbolicName);
5893
5894 // Using this symbolic name for the PHI, analyze the value coming around
5895 // the back-edge.
5896 const SCEV *BEValue = getSCEV(BEValueV);
5897
5898 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5899 // has a special value for the first iteration of the loop.
5900
5901 // If the value coming around the backedge is an add with the symbolic
5902 // value we just inserted, then we found a simple induction variable!
5903 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5904 // If there is a single occurrence of the symbolic value, replace it
5905 // with a recurrence.
5906 unsigned FoundIndex = Add->getNumOperands();
5907 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5908 if (Add->getOperand(i) == SymbolicName)
5909 if (FoundIndex == e) {
5910 FoundIndex = i;
5911 break;
5912 }
5913
5914 if (FoundIndex != Add->getNumOperands()) {
5915 // Create an add with everything but the specified operand.
5916 SmallVector<const SCEV *, 8> Ops;
5917 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5918 if (i != FoundIndex)
5919 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5920 L, *this));
5921 const SCEV *Accum = getAddExpr(Ops);
5922
5923 // This is not a valid addrec if the step amount is varying each
5924 // loop iteration, but is not itself an addrec in this loop.
5925 if (isLoopInvariant(Accum, L) ||
5926 (isa<SCEVAddRecExpr>(Accum) &&
5927 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5928 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5929
5930 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5931 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5932 if (BO->IsNUW)
5933 Flags = setFlags(Flags, SCEV::FlagNUW);
5934 if (BO->IsNSW)
5935 Flags = setFlags(Flags, SCEV::FlagNSW);
5936 }
5937 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5938 if (GEP->getOperand(0) == PN) {
5939 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5940 // If the increment has any nowrap flags, then we know the address
5941 // space cannot be wrapped around.
5942 if (NW != GEPNoWrapFlags::none())
5943 Flags = setFlags(Flags, SCEV::FlagNW);
5944 // If the GEP is nuw or nusw with non-negative offset, we know that
5945 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5946 // offset is treated as signed, while the base is unsigned.
5947 if (NW.hasNoUnsignedWrap() ||
5948 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5949 Flags = setFlags(Flags, SCEV::FlagNUW);
5950 }
5951
5952 // We cannot transfer nuw and nsw flags from subtraction
5953 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5954 // for instance.
5955 }
5956
5957 const SCEV *StartVal = getSCEV(StartValueV);
5958 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5959
5960 // Okay, for the entire analysis of this edge we assumed the PHI
5961 // to be symbolic. We now need to go back and purge all of the
5962 // entries for the scalars that use the symbolic expression.
5963 forgetMemoizedResults(SymbolicName);
5964 insertValueToMap(PN, PHISCEV);
5965
5966 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5967 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5968 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5969 proveNoWrapViaConstantRanges(AR)));
5970 }
5971
5972 // We can add Flags to the post-inc expression only if we
5973 // know that it is *undefined behavior* for BEValueV to
5974 // overflow.
5975 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5976 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5977 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5978
5979 return PHISCEV;
5980 }
5981 }
5982 } else {
5983 // Otherwise, this could be a loop like this:
5984 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5985 // In this case, j = {1,+,1} and BEValue is j.
5986 // Because the other in-value of i (0) fits the evolution of BEValue
5987 // i really is an addrec evolution.
5988 //
5989 // We can generalize this by saying that i is the shifted value of BEValue
5990 // by one iteration:
5991 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5992
5993 // Do not allow refinement in rewriting of BEValue.
5994 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5995 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5996 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5997 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5998 const SCEV *StartVal = getSCEV(StartValueV);
5999 if (Start == StartVal) {
6000 // Okay, for the entire analysis of this edge we assumed the PHI
6001 // to be symbolic. We now need to go back and purge all of the
6002 // entries for the scalars that use the symbolic expression.
6003 forgetMemoizedResults(SymbolicName);
6004 insertValueToMap(PN, Shifted);
6005 return Shifted;
6006 }
6007 }
6008 }
6009
6010 // Remove the temporary PHI node SCEV that has been inserted while intending
6011 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
6012 // as it will prevent later (possibly simpler) SCEV expressions from being added
6013 // to the ValueExprMap.
6014 eraseValueFromMap(PN);
6015
6016 return nullptr;
6017}
6018
6019// Try to match a control flow sequence that branches out at BI and merges back
6020// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6021// match.
6022 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
6023 Value *&C, Value *&LHS, Value *&RHS) {
6024 C = BI->getCondition();
6025
6026 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6027 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6028
6029 if (!LeftEdge.isSingleEdge())
6030 return false;
6031
6032 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
6033
6034 Use &LeftUse = Merge->getOperandUse(0);
6035 Use &RightUse = Merge->getOperandUse(1);
6036
6037 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6038 LHS = LeftUse;
6039 RHS = RightUse;
6040 return true;
6041 }
6042
6043 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6044 LHS = RightUse;
6045 RHS = LeftUse;
6046 return true;
6047 }
6048
6049 return false;
6050}
6051
6052const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6053 auto IsReachable =
6054 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6055 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6056 // Try to match
6057 //
6058 // br %cond, label %left, label %right
6059 // left:
6060 // br label %merge
6061 // right:
6062 // br label %merge
6063 // merge:
6064 // V = phi [ %x, %left ], [ %y, %right ]
6065 //
6066 // as "select %cond, %x, %y"
6067
6068 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6069 assert(IDom && "At least the entry block should dominate PN");
6070
6071 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6072 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6073
6074 if (BI && BI->isConditional() &&
6075 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6076 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6077 properlyDominates(getSCEV(RHS), PN->getParent()))
6078 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6079 }
6080
6081 return nullptr;
6082}
6083
6084/// Returns SCEV for the first operand of a phi if all phi operands have
6085/// identical opcodes and operands
6086/// eg.
6087/// a: %add = %a + %b
6088/// br %c
6089/// b: %add1 = %a + %b
6090/// br %c
6091/// c: %phi = phi [%add, a], [%add1, b]
6092/// scev(%phi) => scev(%add)
6093const SCEV *
6094ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6095 BinaryOperator *CommonInst = nullptr;
6096 // Check if instructions are identical.
6097 for (Value *Incoming : PN->incoming_values()) {
6098 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6099 if (!IncomingInst)
6100 return nullptr;
6101 if (CommonInst) {
6102 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6103 return nullptr; // Not identical, give up
6104 } else {
6105 // Remember binary operator
6106 CommonInst = IncomingInst;
6107 }
6108 }
6109 if (!CommonInst)
6110 return nullptr;
6111
6112 // Check if SCEV exprs for instructions are identical.
6113 const SCEV *CommonSCEV = getSCEV(CommonInst);
6114 bool SCEVExprsIdentical =
6115 all_of(drop_begin(PN->incoming_values()),
6116 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6117 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6118}
6119
6120const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6121 if (const SCEV *S = createAddRecFromPHI(PN))
6122 return S;
6123
6124 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6125 // phi node for X.
6126 if (Value *V = simplifyInstruction(
6127 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6128 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6129 return getSCEV(V);
6130
6131 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6132 return S;
6133
6134 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6135 return S;
6136
6137 // If it's not a loop phi, we can't handle it yet.
6138 return getUnknown(PN);
6139}
6140
6141bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6142 SCEVTypes RootKind) {
6143 struct FindClosure {
6144 const SCEV *OperandToFind;
6145 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6146 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6147
6148 bool Found = false;
6149
6150 bool canRecurseInto(SCEVTypes Kind) const {
6151 // We can only recurse into the SCEV expression of the same effective type
6152 // as the type of our root SCEV expression, and into zero-extensions.
6153 return RootKind == Kind || NonSequentialRootKind == Kind ||
6154 scZeroExtend == Kind;
6155 };
6156
6157 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6158 : OperandToFind(OperandToFind), RootKind(RootKind),
6159 NonSequentialRootKind(
6160 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6161 RootKind)) {}
6162
6163 bool follow(const SCEV *S) {
6164 Found = S == OperandToFind;
6165
6166 return !isDone() && canRecurseInto(S->getSCEVType());
6167 }
6168
6169 bool isDone() const { return Found; }
6170 };
6171
6172 FindClosure FC(OperandToFind, RootKind);
6173 visitAll(Root, FC);
6174 return FC.Found;
6175}
6176
6177std::optional<const SCEV *>
6178ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6179 ICmpInst *Cond,
6180 Value *TrueVal,
6181 Value *FalseVal) {
6182 // Try to match some simple smax or umax patterns.
6183 auto *ICI = Cond;
6184
6185 Value *LHS = ICI->getOperand(0);
6186 Value *RHS = ICI->getOperand(1);
6187
6188 switch (ICI->getPredicate()) {
6189 case ICmpInst::ICMP_SLT:
6190 case ICmpInst::ICMP_SLE:
6191 case ICmpInst::ICMP_ULT:
6192 case ICmpInst::ICMP_ULE:
6193 std::swap(LHS, RHS);
6194 [[fallthrough]];
6195 case ICmpInst::ICMP_SGT:
6196 case ICmpInst::ICMP_SGE:
6197 case ICmpInst::ICMP_UGT:
6198 case ICmpInst::ICMP_UGE:
6199 // a > b ? a+x : b+x -> max(a, b)+x
6200 // a > b ? b+x : a+x -> min(a, b)+x
6201 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6202 bool Signed = ICI->isSigned();
6203 const SCEV *LA = getSCEV(TrueVal);
6204 const SCEV *RA = getSCEV(FalseVal);
6205 const SCEV *LS = getSCEV(LHS);
6206 const SCEV *RS = getSCEV(RHS);
6207 if (LA->getType()->isPointerTy()) {
6208 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6209 // Need to make sure we can't produce weird expressions involving
6210 // negated pointers.
6211 if (LA == LS && RA == RS)
6212 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6213 if (LA == RS && RA == LS)
6214 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6215 }
6216 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6217 if (Op->getType()->isPointerTy()) {
6218 Op = getLosslessPtrToIntExpr(Op);
6219 if (isa<SCEVCouldNotCompute>(Op))
6220 return Op;
6221 }
6222 if (Signed)
6223 Op = getNoopOrSignExtend(Op, Ty);
6224 else
6225 Op = getNoopOrZeroExtend(Op, Ty);
6226 return Op;
6227 };
6228 LS = CoerceOperand(LS);
6229 RS = CoerceOperand(RS);
6230 if (LS == getCouldNotCompute() || RS == getCouldNotCompute())
6231 break;
6232 const SCEV *LDiff = getMinusSCEV(LA, LS);
6233 const SCEV *RDiff = getMinusSCEV(RA, RS);
6234 if (LDiff == RDiff)
6235 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6236 LDiff);
6237 LDiff = getMinusSCEV(LA, RS);
6238 RDiff = getMinusSCEV(RA, LS);
6239 if (LDiff == RDiff)
6240 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6241 LDiff);
6242 }
6243 break;
6244 case ICmpInst::ICMP_NE:
6245 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6246 std::swap(TrueVal, FalseVal);
6247 [[fallthrough]];
6248 case ICmpInst::ICMP_EQ:
6249 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6250 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6251 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6252 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6253 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6254 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6255 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6256 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6257 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6258 return getAddExpr(getUMaxExpr(X, C), Y);
6259 }
6260 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6261 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6262 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6263 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6264 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6265 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6266 const SCEV *X = getSCEV(LHS);
6267 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6268 X = ZExt->getOperand();
6269 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6270 const SCEV *FalseValExpr = getSCEV(FalseVal);
6271 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6272 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6273 /*Sequential=*/true);
6274 }
6275 }
6276 break;
6277 default:
6278 break;
6279 }
6280
6281 return std::nullopt;
6282}
6283
6284static std::optional<const SCEV *>
6285 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6286 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6287 assert(CondExpr->getType()->isIntegerTy(1) &&
6288 TrueExpr->getType() == FalseExpr->getType() &&
6289 TrueExpr->getType()->isIntegerTy(1) &&
6290 "Unexpected operands of a select.");
6291
6292 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6293 // --> C + (umin_seq cond, x - C)
6294 //
6295 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6296 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6297 // --> C + (umin_seq ~cond, x - C)
6298
6299 // FIXME: while we can't legally model the case where both of the hands
6300 // are fully variable, we only require that the *difference* is constant.
6301 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6302 return std::nullopt;
6303
6304 const SCEV *X, *C;
6305 if (isa<SCEVConstant>(TrueExpr)) {
6306 CondExpr = SE->getNotSCEV(CondExpr);
6307 X = FalseExpr;
6308 C = TrueExpr;
6309 } else {
6310 X = TrueExpr;
6311 C = FalseExpr;
6312 }
6313 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6314 /*Sequential=*/true));
6315}
6316
6317static std::optional<const SCEV *>
6318 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6319 Value *FalseVal) {
6320 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6321 return std::nullopt;
6322
6323 const auto *SECond = SE->getSCEV(Cond);
6324 const auto *SETrue = SE->getSCEV(TrueVal);
6325 const auto *SEFalse = SE->getSCEV(FalseVal);
6326 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6327}
6328
6329const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6330 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6331 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6332 assert(TrueVal->getType() == FalseVal->getType() &&
6333 V->getType() == TrueVal->getType() &&
6334 "Types of select hands and of the result must match.");
6335
6336 // For now, only deal with i1-typed `select`s.
6337 if (!V->getType()->isIntegerTy(1))
6338 return getUnknown(V);
6339
6340 if (std::optional<const SCEV *> S =
6341 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6342 return *S;
6343
6344 return getUnknown(V);
6345}
6346
6347const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6348 Value *TrueVal,
6349 Value *FalseVal) {
6350 // Handle "constant" branch or select. This can occur for instance when a
6351 // loop pass transforms an inner loop and moves on to process the outer loop.
6352 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6353 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6354
6355 if (auto *I = dyn_cast<Instruction>(V)) {
6356 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6357 if (std::optional<const SCEV *> S =
6358 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6359 TrueVal, FalseVal))
6360 return *S;
6361 }
6362 }
6363
6364 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6365}
6366
6367/// Expand GEP instructions into add and multiply operations. This allows them
6368/// to be analyzed by regular SCEV code.
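/// For example, 'getelementptr i32, ptr %base, i64 %i' becomes
/// ((4 * %i) + %base) in SCEV terms, with the multiply scaled by the element
/// type size.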
6369const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6370 assert(GEP->getSourceElementType()->isSized() &&
6371 "GEP source element type must be sized");
6372
6373 SmallVector<const SCEV *, 4> IndexExprs;
6374 for (Value *Index : GEP->indices())
6375 IndexExprs.push_back(getSCEV(Index));
6376 return getGEPExpr(GEP, IndexExprs);
6377}
6378
6379APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6380 const Instruction *CtxI) {
6381 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6382 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6383 return TrailingZeros >= BitWidth
6384 ? APInt::getZero(BitWidth)
6385 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6386 };
6387 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6388 // The result is GCD of all operands results.
6389 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6390 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6391 Res = APIntOps::GreatestCommonDivisor(
6392 Res, getConstantMultiple(N->getOperand(I), CtxI));
6393 return Res;
6394 };
6395
6396 switch (S->getSCEVType()) {
6397 case scConstant:
6398 return cast<SCEVConstant>(S)->getAPInt();
6399 case scPtrToAddr:
6400 case scPtrToInt:
6401 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6402 case scUDivExpr:
6403 case scVScale:
6404 return APInt(BitWidth, 1);
6405 case scTruncate: {
6406 // Only multiples that are a power of 2 will hold after truncation.
6407 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6408 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6409 return GetShiftedByZeros(TZ);
6410 }
6411 case scZeroExtend: {
6412 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6413 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6414 }
6415 case scSignExtend: {
6416 // Only multiples that are a power of 2 will hold after sext.
6417 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6418 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6419 return GetShiftedByZeros(TZ);
6420 }
6421 case scMulExpr: {
6422 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6423 if (M->hasNoUnsignedWrap()) {
6424 // The result is the product of all operand results.
6425 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6426 for (const SCEV *Operand : M->operands().drop_front())
6427 Res = Res * getConstantMultiple(Operand, CtxI);
6428 return Res;
6429 }
6430
6431    // If there are no wrap guarantees, find the trailing zeros, which is the
6432 // sum of trailing zeros for all its operands.
6433 uint32_t TZ = 0;
6434 for (const SCEV *Operand : M->operands())
6435 TZ += getMinTrailingZeros(Operand, CtxI);
6436 return GetShiftedByZeros(TZ);
6437 }
6438 case scAddExpr:
6439 case scAddRecExpr: {
6440 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6441 if (N->hasNoUnsignedWrap())
6442 return GetGCDMultiple(N);
6444    // Otherwise, the number of trailing zero bits is the minimum over the operands.
6444 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6445 for (const SCEV *Operand : N->operands().drop_front())
6446 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6447 return GetShiftedByZeros(TZ);
6448 }
6449 case scUMaxExpr:
6450 case scSMaxExpr:
6451 case scUMinExpr:
6452 case scSMinExpr:
6453  case scSequentialUMinExpr:
6454    return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6455 case scUnknown: {
6456    // Ask ValueTracking for known bits. A SCEVUnknown only becomes available
6457    // at the point its underlying IR instruction has been defined. If CtxI was
6458 // not provided, use:
6459 // * the first instruction in the entry block if it is an argument
6460 // * the instruction itself otherwise.
6461 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6462 if (!CtxI) {
6463 if (isa<Argument>(U->getValue()))
6464 CtxI = &*F.getEntryBlock().begin();
6465 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6466 CtxI = I;
6467 }
6468 unsigned Known =
6469 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6470 .countMinTrailingZeros();
6471 return GetShiftedByZeros(Known);
6472 }
6473 case scCouldNotCompute:
6474 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6475 }
6476 llvm_unreachable("Unknown SCEV kind!");
6477}
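// Worked example (illustrative): for a <nuw> add such as (8 + 12 * %x) with
// nothing extra known about %x, the operands' constant multiples are 8 and
// 12, so GetGCDMultiple yields gcd(8, 12) = 4 and the whole expression is
// known to be a multiple of 4. The same computation in standalone C++
// (illustrative only, not part of the analysis):
#include <cstdint>
#include <numeric>
static uint64_t constantMultipleOfNUWAdd(uint64_t LHSMultiple,
                                         uint64_t RHSMultiple) {
  return std::gcd(LHSMultiple, RHSMultiple); // GCD of the operands' multiples
}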
6478
6479APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6480                                           const Instruction *CtxI) {
6481 // Skip looking up and updating the cache if there is a context instruction,
6482 // as the result will only be valid in the specified context.
6483 if (CtxI)
6484 return getConstantMultipleImpl(S, CtxI);
6485
6486 auto I = ConstantMultipleCache.find(S);
6487 if (I != ConstantMultipleCache.end())
6488 return I->second;
6489
6490 APInt Result = getConstantMultipleImpl(S, CtxI);
6491 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6492 assert(InsertPair.second && "Should insert a new key");
6493 return InsertPair.first->second;
6494}
6495
6496APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6497  APInt Multiple = getConstantMultiple(S);
6498 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6499}
6500
6501uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6502                                              const Instruction *CtxI) {
6503 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6504 (unsigned)getTypeSizeInBits(S->getType()));
6505}
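// The two queries are related: a value known to be a multiple of 2^k has at
// least k trailing zero bits, and the std::min above merely clamps that count
// to the type width. Illustrative sketch in standalone C++20 (not LLVM API):
#include <bit>
#include <cstdint>
static unsigned minTrailingZerosFromMultiple(uint64_t Multiple) {
  // e.g. a known multiple of 8 gives at least 3 trailing zero bits.
  return Multiple == 0 ? 64u : static_cast<unsigned>(std::countr_zero(Multiple));
}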
6506
6507/// Helper method to assign a range to V from metadata present in the IR.
6508static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6509  if (Instruction *I = dyn_cast<Instruction>(V)) {
6510    if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6511 return getConstantRangeFromMetadata(*MD);
6512 if (const auto *CB = dyn_cast<CallBase>(V))
6513 if (std::optional<ConstantRange> Range = CB->getRange())
6514 return Range;
6515 }
6516 if (auto *A = dyn_cast<Argument>(V))
6517 if (std::optional<ConstantRange> Range = A->getRange())
6518 return Range;
6519
6520 return std::nullopt;
6521}
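// Example (illustrative): a load such as
//   %v = load i32, ptr %p, !range !0      ; !0 = !{i32 0, i32 100}
// yields the half-open ConstantRange [0, 100); call sites and arguments can
// carry the equivalent range(...) attribute, which is what the getRange()
// calls above return.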
6522
6523void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6524                                     SCEV::NoWrapFlags Flags) {
6525 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6526 AddRec->setNoWrapFlags(Flags);
6527 UnsignedRanges.erase(AddRec);
6528 SignedRanges.erase(AddRec);
6529 ConstantMultipleCache.erase(AddRec);
6530 }
6531}
6532
6533ConstantRange ScalarEvolution::
6534getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6535 const DataLayout &DL = getDataLayout();
6536
6537 unsigned BitWidth = getTypeSizeInBits(U->getType());
6538 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6539
6540 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6541 // use information about the trip count to improve our available range. Note
6542 // that the trip count independent cases are already handled by known bits.
6543 // WARNING: The definition of recurrence used here is subtly different than
6544 // the one used by AddRec (and thus most of this file). Step is allowed to
6545 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6546 // and other addrecs in the same loop (for non-affine addrecs). The code
6547 // below intentionally handles the case where step is not loop invariant.
6548 auto *P = dyn_cast<PHINode>(U->getValue());
6549 if (!P)
6550 return FullSet;
6551
6552 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6553 // even the values that are not available in these blocks may come from them,
6554  // and this leads to a false-positive recurrence test.
6555 for (auto *Pred : predecessors(P->getParent()))
6556 if (!DT.isReachableFromEntry(Pred))
6557 return FullSet;
6558
6559 BinaryOperator *BO;
6560 Value *Start, *Step;
6561 if (!matchSimpleRecurrence(P, BO, Start, Step))
6562 return FullSet;
6563
6564 // If we found a recurrence in reachable code, we must be in a loop. Note
6565 // that BO might be in some subloop of L, and that's completely okay.
6566 auto *L = LI.getLoopFor(P->getParent());
6567 assert(L && L->getHeader() == P->getParent());
6568 if (!L->contains(BO->getParent()))
6569 // NOTE: This bailout should be an assert instead. However, asserting
6570 // the condition here exposes a case where LoopFusion is querying SCEV
6571    // the condition here exposes a case where LoopFusion is querying SCEV
6572    // with malformed loop information in the midst of the transform. There
6573    // doesn't appear to be an obvious fix, so for the moment we bail out
6574    // until the caller's issue can be fixed. PR49566 tracks the bug.
6574 return FullSet;
6575
6576 // TODO: Extend to other opcodes such as mul, and div
6577 switch (BO->getOpcode()) {
6578 default:
6579 return FullSet;
6580 case Instruction::AShr:
6581 case Instruction::LShr:
6582 case Instruction::Shl:
6583 break;
6584 };
6585
6586 if (BO->getOperand(0) != P)
6587 // TODO: Handle the power function forms some day.
6588 return FullSet;
6589
6590 unsigned TC = getSmallConstantMaxTripCount(L);
6591 if (!TC || TC >= BitWidth)
6592 return FullSet;
6593
6594 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6595 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6596 assert(KnownStart.getBitWidth() == BitWidth &&
6597 KnownStep.getBitWidth() == BitWidth);
6598
6599 // Compute total shift amount, being careful of overflow and bitwidths.
6600 auto MaxShiftAmt = KnownStep.getMaxValue();
6601 APInt TCAP(BitWidth, TC-1);
6602 bool Overflow = false;
6603 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6604 if (Overflow)
6605 return FullSet;
6606
6607 switch (BO->getOpcode()) {
6608 default:
6609 llvm_unreachable("filtered out above");
6610 case Instruction::AShr: {
6611 // For each ashr, three cases:
6612 // shift = 0 => unchanged value
6613 // saturation => 0 or -1
6614 // other => a value closer to zero (of the same sign)
6615 // Thus, the end value is closer to zero than the start.
6616 auto KnownEnd = KnownBits::ashr(KnownStart,
6617 KnownBits::makeConstant(TotalShift));
6618 if (KnownStart.isNonNegative())
6619 // Analogous to lshr (simply not yet canonicalized)
6620 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6621 KnownStart.getMaxValue() + 1);
6622 if (KnownStart.isNegative())
6623 // End >=u Start && End <=s Start
6624 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6625 KnownEnd.getMaxValue() + 1);
6626 break;
6627 }
6628 case Instruction::LShr: {
6629 // For each lshr, three cases:
6630 // shift = 0 => unchanged value
6631 // saturation => 0
6632 // other => a smaller positive number
6633 // Thus, the low end of the unsigned range is the last value produced.
6634 auto KnownEnd = KnownBits::lshr(KnownStart,
6635 KnownBits::makeConstant(TotalShift));
6636 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6637 KnownStart.getMaxValue() + 1);
6638 }
6639 case Instruction::Shl: {
6640 // Iff no bits are shifted out, value increases on every shift.
6641 auto KnownEnd = KnownBits::shl(KnownStart,
6642 KnownBits::makeConstant(TotalShift));
6643 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6644 return ConstantRange(KnownStart.getMinValue(),
6645 KnownEnd.getMaxValue() + 1);
6646 break;
6647 }
6648 };
6649 return FullSet;
6650}
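// Worked example (illustrative): for the shift recurrence
//   %p = phi i8 [ 200, %preheader ], [ %p.next, %latch ]
//   %p.next = lshr i8 %p, 1
// with a constant maximum trip count of 4, the total shift is at most
// 1 * (4 - 1) = 3, so KnownEnd is 200 >> 3 = 25 and the returned range is
// [25, 201). The values the phi actually takes, 200, 100, 50 and 25, all lie
// in that range.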
6651
6652const ConstantRange &
6653ScalarEvolution::getRangeRefIter(const SCEV *S,
6654 ScalarEvolution::RangeSignHint SignHint) {
6655 DenseMap<const SCEV *, ConstantRange> &Cache =
6656 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6657 : SignedRanges;
6658  SmallVector<const SCEV *> WorkList;
6659  SmallPtrSet<const SCEV *, 8> Seen;
6660
6661 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6662 // SCEVUnknown PHI node.
6663 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6664 if (!Seen.insert(Expr).second)
6665 return;
6666 if (Cache.contains(Expr))
6667 return;
6668 switch (Expr->getSCEVType()) {
6669 case scUnknown:
6670 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6671 break;
6672 [[fallthrough]];
6673 case scConstant:
6674 case scVScale:
6675 case scTruncate:
6676 case scZeroExtend:
6677 case scSignExtend:
6678 case scPtrToAddr:
6679 case scPtrToInt:
6680 case scAddExpr:
6681 case scMulExpr:
6682 case scUDivExpr:
6683 case scAddRecExpr:
6684 case scUMaxExpr:
6685 case scSMaxExpr:
6686 case scUMinExpr:
6687 case scSMinExpr:
6688    case scSequentialUMinExpr:
6689      WorkList.push_back(Expr);
6690 break;
6691 case scCouldNotCompute:
6692 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6693 }
6694 };
6695 AddToWorklist(S);
6696
6697 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6698 for (unsigned I = 0; I != WorkList.size(); ++I) {
6699 const SCEV *P = WorkList[I];
6700 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6701 // If it is not a `SCEVUnknown`, just recurse into operands.
6702 if (!UnknownS) {
6703 for (const SCEV *Op : P->operands())
6704 AddToWorklist(Op);
6705 continue;
6706 }
6707 // `SCEVUnknown`'s require special treatment.
6708 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6709 if (!PendingPhiRangesIter.insert(P).second)
6710 continue;
6711 for (auto &Op : reverse(P->operands()))
6712 AddToWorklist(getSCEV(Op));
6713 }
6714 }
6715
6716 if (!WorkList.empty()) {
6717 // Use getRangeRef to compute ranges for items in the worklist in reverse
6718 // order. This will force ranges for earlier operands to be computed before
6719 // their users in most cases.
6720 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6721 getRangeRef(P, SignHint);
6722
6723 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6724 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6725 PendingPhiRangesIter.erase(P);
6726 }
6727 }
6728
6729 return getRangeRef(S, SignHint, 0);
6730}
6731
6732/// Determine the range for a particular SCEV. If SignHint is
6733/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6734/// with a "cleaner" unsigned (resp. signed) representation.
6735const ConstantRange &ScalarEvolution::getRangeRef(
6736 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6737 DenseMap<const SCEV *, ConstantRange> &Cache =
6738 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6739 : SignedRanges;
6740  ConstantRange::PreferredRangeType RangeType =
6741      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6742                                                       : ConstantRange::Signed;
6743
6744 // See if we've computed this range already.
6745  DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6746  if (I != Cache.end())
6747 return I->second;
6748
6749 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6750 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6751
6752 // Switch to iteratively computing the range for S, if it is part of a deeply
6753 // nested expression.
6754  if (Depth > RangeIterThreshold)
6755    return getRangeRefIter(S, SignHint);
6756
6757 unsigned BitWidth = getTypeSizeInBits(S->getType());
6758 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6759 using OBO = OverflowingBinaryOperator;
6760
6761 // If the value has known zeros, the maximum value will have those known zeros
6762 // as well.
6763 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6764 APInt Multiple = getNonZeroConstantMultiple(S);
6765 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6766 if (!Remainder.isZero())
6767 ConservativeResult =
6768 ConstantRange(APInt::getMinValue(BitWidth),
6769 APInt::getMaxValue(BitWidth) - Remainder + 1);
6770 }
6771 else {
6772 uint32_t TZ = getMinTrailingZeros(S);
6773 if (TZ != 0) {
6774 ConservativeResult = ConstantRange(
6775        APInt::getSignedMinValue(BitWidth),
6776        APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6777 }
6778 }
6779
6780 switch (S->getSCEVType()) {
6781 case scConstant:
6782 llvm_unreachable("Already handled above.");
6783 case scVScale:
6784 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6785 case scTruncate: {
6786 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6787 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6788 return setRange(
6789 Trunc, SignHint,
6790 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6791 }
6792 case scZeroExtend: {
6793 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6794 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6795 return setRange(
6796 ZExt, SignHint,
6797 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6798 }
6799 case scSignExtend: {
6800 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6801 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6802 return setRange(
6803 SExt, SignHint,
6804 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6805 }
6806 case scPtrToAddr:
6807 case scPtrToInt: {
6808 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6809 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6810 return setRange(Cast, SignHint, X);
6811 }
6812 case scAddExpr: {
6813 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6814 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6815 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6816 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6817 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6818 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6819 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6820 ConservativeResult =
6821 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6822 }
6823 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6824 unsigned WrapType = OBO::AnyWrap;
6825 if (Add->hasNoSignedWrap())
6826 WrapType |= OBO::NoSignedWrap;
6827 if (Add->hasNoUnsignedWrap())
6828 WrapType |= OBO::NoUnsignedWrap;
6829 for (const SCEV *Op : drop_begin(Add->operands()))
6830 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6831 RangeType);
6832 return setRange(Add, SignHint,
6833 ConservativeResult.intersectWith(X, RangeType));
6834 }
6835 case scMulExpr: {
6836 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6837 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6838 for (const SCEV *Op : drop_begin(Mul->operands()))
6839 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6840 return setRange(Mul, SignHint,
6841 ConservativeResult.intersectWith(X, RangeType));
6842 }
6843 case scUDivExpr: {
6844 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6845 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6846 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6847 return setRange(UDiv, SignHint,
6848 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6849 }
6850 case scAddRecExpr: {
6851 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6852 // If there's no unsigned wrap, the value will never be less than its
6853 // initial value.
6854 if (AddRec->hasNoUnsignedWrap()) {
6855 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6856 if (!UnsignedMinValue.isZero())
6857 ConservativeResult = ConservativeResult.intersectWith(
6858 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6859 }
6860
6861    // If there's no signed wrap, and all the operands other than the initial
6862    // value have the same sign or are zero, the value won't ever be:
6863    // 1: smaller than the initial value if the operands are non-negative,
6864    // 2: bigger than the initial value if the operands are non-positive.
6865    // In both cases, the value cannot cross the signed min/max boundary.
6866 if (AddRec->hasNoSignedWrap()) {
6867 bool AllNonNeg = true;
6868 bool AllNonPos = true;
6869 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6870 if (!isKnownNonNegative(AddRec->getOperand(i)))
6871 AllNonNeg = false;
6872 if (!isKnownNonPositive(AddRec->getOperand(i)))
6873 AllNonPos = false;
6874 }
6875 if (AllNonNeg)
6876        ConservativeResult = ConservativeResult.intersectWith(
6877            ConstantRange(getSignedRangeMin(AddRec->getStart()),
6878                          APInt::getSignedMinValue(BitWidth)),
6879            RangeType);
6880      else if (AllNonPos)
6881        ConservativeResult = ConservativeResult.intersectWith(
6882            ConstantRange(APInt::getSignedMinValue(BitWidth),
6883                          getSignedRangeMax(AddRec->getStart()) +
6884                              1),
6885            RangeType);
6886 }
6887
6888 // TODO: non-affine addrec
6889 if (AddRec->isAffine()) {
6890 const SCEV *MaxBEScev =
6891        getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6892    if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6893 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6894
6895 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6896 // MaxBECount's active bits are all <= AddRec's bit width.
6897 if (MaxBECount.getBitWidth() > BitWidth &&
6898 MaxBECount.getActiveBits() <= BitWidth)
6899 MaxBECount = MaxBECount.trunc(BitWidth);
6900 else if (MaxBECount.getBitWidth() < BitWidth)
6901 MaxBECount = MaxBECount.zext(BitWidth);
6902
6903 if (MaxBECount.getBitWidth() == BitWidth) {
6904 auto RangeFromAffine = getRangeForAffineAR(
6905 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6906 ConservativeResult =
6907 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6908
6909 auto RangeFromFactoring = getRangeViaFactoring(
6910 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6911 ConservativeResult =
6912 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6913 }
6914 }
6915
6916    // Now try symbolic BE count and more powerful methods.
6917    if (UseExpensiveRangeSharpening) {
6918      const SCEV *SymbolicMaxBECount =
6919          getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6920      if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6921 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6922 AddRec->hasNoSelfWrap()) {
6923 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6924 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6925 ConservativeResult =
6926 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6927 }
6928 }
6929 }
6930
6931 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6932 }
6933 case scUMaxExpr:
6934 case scSMaxExpr:
6935 case scUMinExpr:
6936 case scSMinExpr:
6937 case scSequentialUMinExpr: {
6938    Intrinsic::ID ID;
6939    switch (S->getSCEVType()) {
6940 case scUMaxExpr:
6941 ID = Intrinsic::umax;
6942 break;
6943 case scSMaxExpr:
6944 ID = Intrinsic::smax;
6945 break;
6946 case scUMinExpr:
6947    case scSequentialUMinExpr:
6948      ID = Intrinsic::umin;
6949 break;
6950 case scSMinExpr:
6951 ID = Intrinsic::smin;
6952 break;
6953 default:
6954 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6955 }
6956
6957 const auto *NAry = cast<SCEVNAryExpr>(S);
6958 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6959 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6960 X = X.intrinsic(
6961 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6962 return setRange(S, SignHint,
6963 ConservativeResult.intersectWith(X, RangeType));
6964 }
6965 case scUnknown: {
6966 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6967 Value *V = U->getValue();
6968
6969 // Check if the IR explicitly contains !range metadata.
6970 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6971 if (MDRange)
6972 ConservativeResult =
6973 ConservativeResult.intersectWith(*MDRange, RangeType);
6974
6975 // Use facts about recurrences in the underlying IR. Note that add
6976 // recurrences are AddRecExprs and thus don't hit this path. This
6977 // primarily handles shift recurrences.
6978 auto CR = getRangeForUnknownRecurrence(U);
6979 ConservativeResult = ConservativeResult.intersectWith(CR);
6980
6981 // See if ValueTracking can give us a useful range.
6982 const DataLayout &DL = getDataLayout();
6983 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6984 if (Known.getBitWidth() != BitWidth)
6985 Known = Known.zextOrTrunc(BitWidth);
6986
6987 // ValueTracking may be able to compute a tighter result for the number of
6988 // sign bits than for the value of those sign bits.
6989 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6990 if (U->getType()->isPointerTy()) {
6991 // If the pointer size is larger than the index size type, this can cause
6992 // NS to be larger than BitWidth. So compensate for this.
6993 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6994 int ptrIdxDiff = ptrSize - BitWidth;
6995 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6996 NS -= ptrIdxDiff;
6997 }
6998
6999 if (NS > 1) {
7000 // If we know any of the sign bits, we know all of the sign bits.
7001 if (!Known.Zero.getHiBits(NS).isZero())
7002 Known.Zero.setHighBits(NS);
7003 if (!Known.One.getHiBits(NS).isZero())
7004 Known.One.setHighBits(NS);
7005 }
7006
7007 if (Known.getMinValue() != Known.getMaxValue() + 1)
7008 ConservativeResult = ConservativeResult.intersectWith(
7009 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7010 RangeType);
7011 if (NS > 1)
7012 ConservativeResult = ConservativeResult.intersectWith(
7013 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7014 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7015 RangeType);
7016
7017 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7018 // Strengthen the range if the underlying IR value is a
7019 // global/alloca/heap allocation using the size of the object.
7020 bool CanBeNull, CanBeFreed;
7021 uint64_t DerefBytes =
7022 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7023 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7024 // The highest address the object can start is DerefBytes bytes before
7025 // the end (unsigned max value). If this value is not a multiple of the
7026 // alignment, the last possible start value is the next lowest multiple
7027 // of the alignment. Note: The computations below cannot overflow,
7028 // because if they would there's no possible start address for the
7029 // object.
7030 APInt MaxVal =
7031 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7032 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7033 uint64_t Rem = MaxVal.urem(Align);
7034 MaxVal -= APInt(BitWidth, Rem);
7035 APInt MinVal = APInt::getZero(BitWidth);
7036 if (llvm::isKnownNonZero(V, DL))
7037 MinVal = Align;
7038 ConservativeResult = ConservativeResult.intersectWith(
7039 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7040 }
7041 }
7042
7043    // The range of a Phi is a subset of the union of the ranges of its inputs.
7044 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7045 // Make sure that we do not run over cycled Phis.
7046 if (PendingPhiRanges.insert(Phi).second) {
7047 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7048
7049 for (const auto &Op : Phi->operands()) {
7050 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7051 RangeFromOps = RangeFromOps.unionWith(OpRange);
7052 // No point to continue if we already have a full set.
7053 if (RangeFromOps.isFullSet())
7054 break;
7055 }
7056 ConservativeResult =
7057 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7058 bool Erased = PendingPhiRanges.erase(Phi);
7059 assert(Erased && "Failed to erase Phi properly?");
7060 (void)Erased;
7061 }
7062 }
7063
7064 // vscale can't be equal to zero
7065 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7066 if (II->getIntrinsicID() == Intrinsic::vscale) {
7067 ConstantRange Disallowed = APInt::getZero(BitWidth);
7068 ConservativeResult = ConservativeResult.difference(Disallowed);
7069 }
7070
7071 return setRange(U, SignHint, std::move(ConservativeResult));
7072 }
7073 case scCouldNotCompute:
7074 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7075 }
7076
7077 return setRange(S, SignHint, std::move(ConservativeResult));
7078}
7079
7080// Given a StartRange, Step and MaxBECount for an expression, compute a range of
7081// values that the expression can take. Initially, the expression has a value
7082// from StartRange and then is changed by Step up to MaxBECount times. Signed
7083// argument defines if we treat Step as signed or unsigned.
7084static ConstantRange getRangeForAffineARHelper(APInt Step,
7085                                               const ConstantRange &StartRange,
7086 const APInt &MaxBECount,
7087 bool Signed) {
7088 unsigned BitWidth = Step.getBitWidth();
7089 assert(BitWidth == StartRange.getBitWidth() &&
7090 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7091 // If either Step or MaxBECount is 0, then the expression won't change, and we
7092 // just need to return the initial range.
7093 if (Step == 0 || MaxBECount == 0)
7094 return StartRange;
7095
7096 // If we don't know anything about the initial value (i.e. StartRange is
7097 // FullRange), then we don't know anything about the final range either.
7098 // Return FullRange.
7099 if (StartRange.isFullSet())
7100 return ConstantRange::getFull(BitWidth);
7101
7102 // If Step is signed and negative, then we use its absolute value, but we also
7103 // note that we're moving in the opposite direction.
7104 bool Descending = Signed && Step.isNegative();
7105
7106 if (Signed)
7107 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7108 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7109  // This equation holds due to the well-defined wrap-around behavior of
7110 // APInt.
7111 Step = Step.abs();
7112
7113 // Check if Offset is more than full span of BitWidth. If it is, the
7114 // expression is guaranteed to overflow.
7115 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7116 return ConstantRange::getFull(BitWidth);
7117
7118 // Offset is by how much the expression can change. Checks above guarantee no
7119 // overflow here.
7120 APInt Offset = Step * MaxBECount;
7121
7122 // Minimum value of the final range will match the minimal value of StartRange
7123 // if the expression is increasing and will be decreased by Offset otherwise.
7124 // Maximum value of the final range will match the maximal value of StartRange
7125 // if the expression is decreasing and will be increased by Offset otherwise.
7126 APInt StartLower = StartRange.getLower();
7127 APInt StartUpper = StartRange.getUpper() - 1;
7128 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7129 : (StartUpper + std::move(Offset));
7130
7131 // It's possible that the new minimum/maximum value will fall into the initial
7132 // range (due to wrap around). This means that the expression can take any
7133 // value in this bitwidth, and we have to return full range.
7134 if (StartRange.contains(MovedBoundary))
7135 return ConstantRange::getFull(BitWidth);
7136
7137 APInt NewLower =
7138 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7139 APInt NewUpper =
7140 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7141 NewUpper += 1;
7142
7143 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7144 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7145}
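// Worked example (illustrative): with Step = 3 treated as unsigned,
// StartRange = [10, 20) and MaxBECount = 5, the overflow check passes,
// Offset = 3 * 5 = 15, and the result is [10, 19 + 15 + 1) = [10, 35); every
// value Start + i * 3 with Start in [10, 20) and 0 <= i <= 5 lies in it.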
7146
7147ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7148 const SCEV *Step,
7149 const APInt &MaxBECount) {
7150 assert(getTypeSizeInBits(Start->getType()) ==
7151 getTypeSizeInBits(Step->getType()) &&
7152 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7153 "mismatched bit widths");
7154
7155 // First, consider step signed.
7156 ConstantRange StartSRange = getSignedRange(Start);
7157 ConstantRange StepSRange = getSignedRange(Step);
7158
7159 // If Step can be both positive and negative, we need to find ranges for the
7160 // maximum absolute step values in both directions and union them.
7161 ConstantRange SR = getRangeForAffineARHelper(
7162 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7163  SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7164                                              StartSRange, MaxBECount,
7165 /* Signed = */ true));
7166
7167 // Next, consider step unsigned.
7168 ConstantRange UR = getRangeForAffineARHelper(
7169 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7170 /* Signed = */ false);
7171
7172 // Finally, intersect signed and unsigned ranges.
7173  return SR.intersectWith(UR, ConstantRange::Smallest);
7174}
7175
7176ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7177 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7178 ScalarEvolution::RangeSignHint SignHint) {
7179  assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7180 assert(AddRec->hasNoSelfWrap() &&
7181 "This only works for non-self-wrapping AddRecs!");
7182 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7183 const SCEV *Step = AddRec->getStepRecurrence(*this);
7184 // Only deal with constant step to save compile time.
7185 if (!isa<SCEVConstant>(Step))
7186 return ConstantRange::getFull(BitWidth);
7187 // Let's make sure that we can prove that we do not self-wrap during
7188 // MaxBECount iterations. We need this because MaxBECount is a maximum
7189 // iteration count estimate, and we might infer nw from some exit for which we
7190 // do not know max exit count (or any other side reasoning).
7191 // TODO: Turn into assert at some point.
7192 if (getTypeSizeInBits(MaxBECount->getType()) >
7193 getTypeSizeInBits(AddRec->getType()))
7194 return ConstantRange::getFull(BitWidth);
7195 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7196 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7197 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7198 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7199 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7200 MaxItersWithoutWrap))
7201 return ConstantRange::getFull(BitWidth);
7202
7203 ICmpInst::Predicate LEPred =
7204      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7205  ICmpInst::Predicate GEPred =
7206      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7207 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7208
7209 // We know that there is no self-wrap. Let's take Start and End values and
7210 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7211 // the iteration. They either lie inside the range [Min(Start, End),
7212 // Max(Start, End)] or outside it:
7213 //
7214 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7215 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7216 //
7217 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7218 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7219 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7220 // Start <= End and step is positive, or Start >= End and step is negative.
7221 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7222 ConstantRange StartRange = getRangeRef(Start, SignHint);
7223 ConstantRange EndRange = getRangeRef(End, SignHint);
7224 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7225 // If they already cover full iteration space, we will know nothing useful
7226 // even if we prove what we want to prove.
7227 if (RangeBetween.isFullSet())
7228 return RangeBetween;
7229 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7230 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7231 : RangeBetween.isWrappedSet();
7232 if (IsWrappedSet)
7233 return ConstantRange::getFull(BitWidth);
7234
7235 if (isKnownPositive(Step) &&
7236 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7237 return RangeBetween;
7238 if (isKnownNegative(Step) &&
7239 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7240 return RangeBetween;
7241 return ConstantRange::getFull(BitWidth);
7242}
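// Worked example (illustrative): for {%start,+,1}<nsw><nw> with a symbolic
// max backedge-taken count %n that fits the AddRec's type, End is
// %start + %n. If constant ranges prove %start <= End (Case 1 above, with a
// positive step), every intermediate value lies between Start and End, so the
// union of their ranges computed above is a sound result.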
7243
7244ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7245 const SCEV *Step,
7246 const APInt &MaxBECount) {
7247 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7248 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7249
7250 unsigned BitWidth = MaxBECount.getBitWidth();
7251 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7252 getTypeSizeInBits(Step->getType()) == BitWidth &&
7253 "mismatched bit widths");
7254
7255 struct SelectPattern {
7256 Value *Condition = nullptr;
7257 APInt TrueValue;
7258 APInt FalseValue;
7259
7260 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7261 const SCEV *S) {
7262 std::optional<unsigned> CastOp;
7263 APInt Offset(BitWidth, 0);
7264
7266 "Should be!");
7267
7268 // Peel off a constant offset. In the future we could consider being
7269 // smarter here and handle {Start+Step,+,Step} too.
7270 const APInt *Off;
7271 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7272 Offset = *Off;
7273
7274 // Peel off a cast operation
7275 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7276 CastOp = SCast->getSCEVType();
7277 S = SCast->getOperand();
7278 }
7279
7280 using namespace llvm::PatternMatch;
7281
7282 auto *SU = dyn_cast<SCEVUnknown>(S);
7283 const APInt *TrueVal, *FalseVal;
7284 if (!SU ||
7285 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7286 m_APInt(FalseVal)))) {
7287 Condition = nullptr;
7288 return;
7289 }
7290
7291 TrueValue = *TrueVal;
7292 FalseValue = *FalseVal;
7293
7294 // Re-apply the cast we peeled off earlier
7295 if (CastOp)
7296 switch (*CastOp) {
7297 default:
7298 llvm_unreachable("Unknown SCEV cast type!");
7299
7300 case scTruncate:
7301 TrueValue = TrueValue.trunc(BitWidth);
7302 FalseValue = FalseValue.trunc(BitWidth);
7303 break;
7304 case scZeroExtend:
7305 TrueValue = TrueValue.zext(BitWidth);
7306 FalseValue = FalseValue.zext(BitWidth);
7307 break;
7308 case scSignExtend:
7309 TrueValue = TrueValue.sext(BitWidth);
7310 FalseValue = FalseValue.sext(BitWidth);
7311 break;
7312 }
7313
7314 // Re-apply the constant offset we peeled off earlier
7315 TrueValue += Offset;
7316 FalseValue += Offset;
7317 }
7318
7319 bool isRecognized() { return Condition != nullptr; }
7320 };
7321
7322 SelectPattern StartPattern(*this, BitWidth, Start);
7323 if (!StartPattern.isRecognized())
7324 return ConstantRange::getFull(BitWidth);
7325
7326 SelectPattern StepPattern(*this, BitWidth, Step);
7327 if (!StepPattern.isRecognized())
7328 return ConstantRange::getFull(BitWidth);
7329
7330 if (StartPattern.Condition != StepPattern.Condition) {
7331 // We don't handle this case today; but we could, by considering four
7332 // possibilities below instead of two. I'm not sure if there are cases where
7333 // that will help over what getRange already does, though.
7334 return ConstantRange::getFull(BitWidth);
7335 }
7336
7337 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7338 // construct arbitrary general SCEV expressions here. This function is called
7339 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7340 // say) can end up caching a suboptimal value.
7341
7342 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7343 // C2352 and C2512 (otherwise it isn't needed).
7344
7345 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7346 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7347 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7348 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7349
7350 ConstantRange TrueRange =
7351 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7352 ConstantRange FalseRange =
7353 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7354
7355 return TrueRange.unionWith(FalseRange);
7356}
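// Worked example (illustrative): for Start = select(%c, 0, 100) and
// Step = select(%c, 1, -1) with MaxBECount = 10, the two factored
// recurrences are {0,+,1} and {100,+,-1}; their affine ranges are roughly
// [0, 11) and [90, 101), and the returned union is the single ConstantRange
// [0, 101).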
7357
7358SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7359 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7360 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7361
7362 // Return early if there are no flags to propagate to the SCEV.
7363  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7364  if (BinOp->hasNoUnsignedWrap())
7365    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7366  if (BinOp->hasNoSignedWrap())
7367    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7368 if (Flags == SCEV::FlagAnyWrap)
7369 return SCEV::FlagAnyWrap;
7370
7371 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7372}
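// Example (illustrative): for "%a = add nsw i32 %i, 1", the nsw flag is only
// carried over to the SCEV if isSCEVExprNeverPoison proves that %a executes
// whenever the defining scope of its operands is entered; otherwise, on paths
// where %a is not executed, the computation the shared SCEV describes could
// still wrap and the transferred flag would be unsound.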
7373
7374const Instruction *
7375ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7376 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7377 return &*AddRec->getLoop()->getHeader()->begin();
7378 if (auto *U = dyn_cast<SCEVUnknown>(S))
7379 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7380 return I;
7381 return nullptr;
7382}
7383
7384const Instruction *
7385ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7386 bool &Precise) {
7387 Precise = true;
7388 // Do a bounded search of the def relation of the requested SCEVs.
7389 SmallPtrSet<const SCEV *, 16> Visited;
7390  SmallVector<const SCEV *> Worklist;
7391  auto pushOp = [&](const SCEV *S) {
7392 if (!Visited.insert(S).second)
7393 return;
7394 // Threshold of 30 here is arbitrary.
7395 if (Visited.size() > 30) {
7396 Precise = false;
7397 return;
7398 }
7399 Worklist.push_back(S);
7400 };
7401
7402 for (const auto *S : Ops)
7403 pushOp(S);
7404
7405 const Instruction *Bound = nullptr;
7406 while (!Worklist.empty()) {
7407 auto *S = Worklist.pop_back_val();
7408 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7409 if (!Bound || DT.dominates(Bound, DefI))
7410 Bound = DefI;
7411 } else {
7412 for (const auto *Op : S->operands())
7413 pushOp(Op);
7414 }
7415 }
7416 return Bound ? Bound : &*F.getEntryBlock().begin();
7417}
7418
7419const Instruction *
7420ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7421 bool Discard;
7422 return getDefiningScopeBound(Ops, Discard);
7423}
7424
7425bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7426 const Instruction *B) {
7427 if (A->getParent() == B->getParent() &&
7428      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7429                                                 B->getIterator()))
7430 return true;
7431
7432 auto *BLoop = LI.getLoopFor(B->getParent());
7433 if (BLoop && BLoop->getHeader() == B->getParent() &&
7434 BLoop->getLoopPreheader() == A->getParent() &&
7435      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7436                                                 A->getParent()->end()) &&
7437 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7438 B->getIterator()))
7439 return true;
7440 return false;
7441}
7442
7443bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7444 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7445 visitAll(Op, PC);
7446 return PC.MaybePoison.empty();
7447}
7448
7449bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7450 return !SCEVExprContains(Op, [this](const SCEV *S) {
7451 const SCEV *Op1;
7452 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7453 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7454 // is a non-zero constant, we have to assume the UDiv may be UB.
7455 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7456 });
7457}
7458
7459bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7460 // Only proceed if we can prove that I does not yield poison.
7461  if (!programUndefinedIfPoison(I))
7462    return false;
7463
7464 // At this point we know that if I is executed, then it does not wrap
7465 // according to at least one of NSW or NUW. If I is not executed, then we do
7466 // not know if the calculation that I represents would wrap. Multiple
7467 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7468 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7469 // derived from other instructions that map to the same SCEV. We cannot make
7470  // that guarantee for cases where I is not executed. So we need to find an
7471  // upper bound on the defining scope for the SCEV, and prove that I is
7472 // executed every time we enter that scope. When the bounding scope is a
7473 // loop (the common case), this is equivalent to proving I executes on every
7474 // iteration of that loop.
7475  SmallVector<const SCEV *, 4> SCEVOps;
7476  for (const Use &Op : I->operands()) {
7477 // I could be an extractvalue from a call to an overflow intrinsic.
7478 // TODO: We can do better here in some cases.
7479 if (isSCEVable(Op->getType()))
7480 SCEVOps.push_back(getSCEV(Op));
7481 }
7482 auto *DefI = getDefiningScopeBound(SCEVOps);
7483 return isGuaranteedToTransferExecutionTo(DefI, I);
7484}
7485
7486bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7487 // If we know that \c I can never be poison period, then that's enough.
7488 if (isSCEVExprNeverPoison(I))
7489 return true;
7490
7491 // If the loop only has one exit, then we know that, if the loop is entered,
7492 // any instruction dominating that exit will be executed. If any such
7493 // instruction would result in UB, the addrec cannot be poison.
7494 //
7495 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7496 // also handles uses outside the loop header (they just need to dominate the
7497 // single exit).
7498
7499 auto *ExitingBB = L->getExitingBlock();
7500 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7501 return false;
7502
7503 SmallPtrSet<const Value *, 16> KnownPoison;
7504  SmallVector<const Instruction *, 8> Worklist;
7505
7506 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7507 // things that are known to be poison under that assumption go on the
7508 // Worklist.
7509 KnownPoison.insert(I);
7510 Worklist.push_back(I);
7511
7512 while (!Worklist.empty()) {
7513 const Instruction *Poison = Worklist.pop_back_val();
7514
7515 for (const Use &U : Poison->uses()) {
7516 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7517 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7518 DT.dominates(PoisonUser->getParent(), ExitingBB))
7519 return true;
7520
7521 if (propagatesPoison(U) && L->contains(PoisonUser))
7522 if (KnownPoison.insert(PoisonUser).second)
7523 Worklist.push_back(PoisonUser);
7524 }
7525 }
7526
7527 return false;
7528}
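// Example (illustrative): for a single-exit loop whose post-increment value
//   %iv.next = add nsw i32 %iv, 1
// feeds a GEP that is loaded from on every iteration before the exit, a
// poison %iv.next would be immediate UB (loading from a poison-derived
// pointer), so mustTriggerUB fires on the walk above and the addrec's no-wrap
// flags can be kept.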
7529
7530ScalarEvolution::LoopProperties
7531ScalarEvolution::getLoopProperties(const Loop *L) {
7532 using LoopProperties = ScalarEvolution::LoopProperties;
7533
7534 auto Itr = LoopPropertiesCache.find(L);
7535 if (Itr == LoopPropertiesCache.end()) {
7536 auto HasSideEffects = [](Instruction *I) {
7537 if (auto *SI = dyn_cast<StoreInst>(I))
7538 return !SI->isSimple();
7539
7540 if (I->mayThrow())
7541 return true;
7542
7543 // Non-volatile memset / memcpy do not count as side-effect for forward
7544 // progress.
7545 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7546 return false;
7547
7548 return I->mayWriteToMemory();
7549 };
7550
7551 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7552 /*HasNoSideEffects*/ true};
7553
7554 for (auto *BB : L->getBlocks())
7555 for (auto &I : *BB) {
7556        if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7557          LP.HasNoAbnormalExits = false;
7558 if (HasSideEffects(&I))
7559 LP.HasNoSideEffects = false;
7560 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7561 break; // We're already as pessimistic as we can get.
7562 }
7563
7564 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7565 assert(InsertPair.second && "We just checked!");
7566 Itr = InsertPair.first;
7567 }
7568
7569 return Itr->second;
7570}
7571
7572bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7573  // A mustprogress loop without side effects must be finite.
7574 // TODO: The check used here is very conservative. It's only *specific*
7575 // side effects which are well defined in infinite loops.
7576 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7577}
7578
7579const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7580 // Worklist item with a Value and a bool indicating whether all operands have
7581 // been visited already.
7582  using PointerTy = PointerIntPair<Value *, 1, bool>;
7583  SmallVector<PointerTy> Stack;
7584
7585 Stack.emplace_back(V, true);
7586 Stack.emplace_back(V, false);
7587 while (!Stack.empty()) {
7588 auto E = Stack.pop_back_val();
7589 Value *CurV = E.getPointer();
7590
7591 if (getExistingSCEV(CurV))
7592 continue;
7593
7594    SmallVector<Value *> Ops;
7595    const SCEV *CreatedSCEV = nullptr;
7596 // If all operands have been visited already, create the SCEV.
7597 if (E.getInt()) {
7598 CreatedSCEV = createSCEV(CurV);
7599 } else {
7600 // Otherwise get the operands we need to create SCEV's for before creating
7601 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7602 // just use it.
7603 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7604 }
7605
7606 if (CreatedSCEV) {
7607 insertValueToMap(CurV, CreatedSCEV);
7608 } else {
7609      // Queue CurV for SCEV creation, followed by its operands, which need to
7610 // be constructed first.
7611 Stack.emplace_back(CurV, true);
7612 for (Value *Op : Ops)
7613 Stack.emplace_back(Op, false);
7614 }
7615 }
7616
7617 return getExistingSCEV(V);
7618}
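// The loop above is the usual two-phase explicit-stack post-order pattern:
// each value is pushed once with the flag clear ("visit operands first") and
// once with the flag set ("operands done, build this node"), which avoids
// deep recursion on long use-def chains. A generic standalone sketch of the
// same idea (illustrative only; it assumes an acyclic graph and omits the
// memoization that getExistingSCEV provides; names are made up, not LLVM API):
#include <utility>
#include <vector>
template <typename NodeT, typename GetOpsFn, typename BuildFn>
void postOrderBuild(NodeT Root, GetOpsFn GetOps, BuildFn Build) {
  std::vector<std::pair<NodeT, bool>> Stack{{Root, false}};
  while (!Stack.empty()) {
    auto [N, Done] = Stack.back();
    Stack.pop_back();
    if (Done) {
      Build(N); // all operands of N have been built already
      continue;
    }
    Stack.push_back({N, true}); // revisit N after its operands
    for (NodeT Op : GetOps(N))
      Stack.push_back({Op, false});
  }
}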
7619
7620const SCEV *
7621ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7622 if (!isSCEVable(V->getType()))
7623 return getUnknown(V);
7624
7625 if (Instruction *I = dyn_cast<Instruction>(V)) {
7626 // Don't attempt to analyze instructions in blocks that aren't
7627 // reachable. Such instructions don't matter, and they aren't required
7628 // to obey basic rules for definitions dominating uses which this
7629 // analysis depends on.
7630 if (!DT.isReachableFromEntry(I->getParent()))
7631 return getUnknown(PoisonValue::get(V->getType()));
7632 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7633 return getConstant(CI);
7634 else if (isa<GlobalAlias>(V))
7635 return getUnknown(V);
7636 else if (!isa<ConstantExpr>(V))
7637 return getUnknown(V);
7638
7639  Operator *U = cast<Operator>(V);
7640  if (auto BO =
7641          MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7642 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7643 switch (BO->Opcode) {
7644 case Instruction::Add:
7645 case Instruction::Mul: {
7646 // For additions and multiplications, traverse add/mul chains for which we
7647 // can potentially create a single SCEV, to reduce the number of
7648 // get{Add,Mul}Expr calls.
7649 do {
7650 if (BO->Op) {
7651 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7652 Ops.push_back(BO->Op);
7653 break;
7654 }
7655 }
7656 Ops.push_back(BO->RHS);
7657 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7658                                 dyn_cast<Instruction>(V));
7659      if (!NewBO ||
7660 (BO->Opcode == Instruction::Add &&
7661 (NewBO->Opcode != Instruction::Add &&
7662 NewBO->Opcode != Instruction::Sub)) ||
7663 (BO->Opcode == Instruction::Mul &&
7664 NewBO->Opcode != Instruction::Mul)) {
7665 Ops.push_back(BO->LHS);
7666 break;
7667 }
7668 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7669 // requires a SCEV for the LHS.
7670 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7671 auto *I = dyn_cast<Instruction>(BO->Op);
7672 if (I && programUndefinedIfPoison(I)) {
7673 Ops.push_back(BO->LHS);
7674 break;
7675 }
7676 }
7677 BO = NewBO;
7678 } while (true);
7679 return nullptr;
7680 }
7681 case Instruction::Sub:
7682 case Instruction::UDiv:
7683 case Instruction::URem:
7684 break;
7685 case Instruction::AShr:
7686 case Instruction::Shl:
7687 case Instruction::Xor:
7688 if (!IsConstArg)
7689 return nullptr;
7690 break;
7691 case Instruction::And:
7692 case Instruction::Or:
7693 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7694 return nullptr;
7695 break;
7696 case Instruction::LShr:
7697 return getUnknown(V);
7698 default:
7699 llvm_unreachable("Unhandled binop");
7700 break;
7701 }
7702
7703 Ops.push_back(BO->LHS);
7704 Ops.push_back(BO->RHS);
7705 return nullptr;
7706 }
7707
7708 switch (U->getOpcode()) {
7709 case Instruction::Trunc:
7710 case Instruction::ZExt:
7711 case Instruction::SExt:
7712 case Instruction::PtrToAddr:
7713 case Instruction::PtrToInt:
7714 Ops.push_back(U->getOperand(0));
7715 return nullptr;
7716
7717 case Instruction::BitCast:
7718 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7719 Ops.push_back(U->getOperand(0));
7720 return nullptr;
7721 }
7722 return getUnknown(V);
7723
7724 case Instruction::SDiv:
7725 case Instruction::SRem:
7726 Ops.push_back(U->getOperand(0));
7727 Ops.push_back(U->getOperand(1));
7728 return nullptr;
7729
7730 case Instruction::GetElementPtr:
7731 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7732 "GEP source element type must be sized");
7733 llvm::append_range(Ops, U->operands());
7734 return nullptr;
7735
7736 case Instruction::IntToPtr:
7737 return getUnknown(V);
7738
7739 case Instruction::PHI:
7740    // Keep constructing SCEVs for phis recursively for now.
7741 return nullptr;
7742
7743 case Instruction::Select: {
7744 // Check if U is a select that can be simplified to a SCEVUnknown.
7745 auto CanSimplifyToUnknown = [this, U]() {
7746 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7747 return false;
7748
7749 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7750 if (!ICI)
7751 return false;
7752 Value *LHS = ICI->getOperand(0);
7753 Value *RHS = ICI->getOperand(1);
7754 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7755 ICI->getPredicate() == CmpInst::ICMP_NE) {
7756        if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7757          return true;
7758 } else if (getTypeSizeInBits(LHS->getType()) >
7759 getTypeSizeInBits(U->getType()))
7760 return true;
7761 return false;
7762 };
7763 if (CanSimplifyToUnknown())
7764 return getUnknown(U);
7765
7766 llvm::append_range(Ops, U->operands());
7767 return nullptr;
7768 break;
7769 }
7770 case Instruction::Call:
7771 case Instruction::Invoke:
7772 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7773 Ops.push_back(RV);
7774 return nullptr;
7775 }
7776
7777 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7778 switch (II->getIntrinsicID()) {
7779 case Intrinsic::abs:
7780 Ops.push_back(II->getArgOperand(0));
7781 return nullptr;
7782 case Intrinsic::umax:
7783 case Intrinsic::umin:
7784 case Intrinsic::smax:
7785 case Intrinsic::smin:
7786 case Intrinsic::usub_sat:
7787 case Intrinsic::uadd_sat:
7788 Ops.push_back(II->getArgOperand(0));
7789 Ops.push_back(II->getArgOperand(1));
7790 return nullptr;
7791 case Intrinsic::start_loop_iterations:
7792 case Intrinsic::annotation:
7793 case Intrinsic::ptr_annotation:
7794 Ops.push_back(II->getArgOperand(0));
7795 return nullptr;
7796 default:
7797 break;
7798 }
7799 }
7800 break;
7801 }
7802
7803 return nullptr;
7804}
7805
7806const SCEV *ScalarEvolution::createSCEV(Value *V) {
7807 if (!isSCEVable(V->getType()))
7808 return getUnknown(V);
7809
7810 if (Instruction *I = dyn_cast<Instruction>(V)) {
7811 // Don't attempt to analyze instructions in blocks that aren't
7812 // reachable. Such instructions don't matter, and they aren't required
7813 // to obey basic rules for definitions dominating uses which this
7814 // analysis depends on.
7815 if (!DT.isReachableFromEntry(I->getParent()))
7816 return getUnknown(PoisonValue::get(V->getType()));
7817 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7818 return getConstant(CI);
7819 else if (isa<GlobalAlias>(V))
7820 return getUnknown(V);
7821 else if (!isa<ConstantExpr>(V))
7822 return getUnknown(V);
7823
7824 const SCEV *LHS;
7825 const SCEV *RHS;
7826
7827  Operator *U = cast<Operator>(V);
7828  if (auto BO =
7829          MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7830 switch (BO->Opcode) {
7831 case Instruction::Add: {
7832 // The simple thing to do would be to just call getSCEV on both operands
7833 // and call getAddExpr with the result. However if we're looking at a
7834 // bunch of things all added together, this can be quite inefficient,
7835 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7836 // Instead, gather up all the operands and make a single getAddExpr call.
7837 // LLVM IR canonical form means we need only traverse the left operands.
7838      SmallVector<const SCEV *, 4> AddOps;
7839      do {
7840 if (BO->Op) {
7841 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7842 AddOps.push_back(OpSCEV);
7843 break;
7844 }
7845
7846 // If a NUW or NSW flag can be applied to the SCEV for this
7847 // addition, then compute the SCEV for this addition by itself
7848 // with a separate call to getAddExpr. We need to do that
7849 // instead of pushing the operands of the addition onto AddOps,
7850 // since the flags are only known to apply to this particular
7851 // addition - they may not apply to other additions that can be
7852 // formed with operands from AddOps.
7853 const SCEV *RHS = getSCEV(BO->RHS);
7854 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7855 if (Flags != SCEV::FlagAnyWrap) {
7856 const SCEV *LHS = getSCEV(BO->LHS);
7857 if (BO->Opcode == Instruction::Sub)
7858 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7859 else
7860 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7861 break;
7862 }
7863 }
7864
7865 if (BO->Opcode == Instruction::Sub)
7866 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7867 else
7868 AddOps.push_back(getSCEV(BO->RHS));
7869
7870 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7871                                   dyn_cast<Instruction>(V));
7872        if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7873 NewBO->Opcode != Instruction::Sub)) {
7874 AddOps.push_back(getSCEV(BO->LHS));
7875 break;
7876 }
7877 BO = NewBO;
7878 } while (true);
7879
7880 return getAddExpr(AddOps);
7881 }
7882
7883 case Instruction::Mul: {
7884      SmallVector<const SCEV *, 4> MulOps;
7885      do {
7886 if (BO->Op) {
7887 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7888 MulOps.push_back(OpSCEV);
7889 break;
7890 }
7891
7892 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7893 if (Flags != SCEV::FlagAnyWrap) {
7894 LHS = getSCEV(BO->LHS);
7895 RHS = getSCEV(BO->RHS);
7896 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7897 break;
7898 }
7899 }
7900
7901 MulOps.push_back(getSCEV(BO->RHS));
7902 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7903                                   dyn_cast<Instruction>(V));
7904        if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7905 MulOps.push_back(getSCEV(BO->LHS));
7906 break;
7907 }
7908 BO = NewBO;
7909 } while (true);
7910
7911 return getMulExpr(MulOps);
7912 }
7913 case Instruction::UDiv:
7914 LHS = getSCEV(BO->LHS);
7915 RHS = getSCEV(BO->RHS);
7916 return getUDivExpr(LHS, RHS);
7917 case Instruction::URem:
7918 LHS = getSCEV(BO->LHS);
7919 RHS = getSCEV(BO->RHS);
7920 return getURemExpr(LHS, RHS);
7921 case Instruction::Sub: {
7922      SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7923      if (BO->Op)
7924 Flags = getNoWrapFlagsFromUB(BO->Op);
7925 LHS = getSCEV(BO->LHS);
7926 RHS = getSCEV(BO->RHS);
7927 return getMinusSCEV(LHS, RHS, Flags);
7928 }
7929 case Instruction::And:
7930 // For an expression like x&255 that merely masks off the high bits,
7931 // use zext(trunc(x)) as the SCEV expression.
7932 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7933 if (CI->isZero())
7934 return getSCEV(BO->RHS);
7935 if (CI->isMinusOne())
7936 return getSCEV(BO->LHS);
7937 const APInt &A = CI->getValue();
7938
7939 // Instcombine's ShrinkDemandedConstant may strip bits out of
7940 // constants, obscuring what would otherwise be a low-bits mask.
7941 // Use computeKnownBits to compute what ShrinkDemandedConstant
7942 // knew about to reconstruct a low-bits mask value.
7943 unsigned LZ = A.countl_zero();
7944 unsigned TZ = A.countr_zero();
7945 unsigned BitWidth = A.getBitWidth();
7946 KnownBits Known(BitWidth);
7947 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7948
7949 APInt EffectiveMask =
7950 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7951 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7952 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7953 const SCEV *LHS = getSCEV(BO->LHS);
7954 const SCEV *ShiftedLHS = nullptr;
7955 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7956 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7957 // For an expression like (x * 8) & 8, simplify the multiply.
7958 unsigned MulZeros = OpC->getAPInt().countr_zero();
7959 unsigned GCD = std::min(MulZeros, TZ);
7960 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7961 SmallVector<const SCEV *, 4> MulOps;
7962 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7963 append_range(MulOps, LHSMul->operands().drop_front());
7964 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7965 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7966 }
7967 }
7968 if (!ShiftedLHS)
7969 ShiftedLHS = getUDivExpr(LHS, MulCount);
7970 return getMulExpr(
7971 getZeroExtendExpr(
7972 getTruncateExpr(ShiftedLHS,
7973 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7974 BO->LHS->getType()),
7975 MulCount);
7976 }
7977 }
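        // For illustration: with a mask like 510 (0b111111110, so TZ = 1 and
        // BitWidth - LZ - TZ = 8), "and i32 %x, 510" is modelled roughly as
        // (2 * (zext i8 (trunc i32 (%x /u 2) to i8) to i32)), i.e. as
        // MulCount * zext(trunc(x /u MulCount)).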
7978 // Binary `and` is a bit-wise `umin`.
7979 if (BO->LHS->getType()->isIntegerTy(1)) {
7980 LHS = getSCEV(BO->LHS);
7981 RHS = getSCEV(BO->RHS);
7982 return getUMinExpr(LHS, RHS);
7983 }
7984 break;
7985
7986 case Instruction::Or:
7987 // Binary `or` is a bit-wise `umax`.
7988 if (BO->LHS->getType()->isIntegerTy(1)) {
7989 LHS = getSCEV(BO->LHS);
7990 RHS = getSCEV(BO->RHS);
7991 return getUMaxExpr(LHS, RHS);
7992 }
7993 break;
7994
7995 case Instruction::Xor:
7996 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7997 // If the RHS of xor is -1, then this is a not operation.
7998 if (CI->isMinusOne())
7999 return getNotSCEV(getSCEV(BO->LHS));
8000
8001 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8002 // This is a variant of the check for xor with -1, and it handles
8003 // the case where instcombine has trimmed non-demanded bits out
8004 // of an xor with -1.
8005 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8006 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8007 if (LBO->getOpcode() == Instruction::And &&
8008 LCI->getValue() == CI->getValue())
8009 if (const SCEVZeroExtendExpr *Z =
8010 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
8011 Type *UTy = BO->LHS->getType();
8012 const SCEV *Z0 = Z->getOperand();
8013 Type *Z0Ty = Z0->getType();
8014 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8015
8016 // If C is a low-bits mask, the zero extend is serving to
8017 // mask off the high bits. Complement the operand and
8018 // re-apply the zext.
8019 if (CI->getValue().isMask(Z0TySize))
8020 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8021
8022 // If C is a single bit, it may be in the sign-bit position
8023 // before the zero-extend. In this case, represent the xor
8024 // using an add, which is equivalent, and re-apply the zext.
8025 APInt Trunc = CI->getValue().trunc(Z0TySize);
8026 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8027 Trunc.isSignMask())
8028 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8029 UTy);
8030 }
8031 }
8032 break;
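      // For illustration: when instcombine has shrunk "xor i32 %x, -1" into
      // "%a = and i32 %x, 255; %r = xor i32 %a, 255" and %a is modelled as a
      // zext of an i8 value, the 255 constant is recognised as a low-bits mask
      // and %r becomes zext(~narrow), preserving the original "not" semantics.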
8033
8034 case Instruction::Shl:
8035 // Turn shift left of a constant amount into a multiply.
8036 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8037 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8038
8039 // If the shift count is not less than the bitwidth, the result of
8040 // the shift is undefined. Don't try to analyze it, because the
8041 // resolution chosen here may differ from the resolution chosen in
8042 // other parts of the compiler.
8043 if (SA->getValue().uge(BitWidth))
8044 break;
8045
8046 // We can safely preserve the nuw flag in all cases. It's also safe to
8047 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8048 // requires special handling. It can be preserved as long as we're not
8049 // left shifting by bitwidth - 1.
8050 auto Flags = SCEV::FlagAnyWrap;
8051 if (BO->Op) {
8052 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8053 if ((MulFlags & SCEV::FlagNSW) &&
8054 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8055 Flags = setFlags(Flags, SCEV::FlagNSW);
8056 if (MulFlags & SCEV::FlagNUW)
8057 Flags = setFlags(Flags, SCEV::FlagNUW);
8058 }
8059
8060 ConstantInt *X = ConstantInt::get(
8061 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8062 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8063 }
8064 break;
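      // For illustration: "shl nuw i32 %x, 3" is modelled as (8 * %x)<nuw>;
      // per the rules above, nuw always carries over to the multiply, while
      // nsw carries over only if nuw is also present or the shift amount is
      // less than BitWidth - 1.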
8065
8066 case Instruction::AShr:
8067 // AShr X, C, where C is a constant.
8068 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8069 if (!CI)
8070 break;
8071
8072 Type *OuterTy = BO->LHS->getType();
8073 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8074 // If the shift count is not less than the bitwidth, the result of
8075 // the shift is undefined. Don't try to analyze it, because the
8076 // resolution chosen here may differ from the resolution chosen in
8077 // other parts of the compiler.
8078 if (CI->getValue().uge(BitWidth))
8079 break;
8080
8081 if (CI->isZero())
8082 return getSCEV(BO->LHS); // shift by zero --> noop
8083
8084 uint64_t AShrAmt = CI->getZExtValue();
8085 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8086
8087 Operator *L = dyn_cast<Operator>(BO->LHS);
8088 const SCEV *AddTruncateExpr = nullptr;
8089 ConstantInt *ShlAmtCI = nullptr;
8090 const SCEV *AddConstant = nullptr;
8091
8092 if (L && L->getOpcode() == Instruction::Add) {
8093 // X = Shl A, n
8094 // Y = Add X, c
8095 // Z = AShr Y, m
8096 // n, c and m are constants.
8097
8098 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8099 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8100 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8101 if (AddOperandCI) {
8102 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8103 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8104 // Since we truncate to TruncTy, the AddConstant should be of the
8105 // same type, so create a new constant with the same type as TruncTy.
8106 // Also, the Add constant should be shifted right by the AShr amount.
8107 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8108 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8109 // We model the expression as sext(add(trunc(A), c << n)). Since the
8110 // sext(trunc) part is already handled below, we create an
8111 // AddExpr(TruncExp) which will be used later.
8112 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8113 }
8114 }
8115 } else if (L && L->getOpcode() == Instruction::Shl) {
8116 // X = Shl A, n
8117 // Y = AShr X, m
8118 // Both n and m are constant.
8119
8120 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8121 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8122 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8123 }
8124
8125 if (AddTruncateExpr && ShlAmtCI) {
8126 // We can merge the two given cases into a single SCEV statement;
8127 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8128 // a simpler case. The following code handles the two cases:
8129 //
8130 // 1) For a two-shift sext-inreg, i.e. n = m,
8131 // use sext(trunc(x)) as the SCEV expression.
8132 //
8133 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8134 // expression. We already checked that ShlAmt < BitWidth, so
8135 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8136 // ShlAmt - AShrAmt < Amt.
8137 const APInt &ShlAmt = ShlAmtCI->getValue();
8138 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8139 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8140 ShlAmtCI->getZExtValue() - AShrAmt);
8141 const SCEV *CompositeExpr =
8142 getMulExpr(AddTruncateExpr, getConstant(Mul));
8143 if (L->getOpcode() != Instruction::Shl)
8144 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8145
8146 return getSignExtendExpr(CompositeExpr, OuterTy);
8147 }
8148 }
8149 break;
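    // For illustration: the sign-extend-in-register idiom
    //   %y = shl i32 %x, 24 ; %z = ashr i32 %y, 24  (n = m = 24)
    // becomes (sext i8 (trunc i32 %x to i8) to i32), and with n = 24, m = 16
    // the result is (sext i16 (256 * (trunc i32 %x to i16)) to i32).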
8150 }
8151 }
8152
8153 switch (U->getOpcode()) {
8154 case Instruction::Trunc:
8155 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8156
8157 case Instruction::ZExt:
8158 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8159
8160 case Instruction::SExt:
8161 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8162 dyn_cast<Instruction>(V))) {
8163 // The NSW flag of a subtract does not always survive the conversion to
8164 // A + (-1)*B. By pushing sign extension onto its operands we are much
8165 // more likely to preserve NSW and allow later AddRec optimisations.
8166 //
8167 // NOTE: This is effectively duplicating this logic from getSignExtend:
8168 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8169 // but by that point the NSW information has potentially been lost.
8170 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8171 Type *Ty = U->getType();
8172 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8173 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8174 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8175 }
8176 }
8177 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8178
8179 case Instruction::BitCast:
8180 // BitCasts are no-op casts so we just eliminate the cast.
8181 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8182 return getSCEV(U->getOperand(0));
8183 break;
8184
8185 case Instruction::PtrToAddr:
8186 return getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8187
8188 case Instruction::PtrToInt: {
8189 // Pointer to integer cast is straight-forward, so do model it.
8190 const SCEV *Op = getSCEV(U->getOperand(0));
8191 Type *DstIntTy = U->getType();
8192 // But only if effective SCEV (integer) type is wide enough to represent
8193 // all possible pointer values.
8194 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8195 if (isa<SCEVCouldNotCompute>(IntOp))
8196 return getUnknown(V);
8197 return IntOp;
8198 }
8199 case Instruction::IntToPtr:
8200 // Just don't deal with inttoptr casts.
8201 return getUnknown(V);
8202
8203 case Instruction::SDiv:
8204 // If both operands are non-negative, this is just an udiv.
8205 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8206 isKnownNonNegative(getSCEV(U->getOperand(1))))
8207 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8208 break;
8209
8210 case Instruction::SRem:
8211 // If both operands are non-negative, this is just an urem.
8212 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8213 isKnownNonNegative(getSCEV(U->getOperand(1))))
8214 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8215 break;
8216
8217 case Instruction::GetElementPtr:
8218 return createNodeForGEP(cast<GEPOperator>(U));
8219
8220 case Instruction::PHI:
8221 return createNodeForPHI(cast<PHINode>(U));
8222
8223 case Instruction::Select:
8224 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8225 U->getOperand(2));
8226
8227 case Instruction::Call:
8228 case Instruction::Invoke:
8229 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8230 return getSCEV(RV);
8231
8232 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8233 switch (II->getIntrinsicID()) {
8234 case Intrinsic::abs:
8235 return getAbsExpr(
8236 getSCEV(II->getArgOperand(0)),
8237 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8238 case Intrinsic::umax:
8239 LHS = getSCEV(II->getArgOperand(0));
8240 RHS = getSCEV(II->getArgOperand(1));
8241 return getUMaxExpr(LHS, RHS);
8242 case Intrinsic::umin:
8243 LHS = getSCEV(II->getArgOperand(0));
8244 RHS = getSCEV(II->getArgOperand(1));
8245 return getUMinExpr(LHS, RHS);
8246 case Intrinsic::smax:
8247 LHS = getSCEV(II->getArgOperand(0));
8248 RHS = getSCEV(II->getArgOperand(1));
8249 return getSMaxExpr(LHS, RHS);
8250 case Intrinsic::smin:
8251 LHS = getSCEV(II->getArgOperand(0));
8252 RHS = getSCEV(II->getArgOperand(1));
8253 return getSMinExpr(LHS, RHS);
8254 case Intrinsic::usub_sat: {
8255 const SCEV *X = getSCEV(II->getArgOperand(0));
8256 const SCEV *Y = getSCEV(II->getArgOperand(1));
8257 const SCEV *ClampedY = getUMinExpr(X, Y);
8258 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8259 }
8260 case Intrinsic::uadd_sat: {
8261 const SCEV *X = getSCEV(II->getArgOperand(0));
8262 const SCEV *Y = getSCEV(II->getArgOperand(1));
8263 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8264 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8265 }
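      // For illustration: usub_sat(X, Y) is modelled as X - umin(X, Y) and
      // uadd_sat(X, Y) as umin(X, ~Y) + Y; both rewrites are exact, which is
      // why the resulting subtraction/addition can carry the nuw flag.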
8266 case Intrinsic::start_loop_iterations:
8267 case Intrinsic::annotation:
8268 case Intrinsic::ptr_annotation:
8269 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8270 // just equivalent to the first operand for SCEV purposes.
8271 return getSCEV(II->getArgOperand(0));
8272 case Intrinsic::vscale:
8273 return getVScale(II->getType());
8274 default:
8275 break;
8276 }
8277 }
8278 break;
8279 }
8280
8281 return getUnknown(V);
8282}
8283
8284//===----------------------------------------------------------------------===//
8285// Iteration Count Computation Code
8286//
8287
8288 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8289 if (isa<SCEVCouldNotCompute>(ExitCount))
8290 return getCouldNotCompute();
8291
8292 auto *ExitCountType = ExitCount->getType();
8293 assert(ExitCountType->isIntegerTy());
8294 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8295 1 + ExitCountType->getScalarSizeInBits());
8296 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8297}
8298
8299 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8300 Type *EvalTy,
8301 const Loop *L) {
8302 if (isa<SCEVCouldNotCompute>(ExitCount))
8303 return getCouldNotCompute();
8304
8305 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8306 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8307
8308 auto CanAddOneWithoutOverflow = [&]() {
8309 ConstantRange ExitCountRange =
8310 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8311 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8312 return true;
8313
8314 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8315 getMinusOne(ExitCount->getType()));
8316 };
8317
8318 // If we need to zero extend the backedge count, check if we can add one to
8319 // it prior to zero extending without overflow. Provided this is safe, it
8320 // allows better simplification of the +1.
8321 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8322 return getZeroExtendExpr(
8323 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8324
8325 // Get the total trip count from the count by adding 1. This may wrap.
8326 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8327}
8328
8329static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8330 if (!ExitCount)
8331 return 0;
8332
8333 ConstantInt *ExitConst = ExitCount->getValue();
8334
8335 // Guard against huge trip counts.
8336 if (ExitConst->getValue().getActiveBits() > 32)
8337 return 0;
8338
8339 // In case of integer overflow, this returns 0, which is correct.
8340 return ((unsigned)ExitConst->getZExtValue()) + 1;
8341}
8342
8343 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8344 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8345 return getConstantTripCount(ExitCount);
8346}
8347
8348unsigned
8349 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8350 const BasicBlock *ExitingBlock) {
8351 assert(ExitingBlock && "Must pass a non-null exiting block!");
8352 assert(L->isLoopExiting(ExitingBlock) &&
8353 "Exiting block must actually branch out of the loop!");
8354 const SCEVConstant *ExitCount =
8355 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8356 return getConstantTripCount(ExitCount);
8357}
8358
8359 unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8360 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8361
8362 const auto *MaxExitCount =
8363 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8364 : getConstantMaxBackedgeTakenCount(L);
8365 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8366}
8367
8368 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8369 SmallVector<BasicBlock *, 8> ExitingBlocks;
8370 L->getExitingBlocks(ExitingBlocks);
8371
8372 std::optional<unsigned> Res;
8373 for (auto *ExitingBB : ExitingBlocks) {
8374 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8375 if (!Res)
8376 Res = Multiple;
8377 Res = std::gcd(*Res, Multiple);
8378 }
8379 return Res.value_or(1);
8380}
8381
8382 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8383 const SCEV *ExitCount) {
8384 if (isa<SCEVCouldNotCompute>(ExitCount))
8385 return 1;
8386
8387 // Get the trip count
8388 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8389
8390 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8391 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8392 // the greatest power of 2 divisor less than 2^32.
8393 return Multiple.getActiveBits() > 32
8394 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8395 : (unsigned)Multiple.getZExtValue();
8396}
8397
8398/// Returns the largest constant divisor of the trip count of this loop as a
8399/// normal unsigned value, if possible. This means that the actual trip count is
8400/// always a multiple of the returned value (don't forget the trip count could
8401/// very well be zero as well!).
8402///
8403 /// Returns 1 if the trip count is unknown or not guaranteed to be a
8404 /// multiple of a constant (which is also the case if the trip count is simply
8405 /// constant; use getSmallConstantTripCount for that case). Will also return 1
8406 /// if the trip count is very large (>= 2^32).
8407///
8408/// As explained in the comments for getSmallConstantTripCount, this assumes
8409/// that control exits the loop via ExitingBlock.
8410unsigned
8411 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8412 const BasicBlock *ExitingBlock) {
8413 assert(ExitingBlock && "Must pass a non-null exiting block!");
8414 assert(L->isLoopExiting(ExitingBlock) &&
8415 "Exiting block must actually branch out of the loop!");
8416 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8417 return getSmallConstantTripMultiple(L, ExitCount);
8418}
8419
8420 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8421 const BasicBlock *ExitingBlock,
8422 ExitCountKind Kind) {
8423 switch (Kind) {
8424 case Exact:
8425 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8426 case SymbolicMaximum:
8427 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8428 case ConstantMaximum:
8429 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8430 };
8431 llvm_unreachable("Invalid ExitCountKind!");
8432}
8433
8434 const SCEV *ScalarEvolution::getPredicatedExitCount(
8435 const Loop *L, const BasicBlock *ExitingBlock,
8436 SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8437 switch (Kind) {
8438 case Exact:
8439 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8440 Predicates);
8441 case SymbolicMaximum:
8442 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8443 Predicates);
8444 case ConstantMaximum:
8445 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8446 Predicates);
8447 };
8448 llvm_unreachable("Invalid ExitCountKind!");
8449}
8450
8451 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8452 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8453 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8454}
8455
8456 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8457 ExitCountKind Kind) {
8458 switch (Kind) {
8459 case Exact:
8460 return getBackedgeTakenInfo(L).getExact(L, this);
8461 case ConstantMaximum:
8462 return getBackedgeTakenInfo(L).getConstantMax(this);
8463 case SymbolicMaximum:
8464 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8465 };
8466 llvm_unreachable("Invalid ExitCountKind!");
8467}
8468
8469 const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8470 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8471 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8472}
8473
8474 const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8475 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8476 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8477}
8478
8479 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8480 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8481}
8482
8483/// Push PHI nodes in the header of the given loop onto the given Worklist.
8484 static void PushLoopPHIs(const Loop *L,
8485 SmallVectorImpl<Instruction *> &Worklist,
8486 SmallPtrSetImpl<Instruction *> &Visited) {
8487 BasicBlock *Header = L->getHeader();
8488
8489 // Push all Loop-header PHIs onto the Worklist stack.
8490 for (PHINode &PN : Header->phis())
8491 if (Visited.insert(&PN).second)
8492 Worklist.push_back(&PN);
8493}
8494
8495ScalarEvolution::BackedgeTakenInfo &
8496ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8497 auto &BTI = getBackedgeTakenInfo(L);
8498 if (BTI.hasFullInfo())
8499 return BTI;
8500
8501 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8502
8503 if (!Pair.second)
8504 return Pair.first->second;
8505
8506 BackedgeTakenInfo Result =
8507 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8508
8509 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8510}
8511
8512ScalarEvolution::BackedgeTakenInfo &
8513ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8514 // Initially insert an invalid entry for this loop. If the insertion
8515 // succeeds, proceed to actually compute a backedge-taken count and
8516 // update the value. The temporary CouldNotCompute value tells SCEV
8517 // code elsewhere that it shouldn't attempt to request a new
8518 // backedge-taken count, which could result in infinite recursion.
8519 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8520 BackedgeTakenCounts.try_emplace(L);
8521 if (!Pair.second)
8522 return Pair.first->second;
8523
8524 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8525 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8526 // must be cleared in this scope.
8527 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8528
8529 // Now that we know more about the trip count for this loop, forget any
8530 // existing SCEV values for PHI nodes in this loop since they are only
8531 // conservative estimates made without the benefit of trip count
8532 // information. This invalidation is not necessary for correctness, and is
8533 // only done to produce more precise results.
8534 if (Result.hasAnyInfo()) {
8535 // Invalidate any expression using an addrec in this loop.
8536 SmallVector<const SCEV *, 8> ToForget;
8537 auto LoopUsersIt = LoopUsers.find(L);
8538 if (LoopUsersIt != LoopUsers.end())
8539 append_range(ToForget, LoopUsersIt->second);
8540 forgetMemoizedResults(ToForget);
8541
8542 // Invalidate constant-evolved loop header phis.
8543 for (PHINode &PN : L->getHeader()->phis())
8544 ConstantEvolutionLoopExitValue.erase(&PN);
8545 }
8546
8547 // Re-lookup the insert position, since the call to
8548 // computeBackedgeTakenCount above could result in a
8549 // recursive call to getBackedgeTakenInfo (on a different
8550 // loop), which would invalidate the iterator computed
8551 // earlier.
8552 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8553}
8554
8555 void ScalarEvolution::forgetAllLoops() {
8556 // This method is intended to forget all info about loops. It should
8557 // invalidate caches as if the following happened:
8558 // - The trip counts of all loops have changed arbitrarily
8559 // - Every llvm::Value has been updated in place to produce a different
8560 // result.
8561 BackedgeTakenCounts.clear();
8562 PredicatedBackedgeTakenCounts.clear();
8563 BECountUsers.clear();
8564 LoopPropertiesCache.clear();
8565 ConstantEvolutionLoopExitValue.clear();
8566 ValueExprMap.clear();
8567 ValuesAtScopes.clear();
8568 ValuesAtScopesUsers.clear();
8569 LoopDispositions.clear();
8570 BlockDispositions.clear();
8571 UnsignedRanges.clear();
8572 SignedRanges.clear();
8573 ExprValueMap.clear();
8574 HasRecMap.clear();
8575 ConstantMultipleCache.clear();
8576 PredicatedSCEVRewrites.clear();
8577 FoldCache.clear();
8578 FoldCacheUser.clear();
8579}
8580 void ScalarEvolution::visitAndClearUsers(
8581 SmallVectorImpl<Instruction *> &Worklist,
8582 SmallPtrSetImpl<Instruction *> &Visited,
8583 SmallVectorImpl<const SCEV *> &ToForget) {
8584 while (!Worklist.empty()) {
8585 Instruction *I = Worklist.pop_back_val();
8586 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8587 continue;
8588
8589 ValueExprMapType::iterator It =
8590 ValueExprMap.find_as(static_cast<Value *>(I));
8591 if (It != ValueExprMap.end()) {
8592 eraseValueFromMap(It->first);
8593 ToForget.push_back(It->second);
8594 if (PHINode *PN = dyn_cast<PHINode>(I))
8595 ConstantEvolutionLoopExitValue.erase(PN);
8596 }
8597
8598 PushDefUseChildren(I, Worklist, Visited);
8599 }
8600}
8601
8602 void ScalarEvolution::forgetLoop(const Loop *L) {
8603 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8604 SmallVector<Instruction *, 32> Worklist;
8605 SmallPtrSet<Instruction *, 16> Visited;
8606 SmallVector<const SCEV *, 16> ToForget;
8607
8608 // Iterate over all the loops and sub-loops to drop SCEV information.
8609 while (!LoopWorklist.empty()) {
8610 auto *CurrL = LoopWorklist.pop_back_val();
8611
8612 // Drop any stored trip count value.
8613 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8614 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8615
8616 // Drop information about predicated SCEV rewrites for this loop.
8617 for (auto I = PredicatedSCEVRewrites.begin();
8618 I != PredicatedSCEVRewrites.end();) {
8619 std::pair<const SCEV *, const Loop *> Entry = I->first;
8620 if (Entry.second == CurrL)
8621 PredicatedSCEVRewrites.erase(I++);
8622 else
8623 ++I;
8624 }
8625
8626 auto LoopUsersItr = LoopUsers.find(CurrL);
8627 if (LoopUsersItr != LoopUsers.end())
8628 llvm::append_range(ToForget, LoopUsersItr->second);
8629
8630 // Drop information about expressions based on loop-header PHIs.
8631 PushLoopPHIs(CurrL, Worklist, Visited);
8632 visitAndClearUsers(Worklist, Visited, ToForget);
8633
8634 LoopPropertiesCache.erase(CurrL);
8635 // Forget all contained loops too, to avoid dangling entries in the
8636 // ValuesAtScopes map.
8637 LoopWorklist.append(CurrL->begin(), CurrL->end());
8638 }
8639 forgetMemoizedResults(ToForget);
8640}
8641
8642 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8643 forgetLoop(L->getOutermostLoop());
8644}
8645
8646 void ScalarEvolution::forgetValue(Value *V) {
8647 Instruction *I = dyn_cast<Instruction>(V);
8648 if (!I) return;
8649
8650 // Drop information about expressions based on loop-header PHIs.
8651 SmallVector<Instruction *, 16> Worklist;
8652 SmallPtrSet<Instruction *, 8> Visited;
8653 SmallVector<const SCEV *, 8> ToForget;
8654 Worklist.push_back(I);
8655 Visited.insert(I);
8656 visitAndClearUsers(Worklist, Visited, ToForget);
8657
8658 forgetMemoizedResults(ToForget);
8659}
8660
8661 void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8662 if (!isSCEVable(V->getType()))
8663 return;
8664
8665 // If SCEV looked through a trivial LCSSA phi node, we might have SCEVs
8666 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8667 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8668 // AddRecs defined in the loop and invalidate any SCEVs making use of them.
8669 if (const SCEV *S = getExistingSCEV(V)) {
8670 struct InvalidationRootCollector {
8671 Loop *L;
8672 SmallVector<const SCEV *, 8> Roots;
8673
8674 InvalidationRootCollector(Loop *L) : L(L) {}
8675
8676 bool follow(const SCEV *S) {
8677 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8678 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8679 if (L->contains(I))
8680 Roots.push_back(S);
8681 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8682 if (L->contains(AddRec->getLoop()))
8683 Roots.push_back(S);
8684 }
8685 return true;
8686 }
8687 bool isDone() const { return false; }
8688 };
8689
8690 InvalidationRootCollector C(L);
8691 visitAll(S, C);
8692 forgetMemoizedResults(C.Roots);
8693 }
8694
8695 // Also perform the normal invalidation.
8696 forgetValue(V);
8697}
8698
8699void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8700
8701 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8702 // Unless a specific value is passed to invalidation, completely clear both
8703 // caches.
8704 if (!V) {
8705 BlockDispositions.clear();
8706 LoopDispositions.clear();
8707 return;
8708 }
8709
8710 if (!isSCEVable(V->getType()))
8711 return;
8712
8713 const SCEV *S = getExistingSCEV(V);
8714 if (!S)
8715 return;
8716
8717 // Invalidate the block and loop dispositions cached for S. Dispositions of
8718 // S's users may change if S's disposition changes (i.e. a user may change to
8719 // loop-invariant, if S changes to loop invariant), so also invalidate
8720 // dispositions of S's users recursively.
8721 SmallVector<const SCEV *, 8> Worklist = {S};
8722 SmallPtrSet<const SCEV *, 8> Seen = {S};
8723 while (!Worklist.empty()) {
8724 const SCEV *Curr = Worklist.pop_back_val();
8725 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8726 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8727 if (!LoopDispoRemoved && !BlockDispoRemoved)
8728 continue;
8729 auto Users = SCEVUsers.find(Curr);
8730 if (Users != SCEVUsers.end())
8731 for (const auto *User : Users->second)
8732 if (Seen.insert(User).second)
8733 Worklist.push_back(User);
8734 }
8735}
8736
8737/// Get the exact loop backedge taken count considering all loop exits. A
8738/// computable result can only be returned for loops with all exiting blocks
8739/// dominating the latch. howFarToZero assumes that the limit of each loop test
8740/// is never skipped. This is a valid assumption as long as the loop exits via
8741/// that test. For precise results, it is the caller's responsibility to specify
8742/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8743const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8744 const Loop *L, ScalarEvolution *SE,
8745 SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8746 // If any exits were not computable, the loop is not computable.
8747 if (!isComplete() || ExitNotTaken.empty())
8748 return SE->getCouldNotCompute();
8749
8750 const BasicBlock *Latch = L->getLoopLatch();
8751 // All exiting blocks we have collected must dominate the only backedge.
8752 if (!Latch)
8753 return SE->getCouldNotCompute();
8754
8755 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8756 // count is simply a minimum out of all these calculated exit counts.
8757 SmallVector<const SCEV *, 2> Ops;
8758 for (const auto &ENT : ExitNotTaken) {
8759 const SCEV *BECount = ENT.ExactNotTaken;
8760 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8761 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8762 "We should only have known counts for exiting blocks that dominate "
8763 "latch!");
8764
8765 Ops.push_back(BECount);
8766
8767 if (Preds)
8768 append_range(*Preds, ENT.Predicates);
8769
8770 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8771 "Predicate should be always true!");
8772 }
8773
8774 // If an earlier exit exits on the first iteration (exit count zero), then
8775 // a later poison exit count should not propagate into the result. These are
8776 // exactly the semantics provided by umin_seq.
8777 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8778}
8779
8780const ScalarEvolution::ExitNotTakenInfo *
8781ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8782 const BasicBlock *ExitingBlock,
8783 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8784 for (const auto &ENT : ExitNotTaken)
8785 if (ENT.ExitingBlock == ExitingBlock) {
8786 if (ENT.hasAlwaysTruePredicate())
8787 return &ENT;
8788 else if (Predicates) {
8789 append_range(*Predicates, ENT.Predicates);
8790 return &ENT;
8791 }
8792 }
8793
8794 return nullptr;
8795}
8796
8797/// getConstantMax - Get the constant max backedge taken count for the loop.
8798const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8799 ScalarEvolution *SE,
8800 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8801 if (!getConstantMax())
8802 return SE->getCouldNotCompute();
8803
8804 for (const auto &ENT : ExitNotTaken)
8805 if (!ENT.hasAlwaysTruePredicate()) {
8806 if (!Predicates)
8807 return SE->getCouldNotCompute();
8808 append_range(*Predicates, ENT.Predicates);
8809 }
8810
8811 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8812 isa<SCEVConstant>(getConstantMax())) &&
8813 "No point in having a non-constant max backedge taken count!");
8814 return getConstantMax();
8815}
8816
8817const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8818 const Loop *L, ScalarEvolution *SE,
8819 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8820 if (!SymbolicMax) {
8821 // Form an expression for the maximum exit count possible for this loop. We
8822 // merge the max and exact information to approximate a version of
8823 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8824 // constants.
8825 SmallVector<const SCEV *, 4> ExitCounts;
8826
8827 for (const auto &ENT : ExitNotTaken) {
8828 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8829 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8830 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8831 "We should only have known counts for exiting blocks that "
8832 "dominate latch!");
8833 ExitCounts.push_back(ExitCount);
8834 if (Predicates)
8835 append_range(*Predicates, ENT.Predicates);
8836
8837 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8838 "Predicate should be always true!");
8839 }
8840 }
8841 if (ExitCounts.empty())
8842 SymbolicMax = SE->getCouldNotCompute();
8843 else
8844 SymbolicMax =
8845 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8846 }
8847 return SymbolicMax;
8848}
8849
8850bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8851 ScalarEvolution *SE) const {
8852 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8853 return !ENT.hasAlwaysTruePredicate();
8854 };
8855 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8856}
8857
8860
8861 ScalarEvolution::ExitLimit::ExitLimit(
8862 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8863 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8864 ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists)
8865 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8866 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8867 // If we prove the max count is zero, so is the symbolic bound. This happens
8868 // in practice due to differences in a) how context sensitive we've chosen
8869 // to be and b) how we reason about bounds implied by UB.
8870 if (ConstantMaxNotTaken->isZero()) {
8871 this->ExactNotTaken = E = ConstantMaxNotTaken;
8872 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8873 }
8874
8877 "Exact is not allowed to be less precise than Constant Max");
8880 "Exact is not allowed to be less precise than Symbolic Max");
8883 "Symbolic Max is not allowed to be less precise than Constant Max");
8886 "No point in having a non-constant max backedge taken count!");
8888 for (const auto PredList : PredLists)
8889 for (const auto *P : PredList) {
8890 if (SeenPreds.contains(P))
8891 continue;
8892 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8893 SeenPreds.insert(P);
8894 Predicates.push_back(P);
8895 }
8896 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8897 "Backedge count should be int");
8898 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8899 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8900 "Max backedge count should be int");
8901}
8902
8910
8911/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8912/// computable exit into a persistent ExitNotTakenInfo array.
8913 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8914 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
8915 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8916 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8917 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8918
8919 ExitNotTaken.reserve(ExitCounts.size());
8920 std::transform(ExitCounts.begin(), ExitCounts.end(),
8921 std::back_inserter(ExitNotTaken),
8922 [&](const EdgeExitInfo &EEI) {
8923 BasicBlock *ExitBB = EEI.first;
8924 const ExitLimit &EL = EEI.second;
8925 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8926 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8927 EL.Predicates);
8928 });
8929 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8930 isa<SCEVConstant>(ConstantMax)) &&
8931 "No point in having a non-constant max backedge taken count!");
8932}
8933
8934/// Compute the number of times the backedge of the specified loop will execute.
8935ScalarEvolution::BackedgeTakenInfo
8936ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8937 bool AllowPredicates) {
8938 SmallVector<BasicBlock *, 8> ExitingBlocks;
8939 L->getExitingBlocks(ExitingBlocks);
8940
8941 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8942
8943 SmallVector<EdgeExitInfo, 4> ExitCounts;
8944 bool CouldComputeBECount = true;
8945 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8946 const SCEV *MustExitMaxBECount = nullptr;
8947 const SCEV *MayExitMaxBECount = nullptr;
8948 bool MustExitMaxOrZero = false;
8949 bool IsOnlyExit = ExitingBlocks.size() == 1;
8950
8951 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8952 // and compute maxBECount.
8953 // Do a union of all the predicates here.
8954 for (BasicBlock *ExitBB : ExitingBlocks) {
8955 // We canonicalize untaken exits to br (constant), ignore them so that
8956 // proving an exit untaken doesn't negatively impact our ability to reason
8957 // about the loop as a whole.
8958 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8959 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8960 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8961 if (ExitIfTrue == CI->isZero())
8962 continue;
8963 }
8964
8965 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8966
8967 assert((AllowPredicates || EL.Predicates.empty()) &&
8968 "Predicated exit limit when predicates are not allowed!");
8969
8970 // 1. For each exit that can be computed, add an entry to ExitCounts.
8971 // CouldComputeBECount is true only if all exits can be computed.
8972 if (EL.ExactNotTaken != getCouldNotCompute())
8973 ++NumExitCountsComputed;
8974 else
8975 // We couldn't compute an exact value for this exit, so
8976 // we won't be able to compute an exact value for the loop.
8977 CouldComputeBECount = false;
8978 // Remember exit count if either exact or symbolic is known. Because
8979 // Exact always implies symbolic, only check symbolic.
8980 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8981 ExitCounts.emplace_back(ExitBB, EL);
8982 else {
8983 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8984 "Exact is known but symbolic isn't?");
8985 ++NumExitCountsNotComputed;
8986 }
8987
8988 // 2. Derive the loop's MaxBECount from each exit's max number of
8989 // non-exiting iterations. Partition the loop exits into two kinds:
8990 // LoopMustExits and LoopMayExits.
8991 //
8992 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8993 // is a LoopMayExit. If any computable LoopMustExit is found, then
8994 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8995 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8996 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater
8997 // than any computable EL.ConstantMaxNotTaken.
8998
8999 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9000 DT.dominates(ExitBB, Latch)) {
9001 if (!MustExitMaxBECount) {
9002 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9003 MustExitMaxOrZero = EL.MaxOrZero;
9004 } else {
9005 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9006 EL.ConstantMaxNotTaken);
9007 }
9008 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9009 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9010 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9011 else {
9012 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9013 EL.ConstantMaxNotTaken);
9014 }
9015 }
9016 }
9017 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9018 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9019 // The loop backedge will be taken the maximum or zero times if there's
9020 // a single exit that must be taken the maximum or zero times.
9021 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9022
9023 // Remember which SCEVs are used in exit limits for invalidation purposes.
9024 // We only care about non-constant SCEVs here, so we can ignore
9025 // EL.ConstantMaxNotTaken and MaxBECount, which must be SCEVConstant.
9026
9027 for (const auto &Pair : ExitCounts) {
9028 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9029 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9030 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9031 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9032 {L, AllowPredicates});
9033 }
9034 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9035 MaxBECount, MaxOrZero);
9036}
9037
9038ScalarEvolution::ExitLimit
9039ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9040 bool IsOnlyExit, bool AllowPredicates) {
9041 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9042 // If our exiting block does not dominate the latch, then its connection with
9043 // the loop's exit limit may be far from trivial.
9044 const BasicBlock *Latch = L->getLoopLatch();
9045 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9046 return getCouldNotCompute();
9047
9048 Instruction *Term = ExitingBlock->getTerminator();
9049 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
9050 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
9051 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9052 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9053 "It should have one successor in loop and one exit block!");
9054 // Proceed to the next level to examine the exit condition expression.
9055 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9056 /*ControlsOnlyExit=*/IsOnlyExit,
9057 AllowPredicates);
9058 }
9059
9060 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9061 // For switch, make sure that there is a single exit from the loop.
9062 BasicBlock *Exit = nullptr;
9063 for (auto *SBB : successors(ExitingBlock))
9064 if (!L->contains(SBB)) {
9065 if (Exit) // Multiple exit successors.
9066 return getCouldNotCompute();
9067 Exit = SBB;
9068 }
9069 assert(Exit && "Exiting block must have at least one exit");
9070 return computeExitLimitFromSingleExitSwitch(
9071 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9072 }
9073
9074 return getCouldNotCompute();
9075}
9076
9077 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
9078 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9079 bool AllowPredicates) {
9080 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9081 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9082 ControlsOnlyExit, AllowPredicates);
9083}
9084
9085std::optional<ScalarEvolution::ExitLimit>
9086ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9087 bool ExitIfTrue, bool ControlsOnlyExit,
9088 bool AllowPredicates) {
9089 (void)this->L;
9090 (void)this->ExitIfTrue;
9091 (void)this->AllowPredicates;
9092
9093 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9094 this->AllowPredicates == AllowPredicates &&
9095 "Variance in assumed invariant key components!");
9096 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9097 if (Itr == TripCountMap.end())
9098 return std::nullopt;
9099 return Itr->second;
9100}
9101
9102void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9103 bool ExitIfTrue,
9104 bool ControlsOnlyExit,
9105 bool AllowPredicates,
9106 const ExitLimit &EL) {
9107 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9108 this->AllowPredicates == AllowPredicates &&
9109 "Variance in assumed invariant key components!");
9110
9111 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9112 assert(InsertResult.second && "Expected successful insertion!");
9113 (void)InsertResult;
9114 (void)ExitIfTrue;
9115}
9116
9117ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9118 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9119 bool ControlsOnlyExit, bool AllowPredicates) {
9120
9121 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9122 AllowPredicates))
9123 return *MaybeEL;
9124
9125 ExitLimit EL = computeExitLimitFromCondImpl(
9126 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9127 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9128 return EL;
9129}
9130
9131ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9132 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9133 bool ControlsOnlyExit, bool AllowPredicates) {
9134 // Handle BinOp conditions (And, Or).
9135 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9136 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9137 return *LimitFromBinOp;
9138
9139 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9140 // Proceed to the next level to examine the icmp.
9141 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9142 ExitLimit EL =
9143 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9144 if (EL.hasFullInfo() || !AllowPredicates)
9145 return EL;
9146
9147 // Try again, but use SCEV predicates this time.
9148 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9149 ControlsOnlyExit,
9150 /*AllowPredicates=*/true);
9151 }
9152
9153 // Check for a constant condition. These are normally stripped out by
9154 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9155 // preserve the CFG and is temporarily leaving constant conditions
9156 // in place.
9157 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9158 if (ExitIfTrue == !CI->getZExtValue())
9159 // The backedge is always taken.
9160 return getCouldNotCompute();
9161 // The backedge is never taken.
9162 return getZero(CI->getType());
9163 }
9164
9165 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9166 // with a constant step, we can form an equivalent icmp predicate and figure
9167 // out how many iterations will be taken before we exit.
9168 const WithOverflowInst *WO;
9169 const APInt *C;
9170 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9171 match(WO->getRHS(), m_APInt(C))) {
9172 ConstantRange NWR =
9173 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9174 WO->getNoWrapKind());
9175 CmpInst::Predicate Pred;
9176 APInt NewRHSC, Offset;
9177 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9178 if (!ExitIfTrue)
9179 Pred = ICmpInst::getInversePredicate(Pred);
9180 auto *LHS = getSCEV(WO->getLHS());
9181 if (Offset != 0)
9182 LHS = getAddExpr(LHS, getConstant(Offset));
9183 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9184 ControlsOnlyExit, AllowPredicates);
9185 if (EL.hasAnyInfo())
9186 return EL;
9187 }
9188
9189 // If it's not an integer or pointer comparison then compute it the hard way.
9190 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9191}
9192
9193std::optional<ScalarEvolution::ExitLimit>
9194ScalarEvolution::computeExitLimitFromCondFromBinOp(
9195 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9196 bool ControlsOnlyExit, bool AllowPredicates) {
9197 // Check if the controlling expression for this loop is an And or Or.
9198 Value *Op0, *Op1;
9199 bool IsAnd = false;
9200 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9201 IsAnd = true;
9202 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9203 IsAnd = false;
9204 else
9205 return std::nullopt;
9206
9207 // EitherMayExit is true in these two cases:
9208 // br (and Op0 Op1), loop, exit
9209 // br (or Op0 Op1), exit, loop
9210 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9211 ExitLimit EL0 = computeExitLimitFromCondCached(
9212 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9213 AllowPredicates);
9214 ExitLimit EL1 = computeExitLimitFromCondCached(
9215 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9216 AllowPredicates);
9217
9218 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9219 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9220 if (isa<ConstantInt>(Op1))
9221 return Op1 == NeutralElement ? EL0 : EL1;
9222 if (isa<ConstantInt>(Op0))
9223 return Op0 == NeutralElement ? EL1 : EL0;
9224
9225 const SCEV *BECount = getCouldNotCompute();
9226 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9227 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9228 if (EitherMayExit) {
9229 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9230 // Both conditions must be the same for the loop to continue executing.
9231 // Choose the less conservative count.
9232 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9233 EL1.ExactNotTaken != getCouldNotCompute()) {
9234 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9235 UseSequentialUMin);
9236 }
9237 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9238 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9239 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9240 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9241 else
9242 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9243 EL1.ConstantMaxNotTaken);
9244 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9245 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9246 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9247 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9248 else
9249 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9250 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9251 } else {
9252 // Both conditions must be the same at the same time for the loop to exit.
9253 // For now, be conservative.
9254 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9255 BECount = EL0.ExactNotTaken;
9256 }
9257
9258 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9259 // to be more aggressive when computing BECount than when computing
9260 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9261 // and EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9262 // EL1.ConstantMaxNotTaken to not.
9263
9264 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9265 !isa<SCEVCouldNotCompute>(BECount))
9266 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9267 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9268 SymbolicMaxBECount =
9269 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9270 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9271 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9272}
9273
9274ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9275 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9276 bool AllowPredicates) {
9277 // If the condition was exit on true, convert the condition to exit on false
9278 CmpPredicate Pred;
9279 if (!ExitIfTrue)
9280 Pred = ExitCond->getCmpPredicate();
9281 else
9282 Pred = ExitCond->getInverseCmpPredicate();
9283 const ICmpInst::Predicate OriginalPred = Pred;
9284
9285 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9286 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9287
9288 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9289 AllowPredicates);
9290 if (EL.hasAnyInfo())
9291 return EL;
9292
9293 auto *ExhaustiveCount =
9294 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9295
9296 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9297 return ExhaustiveCount;
9298
9299 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9300 ExitCond->getOperand(1), L, OriginalPred);
9301}
9302ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9303 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9304 bool ControlsOnlyExit, bool AllowPredicates) {
9305
9306 // Try to evaluate any dependencies out of the loop.
9307 LHS = getSCEVAtScope(LHS, L);
9308 RHS = getSCEVAtScope(RHS, L);
9309
9310 // At this point, we would like to compute how many iterations of the
9311 // loop the predicate will return true for these inputs.
9312 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9313 // If there is a loop-invariant, force it into the RHS.
9314 std::swap(LHS, RHS);
9315 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
9316 }
9317
9318 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9319 loopIsFiniteByAssumption(L);
9320 // Simplify the operands before analyzing them.
9321 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9322
9323 // If we have a comparison of a chrec against a constant, try to use value
9324 // ranges to answer this query.
9325 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9326 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9327 if (AddRec->getLoop() == L) {
9328 // Form the constant range.
9329 ConstantRange CompRange =
9330 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9331
9332 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9333 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9334 }
9335
9336 // If this loop must exit based on this condition (or execute undefined
9337 // behaviour), see if we can improve wrap flags. This is essentially
9338 // a must execute style proof.
9339 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9340 // If we can prove the test sequence produced must repeat the same values
9341 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9342 // because if it did, we'd have an infinite (undefined) loop.
9343 // TODO: We can peel off any functions which are invertible *in L*. Loop
9344 // invariant terms are effectively constants for our purposes here.
9345 auto *InnerLHS = LHS;
9346 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9347 InnerLHS = ZExt->getOperand();
9348 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9349 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9350 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9351 /*OrNegative=*/true)) {
9352 auto Flags = AR->getNoWrapFlags();
9353 Flags = setFlags(Flags, SCEV::FlagNW);
9354 SmallVector<const SCEV *> Operands{AR->operands()};
9355 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9356 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9357 }
9358
9359 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9360 // From no-self-wrap, this follows trivially from the fact that every
9361 // (un)signed-wrapped, but not self-wrapped, value must be less than the
9362 // last value before the (un)signed wrap. Since we know that last value
9363 // didn't exit the loop, neither will any smaller one.
9364 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9365 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9366 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9367 AR && AR->getLoop() == L && AR->isAffine() &&
9368 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9369 isKnownPositive(AR->getStepRecurrence(*this))) {
9370 auto Flags = AR->getNoWrapFlags();
9371 Flags = setFlags(Flags, WrapType);
9372 SmallVector<const SCEV*> Operands{AR->operands()};
9373 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9374 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9375 }
9376 }
9377 }
9378
9379 switch (Pred) {
9380 case ICmpInst::ICMP_NE: { // while (X != Y)
9381 // Convert to: while (X-Y != 0)
9382 if (LHS->getType()->isPointerTy()) {
9383 LHS = getLosslessPtrToIntExpr(LHS);
9384 if (isa<SCEVCouldNotCompute>(LHS))
9385 return LHS;
9386 }
9387 if (RHS->getType()->isPointerTy()) {
9388 RHS = getLosslessPtrToIntExpr(RHS);
9389 if (isa<SCEVCouldNotCompute>(RHS))
9390 return RHS;
9391 }
9392 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9393 AllowPredicates);
9394 if (EL.hasAnyInfo())
9395 return EL;
9396 break;
9397 }
9398 case ICmpInst::ICMP_EQ: { // while (X == Y)
9399 // Convert to: while (X-Y == 0)
9400 if (LHS->getType()->isPointerTy()) {
9401 LHS = getLosslessPtrToIntExpr(LHS);
9402 if (isa<SCEVCouldNotCompute>(LHS))
9403 return LHS;
9404 }
9405 if (RHS->getType()->isPointerTy()) {
9406 RHS = getLosslessPtrToIntExpr(RHS);
9407 if (isa<SCEVCouldNotCompute>(RHS))
9408 return RHS;
9409 }
9410 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9411 if (EL.hasAnyInfo()) return EL;
9412 break;
9413 }
9414 case ICmpInst::ICMP_SLE:
9415 case ICmpInst::ICMP_ULE:
9416 // Since the loop is finite, an invariant RHS cannot include the boundary
9417 // value, otherwise it would loop forever.
9418 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9419 !isLoopInvariant(RHS, L)) {
9420 // Otherwise, perform the addition in a wider type, to avoid overflow.
9421 // If the LHS is an addrec with the appropriate nowrap flag, the
9422 // extension will be sunk into it and the exit count can be analyzed.
9423 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9424 if (!OldType)
9425 break;
9426 // Prefer doubling the bitwidth over adding a single bit to make it more
9427 // likely that we use a legal type.
9428 auto *NewType =
9429 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9430 if (ICmpInst::isSigned(Pred)) {
9431 LHS = getSignExtendExpr(LHS, NewType);
9432 RHS = getSignExtendExpr(RHS, NewType);
9433 } else {
9434 LHS = getZeroExtendExpr(LHS, NewType);
9435 RHS = getZeroExtendExpr(RHS, NewType);
9436 }
9437 }
9439 [[fallthrough]];
9440 case ICmpInst::ICMP_SLT:
9441 case ICmpInst::ICMP_ULT: { // while (X < Y)
9442 bool IsSigned = ICmpInst::isSigned(Pred);
9443 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9444 AllowPredicates);
9445 if (EL.hasAnyInfo())
9446 return EL;
9447 break;
9448 }
9449 case ICmpInst::ICMP_SGE:
9450 case ICmpInst::ICMP_UGE:
9451 // Since the loop is finite, an invariant RHS cannot include the boundary
9452 // value, otherwise it would loop forever.
9453 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9454 !isLoopInvariant(RHS, L))
9455 break;
9457 [[fallthrough]];
9458 case ICmpInst::ICMP_SGT:
9459 case ICmpInst::ICMP_UGT: { // while (X > Y)
9460 bool IsSigned = ICmpInst::isSigned(Pred);
9461 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9462 AllowPredicates);
9463 if (EL.hasAnyInfo())
9464 return EL;
9465 break;
9466 }
9467 default:
9468 break;
9469 }
9470
9471 return getCouldNotCompute();
9472}
9473
9474ScalarEvolution::ExitLimit
9475ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9476 SwitchInst *Switch,
9477 BasicBlock *ExitingBlock,
9478 bool ControlsOnlyExit) {
9479 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9480
9481 // Give up if the exit is the default dest of a switch.
9482 if (Switch->getDefaultDest() == ExitingBlock)
9483 return getCouldNotCompute();
9484
9485 assert(L->contains(Switch->getDefaultDest()) &&
9486 "Default case must not exit the loop!");
9487 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9488 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9489
9490 // while (X != Y) --> while (X-Y != 0)
9491 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9492 if (EL.hasAnyInfo())
9493 return EL;
9494
9495 return getCouldNotCompute();
9496}
9497
9498static ConstantInt *
9499 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9500 ScalarEvolution &SE) {
9501 const SCEV *InVal = SE.getConstant(C);
9502 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9503 assert(isa<SCEVConstant>(Val) &&
9504 "Evaluation of SCEV at constant didn't fold correctly?");
9505 return cast<SCEVConstant>(Val)->getValue();
9506}
9507
9508ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9509 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9510 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9511 if (!RHS)
9512 return getCouldNotCompute();
9513
9514 const BasicBlock *Latch = L->getLoopLatch();
9515 if (!Latch)
9516 return getCouldNotCompute();
9517
9518 const BasicBlock *Predecessor = L->getLoopPredecessor();
9519 if (!Predecessor)
9520 return getCouldNotCompute();
9521
9522 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9523 // Return LHS in OutLHS and shift_op in OutOpCode.
9524 auto MatchPositiveShift =
9525 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9526
9527 using namespace PatternMatch;
9528
9529 ConstantInt *ShiftAmt;
9530 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9531 OutOpCode = Instruction::LShr;
9532 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9533 OutOpCode = Instruction::AShr;
9534 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9535 OutOpCode = Instruction::Shl;
9536 else
9537 return false;
9538
9539 return ShiftAmt->getValue().isStrictlyPositive();
9540 };
9541
9542 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9543 //
9544 // loop:
9545 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9546 // %iv.shifted = lshr i32 %iv, <positive constant>
9547 //
9548 // Return true on a successful match. Return the corresponding PHI node (%iv
9549 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9550 auto MatchShiftRecurrence =
9551 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9552 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9553
9554 {
9555 Instruction::BinaryOps OpC;
9556 Value *V;
9557
9558 // If we encounter a shift instruction, "peel off" the shift operation,
9559 // and remember that we did so. Later when we inspect %iv's backedge
9560 // value, we will make sure that the backedge value uses the same
9561 // operation.
9562 //
9563 // Note: the peeled shift operation does not have to be the same
9564 // instruction as the one feeding into the PHI's backedge value. We only
9565 // really care about it being the same *kind* of shift instruction --
9566 // that's all that is required for our later inferences to hold.
9567 if (MatchPositiveShift(LHS, V, OpC)) {
9568 PostShiftOpCode = OpC;
9569 LHS = V;
9570 }
9571 }
9572
9573 PNOut = dyn_cast<PHINode>(LHS);
9574 if (!PNOut || PNOut->getParent() != L->getHeader())
9575 return false;
9576
9577 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9578 Value *OpLHS;
9579
9580 return
9581 // The backedge value for the PHI node must be a shift by a positive
9582 // amount
9583 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9584
9585 // of the PHI node itself
9586 OpLHS == PNOut &&
9587
9588 // and the kind of shift should match the kind of shift we peeled
9589 // off, if any.
9590 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9591 };
9592
9593 PHINode *PN;
9594 Instruction::BinaryOps OpCode;
9595 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9596 return getCouldNotCompute();
9597
9598 const DataLayout &DL = getDataLayout();
9599
9600 // The key rationale for this optimization is that for some kinds of shift
9601 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9602 // within a finite number of iterations. If the condition guarding the
9603 // backedge (in the sense that the backedge is taken if the condition is true)
9604 // is false for the value the shift recurrence stabilizes to, then we know
9605 // that the backedge is taken only a finite number of times.
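  // NOTE (editorial example, not from the upstream source): the ashr
  // recurrence {-8,ashr,1} evolves as -8, -4, -2, -1, -1, ... and has
  // stabilized to -1 (the signum of the start value) within bitwidth(K)
  // steps; {8,lshr,1} evolves as 8, 4, 2, 1, 0, 0, ... and stabilizes to 0.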
9606
9607 ConstantInt *StableValue = nullptr;
9608 switch (OpCode) {
9609 default:
9610 llvm_unreachable("Impossible case!");
9611
9612 case Instruction::AShr: {
9613 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9614 // bitwidth(K) iterations.
9615 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9616 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9617 Predecessor->getTerminator(), &DT);
9618 auto *Ty = cast<IntegerType>(RHS->getType());
9619 if (Known.isNonNegative())
9620 StableValue = ConstantInt::get(Ty, 0);
9621 else if (Known.isNegative())
9622 StableValue = ConstantInt::get(Ty, -1, true);
9623 else
9624 return getCouldNotCompute();
9625
9626 break;
9627 }
9628 case Instruction::LShr:
9629 case Instruction::Shl:
9630 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9631 // stabilize to 0 in at most bitwidth(K) iterations.
9632 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9633 break;
9634 }
9635
9636 auto *Result =
9637 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9638 assert(Result->getType()->isIntegerTy(1) &&
9639 "Otherwise cannot be an operand to a branch instruction");
9640
9641 if (Result->isZeroValue()) {
9642 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9643 const SCEV *UpperBound =
9644 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9645 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9646 }
9647
9648 return getCouldNotCompute();
9649}
9650
9651/// Return true if we can constant fold an instruction of the specified type,
9652/// assuming that all operands were constants.
9653static bool CanConstantFold(const Instruction *I) {
9654 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9655 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9656 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9657 return true;
9658
9659 if (const CallInst *CI = dyn_cast<CallInst>(I))
9660 if (const Function *F = CI->getCalledFunction())
9661 return canConstantFoldCallTo(CI, F);
9662 return false;
9663}
9664
9665/// Determine whether this instruction can constant evolve within this loop
9666/// assuming its operands can all constant evolve.
9667static bool canConstantEvolve(Instruction *I, const Loop *L) {
9668 // An instruction outside of the loop can't be derived from a loop PHI.
9669 if (!L->contains(I)) return false;
9670
9671 if (isa<PHINode>(I)) {
9672 // We don't currently keep track of the control flow needed to evaluate
9673 // PHIs, so we cannot handle PHIs inside of loops.
9674 return L->getHeader() == I->getParent();
9675 }
9676
9677 // If we won't be able to constant fold this expression even if the operands
9678 // are constants, bail early.
9679 return CanConstantFold(I);
9680}
9681
9682/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9683/// recursing through each instruction operand until reaching a loop header phi.
9684static PHINode *
9685 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9686 DenseMap<Instruction *, PHINode *> &PHIMap,
9687 unsigned Depth) {
9688 if (Depth > MaxConstantEvolvingDepth)
9689 return nullptr;
9690
9691 // Otherwise, we can evaluate this instruction if all of its operands are
9692 // constant or derived from a PHI node themselves.
9693 PHINode *PHI = nullptr;
9694 for (Value *Op : UseInst->operands()) {
9695 if (isa<Constant>(Op)) continue;
9696
9697 Instruction *OpInst = dyn_cast<Instruction>(Op);
9698 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9699
9700 PHINode *P = dyn_cast<PHINode>(OpInst);
9701 if (!P)
9702 // If this operand is already visited, reuse the prior result.
9703 // We may have P != PHI if this is the deepest point at which the
9704 // inconsistent paths meet.
9705 P = PHIMap.lookup(OpInst);
9706 if (!P) {
9707 // Recurse and memoize the results, whether a phi is found or not.
9708 // This recursive call invalidates pointers into PHIMap.
9709 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9710 PHIMap[OpInst] = P;
9711 }
9712 if (!P)
9713 return nullptr; // Not evolving from PHI
9714 if (PHI && PHI != P)
9715 return nullptr; // Evolving from multiple different PHIs.
9716 PHI = P;
9717 }
9718 // This is an expression evolving from a constant PHI!
9719 return PHI;
9720}
9721
9722/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9723/// in the loop that V is derived from. We allow arbitrary operations along the
9724/// way, but the operands of an operation must either be constants or a value
9725/// derived from a constant PHI. If this expression does not fit with these
9726/// constraints, return null.
9727 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9728 Instruction *I = dyn_cast<Instruction>(V);
9729 if (!I || !canConstantEvolve(I, L)) return nullptr;
9730
9731 if (PHINode *PN = dyn_cast<PHINode>(I))
9732 return PN;
9733
9734 // Record non-constant instructions contained by the loop.
9735 DenseMap<Instruction *, PHINode *> PHIMap;
9736 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9737}
9738
9739/// EvaluateExpression - Given an expression that passes the
9740/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9741/// in the loop has the value PHIVal. If we can't fold this expression for some
9742/// reason, return null.
9743 static Constant *EvaluateExpression(Value *V, const Loop *L,
9744 DenseMap<Instruction *, Constant *> &Vals,
9745 const DataLayout &DL,
9746 const TargetLibraryInfo *TLI) {
9747 // Convenient constant check, but redundant for recursive calls.
9748 if (Constant *C = dyn_cast<Constant>(V)) return C;
9749 Instruction *I = dyn_cast<Instruction>(V);
9750 if (!I) return nullptr;
9751
9752 if (Constant *C = Vals.lookup(I)) return C;
9753
9754 // An instruction inside the loop depends on a value outside the loop that we
9755 // weren't given a mapping for, or a value such as a call inside the loop.
9756 if (!canConstantEvolve(I, L)) return nullptr;
9757
9758 // An unmapped PHI can be due to a branch or another loop inside this loop,
9759 // or due to this not being the initial iteration through a loop where we
9760 // couldn't compute the evolution of this particular PHI last time.
9761 if (isa<PHINode>(I)) return nullptr;
9762
9763 std::vector<Constant*> Operands(I->getNumOperands());
9764
9765 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9766 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9767 if (!Operand) {
9768 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9769 if (!Operands[i]) return nullptr;
9770 continue;
9771 }
9772 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9773 Vals[Operand] = C;
9774 if (!C) return nullptr;
9775 Operands[i] = C;
9776 }
9777
9778 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9779 /*AllowNonDeterministic=*/false);
9780}
9781
9782
9783// If every incoming value to PN except the one for BB is a specific Constant,
9784// return that, else return nullptr.
9785 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9786 Constant *IncomingVal = nullptr;
9787
9788 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9789 if (PN->getIncomingBlock(i) == BB)
9790 continue;
9791
9792 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9793 if (!CurrentVal)
9794 return nullptr;
9795
9796 if (IncomingVal != CurrentVal) {
9797 if (IncomingVal)
9798 return nullptr;
9799 IncomingVal = CurrentVal;
9800 }
9801 }
9802
9803 return IncomingVal;
9804}
9805
9806/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9807/// in the header of its containing loop, we know the loop executes a
9808/// constant number of times, and the PHI node is just a recurrence
9809/// involving constants, fold it.
9810Constant *
9811ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9812 const APInt &BEs,
9813 const Loop *L) {
9814 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9815 if (!Inserted)
9816 return I->second;
9817
9818 if (BEs.getActiveBits() > 32)
9819 return nullptr; // Not going to evaluate it.
9820
9821 Constant *&RetVal = I->second;
9822
9823 DenseMap<Instruction *, Constant *> CurrentIterVals;
9824 BasicBlock *Header = L->getHeader();
9825 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9826
9827 BasicBlock *Latch = L->getLoopLatch();
9828 if (!Latch)
9829 return nullptr;
9830
9831 for (PHINode &PHI : Header->phis()) {
9832 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9833 CurrentIterVals[&PHI] = StartCST;
9834 }
9835 if (!CurrentIterVals.count(PN))
9836 return RetVal = nullptr;
9837
9838 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9839
9840 // Execute the loop symbolically to determine the exit value.
9841 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9842 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9843
9844 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9845 unsigned IterationNum = 0;
9846 const DataLayout &DL = getDataLayout();
9847 for (; ; ++IterationNum) {
9848 if (IterationNum == NumIterations)
9849 return RetVal = CurrentIterVals[PN]; // Got exit value!
9850
9851 // Compute the value of the PHIs for the next iteration.
9852 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9853 DenseMap<Instruction *, Constant *> NextIterVals;
9854 Constant *NextPHI =
9855 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9856 if (!NextPHI)
9857 return nullptr; // Couldn't evaluate!
9858 NextIterVals[PN] = NextPHI;
9859
9860 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9861
9862 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9863 // cease to be able to evaluate one of them or if they stop evolving,
9864 // because that doesn't necessarily prevent us from computing PN.
9865 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9866 for (const auto &I : CurrentIterVals) {
9867 PHINode *PHI = dyn_cast<PHINode>(I.first);
9868 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9869 PHIsToCompute.emplace_back(PHI, I.second);
9870 }
9871 // We use two distinct loops because EvaluateExpression may invalidate any
9872 // iterators into CurrentIterVals.
9873 for (const auto &I : PHIsToCompute) {
9874 PHINode *PHI = I.first;
9875 Constant *&NextPHI = NextIterVals[PHI];
9876 if (!NextPHI) { // Not already computed.
9877 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9878 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9879 }
9880 if (NextPHI != I.second)
9881 StoppedEvolving = false;
9882 }
9883
9884 // If all entries in CurrentIterVals == NextIterVals then we can stop
9885 // iterating, the loop can't continue to change.
9886 if (StoppedEvolving)
9887 return RetVal = CurrentIterVals[PN];
9888
9889 CurrentIterVals.swap(NextIterVals);
9890 }
9891}
9892
9893const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9894 Value *Cond,
9895 bool ExitWhen) {
9896 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9897 if (!PN) return getCouldNotCompute();
9898
9899 // If the loop is canonicalized, the PHI will have exactly two entries.
9900 // That's the only form we support here.
9901 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9902
9903 DenseMap<Instruction *, Constant *> CurrentIterVals;
9904 BasicBlock *Header = L->getHeader();
9905 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9906
9907 BasicBlock *Latch = L->getLoopLatch();
9908 assert(Latch && "Should follow from NumIncomingValues == 2!");
9909
9910 for (PHINode &PHI : Header->phis()) {
9911 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9912 CurrentIterVals[&PHI] = StartCST;
9913 }
9914 if (!CurrentIterVals.count(PN))
9915 return getCouldNotCompute();
9916
9917 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9918 // the loop symbolically to determine when the condition gets a value of
9919 // "ExitWhen".
9920 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9921 const DataLayout &DL = getDataLayout();
9922 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9923 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9924 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9925
9926 // Couldn't symbolically evaluate.
9927 if (!CondVal) return getCouldNotCompute();
9928
9929 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9930 ++NumBruteForceTripCountsComputed;
9931 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9932 }
9933
9934 // Update all the PHI nodes for the next iteration.
9935 DenseMap<Instruction *, Constant *> NextIterVals;
9936
9937 // Create a list of which PHIs we need to compute. We want to do this before
9938 // calling EvaluateExpression on them because that may invalidate iterators
9939 // into CurrentIterVals.
9940 SmallVector<PHINode *, 8> PHIsToCompute;
9941 for (const auto &I : CurrentIterVals) {
9942 PHINode *PHI = dyn_cast<PHINode>(I.first);
9943 if (!PHI || PHI->getParent() != Header) continue;
9944 PHIsToCompute.push_back(PHI);
9945 }
9946 for (PHINode *PHI : PHIsToCompute) {
9947 Constant *&NextPHI = NextIterVals[PHI];
9948 if (NextPHI) continue; // Already computed!
9949
9950 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9951 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9952 }
9953 CurrentIterVals.swap(NextIterVals);
9954 }
9955
9956 // Too many iterations were needed to evaluate.
9957 return getCouldNotCompute();
9958}
9959
9960const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9961 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9962 ValuesAtScopes[V];
9963 // Check to see if we've folded this expression at this loop before.
9964 for (auto &LS : Values)
9965 if (LS.first == L)
9966 return LS.second ? LS.second : V;
9967
9968 Values.emplace_back(L, nullptr);
9969
9970 // Otherwise compute it.
9971 const SCEV *C = computeSCEVAtScope(V, L);
9972 for (auto &LS : reverse(ValuesAtScopes[V]))
9973 if (LS.first == L) {
9974 LS.second = C;
9975 if (!isa<SCEVConstant>(C))
9976 ValuesAtScopesUsers[C].push_back({L, V});
9977 break;
9978 }
9979 return C;
9980}
9981
9982/// This builds up a Constant using the ConstantExpr interface. That way, we
9983/// will return Constants for objects which aren't represented by a
9984/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9985/// Returns NULL if the SCEV isn't representable as a Constant.
9986 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9987 switch (V->getSCEVType()) {
9988 case scCouldNotCompute:
9989 case scAddRecExpr:
9990 case scVScale:
9991 return nullptr;
9992 case scConstant:
9993 return cast<SCEVConstant>(V)->getValue();
9994 case scUnknown:
9995 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9996 case scPtrToAddr: {
9998 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9999 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10000
10001 return nullptr;
10002 }
10003 case scPtrToInt: {
10004 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
10005 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10006 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10007
10008 return nullptr;
10009 }
10010 case scTruncate: {
10011 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
10012 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10013 return ConstantExpr::getTrunc(CastOp, ST->getType());
10014 return nullptr;
10015 }
10016 case scAddExpr: {
10017 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10018 Constant *C = nullptr;
10019 for (const SCEV *Op : SA->operands()) {
10020 Constant *OpC = BuildConstantFromSCEV(Op);
10021 if (!OpC)
10022 return nullptr;
10023 if (!C) {
10024 C = OpC;
10025 continue;
10026 }
10027 assert(!C->getType()->isPointerTy() &&
10028 "Can only have one pointer, and it must be last");
10029 if (OpC->getType()->isPointerTy()) {
10030 // The offsets have been converted to bytes. We can add bytes using
10031 // an i8 GEP.
10032 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
10033 OpC, C);
10034 } else {
10035 C = ConstantExpr::getAdd(C, OpC);
10036 }
10037 }
10038 return C;
10039 }
10040 case scMulExpr:
10041 case scSignExtend:
10042 case scZeroExtend:
10043 case scUDivExpr:
10044 case scSMaxExpr:
10045 case scUMaxExpr:
10046 case scSMinExpr:
10047 case scUMinExpr:
10048 case scSequentialUMinExpr:
10049 return nullptr;
10050 }
10051 llvm_unreachable("Unknown SCEV kind!");
10052}
10053
10054const SCEV *
10055ScalarEvolution::getWithOperands(const SCEV *S,
10056 SmallVectorImpl<const SCEV *> &NewOps) {
10057 switch (S->getSCEVType()) {
10058 case scTruncate:
10059 case scZeroExtend:
10060 case scSignExtend:
10061 case scPtrToAddr:
10062 case scPtrToInt:
10063 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10064 case scAddRecExpr: {
10065 auto *AddRec = cast<SCEVAddRecExpr>(S);
10066 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10067 }
10068 case scAddExpr:
10069 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10070 case scMulExpr:
10071 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10072 case scUDivExpr:
10073 return getUDivExpr(NewOps[0], NewOps[1]);
10074 case scUMaxExpr:
10075 case scSMaxExpr:
10076 case scUMinExpr:
10077 case scSMinExpr:
10078 return getMinMaxExpr(S->getSCEVType(), NewOps);
10079 case scSequentialUMinExpr:
10080 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10081 case scConstant:
10082 case scVScale:
10083 case scUnknown:
10084 return S;
10085 case scCouldNotCompute:
10086 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10087 }
10088 llvm_unreachable("Unknown SCEV kind!");
10089}
10090
10091const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10092 switch (V->getSCEVType()) {
10093 case scConstant:
10094 case scVScale:
10095 return V;
10096 case scAddRecExpr: {
10097 // If this is a loop recurrence for a loop that does not contain L, then we
10098 // are dealing with the final value computed by the loop.
10099 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10100 // First, attempt to evaluate each operand.
10101 // Avoid performing the look-up in the common case where the specified
10102 // expression has no loop-variant portions.
10103 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10104 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10105 if (OpAtScope == AddRec->getOperand(i))
10106 continue;
10107
10108 // Okay, at least one of these operands is loop variant but might be
10109 // foldable. Build a new instance of the folded commutative expression.
10110 SmallVector<const SCEV *, 8> NewOps;
10111 NewOps.reserve(AddRec->getNumOperands());
10112 append_range(NewOps, AddRec->operands().take_front(i));
10113 NewOps.push_back(OpAtScope);
10114 for (++i; i != e; ++i)
10115 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10116
10117 const SCEV *FoldedRec = getAddRecExpr(
10118 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10119 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10120 // The addrec may be folded to a nonrecurrence, for example, if the
10121 // induction variable is multiplied by zero after constant folding. Go
10122 // ahead and return the folded value.
10123 if (!AddRec)
10124 return FoldedRec;
10125 break;
10126 }
10127
10128 // If the scope is outside the addrec's loop, evaluate it by using the
10129 // loop exit value of the addrec.
10130 if (!AddRec->getLoop()->contains(L)) {
10131 // To evaluate this recurrence, we need to know how many times the AddRec
10132 // loop iterates. Compute this now.
10133 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10134 if (BackedgeTakenCount == getCouldNotCompute())
10135 return AddRec;
10136
10137 // Then, evaluate the AddRec.
10138 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10139 }
10140
10141 return AddRec;
10142 }
10143 case scTruncate:
10144 case scZeroExtend:
10145 case scSignExtend:
10146 case scPtrToAddr:
10147 case scPtrToInt:
10148 case scAddExpr:
10149 case scMulExpr:
10150 case scUDivExpr:
10151 case scUMaxExpr:
10152 case scSMaxExpr:
10153 case scUMinExpr:
10154 case scSMinExpr:
10155 case scSequentialUMinExpr: {
10156 ArrayRef<const SCEV *> Ops = V->operands();
10157 // Avoid performing the look-up in the common case where the specified
10158 // expression has no loop-variant portions.
10159 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10160 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10161 if (OpAtScope != Ops[i]) {
10162 // Okay, at least one of these operands is loop variant but might be
10163 // foldable. Build a new instance of the folded commutative expression.
10164 SmallVector<const SCEV *, 8> NewOps;
10165 NewOps.reserve(Ops.size());
10166 append_range(NewOps, Ops.take_front(i));
10167 NewOps.push_back(OpAtScope);
10168
10169 for (++i; i != e; ++i) {
10170 OpAtScope = getSCEVAtScope(Ops[i], L);
10171 NewOps.push_back(OpAtScope);
10172 }
10173
10174 return getWithOperands(V, NewOps);
10175 }
10176 }
10177 // If we got here, all operands are loop invariant.
10178 return V;
10179 }
10180 case scUnknown: {
10181 // If this instruction is evolved from a constant-evolving PHI, compute the
10182 // exit value from the loop without using SCEVs.
10183 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10184 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10185 if (!I)
10186 return V; // This is some other type of SCEVUnknown, just return it.
10187
10188 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10189 const Loop *CurrLoop = this->LI[I->getParent()];
10190 // Looking for loop exit value.
10191 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10192 PN->getParent() == CurrLoop->getHeader()) {
10193 // Okay, there is no closed form solution for the PHI node. Check
10194 // to see if the loop that contains it has a known backedge-taken
10195 // count. If so, we may be able to force computation of the exit
10196 // value.
10197 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10198 // This trivial case can show up in some degenerate cases where
10199 // the incoming IR has not yet been fully simplified.
10200 if (BackedgeTakenCount->isZero()) {
10201 Value *InitValue = nullptr;
10202 bool MultipleInitValues = false;
10203 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10204 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10205 if (!InitValue)
10206 InitValue = PN->getIncomingValue(i);
10207 else if (InitValue != PN->getIncomingValue(i)) {
10208 MultipleInitValues = true;
10209 break;
10210 }
10211 }
10212 }
10213 if (!MultipleInitValues && InitValue)
10214 return getSCEV(InitValue);
10215 }
10216 // Do we have a loop invariant value flowing around the backedge
10217 // for a loop which must execute the backedge?
10218 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10219 isKnownNonZero(BackedgeTakenCount) &&
10220 PN->getNumIncomingValues() == 2) {
10221
10222 unsigned InLoopPred =
10223 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10224 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10225 if (CurrLoop->isLoopInvariant(BackedgeVal))
10226 return getSCEV(BackedgeVal);
10227 }
10228 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10229 // Okay, we know how many times the containing loop executes. If
10230 // this is a constant evolving PHI node, get the final value at
10231 // the specified iteration number.
10232 Constant *RV =
10233 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10234 if (RV)
10235 return getSCEV(RV);
10236 }
10237 }
10238 }
10239
10240 // Okay, this is an expression that we cannot symbolically evaluate
10241 // into a SCEV. Check to see if it's possible to symbolically evaluate
10242 // the arguments into constants, and if so, try to constant propagate the
10243 // result. This is particularly useful for computing loop exit values.
10244 if (!CanConstantFold(I))
10245 return V; // This is some other type of SCEVUnknown, just return it.
10246
10247 SmallVector<Constant *, 4> Operands;
10248 Operands.reserve(I->getNumOperands());
10249 bool MadeImprovement = false;
10250 for (Value *Op : I->operands()) {
10251 if (Constant *C = dyn_cast<Constant>(Op)) {
10252 Operands.push_back(C);
10253 continue;
10254 }
10255
10256 // If any of the operands is non-constant and if they are
10257 // non-integer and non-pointer, don't even try to analyze them
10258 // with scev techniques.
10259 if (!isSCEVable(Op->getType()))
10260 return V;
10261
10262 const SCEV *OrigV = getSCEV(Op);
10263 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10264 MadeImprovement |= OrigV != OpV;
10265
10266 Constant *C = BuildConstantFromSCEV(OpV);
10267 if (!C)
10268 return V;
10269 assert(C->getType() == Op->getType() && "Type mismatch");
10270 Operands.push_back(C);
10271 }
10272
10273 // Check to see if getSCEVAtScope actually made an improvement.
10274 if (!MadeImprovement)
10275 return V; // This is some other type of SCEVUnknown, just return it.
10276
10277 Constant *C = nullptr;
10278 const DataLayout &DL = getDataLayout();
10279 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10280 /*AllowNonDeterministic=*/false);
10281 if (!C)
10282 return V;
10283 return getSCEV(C);
10284 }
10285 case scCouldNotCompute:
10286 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10287 }
10288 llvm_unreachable("Unknown SCEV type!");
10289}
10290
10291 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10292 return getSCEVAtScope(getSCEV(V), L);
10293}
10294
10295const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10296 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10297 return stripInjectiveFunctions(ZExt->getOperand());
10298 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10299 return stripInjectiveFunctions(SExt->getOperand());
10300 return S;
10301}
10302
10303/// Finds the minimum unsigned root of the following equation:
10304///
10305/// A * X = B (mod N)
10306///
10307/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10308/// A and B isn't important.
10309///
10310/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10311static const SCEV *
10312 SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10313 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10314 ScalarEvolution &SE, const Loop *L) {
10315 uint32_t BW = A.getBitWidth();
10316 assert(BW == SE.getTypeSizeInBits(B->getType()));
10317 assert(A != 0 && "A must be non-zero.");
10318
10319 // 1. D = gcd(A, N)
10320 //
10321 // The gcd of A and N may have only one prime factor: 2. The number of
10322 // trailing zeros in A is its multiplicity
10323 uint32_t Mult2 = A.countr_zero();
10324 // D = 2^Mult2
10325
10326 // 2. Check if B is divisible by D.
10327 //
10328 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10329 // is not less than multiplicity of this prime factor for D.
10330 unsigned MinTZ = SE.getMinTrailingZeros(B);
10331 // Try again with the terminator of the loop predecessor for context-specific
10332 // result, if MinTZ is too small.
10333 if (MinTZ < Mult2 && L->getLoopPredecessor())
10334 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10335 if (MinTZ < Mult2) {
10336 // Check if we can prove there's no remainder using URem.
10337 const SCEV *URem =
10338 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10339 const SCEV *Zero = SE.getZero(B->getType());
10340 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10341 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10342 if (!Predicates)
10343 return SE.getCouldNotCompute();
10344
10345 // Avoid adding a predicate that is known to be false.
10346 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10347 return SE.getCouldNotCompute();
10348 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10349 }
10350 }
10351
10352 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10353 // modulo (N / D).
10354 //
10355 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10356 // (N / D) in general. The inverse itself always fits into BW bits, though,
10357 // so we immediately truncate it.
10358 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10359 APInt I = AD.multiplicativeInverse().zext(BW);
10360
10361 // 4. Compute the minimum unsigned root of the equation:
10362 // I * (B / D) mod (N / D)
10363 // To simplify the computation, we factor out the divide by D:
10364 // (I * B mod N) / D
10365 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10366 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10367}
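// NOTE (editorial worked example, not from the upstream source): with BW = 4,
// solving 6*X = 10 (mod 16) follows the steps above: Mult2 = 1 so D = 2, and
// 10 is divisible by 2; A/D = 3, whose multiplicative inverse mod 8 is 3
// (3*3 = 9 = 1 mod 8), so I = 3; the minimum root is (3*10 mod 16) / 2 =
// 14 / 2 = 7, and indeed 6*7 = 42 = 10 (mod 16).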
10368
10369/// For a given quadratic addrec, generate coefficients of the corresponding
10370/// quadratic equation, multiplied by a common value to ensure that they are
10371/// integers.
10372/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10373/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10374/// were multiplied by, and BitWidth is the bit width of the original addrec
10375/// coefficients.
10376/// This function returns std::nullopt if the addrec coefficients are not
10377 /// compile-time constants.
10378static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10379 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10380 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10381 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10382 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10383 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10384 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10385 << *AddRec << '\n');
10386
10387 // We currently can only solve this if the coefficients are constants.
10388 if (!LC || !MC || !NC) {
10389 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10390 return std::nullopt;
10391 }
10392
10393 APInt L = LC->getAPInt();
10394 APInt M = MC->getAPInt();
10395 APInt N = NC->getAPInt();
10396 assert(!N.isZero() && "This is not a quadratic addrec");
10397
10398 unsigned BitWidth = LC->getAPInt().getBitWidth();
10399 unsigned NewWidth = BitWidth + 1;
10400 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10401 << BitWidth << '\n');
10402 // The sign-extension (as opposed to a zero-extension) here matches the
10403 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10404 N = N.sext(NewWidth);
10405 M = M.sext(NewWidth);
10406 L = L.sext(NewWidth);
10407
10408 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10409 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10410 // L+M, L+2M+N, L+3M+3N, ...
10411 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10412 //
10413 // The equation Acc = 0 is then
10414 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10415 // In a quadratic form it becomes:
10416 // N n^2 + (2M-N) n + 2L = 0.
10417
10418 APInt A = N;
10419 APInt B = 2 * M - A;
10420 APInt C = 2 * L;
10421 APInt T = APInt(NewWidth, 2);
10422 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10423 << "x + " << C << ", coeff bw: " << NewWidth
10424 << ", multiplied by " << T << '\n');
10425 return std::make_tuple(A, B, C, T, BitWidth);
10426}
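// NOTE (editorial worked example, not from the upstream source): for the
// addrec {0,+,1,+,2} we have L = 0, M = 1, N = 2, so after n iterations
// Acc = n*1 + n(n-1)/2 * 2 = n^2, and the returned tuple describes
// 2n^2 + 0*n + 0 = 0, i.e. A = N = 2, B = 2M - N = 0, C = 2L = 0, with the
// common multiplier T = 2 and the original coefficient bit width.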
10427
10428/// Helper function to compare optional APInts:
10429/// (a) if X and Y both exist, return min(X, Y),
10430/// (b) if neither X nor Y exist, return std::nullopt,
10431/// (c) if exactly one of X and Y exists, return that value.
10432static std::optional<APInt> MinOptional(std::optional<APInt> X,
10433 std::optional<APInt> Y) {
10434 if (X && Y) {
10435 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10436 APInt XW = X->sext(W);
10437 APInt YW = Y->sext(W);
10438 return XW.slt(YW) ? *X : *Y;
10439 }
10440 if (!X && !Y)
10441 return std::nullopt;
10442 return X ? *X : *Y;
10443}
10444
10445/// Helper function to truncate an optional APInt to a given BitWidth.
10446/// When solving addrec-related equations, it is preferable to return a value
10447/// that has the same bit width as the original addrec's coefficients. If the
10448/// solution fits in the original bit width, truncate it (except for i1).
10449/// Returning a value of a different bit width may inhibit some optimizations.
10450///
10451/// In general, a solution to a quadratic equation generated from an addrec
10452/// may require BW+1 bits, where BW is the bit width of the addrec's
10453/// coefficients. The reason is that the coefficients of the quadratic
10454/// equation are BW+1 bits wide (to avoid truncation when converting from
10455/// the addrec to the equation).
10456static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10457 unsigned BitWidth) {
10458 if (!X)
10459 return std::nullopt;
10460 unsigned W = X->getBitWidth();
10461 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10462 return X->trunc(BitWidth);
10463 return X;
10464}
10465
10466/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10467/// iterations. The values L, M, N are assumed to be signed, and they
10468/// should all have the same bit widths.
10469/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10470/// where BW is the bit width of the addrec's coefficients.
10471/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10472/// returned as such, otherwise the bit width of the returned value may
10473/// be greater than BW.
10474///
10475/// This function returns std::nullopt if
10476/// (a) the addrec coefficients are not constant, or
10477/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10478/// like x^2 = 5, no integer solutions exist, in other cases an integer
10479/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10480static std::optional<APInt>
10481 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10482 APInt A, B, C, M;
10483 unsigned BitWidth;
10484 auto T = GetQuadraticEquation(AddRec);
10485 if (!T)
10486 return std::nullopt;
10487
10488 std::tie(A, B, C, M, BitWidth) = *T;
10489 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10490 std::optional<APInt> X =
10491 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10492 if (!X)
10493 return std::nullopt;
10494
10495 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10496 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10497 if (!V->isZero())
10498 return std::nullopt;
10499
10500 return TruncIfPossible(X, BitWidth);
10501}
10502
10503/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10504/// iterations. The values M, N are assumed to be signed, and they
10505/// should all have the same bit widths.
10506/// Find the least n such that c(n) does not belong to the given range,
10507/// while c(n-1) does.
10508///
10509/// This function returns std::nullopt if
10510/// (a) the addrec coefficients are not constant, or
10511/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10512/// bounds of the range.
10513static std::optional<APInt>
10514 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10515 const ConstantRange &Range, ScalarEvolution &SE) {
10516 assert(AddRec->getOperand(0)->isZero() &&
10517 "Starting value of addrec should be 0");
10518 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10519 << Range << ", addrec " << *AddRec << '\n');
10520 // This case is handled in getNumIterationsInRange. Here we can assume that
10521 // we start in the range.
10522 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10523 "Addrec's initial value should be in range");
10524
10525 APInt A, B, C, M;
10526 unsigned BitWidth;
10527 auto T = GetQuadraticEquation(AddRec);
10528 if (!T)
10529 return std::nullopt;
10530
10531 // Be careful about the return value: there can be two reasons for not
10532 // returning an actual number. First, if no solutions to the equations
10533 // were found, and second, if the solutions don't leave the given range.
10534 // The first case means that the actual solution is "unknown", the second
10535 // means that it's known, but not valid. If the solution is unknown, we
10536 // cannot make any conclusions.
10537 // Return a pair: the optional solution and a flag indicating if the
10538 // solution was found.
10539 auto SolveForBoundary =
10540 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10541 // Solve for signed overflow and unsigned overflow, pick the lower
10542 // solution.
10543 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10544 << Bound << " (before multiplying by " << M << ")\n");
10545 Bound *= M; // The quadratic equation multiplier.
10546
10547 std::optional<APInt> SO;
10548 if (BitWidth > 1) {
10549 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10550 "signed overflow\n");
10551 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10552 }
10553 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10554 "unsigned overflow\n");
10555 std::optional<APInt> UO =
10556 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10557 
10558 auto LeavesRange = [&] (const APInt &X) {
10559 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10560 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10561 if (Range.contains(V0->getValue()))
10562 return false;
10563 // X should be at least 1, so X-1 is non-negative.
10564 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10565 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10566 if (Range.contains(V1->getValue()))
10567 return true;
10568 return false;
10569 };
10570
10571 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10572 // can be a solution, but the function failed to find it. We cannot treat it
10573 // as "no solution".
10574 if (!SO || !UO)
10575 return {std::nullopt, false};
10576
10577 // Check the smaller value first to see if it leaves the range.
10578 // At this point, both SO and UO must have values.
10579 std::optional<APInt> Min = MinOptional(SO, UO);
10580 if (LeavesRange(*Min))
10581 return { Min, true };
10582 std::optional<APInt> Max = Min == SO ? UO : SO;
10583 if (LeavesRange(*Max))
10584 return { Max, true };
10585
10586 // Solutions were found, but were eliminated, hence the "true".
10587 return {std::nullopt, true};
10588 };
10589
10590 std::tie(A, B, C, M, BitWidth) = *T;
10591 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10592 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10593 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10594 auto SL = SolveForBoundary(Lower);
10595 auto SU = SolveForBoundary(Upper);
10596 // If any of the solutions was unknown, no meaningful conclusions can
10597 // be made.
10598 if (!SL.second || !SU.second)
10599 return std::nullopt;
10600
10601 // Claim: The correct solution is not some value between Min and Max.
10602 //
10603 // Justification: Assuming that Min and Max are different values, one of
10604 // them is when the first signed overflow happens, the other is when the
10605 // first unsigned overflow happens. Crossing the range boundary is only
10606 // possible via an overflow (treating 0 as a special case of it, modeling
10607 // an overflow as crossing k*2^W for some k).
10608 //
10609 // The interesting case here is when Min was eliminated as an invalid
10610 // solution, but Max was not. The argument is that if there was another
10611 // overflow between Min and Max, it would also have been eliminated if
10612 // it was considered.
10613 //
10614 // For a given boundary, it is possible to have two overflows of the same
10615 // type (signed/unsigned) without having the other type in between: this
10616 // can happen when the vertex of the parabola is between the iterations
10617 // corresponding to the overflows. This is only possible when the two
10618 // overflows cross k*2^W for the same k. In such case, if the second one
10619 // left the range (and was the first one to do so), the first overflow
10620 // would have to enter the range, which would mean that either we had left
10621 // the range before or that we started outside of it. Both of these cases
10622 // are contradictions.
10623 //
10624 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10625 // solution is not some value between the Max for this boundary and the
10626 // Min of the other boundary.
10627 //
10628 // Justification: Assume that we had such Max_A and Min_B corresponding
10629 // to range boundaries A and B and such that Max_A < Min_B. If there was
10630 // a solution between Max_A and Min_B, it would have to be caused by an
10631 // overflow corresponding to either A or B. It cannot correspond to B,
10632 // since Min_B is the first occurrence of such an overflow. If it
10633 // corresponded to A, it would have to be either a signed or an unsigned
10634 // overflow that is larger than both eliminated overflows for A. But
10635 // between the eliminated overflows and this overflow, the values would
10636 // cover the entire value space, thus crossing the other boundary, which
10637 // is a contradiction.
10638
10639 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10640}
10641
10642ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10643 const Loop *L,
10644 bool ControlsOnlyExit,
10645 bool AllowPredicates) {
10646
10647 // This is only used for loops with a "x != y" exit test. The exit condition
10648 // is now expressed as a single expression, V = x-y. So the exit test is
10649 // effectively V != 0. We know, and take advantage of the fact, that this
10650 // expression is only used in a comparison-with-zero context.
10651
10652 SmallVector<const SCEVPredicate *> Predicates;
10653 // If the value is a constant
10654 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10655 // If the value is already zero, the branch will execute zero times.
10656 if (C->getValue()->isZero()) return C;
10657 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10658 }
10659
10660 const SCEVAddRecExpr *AddRec =
10661 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10662
10663 if (!AddRec && AllowPredicates)
10664 // Try to make this an AddRec using runtime tests, in the first X
10665 // iterations of this loop, where X is the SCEV expression found by the
10666 // algorithm below.
10667 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10668
10669 if (!AddRec || AddRec->getLoop() != L)
10670 return getCouldNotCompute();
10671
10672 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10673 // the quadratic equation to solve it.
10674 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10675 // We can only use this value if the chrec ends up with an exact zero
10676 // value at this index. When solving for "X*X != 5", for example, we
10677 // should not accept a root of 2.
10678 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10679 const auto *R = cast<SCEVConstant>(getConstant(*S));
10680 return ExitLimit(R, R, R, false, Predicates);
10681 }
10682 return getCouldNotCompute();
10683 }
10684
10685 // Otherwise we can only handle this if it is affine.
10686 if (!AddRec->isAffine())
10687 return getCouldNotCompute();
10688
10689 // If this is an affine expression, the execution count of this branch is
10690 // the minimum unsigned root of the following equation:
10691 //
10692 // Start + Step*N = 0 (mod 2^BW)
10693 //
10694 // equivalent to:
10695 //
10696 // Step*N = -Start (mod 2^BW)
10697 //
10698 // where BW is the common bit width of Start and Step.
10699
10700 // Get the initial value for the loop.
10701 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10702 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10703
10704 if (!isLoopInvariant(Step, L))
10705 return getCouldNotCompute();
10706
10707 LoopGuards Guards = LoopGuards::collect(L, *this);
10708 // Specialize step for this loop so we get context sensitive facts below.
10709 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10710
10711 // For positive steps (counting up until unsigned overflow):
10712 // N = -Start/Step (as unsigned)
10713 // For negative steps (counting down to zero):
10714 // N = Start/-Step
10715 // First compute the unsigned distance from zero in the direction of Step.
10716 bool CountDown = isKnownNegative(StepWLG);
10717 if (!CountDown && !isKnownNonNegative(StepWLG))
10718 return getCouldNotCompute();
10719
10720 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10721 // Handle unitary steps, which cannot wraparound.
10722 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10723 // N = Distance (as unsigned)
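  // NOTE (editorial example, not from the upstream source): for the step -1
  // addrec {8,+,-1} tested against zero, Distance = Start = 8 and the
  // backedge is taken exactly 8 times (values 8, 7, ..., 1 re-enter; 0 exits).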
10724
10725 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10726 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10727 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10728
10729 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10730 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10731 // case, and see if we can improve the bound.
10732 //
10733 // Explicitly handling this here is necessary because getUnsignedRange
10734 // isn't context-sensitive; it doesn't know that we only care about the
10735 // range inside the loop.
10736 const SCEV *Zero = getZero(Distance->getType());
10737 const SCEV *One = getOne(Distance->getType());
10738 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10739 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10740 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10741 // as "unsigned_max(Distance + 1) - 1".
10742 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10743 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10744 }
10745 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10746 Predicates);
10747 }
10748
10749 // If the condition controls loop exit (the loop exits only if the expression
10750 // is true) and the addition is no-wrap we can use unsigned divide to
10751 // compute the backedge count. In this case, the step may not divide the
10752 // distance, but we don't care because if the condition is "missed" the loop
10753 // will have undefined behavior due to wrapping.
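  // NOTE (editorial example, not from the upstream source): for an IV
  // {0,+,3} exiting on "iv != 10" (so V = {-10,+,3} at this point), the values
  // 0, 3, 6, 9, 12, ... never equal 10; the IV would have to self-wrap, which
  // is UB under the no-self-wrap assumption, so reporting
  // Distance /u Step = 10 /u 3 = 3 as the exit count is still sound.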
10754 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10755 loopHasNoAbnormalExits(AddRec->getLoop())) {
10756
10757 // If the stride is zero and the start is non-zero, the loop must be
10758 // infinite. In C++, most loops are finite by assumption, in which case the
10759 // step being zero implies UB must execute if the loop is entered.
10760 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10761 !isKnownNonZero(StepWLG))
10762 return getCouldNotCompute();
10763
10764 const SCEV *Exact =
10765 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10766 const SCEV *ConstantMax = getCouldNotCompute();
10767 if (Exact != getCouldNotCompute()) {
10768 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10769 ConstantMax =
10770 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10771 }
10772 const SCEV *SymbolicMax =
10773 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10774 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10775 }
10776
10777 // Solve the general equation.
10778 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10779 if (!StepC || StepC->getValue()->isZero())
10780 return getCouldNotCompute();
10781 const SCEV *E = SolveLinEquationWithOverflow(
10782 StepC->getAPInt(), getNegativeSCEV(Start),
10783 AllowPredicates ? &Predicates : nullptr, *this, L);
10784
10785 const SCEV *M = E;
10786 if (E != getCouldNotCompute()) {
10787 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10788 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10789 }
10790 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10791 return ExitLimit(E, M, S, false, Predicates);
10792}
10793
10794ScalarEvolution::ExitLimit
10795ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10796 // Loops that look like: while (X == 0) are very strange indeed. We don't
10797 // handle them yet except for the trivial case. This could be expanded in the
10798 // future as needed.
10799
10800 // If the value is a constant, check to see if it is known to be non-zero
10801 // already. If so, the backedge will execute zero times.
10802 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10803 if (!C->getValue()->isZero())
10804 return getZero(C->getType());
10805 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10806 }
10807
10808 // We could implement others, but I really doubt anyone writes loops like
10809 // this, and if they did, they would already be constant folded.
10810 return getCouldNotCompute();
10811}
10812
10813std::pair<const BasicBlock *, const BasicBlock *>
10814ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10815 const {
10816 // If the block has a unique predecessor, then there is no path from the
10817 // predecessor to the block that does not go through the direct edge
10818 // from the predecessor to the block.
10819 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10820 return {Pred, BB};
10821
10822 // A loop's header is defined to be a block that dominates the loop.
10823 // If the header has a unique predecessor outside the loop, it must be
10824 // a block that has exactly one successor that can reach the loop.
10825 if (const Loop *L = LI.getLoopFor(BB))
10826 return {L->getLoopPredecessor(), L->getHeader()};
10827
10828 return {nullptr, BB};
10829}
10830
10831/// SCEV structural equivalence is usually sufficient for testing whether two
10832 /// expressions are equal; however, for the purposes of looking for a condition
10833/// guarding a loop, it can be useful to be a little more general, since a
10834/// front-end may have replicated the controlling expression.
10835static bool HasSameValue(const SCEV *A, const SCEV *B) {
10836 // Quick check to see if they are the same SCEV.
10837 if (A == B) return true;
10838
10839 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10840 // Not all instructions that are "identical" compute the same value. For
10841 // instance, two distinct alloca instructions allocating the same type are
10842 // identical and do not read memory; but compute distinct values.
10843 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10844 };
10845
10846 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10847 // two different instructions with the same value. Check for this case.
10848 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10849 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10850 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10851 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10852 if (ComputesEqualValues(AI, BI))
10853 return true;
10854
10855 // Otherwise assume they may have a different value.
10856 return false;
10857}
10858
10859static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10860 const SCEV *Op0, *Op1;
10861 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
10862 return false;
10863 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10864 LHS = Op1;
10865 return true;
10866 }
10867 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10868 LHS = Op0;
10869 return true;
10870 }
10871 return false;
10872}
10873
10874 bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10875 const SCEV *&RHS, unsigned Depth) {
10876 bool Changed = false;
10877 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10878 // '0 != 0'.
10879 auto TrivialCase = [&](bool TriviallyTrue) {
10880 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10881 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10882 return true;
10883 };
10884 // If we hit the max recursion limit bail out.
10885 if (Depth >= 3)
10886 return false;
10887
10888 const SCEV *NewLHS, *NewRHS;
10889 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10890 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10891 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10892 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10893
10894 // (X * vscale) pred (Y * vscale) ==> X pred Y
10895 // when both multiples are NSW.
10896 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10897 // when both multiples are NUW.
10898 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10899 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10900 !ICmpInst::isSigned(Pred))) {
10901 LHS = NewLHS;
10902 RHS = NewRHS;
10903 Changed = true;
10904 }
10905 }
10906
10907 // Canonicalize a constant to the right side.
10908 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10909 // Check for both operands constant.
10910 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10911 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10912 return TrivialCase(false);
10913 return TrivialCase(true);
10914 }
10915 // Otherwise swap the operands to put the constant on the right.
10916 std::swap(LHS, RHS);
10917 Pred = ICmpInst::getSwappedPredicate(Pred);
10918 Changed = true;
10919 }
10920
10921 // If we're comparing an addrec with a value which is loop-invariant in the
10922 // addrec's loop, put the addrec on the left. Also make a dominance check,
10923 // as both operands could be addrecs loop-invariant in each other's loop.
10924 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10925 const Loop *L = AR->getLoop();
10926 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10927 std::swap(LHS, RHS);
10928 Pred = ICmpInst::getSwappedPredicate(Pred);
10929 Changed = true;
10930 }
10931 }
10932
10933 // If there's a constant operand, canonicalize comparisons with boundary
10934 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10935 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10936 const APInt &RA = RC->getAPInt();
10937
10938 bool SimplifiedByConstantRange = false;
10939
10940 if (!ICmpInst::isEquality(Pred)) {
10941 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10942 if (ExactCR.isFullSet())
10943 return TrivialCase(true);
10944 if (ExactCR.isEmptySet())
10945 return TrivialCase(false);
10946
10947 APInt NewRHS;
10948 CmpInst::Predicate NewPred;
10949 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10950 ICmpInst::isEquality(NewPred)) {
10951 // We were able to convert an inequality to an equality.
10952 Pred = NewPred;
10953 RHS = getConstant(NewRHS);
10954 Changed = SimplifiedByConstantRange = true;
10955 }
10956 }
10957
10958 if (!SimplifiedByConstantRange) {
10959 switch (Pred) {
10960 default:
10961 break;
10962 case ICmpInst::ICMP_EQ:
10963 case ICmpInst::ICMP_NE:
10964 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10965 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10966 Changed = true;
10967 break;
10968
10969 // The "Should have been caught earlier!" messages refer to the fact
10970 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10971 // should have fired on the corresponding cases, and canonicalized the
10972 // check to a trivial case.
10973
10974 case ICmpInst::ICMP_UGE:
10975 assert(!RA.isMinValue() && "Should have been caught earlier!");
10976 Pred = ICmpInst::ICMP_UGT;
10977 RHS = getConstant(RA - 1);
10978 Changed = true;
10979 break;
10980 case ICmpInst::ICMP_ULE:
10981 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10982 Pred = ICmpInst::ICMP_ULT;
10983 RHS = getConstant(RA + 1);
10984 Changed = true;
10985 break;
10986 case ICmpInst::ICMP_SGE:
10987 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10988 Pred = ICmpInst::ICMP_SGT;
10989 RHS = getConstant(RA - 1);
10990 Changed = true;
10991 break;
10992 case ICmpInst::ICMP_SLE:
10993 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10994 Pred = ICmpInst::ICMP_SLT;
10995 RHS = getConstant(RA + 1);
10996 Changed = true;
10997 break;
10998 }
10999 }
11000 }
11001
11002 // Check for obvious equality.
11003 if (HasSameValue(LHS, RHS)) {
11004 if (ICmpInst::isTrueWhenEqual(Pred))
11005 return TrivialCase(true);
11006 if (ICmpInst::isFalseWhenEqual(Pred))
11007 return TrivialCase(false);
11008 }
11009
11010 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11011 // adding or subtracting 1 from one of the operands.
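// For example, "X s<= Y" becomes "X s< (Y + 1)" when (Y + 1) cannot overflow,
// or "(X + (-1)) s< Y" when (X - 1) cannot overflow.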
11012 switch (Pred) {
11013 case ICmpInst::ICMP_SLE:
11014 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11015 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11016 SCEV::FlagNSW);
11017 Pred = ICmpInst::ICMP_SLT;
11018 Changed = true;
11019 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11020 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11021 SCEV::FlagNSW);
11022 Pred = ICmpInst::ICMP_SLT;
11023 Changed = true;
11024 }
11025 break;
11026 case ICmpInst::ICMP_SGE:
11027 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11028 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11029 SCEV::FlagNSW);
11030 Pred = ICmpInst::ICMP_SGT;
11031 Changed = true;
11032 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11033 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11034 SCEV::FlagNSW);
11035 Pred = ICmpInst::ICMP_SGT;
11036 Changed = true;
11037 }
11038 break;
11039 case ICmpInst::ICMP_ULE:
11040 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11041 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11042 SCEV::FlagNUW);
11043 Pred = ICmpInst::ICMP_ULT;
11044 Changed = true;
11045 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11046 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11047 Pred = ICmpInst::ICMP_ULT;
11048 Changed = true;
11049 }
11050 break;
11051 case ICmpInst::ICMP_UGE:
11052 // If RHS is an op we can fold the -1, try that first.
11053 // Otherwise prefer LHS to preserve the nuw flag.
11054 if ((isa<SCEVConstant>(RHS) ||
11056 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11057 !getUnsignedRangeMin(RHS).isMinValue()) {
11058 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11059 Pred = ICmpInst::ICMP_UGT;
11060 Changed = true;
11061 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11062 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11063 SCEV::FlagNUW);
11064 Pred = ICmpInst::ICMP_UGT;
11065 Changed = true;
11066 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11067 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11068 Pred = ICmpInst::ICMP_UGT;
11069 Changed = true;
11070 }
11071 break;
11072 default:
11073 break;
11074 }
11075
11076 // TODO: More simplifications are possible here.
11077
11078 // Recursively simplify until we either hit a recursion limit or nothing
11079 // changes.
11080 if (Changed)
11081 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11082
11083 return Changed;
11084}
11085
11086bool ScalarEvolution::isKnownNegative(const SCEV *S) {
11087 return getSignedRangeMax(S).isNegative();
11088}
11089
11090bool ScalarEvolution::isKnownPositive(const SCEV *S) {
11091 return getSignedRangeMin(S).isStrictlyPositive();
11092}
11093
11094bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
11095 return !getSignedRangeMin(S).isNegative();
11096}
11097
11098bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
11099 return !getSignedRangeMax(S).isStrictlyPositive();
11100}
11101
11102bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11103 // Query push down for cases where the unsigned range is
11104 // less than sufficient.
11105 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11106 return isKnownNonZero(SExt->getOperand(0));
11107 return getUnsignedRangeMin(S) != 0;
11108}
11109
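/// A SCEV is known to be a power of two if it is a power-of-two constant, a
/// vscale whose vscale_range attribute guarantees a power of two, or a product
/// of such values (e.g. (8 * vscale)) that is additionally known non-zero
/// unless OrZero is set. With OrNegative set, negated power-of-two constants
/// are accepted as well.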
11110bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11111 bool OrNegative) {
11112 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11113 if (auto *C = dyn_cast<SCEVConstant>(S))
11114 return C->getAPInt().isPowerOf2() ||
11115 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11116
11117 // The vscale_range indicates vscale is a power-of-two.
11118 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11119 };
11120
11121 if (NonRecursive(S))
11122 return true;
11123
11124 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11125 if (!Mul)
11126 return false;
11127 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11128}
11129
11130bool ScalarEvolution::isKnownMultipleOf(
11131 const SCEV *S, uint64_t M,
11132 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11133 if (M == 0)
11134 return false;
11135 if (M == 1)
11136 return true;
11137
11138 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11139 // starts with a multiple of M and at every iteration step S only adds
11140 // multiples of M.
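// For example, {0,+,4} is known to be a multiple of 4, because its start (0)
// and its step (4) are both multiples of 4.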
11141 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11142 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11143 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11144
11145 // For a constant, check that "S % M == 0".
11146 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11147 APInt C = Cst->getAPInt();
11148 return C.urem(M) == 0;
11149 }
11150
11151 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11152
11153 // Basic tests have failed.
11154 // Check "S % M == 0" at compile time and record runtime Assumptions.
11155 auto *STy = dyn_cast<IntegerType>(S->getType());
11156 const SCEV *SmodM =
11157 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11158 const SCEV *Zero = getZero(STy);
11159
11160 // Check whether "S % M == 0" is known at compile time.
11161 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11162 return true;
11163
11164 // Check whether "S % M != 0" is known at compile time.
11165 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11166 return false;
11167
11168 const SCEVPredicate *P = getEqualPredicate(SmodM, Zero);
11169
11170 // Detect redundant predicates.
11171 for (auto *A : Assumptions)
11172 if (A->implies(P, *this))
11173 return true;
11174
11175 // Only record non-redundant predicates.
11176 Assumptions.push_back(P);
11177 return true;
11178}
11179
11180bool ScalarEvolution::haveSameSign(const SCEV *S1, const SCEV *S2) {
11181 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11182 (isKnownNegative(S1) && isKnownNegative(S2)));
11183}
11184
11185std::pair<const SCEV *, const SCEV *>
11186ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11187 // Compute SCEV on entry of loop L.
11188 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11189 if (Start == getCouldNotCompute())
11190 return { Start, Start };
11191 // Compute post increment SCEV for loop L.
11192 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11193 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11194 return { Start, PostInc };
11195}
11196
11197bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11198 const SCEV *RHS) {
11199 // First collect all loops.
11200 SmallPtrSet<const Loop *, 8> LoopsUsed;
11201 getUsedLoops(LHS, LoopsUsed);
11202 getUsedLoops(RHS, LoopsUsed);
11203
11204 if (LoopsUsed.empty())
11205 return false;
11206
11207 // Domination relationship must be a linear order on collected loops.
11208#ifndef NDEBUG
11209 for (const auto *L1 : LoopsUsed)
11210 for (const auto *L2 : LoopsUsed)
11211 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11212 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11213 "Domination relationship is not a linear order");
11214#endif
11215
11216 const Loop *MDL =
11217 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11218 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11219 });
11220
11221 // Get init and post increment value for LHS.
11222 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11223 // If LHS contains an unknown non-invariant SCEV, bail out.
11224 if (SplitLHS.first == getCouldNotCompute())
11225 return false;
11226 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11227 // Get init and post increment value for RHS.
11228 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11229 // If RHS contains an unknown non-invariant SCEV, bail out.
11230 if (SplitRHS.first == getCouldNotCompute())
11231 return false;
11232 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11233 // It is possible that init SCEV contains an invariant load but it does
11234 // not dominate MDL and is not available at MDL loop entry, so we should
11235 // check it here.
11236 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11237 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11238 return false;
11239
11240 // The backedge guard check appears to be faster than the entry guard check,
11241 // so trying it first can short-circuit and speed up the whole estimation.
11242 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11243 SplitRHS.second) &&
11244 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11245}
11246
11247bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11248 const SCEV *RHS) {
11249 // Canonicalize the inputs first.
11250 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11251
11252 if (isKnownViaInduction(Pred, LHS, RHS))
11253 return true;
11254
11255 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11256 return true;
11257
11258 // Otherwise see what can be done with some simple reasoning.
11259 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11260}
11261
11262std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11263 const SCEV *LHS,
11264 const SCEV *RHS) {
11265 if (isKnownPredicate(Pred, LHS, RHS))
11266 return true;
11267 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11268 return false;
11269 return std::nullopt;
11270}
11271
11272bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11273 const SCEV *RHS,
11274 const Instruction *CtxI) {
11275 // TODO: Analyze guards and assumes from Context's block.
11276 return isKnownPredicate(Pred, LHS, RHS) ||
11277 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11278}
11279
11280std::optional<bool>
11281ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11282 const SCEV *RHS, const Instruction *CtxI) {
11283 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11284 if (KnownWithoutContext)
11285 return KnownWithoutContext;
11286
11287 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11288 return true;
11289 if (isBasicBlockEntryGuardedByCond(
11290 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11291 return false;
11292 return std::nullopt;
11293}
11294
11295bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11296 const SCEVAddRecExpr *LHS,
11297 const SCEV *RHS) {
11298 const Loop *L = LHS->getLoop();
11299 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11300 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11301}
11302
11303std::optional<ScalarEvolution::MonotonicPredicateType>
11304ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11305 ICmpInst::Predicate Pred) {
11306 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11307
11308#ifndef NDEBUG
11309 // Verify an invariant: inverting the predicate should turn a monotonically
11310 // increasing change to a monotonically decreasing one, and vice versa.
11311 if (Result) {
11312 auto ResultSwapped =
11313 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11314
11315 assert(*ResultSwapped != *Result &&
11316 "monotonicity should flip as we flip the predicate");
11317 }
11318#endif
11319
11320 return Result;
11321}
11322
11323std::optional<ScalarEvolution::MonotonicPredicateType>
11324ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11325 ICmpInst::Predicate Pred) {
11326 // A zero step value for LHS means the induction variable is essentially a
11327 // loop invariant value. We don't really depend on the predicate actually
11328 // flipping from false to true (for increasing predicates, and the other way
11329 // around for decreasing predicates), all we care about is that *if* the
11330 // predicate changes then it only changes from false to true.
11331 //
11332 // A zero step value in itself is not very useful, but there may be places
11333 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11334 // as general as possible.
11335
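// For example, for the addrec {0,+,1}<nsw>, the predicate ({0,+,1}<nsw> s> X)
// is monotonically increasing: once it becomes true on some iteration, it
// stays true on all later iterations.
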
11336 // Only handle LE/LT/GE/GT predicates.
11337 if (!ICmpInst::isRelational(Pred))
11338 return std::nullopt;
11339
11340 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11341 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11342 "Should be greater or less!");
11343
11344 // Check that AR does not wrap.
11345 if (ICmpInst::isUnsigned(Pred)) {
11346 if (!LHS->hasNoUnsignedWrap())
11347 return std::nullopt;
11348 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11349 }
11350 assert(ICmpInst::isSigned(Pred) &&
11351 "Relational predicate is either signed or unsigned!");
11352 if (!LHS->hasNoSignedWrap())
11353 return std::nullopt;
11354
11355 const SCEV *Step = LHS->getStepRecurrence(*this);
11356
11357 if (isKnownNonNegative(Step))
11358 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11359
11360 if (isKnownNonPositive(Step))
11361 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11362
11363 return std::nullopt;
11364}
11365
11366std::optional<ScalarEvolution::LoopInvariantPredicate>
11367ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11368 const SCEV *RHS, const Loop *L,
11369 const Instruction *CtxI) {
11370 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11371 if (!isLoopInvariant(RHS, L)) {
11372 if (!isLoopInvariant(LHS, L))
11373 return std::nullopt;
11374
11375 std::swap(LHS, RHS);
11376 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11377 }
11378
11379 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11380 if (!ArLHS || ArLHS->getLoop() != L)
11381 return std::nullopt;
11382
11383 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11384 if (!MonotonicType)
11385 return std::nullopt;
11386 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11387 // true as the loop iterates, and the backedge is control dependent on
11388 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11389 //
11390 // * if the predicate was false in the first iteration then the predicate
11391 // is never evaluated again, since the loop exits without taking the
11392 // backedge.
11393 // * if the predicate was true in the first iteration then it will
11394 // continue to be true for all future iterations since it is
11395 // monotonically increasing.
11396 //
11397 // For both the above possibilities, we can replace the loop varying
11398 // predicate with its value on the first iteration of the loop (which is
11399 // loop invariant).
11400 //
11401 // A similar reasoning applies for a monotonically decreasing predicate, by
11402 // replacing true with false and false with true in the above two bullets.
11403 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11404 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11405
11406 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11407 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11408 RHS);
11409
11410 if (!CtxI)
11411 return std::nullopt;
11412 // Try to prove via context.
11413 // TODO: Support other cases.
11414 switch (Pred) {
11415 default:
11416 break;
11417 case ICmpInst::ICMP_ULE:
11418 case ICmpInst::ICMP_ULT: {
11419 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11420 // Given preconditions
11421 // (1) ArLHS does not cross the border of positive and negative parts of
11422 // range because of:
11423 // - Positive step; (TODO: lift this limitation)
11424 // - nuw - does not cross zero boundary;
11425 // - nsw - does not cross SINT_MAX boundary;
11426 // (2) ArLHS <s RHS
11427 // (3) RHS >=s 0
11428 // we can replace the loop variant ArLHS <u RHS condition with loop
11429 // invariant Start(ArLHS) <u RHS.
11430 //
11431 // Because of (1) there are two options:
11432 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11433 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11434 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11435 // Because of (2) ArLHS <u RHS is trivially true.
11436 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11437 // We can strengthen this to Start(ArLHS) <u RHS.
11438 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11439 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11440 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11441 isKnownNonNegative(RHS) &&
11442 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11443 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11444 RHS);
11445 }
11446 }
11447
11448 return std::nullopt;
11449}
11450
11451std::optional<ScalarEvolution::LoopInvariantPredicate>
11452ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11453 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11454 const Instruction *CtxI, const SCEV *MaxIter) {
11455 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11456 Pred, LHS, RHS, L, CtxI, MaxIter))
11457 return LIP;
11458 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11459 // Number of iterations expressed as UMIN isn't always great for expressing
11460 // the value on the last iteration. If the straightforward approach didn't
11461 // work, try the following trick: if a predicate is invariant for X, it
11462 // is also invariant for umin(X, ...). So try to find something that works
11463 // among subexpressions of MaxIter expressed as umin.
11464 for (auto *Op : UMin->operands())
11465 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11466 Pred, LHS, RHS, L, CtxI, Op))
11467 return LIP;
11468 return std::nullopt;
11469}
11470
11471std::optional<ScalarEvolution::LoopInvariantPredicate>
11472ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11473 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11474 const Instruction *CtxI, const SCEV *MaxIter) {
11475 // Try to prove the following set of facts:
11476 // - The predicate is monotonic in the iteration space.
11477 // - If the check does not fail on the 1st iteration:
11478 // - No overflow will happen during first MaxIter iterations;
11479 // - It will not fail on the MaxIter'th iteration.
11480 // If the check does fail on the 1st iteration, we leave the loop and no
11481 // other checks matter.
11482
11483 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11484 if (!isLoopInvariant(RHS, L)) {
11485 if (!isLoopInvariant(LHS, L))
11486 return std::nullopt;
11487
11488 std::swap(LHS, RHS);
11489 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11490 }
11491
11492 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11493 if (!AR || AR->getLoop() != L)
11494 return std::nullopt;
11495
11496 // The predicate must be relational (i.e. <, <=, >=, >).
11497 if (!ICmpInst::isRelational(Pred))
11498 return std::nullopt;
11499
11500 // TODO: Support steps other than +/- 1.
11501 const SCEV *Step = AR->getStepRecurrence(*this);
11502 auto *One = getOne(Step->getType());
11503 auto *MinusOne = getNegativeSCEV(One);
11504 if (Step != One && Step != MinusOne)
11505 return std::nullopt;
11506
11507 // A type mismatch here means that MaxIter is potentially larger than the
11508 // maximum unsigned value of the start type, which means we cannot prove
11509 // absence of wrapping for the indvar.
11510 if (AR->getType() != MaxIter->getType())
11511 return std::nullopt;
11512
11513 // Value of IV on suggested last iteration.
11514 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11515 // Does it still meet the requirement?
11516 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11517 return std::nullopt;
11518 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11519 // not exceed max unsigned value of this type), this effectively proves
11520 // that there is no wrap during the iteration. To prove that there is no
11521 // signed/unsigned wrap, we need to check that
11522 // Start <= Last for step = 1 or Start >= Last for step = -1.
11523 ICmpInst::Predicate NoOverflowPred =
11524 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11525 if (Step == MinusOne)
11526 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11527 const SCEV *Start = AR->getStart();
11528 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11529 return std::nullopt;
11530
11531 // Everything is fine.
11532 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11533}
11534
11535bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11536 const SCEV *LHS,
11537 const SCEV *RHS) {
11538 if (HasSameValue(LHS, RHS))
11539 return ICmpInst::isTrueWhenEqual(Pred);
11540
11541 auto CheckRange = [&](bool IsSigned) {
11542 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11543 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11544 return RangeLHS.icmp(Pred, RangeRHS);
11545 };
11546
11547 // The check at the top of the function catches the case where the values are
11548 // known to be equal.
11549 if (Pred == CmpInst::ICMP_EQ)
11550 return false;
11551
11552 if (Pred == CmpInst::ICMP_NE) {
11553 if (CheckRange(true) || CheckRange(false))
11554 return true;
11555 auto *Diff = getMinusSCEV(LHS, RHS);
11556 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11557 }
11558
11559 return CheckRange(CmpInst::isSigned(Pred));
11560}
11561
11562bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11563 const SCEV *LHS,
11564 const SCEV *RHS) {
11565 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11566 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11567 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11568 // OutC1 and OutC2.
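// For example, this lets us prove ((%x + 1)<nsw> s< (%x + 3)<nsw>): both sides
// share %x and 1 s< 3.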
11569 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11570 APInt &OutC1, APInt &OutC2,
11571 SCEV::NoWrapFlags ExpectedFlags) {
11572 const SCEV *XNonConstOp, *XConstOp;
11573 const SCEV *YNonConstOp, *YConstOp;
11574 SCEV::NoWrapFlags XFlagsPresent;
11575 SCEV::NoWrapFlags YFlagsPresent;
11576
11577 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11578 XConstOp = getZero(X->getType());
11579 XNonConstOp = X;
11580 XFlagsPresent = ExpectedFlags;
11581 }
11582 if (!isa<SCEVConstant>(XConstOp))
11583 return false;
11584
11585 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11586 YConstOp = getZero(Y->getType());
11587 YNonConstOp = Y;
11588 YFlagsPresent = ExpectedFlags;
11589 }
11590
11591 if (YNonConstOp != XNonConstOp)
11592 return false;
11593
11594 if (!isa<SCEVConstant>(YConstOp))
11595 return false;
11596
11597 // When matching ADDs with NUW flags (and unsigned predicates), only the
11598 // second ADD (with the larger constant) requires NUW.
11599 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11600 return false;
11601 if (ExpectedFlags != SCEV::FlagNUW &&
11602 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11603 return false;
11604 }
11605
11606 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11607 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11608
11609 return true;
11610 };
11611
11612 APInt C1;
11613 APInt C2;
11614
11615 switch (Pred) {
11616 default:
11617 break;
11618
11619 case ICmpInst::ICMP_SGE:
11620 std::swap(LHS, RHS);
11621 [[fallthrough]];
11622 case ICmpInst::ICMP_SLE:
11623 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11624 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11625 return true;
11626
11627 break;
11628
11629 case ICmpInst::ICMP_SGT:
11630 std::swap(LHS, RHS);
11631 [[fallthrough]];
11632 case ICmpInst::ICMP_SLT:
11633 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11634 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11635 return true;
11636
11637 break;
11638
11639 case ICmpInst::ICMP_UGE:
11640 std::swap(LHS, RHS);
11641 [[fallthrough]];
11642 case ICmpInst::ICMP_ULE:
11643 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11644 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11645 return true;
11646
11647 break;
11648
11649 case ICmpInst::ICMP_UGT:
11650 std::swap(LHS, RHS);
11651 [[fallthrough]];
11652 case ICmpInst::ICMP_ULT:
11653 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11654 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11655 return true;
11656 break;
11657 }
11658
11659 return false;
11660}
11661
11662bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11663 const SCEV *LHS,
11664 const SCEV *RHS) {
11665 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11666 return false;
11667
11668 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
11669 // on the stack can result in exponential time complexity.
11670 SaveAndRestore Restore(ProvingSplitPredicate, true);
11671
11672 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11673 //
11674 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11675 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11676 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11677 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11678 // use isKnownPredicate later if needed.
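// For example, proving (%i u< %len) when (%len s>= 0) reduces to proving
// (%i s>= 0) and (%i s< %len).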
11679 return isKnownNonNegative(RHS) &&
11680 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11681 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11682}
11683
11684bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11685 const SCEV *LHS, const SCEV *RHS) {
11686 // No need to even try if we know the module has no guards.
11687 if (!HasGuards)
11688 return false;
11689
11690 return any_of(*BB, [&](const Instruction &I) {
11691 using namespace llvm::PatternMatch;
11692
11693 Value *Condition;
11694 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11695 m_Value(Condition))) &&
11696 isImpliedCond(Pred, LHS, RHS, Condition, false);
11697 });
11698}
11699
11700/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11701/// protected by a conditional between LHS and RHS. This is used to
11702/// eliminate casts.
11703bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11704 CmpPredicate Pred,
11705 const SCEV *LHS,
11706 const SCEV *RHS) {
11707 // Interpret a null as meaning no loop, where there is obviously no guard
11708 // (interprocedural conditions notwithstanding). Do not bother about
11709 // unreachable loops.
11710 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11711 return true;
11712
11713 if (VerifyIR)
11714 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11715 "This cannot be done on broken IR!");
11716
11717
11718 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11719 return true;
11720
11721 BasicBlock *Latch = L->getLoopLatch();
11722 if (!Latch)
11723 return false;
11724
11725 BranchInst *LoopContinuePredicate =
11726 dyn_cast<BranchInst>(Latch->getTerminator());
11727 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11728 isImpliedCond(Pred, LHS, RHS,
11729 LoopContinuePredicate->getCondition(),
11730 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11731 return true;
11732
11733 // We don't want more than one activation of the following loops on the stack
11734 // -- that can lead to O(n!) time complexity.
11735 if (WalkingBEDominatingConds)
11736 return false;
11737
11738 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11739
11740 // See if we can exploit a trip count to prove the predicate.
11741 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11742 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11743 if (LatchBECount != getCouldNotCompute()) {
11744 // We know that Latch branches back to the loop header exactly
11745 // LatchBECount times. This means the backedge condition at Latch is
11746 // equivalent to "{0,+,1} u< LatchBECount".
11747 Type *Ty = LatchBECount->getType();
11748 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11749 const SCEV *LoopCounter =
11750 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11751 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11752 LatchBECount))
11753 return true;
11754 }
11755
11756 // Check conditions due to any @llvm.assume intrinsics.
11757 for (auto &AssumeVH : AC.assumptions()) {
11758 if (!AssumeVH)
11759 continue;
11760 auto *CI = cast<CallInst>(AssumeVH);
11761 if (!DT.dominates(CI, Latch->getTerminator()))
11762 continue;
11763
11764 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11765 return true;
11766 }
11767
11768 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11769 return true;
11770
11771 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11772 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11773 assert(DTN && "should reach the loop header before reaching the root!");
11774
11775 BasicBlock *BB = DTN->getBlock();
11776 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11777 return true;
11778
11779 BasicBlock *PBB = BB->getSinglePredecessor();
11780 if (!PBB)
11781 continue;
11782
11783 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11784 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11785 continue;
11786
11787 Value *Condition = ContinuePredicate->getCondition();
11788
11789 // If we have an edge `E` within the loop body that dominates the only
11790 // latch, the condition guarding `E` also guards the backedge. This
11791 // reasoning works only for loops with a single latch.
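// For example, if the only path to the latch goes through an edge that is
// taken only when (i u< n) holds, then (i u< n) holds every time the
// backedge is taken.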
11792
11793 BasicBlockEdge DominatingEdge(PBB, BB);
11794 if (DominatingEdge.isSingleEdge()) {
11795 // We're constructively (and conservatively) enumerating edges within the
11796 // loop body that dominate the latch. The dominator tree better agree
11797 // with us on this:
11798 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11799
11800 if (isImpliedCond(Pred, LHS, RHS, Condition,
11801 BB != ContinuePredicate->getSuccessor(0)))
11802 return true;
11803 }
11804 }
11805
11806 return false;
11807}
11808
11809bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11810 CmpPredicate Pred,
11811 const SCEV *LHS,
11812 const SCEV *RHS) {
11813 // Do not bother proving facts for unreachable code.
11814 if (!DT.isReachableFromEntry(BB))
11815 return true;
11816 if (VerifyIR)
11817 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11818 "This cannot be done on broken IR!");
11819
11820 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11821 // the facts (a >= b && a != b) separately. A typical situation is when the
11822 // non-strict comparison is known from ranges and non-equality is known from
11823 // dominating predicates. If we are proving strict comparison, we always try
11824 // to prove non-equality and non-strict comparison separately.
11825 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11826 const bool ProvingStrictComparison =
11827 Pred != NonStrictPredicate.dropSameSign();
11828 bool ProvedNonStrictComparison = false;
11829 bool ProvedNonEquality = false;
11830
11831 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11832 if (!ProvedNonStrictComparison)
11833 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11834 if (!ProvedNonEquality)
11835 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11836 if (ProvedNonStrictComparison && ProvedNonEquality)
11837 return true;
11838 return false;
11839 };
11840
11841 if (ProvingStrictComparison) {
11842 auto ProofFn = [&](CmpPredicate P) {
11843 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11844 };
11845 if (SplitAndProve(ProofFn))
11846 return true;
11847 }
11848
11849 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11850 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11851 const Instruction *CtxI = &BB->front();
11852 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11853 return true;
11854 if (ProvingStrictComparison) {
11855 auto ProofFn = [&](CmpPredicate P) {
11856 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11857 };
11858 if (SplitAndProve(ProofFn))
11859 return true;
11860 }
11861 return false;
11862 };
11863
11864 // Starting at the block's predecessor, climb up the predecessor chain, as long
11865 // as there are predecessors that can be found that have unique successors
11866 // leading to the original block.
11867 const Loop *ContainingLoop = LI.getLoopFor(BB);
11868 const BasicBlock *PredBB;
11869 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11870 PredBB = ContainingLoop->getLoopPredecessor();
11871 else
11872 PredBB = BB->getSinglePredecessor();
11873 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11874 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11875 const BranchInst *BlockEntryPredicate =
11876 dyn_cast<BranchInst>(Pair.first->getTerminator());
11877 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11878 continue;
11879
11880 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11881 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11882 return true;
11883 }
11884
11885 // Check conditions due to any @llvm.assume intrinsics.
11886 for (auto &AssumeVH : AC.assumptions()) {
11887 if (!AssumeVH)
11888 continue;
11889 auto *CI = cast<CallInst>(AssumeVH);
11890 if (!DT.dominates(CI, BB))
11891 continue;
11892
11893 if (ProveViaCond(CI->getArgOperand(0), false))
11894 return true;
11895 }
11896
11897 // Check conditions due to any @llvm.experimental.guard intrinsics.
11898 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11899 F.getParent(), Intrinsic::experimental_guard);
11900 if (GuardDecl)
11901 for (const auto *GU : GuardDecl->users())
11902 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11903 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11904 if (ProveViaCond(Guard->getArgOperand(0), false))
11905 return true;
11906 return false;
11907}
11908
11909bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11910 const SCEV *LHS,
11911 const SCEV *RHS) {
11912 // Interpret a null as meaning no loop, where there is obviously no guard
11913 // (interprocedural conditions notwithstanding).
11914 if (!L)
11915 return false;
11916
11917 // Both LHS and RHS must be available at loop entry.
11919 "LHS is not available at Loop Entry");
11921 "RHS is not available at Loop Entry");
11922
11923 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11924 return true;
11925
11926 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11927}
11928
11929bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11930 const SCEV *RHS,
11931 const Value *FoundCondValue, bool Inverse,
11932 const Instruction *CtxI) {
11933 // A false condition implies anything. Do not bother analyzing it further.
11934 if (FoundCondValue ==
11935 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11936 return true;
11937
11938 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11939 return false;
11940
11941 llvm::scope_exit ClearOnExit(
11942 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
11943
11944 // Recursively handle And and Or conditions.
11945 const Value *Op0, *Op1;
11946 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11947 if (!Inverse)
11948 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11949 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11950 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11951 if (Inverse)
11952 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11953 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11954 }
11955
11956 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11957 if (!ICI) return false;
11958
11959 // We have now found a conditional branch that dominates the loop or controls
11960 // the loop latch. Check to see if it is the comparison we are looking for.
11961 CmpPredicate FoundPred;
11962 if (Inverse)
11963 FoundPred = ICI->getInverseCmpPredicate();
11964 else
11965 FoundPred = ICI->getCmpPredicate();
11966
11967 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11968 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11969
11970 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11971}
11972
11973bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11974 const SCEV *RHS, CmpPredicate FoundPred,
11975 const SCEV *FoundLHS, const SCEV *FoundRHS,
11976 const Instruction *CtxI) {
11977 // Balance the types.
11978 if (getTypeSizeInBits(LHS->getType()) <
11979 getTypeSizeInBits(FoundLHS->getType())) {
11980 // For unsigned and equality predicates, try to prove that both found
11981 // operands fit into narrow unsigned range. If so, try to prove facts in
11982 // narrow types.
11983 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11984 !FoundRHS->getType()->isPointerTy()) {
11985 auto *NarrowType = LHS->getType();
11986 auto *WideType = FoundLHS->getType();
11987 auto BitWidth = getTypeSizeInBits(NarrowType);
11988 const SCEV *MaxValue = getZeroExtendExpr(
11989 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11990 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11991 MaxValue) &&
11992 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11993 MaxValue)) {
11994 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11995 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11996 // We cannot preserve samesign after truncation.
11997 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11998 TruncFoundLHS, TruncFoundRHS, CtxI))
11999 return true;
12000 }
12001 }
12002
12003 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12004 return false;
12005 if (CmpInst::isSigned(Pred)) {
12006 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12007 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12008 } else {
12009 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12010 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12011 }
12012 } else if (getTypeSizeInBits(LHS->getType()) >
12013 getTypeSizeInBits(FoundLHS->getType())) {
12014 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12015 return false;
12016 if (CmpInst::isSigned(FoundPred)) {
12017 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12018 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12019 } else {
12020 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12021 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12022 }
12023 }
12024 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12025 FoundRHS, CtxI);
12026}
12027
12028bool ScalarEvolution::isImpliedCondBalancedTypes(
12029 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12030 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
12031 assert(getTypeSizeInBits(LHS->getType()) ==
12032 getTypeSizeInBits(FoundLHS->getType()) &&
12033 "Types should be balanced!");
12034 // Canonicalize the query to match the way instcombine will have
12035 // canonicalized the comparison.
12036 if (SimplifyICmpOperands(Pred, LHS, RHS))
12037 if (LHS == RHS)
12038 return CmpInst::isTrueWhenEqual(Pred);
12039 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12040 if (FoundLHS == FoundRHS)
12041 return CmpInst::isFalseWhenEqual(FoundPred);
12042
12043 // Check to see if we can make the LHS or RHS match.
12044 if (LHS == FoundRHS || RHS == FoundLHS) {
12045 if (isa<SCEVConstant>(RHS)) {
12046 std::swap(FoundLHS, FoundRHS);
12047 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12048 } else {
12049 std::swap(LHS, RHS);
12050 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12051 }
12052 }
12053
12054 // Check whether the found predicate is the same as the desired predicate.
12055 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12056 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12057
12058 // Check whether swapping the found predicate makes it the same as the
12059 // desired predicate.
12060 if (auto P = CmpPredicate::getMatching(
12061 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12062 // We can write the implication
12063 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12064 // using one of the following ways:
12065 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12066 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12067 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12068 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12069 // Forms 1. and 2. require swapping the operands of one condition. Don't
12070 // do this if it would break canonical constant/addrec ordering.
12071 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
12072 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12073 LHS, FoundLHS, FoundRHS, CtxI);
12074 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12075 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12076
12077 // There's no clear preference between forms 3. and 4., try both. Avoid
12078 // forming getNotSCEV of pointer values as the resulting subtract is
12079 // not legal.
12080 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12081 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12082 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12083 FoundRHS, CtxI))
12084 return true;
12085
12086 if (!FoundLHS->getType()->isPointerTy() &&
12087 !FoundRHS->getType()->isPointerTy() &&
12088 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12089 getNotSCEV(FoundRHS), CtxI))
12090 return true;
12091
12092 return false;
12093 }
12094
12095 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12096 CmpInst::Predicate P2) {
12097 assert(P1 != P2 && "Handled earlier!");
12098 return CmpInst::isRelational(P2) &&
12099 P1 == ICmpInst::getFlippedSignednessPredicate(P2);
12100 };
12101 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12102 // Unsigned comparison is the same as signed comparison when both the
12103 // operands are non-negative or negative.
12104 if (haveSameSign(FoundLHS, FoundRHS))
12105 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12106 // Create local copies that we can freely swap and canonicalize our
12107 // conditions to "le/lt".
12108 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12109 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12110 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12111 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12112 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12113 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12114 std::swap(CanonicalLHS, CanonicalRHS);
12115 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12116 }
12117 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12118 "Must be!");
12119 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12120 ICmpInst::isLE(CanonicalFoundPred)) &&
12121 "Must be!");
12122 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12123 // Use implication:
12124 // x <u y && y >=s 0 --> x <s y.
12125 // If we can prove the left part, the right part is also proven.
12126 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12127 CanonicalRHS, CanonicalFoundLHS,
12128 CanonicalFoundRHS);
12129 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12130 // Use implication:
12131 // x <s y && y <s 0 --> x <u y.
12132 // If we can prove the left part, the right part is also proven.
12133 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12134 CanonicalRHS, CanonicalFoundLHS,
12135 CanonicalFoundRHS);
12136 }
12137
12138 // Check if we can make progress by sharpening ranges.
12139 if (FoundPred == ICmpInst::ICMP_NE &&
12140 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12141
12142 const SCEVConstant *C = nullptr;
12143 const SCEV *V = nullptr;
12144
12145 if (isa<SCEVConstant>(FoundLHS)) {
12146 C = cast<SCEVConstant>(FoundLHS);
12147 V = FoundRHS;
12148 } else {
12149 C = cast<SCEVConstant>(FoundRHS);
12150 V = FoundLHS;
12151 }
12152
12153 // The guarding predicate tells us that C != V. If the known range
12154 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12155 // range we consider has to correspond to same signedness as the
12156 // predicate we're interested in folding.
12157
12158 APInt Min = ICmpInst::isSigned(Pred) ?
12159 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12160
12161 if (Min == C->getAPInt()) {
12162 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12163 // This is true even if (Min + 1) wraps around -- in case of
12164 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
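// For example, if the guarding condition gives (V != 0) and the unsigned
// minimum of V is 0, the fact (V u>= 1) can be used below instead of the
// weaker (V u>= 0).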
12165
12166 APInt SharperMin = Min + 1;
12167
12168 switch (Pred) {
12169 case ICmpInst::ICMP_SGE:
12170 case ICmpInst::ICMP_UGE:
12171 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12172 // RHS, we're done.
12173 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12174 CtxI))
12175 return true;
12176 [[fallthrough]];
12177
12178 case ICmpInst::ICMP_SGT:
12179 case ICmpInst::ICMP_UGT:
12180 // We know from the range information that (V `Pred` Min ||
12181 // V == Min). We know from the guarding condition that !(V
12182 // == Min). This gives us
12183 //
12184 // V `Pred` Min || V == Min && !(V == Min)
12185 // => V `Pred` Min
12186 //
12187 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12188
12189 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12190 return true;
12191 break;
12192
12193 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12194 case ICmpInst::ICMP_SLE:
12195 case ICmpInst::ICMP_ULE:
12196 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12197 LHS, V, getConstant(SharperMin), CtxI))
12198 return true;
12199 [[fallthrough]];
12200
12201 case ICmpInst::ICMP_SLT:
12202 case ICmpInst::ICMP_ULT:
12203 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12204 LHS, V, getConstant(Min), CtxI))
12205 return true;
12206 break;
12207
12208 default:
12209 // No change
12210 break;
12211 }
12212 }
12213 }
12214
12215 // Check whether the actual condition is beyond sufficient.
12216 if (FoundPred == ICmpInst::ICMP_EQ)
12217 if (ICmpInst::isTrueWhenEqual(Pred))
12218 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12219 return true;
12220 if (Pred == ICmpInst::ICMP_NE)
12221 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12222 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12223 return true;
12224
12225 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12226 return true;
12227
12228 // Otherwise assume the worst.
12229 return false;
12230}
12231
12232bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12233 const SCEV *&L, const SCEV *&R,
12234 SCEV::NoWrapFlags &Flags) {
12235 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12236 return false;
12237
12238 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12239 return true;
12240}
12241
12242std::optional<APInt>
12243ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12244 // We avoid subtracting expressions here because this function is usually
12245 // fairly deep in the call stack (i.e. is called many times).
12246
12247 unsigned BW = getTypeSizeInBits(More->getType());
12248 APInt Diff(BW, 0);
12249 APInt DiffMul(BW, 1);
12250 // Try various simplifications to reduce the difference to a constant. Limit
12251 // the number of allowed simplifications to keep compile-time low.
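// For example, computeConstantDifference((%x + 5), (%x + 1)) reduces to 4
// after the common %x operand cancels out.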
12252 for (unsigned I = 0; I < 8; ++I) {
12253 if (More == Less)
12254 return Diff;
12255
12256 // Reduce addrecs with identical steps to their start value.
12257 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12258 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12259 const auto *MAR = cast<SCEVAddRecExpr>(More);
12260
12261 if (LAR->getLoop() != MAR->getLoop())
12262 return std::nullopt;
12263
12264 // We look at affine expressions only; not for correctness but to keep
12265 // getStepRecurrence cheap.
12266 if (!LAR->isAffine() || !MAR->isAffine())
12267 return std::nullopt;
12268
12269 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12270 return std::nullopt;
12271
12272 Less = LAR->getStart();
12273 More = MAR->getStart();
12274 continue;
12275 }
12276
12277 // Try to match a common constant multiply.
12278 auto MatchConstMul =
12279 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12280 const APInt *C;
12281 const SCEV *Op;
12282 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12283 return {{Op, *C}};
12284 return std::nullopt;
12285 };
12286 if (auto MatchedMore = MatchConstMul(More)) {
12287 if (auto MatchedLess = MatchConstMul(Less)) {
12288 if (MatchedMore->second == MatchedLess->second) {
12289 More = MatchedMore->first;
12290 Less = MatchedLess->first;
12291 DiffMul *= MatchedMore->second;
12292 continue;
12293 }
12294 }
12295 }
12296
12297 // Try to cancel out common factors in two add expressions.
12298 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12299 auto Add = [&](const SCEV *S, int Mul) {
12300 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12301 if (Mul == 1) {
12302 Diff += C->getAPInt() * DiffMul;
12303 } else {
12304 assert(Mul == -1);
12305 Diff -= C->getAPInt() * DiffMul;
12306 }
12307 } else
12308 Multiplicity[S] += Mul;
12309 };
12310 auto Decompose = [&](const SCEV *S, int Mul) {
12311 if (isa<SCEVAddExpr>(S)) {
12312 for (const SCEV *Op : S->operands())
12313 Add(Op, Mul);
12314 } else
12315 Add(S, Mul);
12316 };
12317 Decompose(More, 1);
12318 Decompose(Less, -1);
12319
12320 // Check whether all the non-constants cancel out, or reduce to new
12321 // More/Less values.
12322 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12323 for (const auto &[S, Mul] : Multiplicity) {
12324 if (Mul == 0)
12325 continue;
12326 if (Mul == 1) {
12327 if (NewMore)
12328 return std::nullopt;
12329 NewMore = S;
12330 } else if (Mul == -1) {
12331 if (NewLess)
12332 return std::nullopt;
12333 NewLess = S;
12334 } else
12335 return std::nullopt;
12336 }
12337
12338 // Values stayed the same, no point in trying further.
12339 if (NewMore == More || NewLess == Less)
12340 return std::nullopt;
12341
12342 More = NewMore;
12343 Less = NewLess;
12344
12345 // Reduced to constant.
12346 if (!More && !Less)
12347 return Diff;
12348
12349 // Left with variable on only one side, bail out.
12350 if (!More || !Less)
12351 return std::nullopt;
12352 }
12353
12354 // Did not reduce to constant.
12355 return std::nullopt;
12356}
12357
12358bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12359 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12360 const SCEV *FoundRHS, const Instruction *CtxI) {
12361 // Try to recognize the following pattern:
12362 //
12363 // FoundRHS = ...
12364 // ...
12365 // loop:
12366 // FoundLHS = {Start,+,W}
12367 // context_bb: // Basic block from the same loop
12368 // known(Pred, FoundLHS, FoundRHS)
12369 //
12370 // If some predicate is known in the context of a loop, it is also known on
12371 // each iteration of this loop, including the first iteration. Therefore, in
12372 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12373 // prove the original pred using this fact.
12374 if (!CtxI)
12375 return false;
12376 const BasicBlock *ContextBB = CtxI->getParent();
12377 // Make sure AR varies in the context block.
12378 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12379 const Loop *L = AR->getLoop();
12380 // Make sure that context belongs to the loop and executes on 1st iteration
12381 // (if it ever executes at all).
12382 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12383 return false;
12384 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12385 return false;
12386 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12387 }
12388
12389 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12390 const Loop *L = AR->getLoop();
12391 // Make sure that context belongs to the loop and executes on 1st iteration
12392 // (if it ever executes at all).
12393 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12394 return false;
12395 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12396 return false;
12397 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12398 }
12399
12400 return false;
12401}
12402
12403bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12404 const SCEV *LHS,
12405 const SCEV *RHS,
12406 const SCEV *FoundLHS,
12407 const SCEV *FoundRHS) {
12408 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12409 return false;
12410
12411 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12412 if (!AddRecLHS)
12413 return false;
12414
12415 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12416 if (!AddRecFoundLHS)
12417 return false;
12418
12419 // We'd like to let SCEV reason about control dependencies, so we constrain
12420 // both the inequalities to be about add recurrences on the same loop. This
12421 // way we can use isLoopEntryGuardedByCond later.
12422
12423 const Loop *L = AddRecFoundLHS->getLoop();
12424 if (L != AddRecLHS->getLoop())
12425 return false;
12426
12427 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12428 //
12429 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12430 // ... (2)
12431 //
12432 // Informal proof for (2), assuming (1) [*]:
12433 //
12434 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12435 //
12436 // Then
12437 //
12438 // FoundLHS s< FoundRHS s< INT_MIN - C
12439 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12440 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12441 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12442 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12443 // <=> FoundLHS + C s< FoundRHS + C
12444 //
12445 // [*]: (1) can be proved by ruling out overflow.
12446 //
12447 // [**]: This can be proved by analyzing all the four possibilities:
12448 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12449 // (A s>= 0, B s>= 0).
12450 //
12451 // Note:
12452 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12453 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12454 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12455 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12456 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12457 // C)".
12458
12459 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12460 if (!LDiff)
12461 return false;
12462 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12463 if (!RDiff || *LDiff != *RDiff)
12464 return false;
12465
12466 if (LDiff->isMinValue())
12467 return true;
12468
12469 APInt FoundRHSLimit;
12470
12471 if (Pred == CmpInst::ICMP_ULT) {
12472 FoundRHSLimit = -(*RDiff);
12473 } else {
12474 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12475 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12476 }
12477
12478 // Try to prove (1) or (2), as needed.
12479 return isAvailableAtLoopEntry(FoundRHS, L) &&
12480 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12481 getConstant(FoundRHSLimit));
12482}
12483
12484bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12485 const SCEV *RHS, const SCEV *FoundLHS,
12486 const SCEV *FoundRHS, unsigned Depth) {
12487 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12488
12489 llvm::scope_exit ClearOnExit([&]() {
12490 if (LPhi) {
12491 bool Erased = PendingMerges.erase(LPhi);
12492 assert(Erased && "Failed to erase LPhi!");
12493 (void)Erased;
12494 }
12495 if (RPhi) {
12496 bool Erased = PendingMerges.erase(RPhi);
12497 assert(Erased && "Failed to erase RPhi!");
12498 (void)Erased;
12499 }
12500 });
12501
12502 // Find the respective Phis and check that they are not pending.
12503 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12504 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12505 if (!PendingMerges.insert(Phi).second)
12506 return false;
12507 LPhi = Phi;
12508 }
12509 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12510 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12511 // If we detect a loop of Phi nodes being processed by this method, for
12512 // example:
12513 //
12514 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12515 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12516 //
12517 // we don't want to deal with a case that complex, so return conservative
12518 // answer false.
12519 if (!PendingMerges.insert(Phi).second)
12520 return false;
12521 RPhi = Phi;
12522 }
12523
12524 // If none of LHS, RHS is a Phi, nothing to do here.
12525 if (!LPhi && !RPhi)
12526 return false;
12527
12528 // If there is a SCEVUnknown Phi we are interested in, make it left.
12529 if (!LPhi) {
12530 std::swap(LHS, RHS);
12531 std::swap(FoundLHS, FoundRHS);
12532 std::swap(LPhi, RPhi);
12533 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12534 }
12535
12536 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12537 const BasicBlock *LBB = LPhi->getParent();
12538 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12539
12540 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12541 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12542 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12543 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12544 };
12545
12546 if (RPhi && RPhi->getParent() == LBB) {
12547 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12548 // If we compare two Phis from the same block, and the predicate is true
12549 // for the incoming values from every predecessor block, then the
12550 // predicate is also true for the Phis.
12551 for (const BasicBlock *IncBB : predecessors(LBB)) {
12552 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12553 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12554 if (!ProvedEasily(L, R))
12555 return false;
12556 }
12557 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12558 // Case two: RHS is an AddRec whose loop has LBB as its header. This means
12559 // the loop contains both an AddRec and an Unknown PHI, so we can compare
12560 // the AddRec's incoming values from above the loop and from the latch with
12561 // the respective incoming values of LPhi.
12562 // TODO: Generalize to handle loops with many inputs in a header.
12563 if (LPhi->getNumIncomingValues() != 2) return false;
12564
12565 auto *RLoop = RAR->getLoop();
12566 auto *Predecessor = RLoop->getLoopPredecessor();
12567 assert(Predecessor && "Loop with AddRec with no predecessor?");
12568 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12569 if (!ProvedEasily(L1, RAR->getStart()))
12570 return false;
12571 auto *Latch = RLoop->getLoopLatch();
12572 assert(Latch && "Loop with AddRec with no latch?");
12573 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12574 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12575 return false;
12576 } else {
12577 // In all other cases go over the inputs of LHS and compare each of them to
12578 // RHS; the predicate is true for (LHS, RHS) if it is true for all such pairs.
12579 // At this point RHS is either a non-Phi, or it is a Phi from some block
12580 // different from LBB.
12581 for (const BasicBlock *IncBB : predecessors(LBB)) {
12582 // Check that RHS is available in this block.
12583 if (!dominates(RHS, IncBB))
12584 return false;
12585 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12586 // Make sure L does not refer to a value from a potentially previous
12587 // iteration of a loop.
12588 if (!properlyDominates(L, LBB))
12589 return false;
12590 // Addrecs are considered to properly dominate their loop, so are missed
12591 // by the previous check. Discard any values that have computable
12592 // evolution in this loop.
12593 if (auto *Loop = LI.getLoopFor(LBB))
12594 if (hasComputableLoopEvolution(L, Loop))
12595 return false;
12596 if (!ProvedEasily(L, RHS))
12597 return false;
12598 }
12599 }
12600 return true;
12601}
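// Illustrative IR shape for "case two" above (hypothetical names):
//
//   header:
//     %iv  = phi i32 [ 0, %preheader ], [ %iv.next, %latch ]  ; AddRec {0,+,1}
//     %unk = phi i32 [ %a, %preheader ], [ %b, %latch ]       ; SCEVUnknown Phi
//
// To prove %unk s> %iv it suffices to prove %a s> 0 (the AddRec start) for
// the preheader input and %b s> {1,+,1} (the post-increment value) for the
// latch input.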
12602
12603bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12604 const SCEV *LHS,
12605 const SCEV *RHS,
12606 const SCEV *FoundLHS,
12607 const SCEV *FoundRHS) {
12608 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12609 // sure that we are dealing with the same LHS.
12610 if (RHS == FoundRHS) {
12611 std::swap(LHS, RHS);
12612 std::swap(FoundLHS, FoundRHS);
12614 }
12615 if (LHS != FoundLHS)
12616 return false;
12617
12618 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12619 if (!SUFoundRHS)
12620 return false;
12621
12622 Value *Shiftee, *ShiftValue;
12623
12624 using namespace PatternMatch;
12625 if (match(SUFoundRHS->getValue(),
12626 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12627 auto *ShifteeS = getSCEV(Shiftee);
12628 // Prove one of the following:
12629 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12630 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12631 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12632 // ---> LHS <s RHS
12633 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12634 // ---> LHS <=s RHS
12635 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12636 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12637 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12638 if (isKnownNonNegative(ShifteeS))
12639 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12640 }
12641
12642 return false;
12643}
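// Intuition for the rules above: a logical shift right never increases an
// unsigned value, so (shiftee >> shiftvalue) u<= shiftee. For instance
// (hypothetical values), if LHS u< (40 >> 2) = 10 and 40 u<= RHS, then
// LHS u< 10 u<= 40 u<= RHS. The extra "shiftee >=s 0" requirement makes the
// same chain valid for the signed variants.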
12644
12645bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12646 const SCEV *RHS,
12647 const SCEV *FoundLHS,
12648 const SCEV *FoundRHS,
12649 const Instruction *CtxI) {
12650 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12651 FoundRHS) ||
12652 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12653 FoundRHS) ||
12654 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12655 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12656 CtxI) ||
12657 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12658}
12659
12660/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12661template <typename MinMaxExprType>
12662static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12663 const SCEV *Candidate) {
12664 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12665 if (!MinMaxExpr)
12666 return false;
12667
12668 return is_contained(MinMaxExpr->operands(), Candidate);
12669}
12670
12671 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12672 CmpPredicate Pred, const SCEV *LHS,
12673 const SCEV *RHS) {
12674 // If both sides are affine addrecs for the same loop, with equal
12675 // steps, and we know the recurrences don't wrap, then we only
12676 // need to check the predicate on the starting values.
12677
12678 if (!ICmpInst::isRelational(Pred))
12679 return false;
12680
12681 const SCEV *LStart, *RStart, *Step;
12682 const Loop *L;
12683 if (!match(LHS,
12684 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12685 !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12686 m_SpecificLoop(L))))
12687 return false;
12692 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12693 return false;
12694
12695 return SE.isKnownPredicate(Pred, LStart, RStart);
12696}
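// Example (illustrative): for LHS = {1,+,4}<nsw> and RHS = {7,+,4}<nsw> over
// the same loop, both recurrences advance by the same step and neither wraps,
// so proving 1 s< 7 on the start values is enough to conclude LHS s< RHS on
// every iteration.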
12697
12698 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
12699/// expression?
12700 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12701 const SCEV *LHS, const SCEV *RHS) {
12702 switch (Pred) {
12703 default:
12704 return false;
12705
12706 case ICmpInst::ICMP_SGE:
12707 std::swap(LHS, RHS);
12708 [[fallthrough]];
12709 case ICmpInst::ICMP_SLE:
12710 return
12711 // min(A, ...) <= A
12712 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12713 // A <= max(A, ...)
12714 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12715
12716 case ICmpInst::ICMP_UGE:
12717 std::swap(LHS, RHS);
12718 [[fallthrough]];
12719 case ICmpInst::ICMP_ULE:
12720 return
12721 // min(A, ...) <= A
12722 // FIXME: what about umin_seq?
12723 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12724 // A <= max(A, ...)
12725 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12726 }
12727
12728 llvm_unreachable("covered switch fell through?!");
12729}
12730
12731bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12732 const SCEV *RHS,
12733 const SCEV *FoundLHS,
12734 const SCEV *FoundRHS,
12735 unsigned Depth) {
12736 assert(getTypeSizeInBits(LHS->getType()) ==
12737 getTypeSizeInBits(RHS->getType()) &&
12738 "LHS and RHS have different sizes?");
12739 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12740 getTypeSizeInBits(FoundRHS->getType()) &&
12741 "FoundLHS and FoundRHS have different sizes?");
12742 // We want to avoid hurting compile time with analysis of overly large trees.
12743 if (Depth > MaxSCEVOperationsImplicationDepth)
12744 return false;
12745
12746 // We only want to work with GT comparison so far.
12747 if (ICmpInst::isLT(Pred)) {
12749 std::swap(LHS, RHS);
12750 std::swap(FoundLHS, FoundRHS);
12751 }
12752
12754
12755 // For unsigned, try to reduce it to corresponding signed comparison.
12756 if (P == ICmpInst::ICMP_UGT)
12757 // We can replace unsigned predicate with its signed counterpart if all
12758 // involved values are non-negative.
12759 // TODO: We could have better support for unsigned.
12760 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12761 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12762 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12763 // use this fact to prove that LHS and RHS are non-negative.
12764 const SCEV *MinusOne = getMinusOne(LHS->getType());
12765 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12766 FoundRHS) &&
12767 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12768 FoundRHS))
12769 P = ICmpInst::ICMP_SGT;
12770 }
12771
12772 if (P != ICmpInst::ICMP_SGT)
12773 return false;
12774
12775 auto GetOpFromSExt = [&](const SCEV *S) {
12776 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12777 return Ext->getOperand();
12778 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12779 // the constant in some cases.
12780 return S;
12781 };
12782
12783 // Acquire values from extensions.
12784 auto *OrigLHS = LHS;
12785 auto *OrigFoundLHS = FoundLHS;
12786 LHS = GetOpFromSExt(LHS);
12787 FoundLHS = GetOpFromSExt(FoundLHS);
12788
12789 // Check whether the SGT predicate can be proved trivially or using the found
12790 // context.
12790 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12791 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12792 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12793 FoundRHS, Depth + 1);
12794 };
12795
12796 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12797 // We want to avoid creation of any new non-constant SCEV. Since we are
12798 // going to compare the operands to RHS, we should be certain that we don't
12799 // need any size extensions for this. So let's decline all cases when the
12800 // sizes of types of LHS and RHS do not match.
12801 // TODO: Maybe try to get RHS from sext to catch more cases?
12802 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12803 return false;
12804
12805 // Should not overflow.
12806 if (!LHSAddExpr->hasNoSignedWrap())
12807 return false;
12808
12809 auto *LL = LHSAddExpr->getOperand(0);
12810 auto *LR = LHSAddExpr->getOperand(1);
12811 auto *MinusOne = getMinusOne(RHS->getType());
12812
12813 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12814 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12815 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12816 };
12817 // Try to prove the following rule:
12818 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12819 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12820 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12821 return true;
12822 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12823 Value *LL, *LR;
12824 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12825
12826 using namespace llvm::PatternMatch;
12827
12828 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12829 // Rules for division.
12830 // We are going to perform some comparisons with Denominator and its
12831 // derivative expressions. In the general case, creating a SCEV for it may
12832 // lead to a complex analysis of the entire graph, and in particular it
12833 // can request trip count recalculation for the same loop. That would be
12834 // cached as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12835 // this, we only want to create SCEVs that are constants in this section.
12836 // So we bail if Denominator is not a constant.
12837 if (!isa<ConstantInt>(LR))
12838 return false;
12839
12840 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12841
12842 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12843 // then a SCEV for the numerator already exists and matches with FoundLHS.
12844 auto *Numerator = getExistingSCEV(LL);
12845 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12846 return false;
12847
12848 // Make sure that the numerator matches with FoundLHS and the denominator
12849 // is positive.
12850 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12851 return false;
12852
12853 auto *DTy = Denominator->getType();
12854 auto *FRHSTy = FoundRHS->getType();
12855 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12856 // One of types is a pointer and another one is not. We cannot extend
12857 // them properly to a wider type, so let us just reject this case.
12858 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12859 // to avoid this check.
12860 return false;
12861
12862 // Given that:
12863 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12864 auto *WTy = getWiderType(DTy, FRHSTy);
12865 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12866 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12867
12868 // Try to prove the following rule:
12869 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12870 // For example, if FoundRHS is 2, then FoundLHS > 2 means FoundLHS is at
12871 // least 3; dividing it by any Denominator < 4 yields at least 1.
12872 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12873 if (isKnownNonPositive(RHS) &&
12874 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12875 return true;
12876
12877 // Try to prove the following rule:
12878 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12879 // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12880 // If we divide it by Denominator > 2, then:
12881 // 1. If FoundLHS is negative, then the result is 0.
12882 // 2. If FoundLHS is non-negative, then the result is non-negative.
12883 // Anyways, the result is non-negative.
12884 auto *MinusOne = getMinusOne(WTy);
12885 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12886 if (isKnownNegative(RHS) &&
12887 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12888 return true;
12889 }
12890 }
12891
12892 // If our expression contained SCEVUnknown Phis, and we split it down and now
12893 // need to prove something for them, try to prove the predicate for all
12894 // possible incoming values of those Phis.
12895 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12896 return true;
12897
12898 return false;
12899}
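// Worked instance of the first division rule above (hypothetical values):
// with Denominator = 4 and FoundRHS = 3 we have FoundRHS s> Denominator - 2,
// and FoundLHS s> FoundRHS forces FoundLHS s>= 4. Then LHS = FoundLHS /s 4 is
// at least 1, which is greater than any non-positive RHS.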
12900
12902 const SCEV *RHS) {
12903 // zext x u<= sext x, sext x s<= zext x
12904 const SCEV *Op;
12905 switch (Pred) {
12906 case ICmpInst::ICMP_SGE:
12907 std::swap(LHS, RHS);
12908 [[fallthrough]];
12909 case ICmpInst::ICMP_SLE: {
12910 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12911 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12913 }
12914 case ICmpInst::ICMP_UGE:
12915 std::swap(LHS, RHS);
12916 [[fallthrough]];
12917 case ICmpInst::ICMP_ULE: {
12918 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12919 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12921 }
12922 default:
12923 return false;
12924 };
12925 llvm_unreachable("unhandled case");
12926}
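// Example for the unsigned idiom (hypothetical i8 operand): for x = -1
// (0xFF), zext to i16 gives 0x00FF while sext gives 0xFFFF, so
// zext(x) u<= sext(x); for non-negative x the two extensions agree. The
// signed idiom is symmetric: sext(x) s<= zext(x), with equality again when
// x is non-negative.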
12927
12928bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12929 const SCEV *LHS,
12930 const SCEV *RHS) {
12931 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12932 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12933 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12934 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12935 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12936}
12937
12938bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12939 const SCEV *LHS,
12940 const SCEV *RHS,
12941 const SCEV *FoundLHS,
12942 const SCEV *FoundRHS) {
12943 switch (Pred) {
12944 default:
12945 llvm_unreachable("Unexpected CmpPredicate value!");
12946 case ICmpInst::ICMP_EQ:
12947 case ICmpInst::ICMP_NE:
12948 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12949 return true;
12950 break;
12951 case ICmpInst::ICMP_SLT:
12952 case ICmpInst::ICMP_SLE:
12953 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12954 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12955 return true;
12956 break;
12957 case ICmpInst::ICMP_SGT:
12958 case ICmpInst::ICMP_SGE:
12959 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12960 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12961 return true;
12962 break;
12963 case ICmpInst::ICMP_ULT:
12964 case ICmpInst::ICMP_ULE:
12965 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12966 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12967 return true;
12968 break;
12969 case ICmpInst::ICMP_UGT:
12970 case ICmpInst::ICMP_UGE:
12971 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12972 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12973 return true;
12974 break;
12975 }
12976
12977 // Maybe it can be proved via operations?
12978 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12979 return true;
12980
12981 return false;
12982}
12983
12984bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12985 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12986 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12987 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12988 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12989 // reduce the compile time impact of this optimization.
12990 return false;
12991
12992 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12993 if (!Addend)
12994 return false;
12995
12996 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12997
12998 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12999 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13000 ConstantRange FoundLHSRange =
13001 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13002
13003 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13004 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13005
13006 // We can also compute the range of values for `LHS` that satisfy the
13007 // consequent, "`LHS` `Pred` `RHS`":
13008 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13009 // The antecedent implies the consequent if every value of `LHS` that
13010 // satisfies the antecedent also satisfies the consequent.
13011 return LHSRange.icmp(Pred, ConstRHS);
13012}
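// Worked example (hypothetical values): the antecedent "FoundLHS u< 10"
// gives FoundLHSRange = [0, 10). With Addend = 3, i.e. LHS = FoundLHS + 3,
// LHSRange becomes [3, 13), and a consequent such as "LHS u< 20" holds for
// every value in that range, so the implication is proved.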
13013
13014bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13015 bool IsSigned) {
13016 assert(isKnownPositive(Stride) && "Positive stride expected!");
13017
13018 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13019 const SCEV *One = getOne(Stride->getType());
13020
13021 if (IsSigned) {
13022 APInt MaxRHS = getSignedRangeMax(RHS);
13023 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13024 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13025
13026 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13027 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13028 }
13029
13030 APInt MaxRHS = getUnsignedRangeMax(RHS);
13031 APInt MaxValue = APInt::getMaxValue(BitWidth);
13032 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13033
13034 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13035 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13036}
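// Example (unsigned i8, hypothetical ranges): if RHS can be as large as 250
// and Stride as large as 10, then 250 + 9 exceeds 255 and the function
// reports that the IV may wrap before crossing RHS. With Stride == 1 the
// check becomes 250 + 0 > 255, which fails, so no overflow is possible.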
13037
13038bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13039 bool IsSigned) {
13040
13041 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13042 const SCEV *One = getOne(Stride->getType());
13043
13044 if (IsSigned) {
13045 APInt MinRHS = getSignedRangeMin(RHS);
13046 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13047 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13048
13049 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13050 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13051 }
13052
13053 APInt MinRHS = getUnsignedRangeMin(RHS);
13054 APInt MinValue = APInt::getMinValue(BitWidth);
13055 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13056
13057 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13058 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13059}
13060
13061 const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
13062 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13063 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13064 // expression fixes the case of N=0.
13065 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13066 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13067 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13068}
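// Example: for N = 7, D = 3 this builds 1 + floor((7 - 1) / 3) = 3, i.e.
// ceil(7 / 3); for N = 0 the umin term is 0 and the whole expression folds
// to 0, avoiding the wrapped N - 1 that a plain "1 + floor((N - 1) / D)"
// would compute for unsigned N = 0.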
13069
13070const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13071 const SCEV *Stride,
13072 const SCEV *End,
13073 unsigned BitWidth,
13074 bool IsSigned) {
13075 // The logic in this function assumes we can represent a positive stride.
13076 // If we can't, the backedge-taken count must be zero.
13077 if (IsSigned && BitWidth == 1)
13078 return getZero(Stride->getType());
13079
13080 // The code below has only been closely audited for negative strides in the
13081 // unsigned comparison case; it may be correct for signed comparisons, but
13082 // that needs to be established.
13083 if (IsSigned && isKnownNegative(Stride))
13084 return getCouldNotCompute();
13085
13086 // Calculate the maximum backedge count based on the range of values
13087 // permitted by Start, End, and Stride.
13088 APInt MinStart =
13089 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13090
13091 APInt MinStride =
13092 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13093
13094 // We assume either the stride is positive, or the backedge-taken count
13095 // is zero. So force StrideForMaxBECount to be at least one.
13096 APInt One(BitWidth, 1);
13097 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13098 : APIntOps::umax(One, MinStride);
13099
13100 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13101 : APInt::getMaxValue(BitWidth);
13102 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13103
13104 // Although End can be a MAX expression we estimate MaxEnd considering only
13105 // the case End = RHS of the loop termination condition. This is safe because
13106 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13107 // taken count.
13108 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13109 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13110
13111 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13112 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13113 : APIntOps::umax(MaxEnd, MinStart);
13114
13115 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13116 getConstant(StrideForMaxBECount) /* Step */);
13117}
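// Example (unsigned i8, hypothetical ranges): with MinStart = 10,
// MinStride = 2 and an End whose unsigned range max is 250, Limit is
// 255 - 1 = 254, MaxEnd stays 250, and the returned bound is
// ceil((250 - 10) / 2) = 120 backedge takings.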
13118
13119 ScalarEvolution::ExitLimit
13120 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13121 const Loop *L, bool IsSigned,
13122 bool ControlsOnlyExit, bool AllowPredicates) {
13124
13125 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13126 bool PredicatedIV = false;
13127 if (!IV) {
13128 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13129 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13130 if (AR && AR->getLoop() == L && AR->isAffine()) {
13131 auto canProveNUW = [&]() {
13132 // We can use the comparison to infer no-wrap flags only if it fully
13133 // controls the loop exit.
13134 if (!ControlsOnlyExit)
13135 return false;
13136
13137 if (!isLoopInvariant(RHS, L))
13138 return false;
13139
13140 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13141 // We need the sequence defined by AR to strictly increase in the
13142 // unsigned integer domain for the logic below to hold.
13143 return false;
13144
13145 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13146 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13147 // If RHS <=u Limit, then there must exist a value V in the sequence
13148 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13149 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13150 // overflow occurs. This limit also implies that a signed comparison
13151 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13152 // the high bits on both sides must be zero.
13153 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13154 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13155 Limit = Limit.zext(OuterBitWidth);
13156 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13157 };
13158 auto Flags = AR->getNoWrapFlags();
13159 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13160 Flags = setFlags(Flags, SCEV::FlagNUW);
13161
13162 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13163 if (AR->hasNoUnsignedWrap()) {
13164 // Emulate what getZeroExtendExpr would have done during construction
13165 // if we'd been able to infer the fact just above at that time.
13166 const SCEV *Step = AR->getStepRecurrence(*this);
13167 Type *Ty = ZExt->getType();
13168 auto *S = getAddRecExpr(
13170 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13172 }
13173 }
13174 }
13175 }
13176
13177
13178 if (!IV && AllowPredicates) {
13179 // Try to make this an AddRec using runtime tests, in the first X
13180 // iterations of this loop, where X is the SCEV expression found by the
13181 // algorithm below.
13182 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13183 PredicatedIV = true;
13184 }
13185
13186 // Avoid weird loops
13187 if (!IV || IV->getLoop() != L || !IV->isAffine())
13188 return getCouldNotCompute();
13189
13190 // A precondition of this method is that the condition being analyzed
13191 // reaches an exiting branch which dominates the latch. Given that, we can
13192 // assume that an increment which violates the nowrap specification and
13193 // produces poison must cause undefined behavior when the resulting poison
13194 // value is branched upon and thus we can conclude that the backedge is
13195 // taken no more often than would be required to produce that poison value.
13196 // Note that a well defined loop can exit on the iteration which violates
13197 // the nowrap specification if there is another exit (either explicit or
13198 // implicit/exceptional) which causes the loop to execute before the
13199 // exiting instruction we're analyzing would trigger UB.
13200 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13201 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13203
13204 const SCEV *Stride = IV->getStepRecurrence(*this);
13205
13206 bool PositiveStride = isKnownPositive(Stride);
13207
13208 // Avoid negative or zero stride values.
13209 if (!PositiveStride) {
13210 // We can compute the correct backedge taken count for loops with unknown
13211 // strides if we can prove that the loop is not an infinite loop with side
13212 // effects. Here's the loop structure we are trying to handle -
13213 //
13214 // i = start
13215 // do {
13216 // A[i] = i;
13217 // i += s;
13218 // } while (i < end);
13219 //
13220 // The backedge taken count for such loops is evaluated as -
13221 // (max(end, start + stride) - start - 1) /u stride
13222 //
13223 // The additional preconditions that we need to check to prove correctness
13224 // of the above formula are as follows -
13225 //
13226 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13227 // NoWrap flag).
13228 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13229 // no side effects within the loop)
13230 // c) loop has a single static exit (with no abnormal exits)
13231 //
13232 // Precondition a) implies that if the stride is negative, this is a single
13233 // trip loop. The backedge taken count formula reduces to zero in this case.
13234 //
13235 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13236 // then a zero stride means the backedge can't be taken without executing
13237 // undefined behavior.
13238 //
13239 // The positive stride case is the same as isKnownPositive(Stride) returning
13240 // true (original behavior of the function).
13241 //
13242 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13244 return getCouldNotCompute();
13245
13246 if (!isKnownNonZero(Stride)) {
13247 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13248 // if it might eventually be greater than start and if so, on which
13249 // iteration. We can't even produce a useful upper bound.
13250 if (!isLoopInvariant(RHS, L))
13251 return getCouldNotCompute();
13252
13253 // We allow a potentially zero stride, but we need to divide by stride
13254 // below. Since the loop can't be infinite and this check must control
13255 // the sole exit, we can infer the exit must be taken on the first
13256 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13257 // we know the numerator in the divides below must be zero, so we can
13258 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13259 // and produce the right result.
13260 // FIXME: Handle the case where Stride is poison?
13261 auto wouldZeroStrideBeUB = [&]() {
13262 // Proof by contradiction. Suppose the stride were zero. If we can
13263 // prove that the backedge *is* taken on the first iteration, then since
13264 // we know this condition controls the sole exit, we must have an
13265 // infinite loop. We can't have a (well defined) infinite loop per
13266 // check just above.
13267 // Note: The (Start - Stride) term is used to get the start' term from
13268 // (start' + stride,+,stride). Remember that we only care about the
13269 // result of this expression when stride == 0 at runtime.
13270 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13271 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13272 };
13273 if (!wouldZeroStrideBeUB()) {
13274 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13275 }
13276 }
13277 } else if (!NoWrap) {
13278 // Avoid proven overflow cases: this will ensure that the backedge taken
13279 // count will not generate any unsigned overflow.
13280 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13281 return getCouldNotCompute();
13282 }
13283
13284 // On all paths just preceding, we established the following invariant:
13285 // IV can be assumed not to overflow up to and including the exiting
13286 // iteration. We proved this in one of two ways:
13287 // 1) We can show overflow doesn't occur before the exiting iteration
13288 // 1a) canIVOverflowOnLT, and b) step of one
13289 // 2) We can show that if overflow occurs, the loop must execute UB
13290 // before any possible exit.
13291 // Note that we have not yet proved RHS invariant (in general).
13292
13293 const SCEV *Start = IV->getStart();
13294
13295 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13296 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13297 // Use integer-typed versions for actual computation; we can't subtract
13298 // pointers in general.
13299 const SCEV *OrigStart = Start;
13300 const SCEV *OrigRHS = RHS;
13301 if (Start->getType()->isPointerTy()) {
13302 Start = getLosslessPtrToIntExpr(Start);
13303 if (isa<SCEVCouldNotCompute>(Start))
13304 return Start;
13305 }
13306 if (RHS->getType()->isPointerTy()) {
13307 RHS = getLosslessPtrToIntExpr(RHS);
13308 if (isa<SCEVCouldNotCompute>(RHS))
13309 return RHS;
13310 }
13311
13312 const SCEV *End = nullptr, *BECount = nullptr,
13313 *BECountIfBackedgeTaken = nullptr;
13314 if (!isLoopInvariant(RHS, L)) {
13315 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13316 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13317 RHSAddRec->getNoWrapFlags()) {
13318 // The structure of loop we are trying to calculate backedge count of:
13319 //
13320 // left = left_start
13321 // right = right_start
13322 //
13323 // while(left < right){
13324 // ... do something here ...
13325 // left += s1; // stride of left is s1 (s1 > 0)
13326 // right += s2; // stride of right is s2 (s2 < 0)
13327 // }
13328 //
13329
13330 const SCEV *RHSStart = RHSAddRec->getStart();
13331 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13332
13333 // If Stride - RHSStride is positive and does not overflow, we can write
13334 // the backedge count as
13335 // ceil((End - Start) /u (Stride - RHSStride))
13336 // where End = max(RHSStart, Start).
13337
13338 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13339 if (isKnownNegative(RHSStride) &&
13340 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13341 RHSStride)) {
13342
13343 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13344 if (isKnownPositive(Denominator)) {
13345 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13346 : getUMaxExpr(RHSStart, Start);
13347
13348 // We can do this because End >= Start, as End = max(RHSStart, Start)
13349 const SCEV *Delta = getMinusSCEV(End, Start);
13350
13351 BECount = getUDivCeilSCEV(Delta, Denominator);
13352 BECountIfBackedgeTaken =
13353 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13354 }
13355 }
13356 }
13357 if (BECount == nullptr) {
13358 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13359 // given the start, stride and max value for the end bound of the
13360 // loop (RHS), and the fact that IV does not overflow (which is
13361 // checked above).
13362 const SCEV *MaxBECount = computeMaxBECountForLT(
13363 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13364 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13365 MaxBECount, false /*MaxOrZero*/, Predicates);
13366 }
13367 } else {
13368 // We use the expression (max(End,Start)-Start)/Stride to describe the
13369 // backedge count, as if the backedge is taken at least once
13370 // max(End,Start) is End and so the result is as above, and if not
13371 // max(End,Start) is Start so we get a backedge count of zero.
13372 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13373 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13374 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13375 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13376 // Can we prove max(RHS,Start) > Start - Stride?
13377 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13378 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13379 // In this case, we can use a refined formula for computing backedge
13380 // taken count. The general formula remains:
13381 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13382 // We want to use the alternate formula:
13383 // "((End - 1) - (Start - Stride)) /u Stride"
13384 // Let's do a quick case analysis to show these are equivalent under
13385 // our precondition that max(RHS,Start) > Start - Stride.
13386 // * For RHS <= Start, the backedge-taken count must be zero.
13387 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13388 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13389 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13390 // of Stride. For a zero stride, we've used umax(Stride, 1) above,
13391 // reducing this to the stride-of-1 case.
13392 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13393 // Stride".
13394 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13395 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13396 // "((RHS - (Start - Stride) - 1) /u Stride".
13397 // Our preconditions trivially imply no overflow in that form.
13398 const SCEV *MinusOne = getMinusOne(Stride->getType());
13399 const SCEV *Numerator =
13400 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13401 BECount = getUDivExpr(Numerator, Stride);
13402 }
13403
13404 if (!BECount) {
13405 auto canProveRHSGreaterThanEqualStart = [&]() {
13406 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13407 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13408 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13409
13410 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13411 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13412 return true;
13413
13414 // (RHS > Start - 1) implies RHS >= Start.
13415 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13416 // "Start - 1" doesn't overflow.
13417 // * For signed comparison, if Start - 1 does overflow, it's equal
13418 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13419 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13420 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13421 //
13422 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13423 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13424 auto *StartMinusOne =
13425 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13426 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13427 };
13428
13429 // If we know that RHS >= Start in the context of loop, then we know
13430 // that max(RHS, Start) = RHS at this point.
13431 if (canProveRHSGreaterThanEqualStart()) {
13432 End = RHS;
13433 } else {
13434 // If RHS < Start, the backedge will be taken zero times. So in
13435 // general, we can write the backedge-taken count as:
13436 //
13437 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13438 //
13439 // We convert it to the following to make it more convenient for SCEV:
13440 //
13441 // ceil(max(RHS, Start) - Start) / Stride
13442 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13443
13444 // See what would happen if we assume the backedge is taken. This is
13445 // used to compute MaxBECount.
13446 BECountIfBackedgeTaken =
13447 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13448 }
13449
13450 // At this point, we know:
13451 //
13452 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13453 // 2. The index variable doesn't overflow.
13454 //
13455 // Therefore, we know N exists such that
13456 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13457 // doesn't overflow.
13458 //
13459 // Using this information, try to prove whether the addition in
13460 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13461 const SCEV *One = getOne(Stride->getType());
13462 bool MayAddOverflow = [&] {
13463 if (isKnownToBeAPowerOfTwo(Stride)) {
13464 // Suppose Stride is a power of two, and Start/End are unsigned
13465 // integers. Let UMAX be the largest representable unsigned
13466 // integer.
13467 //
13468 // By the preconditions of this function, we know
13469 // "(Start + Stride * N) >= End", and this doesn't overflow.
13470 // As a formula:
13471 //
13472 // End <= (Start + Stride * N) <= UMAX
13473 //
13474 // Subtracting Start from all the terms:
13475 //
13476 // End - Start <= Stride * N <= UMAX - Start
13477 //
13478 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13479 //
13480 // End - Start <= Stride * N <= UMAX
13481 //
13482 // Stride * N is a multiple of Stride. Therefore,
13483 //
13484 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13485 //
13486 // Since Stride is a power of two, UMAX + 1 is divisible by
13487 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13488 // write:
13489 //
13490 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13491 //
13492 // Dropping the middle term:
13493 //
13494 // End - Start <= UMAX - (Stride - 1)
13495 //
13496 // Adding Stride - 1 to both sides:
13497 //
13498 // (End - Start) + (Stride - 1) <= UMAX
13499 //
13500 // In other words, the addition doesn't have unsigned overflow.
13501 //
13502 // A similar proof works if we treat Start/End as signed values.
13503 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13504 // to use signed max instead of unsigned max. Note that we're
13505 // trying to prove a lack of unsigned overflow in either case.
13506 return false;
13507 }
13508 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13509 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13510 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13511 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13512 // 1 <s End.
13513 //
13514 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13515 // End.
13516 return false;
13517 }
13518 return true;
13519 }();
13520
13521 const SCEV *Delta = getMinusSCEV(End, Start);
13522 if (!MayAddOverflow) {
13523 // floor((D + (S - 1)) / S)
13524 // We prefer this formulation if it's legal because it's fewer
13525 // operations.
13526 BECount =
13527 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13528 } else {
13529 BECount = getUDivCeilSCEV(Delta, Stride);
13530 }
13531 }
13532 }
13533
13534 const SCEV *ConstantMaxBECount;
13535 bool MaxOrZero = false;
13536 if (isa<SCEVConstant>(BECount)) {
13537 ConstantMaxBECount = BECount;
13538 } else if (BECountIfBackedgeTaken &&
13539 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13540 // If we know exactly how many times the backedge will be taken if it's
13541 // taken at least once, then the backedge count will either be that or
13542 // zero.
13543 ConstantMaxBECount = BECountIfBackedgeTaken;
13544 MaxOrZero = true;
13545 } else {
13546 ConstantMaxBECount = computeMaxBECountForLT(
13547 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13548 }
13549
13550 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13551 !isa<SCEVCouldNotCompute>(BECount))
13552 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13553
13554 const SCEV *SymbolicMaxBECount =
13555 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13556 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13557 Predicates);
13558}
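// End-to-end illustration (hypothetical source loop):
//
//   for (unsigned i = a; i < b; i += 4) ...
//
// Here IV = {a,+,4} and Stride = 4. When b is loop-invariant and the entry
// guards above prove max(b, a) > a - 4, the refined formula yields
// BECount = ((b - 1) - (a - 4)) /u 4; otherwise a ceiling division of
// max(b, a) - a by 4 is used.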
13559
13560ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13561 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13562 bool ControlsOnlyExit, bool AllowPredicates) {
13564 // We handle only IV > Invariant
13565 if (!isLoopInvariant(RHS, L))
13566 return getCouldNotCompute();
13567
13568 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13569 if (!IV && AllowPredicates)
13570 // Try to make this an AddRec using runtime tests, in the first X
13571 // iterations of this loop, where X is the SCEV expression found by the
13572 // algorithm below.
13573 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13574
13575 // Avoid weird loops
13576 if (!IV || IV->getLoop() != L || !IV->isAffine())
13577 return getCouldNotCompute();
13578
13579 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13580 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13582
13583 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13584
13585 // Avoid negative or zero stride values
13586 if (!isKnownPositive(Stride))
13587 return getCouldNotCompute();
13588
13589 // Avoid proven overflow cases: this will ensure that the backedge taken count
13590 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13591 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13592 // behaviors like the case of C language.
13593 if (!Stride->isOne() && !NoWrap)
13594 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13595 return getCouldNotCompute();
13596
13597 const SCEV *Start = IV->getStart();
13598 const SCEV *End = RHS;
13599 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13600 // If we know that Start >= RHS in the context of loop, then we know that
13601 // min(RHS, Start) = RHS at this point.
13602 if (isLoopEntryGuardedByCond(
13603 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13604 End = RHS;
13605 else
13606 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13607 }
13608
13609 if (Start->getType()->isPointerTy()) {
13610 Start = getLosslessPtrToIntExpr(Start);
13611 if (isa<SCEVCouldNotCompute>(Start))
13612 return Start;
13613 }
13614 if (End->getType()->isPointerTy()) {
13615 End = getLosslessPtrToIntExpr(End);
13616 if (isa<SCEVCouldNotCompute>(End))
13617 return End;
13618 }
13619
13620 // Compute ((Start - End) + (Stride - 1)) / Stride.
13621 // FIXME: This can overflow. Holding off on fixing this for now;
13622 // howManyGreaterThans will hopefully be gone soon.
13623 const SCEV *One = getOne(Stride->getType());
13624 const SCEV *BECount = getUDivExpr(
13625 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13626
13627 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13628 : getUnsignedRangeMax(Start);
13629
13630 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13631 : getUnsignedRangeMin(Stride);
13632
13633 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13634 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13635 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13636
13637 // Although End can be a MIN expression we estimate MinEnd considering only
13638 // the case End = RHS. This is safe because in the other case (Start - End)
13639 // is zero, leading to a zero maximum backedge taken count.
13640 APInt MinEnd =
13641 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13642 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13643
13644 const SCEV *ConstantMaxBECount =
13645 isa<SCEVConstant>(BECount)
13646 ? BECount
13647 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13648 getConstant(MinStride));
13649
13650 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13651 ConstantMaxBECount = BECount;
13652 const SCEV *SymbolicMaxBECount =
13653 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13654
13655 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13656 Predicates);
13657}
13658
13660 ScalarEvolution &SE) const {
13661 if (Range.isFullSet()) // Infinite loop.
13662 return SE.getCouldNotCompute();
13663
13664 // If the start is a non-zero constant, shift the range to simplify things.
13665 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13666 if (!SC->getValue()->isZero()) {
13668 Operands[0] = SE.getZero(SC->getType());
13669 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13671 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13672 return ShiftedAddRec->getNumIterationsInRange(
13673 Range.subtract(SC->getAPInt()), SE);
13674 // This is strange and shouldn't happen.
13675 return SE.getCouldNotCompute();
13676 }
13677
13678 // The only time we can solve this is when we have all constant indices.
13679 // Otherwise, we cannot determine the overflow conditions.
13680 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13681 return SE.getCouldNotCompute();
13682
13683 // Okay at this point we know that all elements of the chrec are constants and
13684 // that the start element is zero.
13685
13686 // First check to see if the range contains zero. If not, the first
13687 // iteration exits.
13688 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13689 if (!Range.contains(APInt(BitWidth, 0)))
13690 return SE.getZero(getType());
13691
13692 if (isAffine()) {
13693 // If this is an affine expression then we have this situation:
13694 // Solve {0,+,A} in Range === Ax in Range
13695
13696 // We know that zero is in the range. If A is positive then we know that
13697 // the upper value of the range must be the first possible exit value.
13698 // If A is negative then the lower of the range is the last possible loop
13699 // value. Also note that we already checked for a full range.
13700 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13701 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13702
13703 // The exit value should be (End+A)/A.
13704 APInt ExitVal = (End + A).udiv(A);
13705 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13706
13707 // Evaluate at the exit value. If we really did fall out of the valid
13708 // range, then we computed our trip count, otherwise wrap around or other
13709 // things must have happened.
13710 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13711 if (Range.contains(Val->getValue()))
13712 return SE.getCouldNotCompute(); // Something strange happened
13713
13714 // Ensure that the previous value is in the range.
13715 assert(Range.contains(
13716 EvaluateConstantChrecAtConstant(this,
13717 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13718 "Linear scev computation is off in a bad way!");
13719 return SE.getConstant(ExitValue);
13720 }
13721
13722 if (isQuadratic()) {
13723 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13724 return SE.getConstant(*S);
13725 }
13726
13727 return SE.getCouldNotCompute();
13728}
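// Affine example: for {0,+,3} and Range = [0, 10), A = 3 and End = 9, so
// ExitVal = (9 + 3) /u 3 = 4. Evaluating the recurrence at 4 gives 12, which
// is outside the range, while the value at 3 (namely 9) is still inside, so
// the expression first leaves the range after 4 steps.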
13729
13730const SCEVAddRecExpr *
13732 assert(getNumOperands() > 1 && "AddRec with zero step?");
13733 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13734 // but in this case we cannot guarantee that the value returned will be an
13735 // AddRec because SCEV does not have a fixed point where it stops
13736 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13737 // may happen if we reach the arithmetic depth limit while simplifying. So we
13738 // construct the returned value explicitly.
13740 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13741 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13742 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13743 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13744 // We know that the last operand is not a constant zero (otherwise it would
13745 // have been popped out earlier). This guarantees us that if the result has
13746 // the same last operand, then it will also not be popped out, meaning that
13747 // the returned value will be an AddRec.
13748 const SCEV *Last = getOperand(getNumOperands() - 1);
13749 assert(!Last->isZero() && "Recurrency with zero step?");
13750 Ops.push_back(Last);
13753}
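// Example: for {0,+,1,+,2} the loop above produces {0+1,+,1+2,+,2}, i.e.
// {1,+,3,+,2}, which is the original recurrence advanced by one iteration
// and is guaranteed to still be an AddRec.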
13754
13755 // Return true when S contains at least one undef or poison value.
13757 return SCEVExprContains(
13758 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13759}
13760
13761// Return true when S contains a value that is a nullptr.
13763 return SCEVExprContains(S, [](const SCEV *S) {
13764 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13765 return SU->getValue() == nullptr;
13766 return false;
13767 });
13768}
13769
13770/// Return the size of an element read or written by Inst.
13772 Type *Ty;
13773 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13774 Ty = Store->getValueOperand()->getType();
13775 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13776 Ty = Load->getType();
13777 else
13778 return nullptr;
13779
13781 return getSizeOfExpr(ETy, Ty);
13782}
13783
13784//===----------------------------------------------------------------------===//
13785// SCEVCallbackVH Class Implementation
13786//===----------------------------------------------------------------------===//
13787
13789 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13790 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13791 SE->ConstantEvolutionLoopExitValue.erase(PN);
13792 SE->eraseValueFromMap(getValPtr());
13793 // this now dangles!
13794}
13795
13796void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13797 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13798
13799 // Forget all the expressions associated with users of the old value,
13800 // so that future queries will recompute the expressions using the new
13801 // value.
13802 SE->forgetValue(getValPtr());
13803 // this now dangles!
13804}
13805
13806ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13807 : CallbackVH(V), SE(se) {}
13808
13809//===----------------------------------------------------------------------===//
13810// ScalarEvolution Class Implementation
13811//===----------------------------------------------------------------------===//
13812
13815 LoopInfo &LI)
13816 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13817 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13818 LoopDispositions(64), BlockDispositions(64) {
13819 // To use guards for proving predicates, we need to scan every instruction in
13820 // relevant basic blocks, and not just terminators. Doing this is a waste of
13821 // time if the IR does not actually contain any calls to
13822 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13823 //
13824 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13825 // to _add_ guards to the module when there weren't any before, and wants
13826 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13827 // efficient in lieu of being smart in that rather obscure case.
13828
13829 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13830 F.getParent(), Intrinsic::experimental_guard);
13831 HasGuards = GuardDecl && !GuardDecl->use_empty();
13832}
13833
13835 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13836 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13837 ValueExprMap(std::move(Arg.ValueExprMap)),
13838 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13839 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13840 PendingMerges(std::move(Arg.PendingMerges)),
13841 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13842 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13843 PredicatedBackedgeTakenCounts(
13844 std::move(Arg.PredicatedBackedgeTakenCounts)),
13845 BECountUsers(std::move(Arg.BECountUsers)),
13846 ConstantEvolutionLoopExitValue(
13847 std::move(Arg.ConstantEvolutionLoopExitValue)),
13848 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13849 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13850 LoopDispositions(std::move(Arg.LoopDispositions)),
13851 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13852 BlockDispositions(std::move(Arg.BlockDispositions)),
13853 SCEVUsers(std::move(Arg.SCEVUsers)),
13854 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13855 SignedRanges(std::move(Arg.SignedRanges)),
13856 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13857 UniquePreds(std::move(Arg.UniquePreds)),
13858 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13859 LoopUsers(std::move(Arg.LoopUsers)),
13860 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13861 FirstUnknown(Arg.FirstUnknown) {
13862 Arg.FirstUnknown = nullptr;
13863}
13864
13866 // Iterate through all the SCEVUnknown instances and call their
13867 // destructors, so that they release their references to their values.
13868 for (SCEVUnknown *U = FirstUnknown; U;) {
13869 SCEVUnknown *Tmp = U;
13870 U = U->Next;
13871 Tmp->~SCEVUnknown();
13872 }
13873 FirstUnknown = nullptr;
13874
13875 ExprValueMap.clear();
13876 ValueExprMap.clear();
13877 HasRecMap.clear();
13878 BackedgeTakenCounts.clear();
13879 PredicatedBackedgeTakenCounts.clear();
13880
13881 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13882 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13883 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13884 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13885 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13886}
13887
13891
13892/// When printing a top-level SCEV for trip counts, it's helpful to include
13893/// a type for constants which are otherwise hard to disambiguate.
13894static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13895 if (isa<SCEVConstant>(S))
13896 OS << *S->getType() << " ";
13897 OS << *S;
13898}
13899
13901 const Loop *L) {
13902 // Print all inner loops first
13903 for (Loop *I : *L)
13904 PrintLoopInfo(OS, SE, I);
13905
13906 OS << "Loop ";
13907 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13908 OS << ": ";
13909
13910 SmallVector<BasicBlock *, 8> ExitingBlocks;
13911 L->getExitingBlocks(ExitingBlocks);
13912 if (ExitingBlocks.size() != 1)
13913 OS << "<multiple exits> ";
13914
13915 auto *BTC = SE->getBackedgeTakenCount(L);
13916 if (!isa<SCEVCouldNotCompute>(BTC)) {
13917 OS << "backedge-taken count is ";
13918 PrintSCEVWithTypeHint(OS, BTC);
13919 } else
13920 OS << "Unpredictable backedge-taken count.";
13921 OS << "\n";
13922
13923 if (ExitingBlocks.size() > 1)
13924 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13925 OS << " exit count for " << ExitingBlock->getName() << ": ";
13926 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13927 PrintSCEVWithTypeHint(OS, EC);
13928 if (isa<SCEVCouldNotCompute>(EC)) {
13929 // Retry with predicates.
13931 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13932 if (!isa<SCEVCouldNotCompute>(EC)) {
13933 OS << "\n predicated exit count for " << ExitingBlock->getName()
13934 << ": ";
13935 PrintSCEVWithTypeHint(OS, EC);
13936 OS << "\n Predicates:\n";
13937 for (const auto *P : Predicates)
13938 P->print(OS, 4);
13939 }
13940 }
13941 OS << "\n";
13942 }
13943
13944 OS << "Loop ";
13945 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13946 OS << ": ";
13947
13948 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13949 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13950 OS << "constant max backedge-taken count is ";
13951 PrintSCEVWithTypeHint(OS, ConstantBTC);
13952 if (SE->isBackedgeTakenCountMaxOrZero(L))
13953 OS << ", actual taken count either this or zero.";
13954 } else {
13955 OS << "Unpredictable constant max backedge-taken count. ";
13956 }
13957
13958 OS << "\n"
13959 "Loop ";
13960 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13961 OS << ": ";
13962
13963 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13964 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13965 OS << "symbolic max backedge-taken count is ";
13966 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13967 if (SE->isBackedgeTakenCountMaxOrZero(L))
13968 OS << ", actual taken count either this or zero.";
13969 } else {
13970 OS << "Unpredictable symbolic max backedge-taken count. ";
13971 }
13972 OS << "\n";
13973
13974 if (ExitingBlocks.size() > 1)
13975 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13976 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13977 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13978 ScalarEvolution::SymbolicMaximum);
13979 PrintSCEVWithTypeHint(OS, ExitBTC);
13980 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13981 // Retry with predicates.
13982 SmallVector<const SCEVPredicate *, 4> Predicates;
13983 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13984 ScalarEvolution::SymbolicMaximum);
13985 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13986 OS << "\n predicated symbolic max exit count for "
13987 << ExitingBlock->getName() << ": ";
13988 PrintSCEVWithTypeHint(OS, ExitBTC);
13989 OS << "\n Predicates:\n";
13990 for (const auto *P : Predicates)
13991 P->print(OS, 4);
13992 }
13993 }
13994 OS << "\n";
13995 }
13996
13997 SmallVector<const SCEVPredicate *, 4> Preds;
13998 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13999 if (PBT != BTC) {
14000 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14001 OS << "Loop ";
14002 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14003 OS << ": ";
14004 if (!isa<SCEVCouldNotCompute>(PBT)) {
14005 OS << "Predicated backedge-taken count is ";
14006 PrintSCEVWithTypeHint(OS, PBT);
14007 } else
14008 OS << "Unpredictable predicated backedge-taken count.";
14009 OS << "\n";
14010 OS << " Predicates:\n";
14011 for (const auto *P : Preds)
14012 P->print(OS, 4);
14013 }
14014 Preds.clear();
14015
14016 auto *PredConstantMax =
14017 SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
14018 if (PredConstantMax != ConstantBTC) {
14019 assert(!Preds.empty() &&
14020 "different predicated constant max BTC but no predicates");
14021 OS << "Loop ";
14022 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14023 OS << ": ";
14024 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14025 OS << "Predicated constant max backedge-taken count is ";
14026 PrintSCEVWithTypeHint(OS, PredConstantMax);
14027 } else
14028 OS << "Unpredictable predicated constant max backedge-taken count.";
14029 OS << "\n";
14030 OS << " Predicates:\n";
14031 for (const auto *P : Preds)
14032 P->print(OS, 4);
14033 }
14034 Preds.clear();
14035
14036 auto *PredSymbolicMax =
14037 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
14038 if (SymbolicBTC != PredSymbolicMax) {
14039 assert(!Preds.empty() &&
14040 "Different predicated symbolic max BTC, but no predicates");
14041 OS << "Loop ";
14042 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14043 OS << ": ";
14044 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14045 OS << "Predicated symbolic max backedge-taken count is ";
14046 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14047 } else
14048 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14049 OS << "\n";
14050 OS << " Predicates:\n";
14051 for (const auto *P : Preds)
14052 P->print(OS, 4);
14053 }
14054
14055 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
14056 OS << "Loop ";
14057 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14058 OS << ": ";
14059 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14060 }
14061}
14062
14063namespace llvm {
14064// Note: these overloaded operators need to be in the llvm namespace for them
14065// to be resolved correctly. If we put them outside the llvm namespace, the
14066//
14067// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14068//
14069// code below "breaks" and starts printing raw enum values as opposed to the
14070// string values.
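// (The enums' associated namespace is llvm, so overloads declared outside it
// are not found by argument-dependent lookup; the enum would then convert to
// an integer and raw_ostream's integer overload would be used instead.)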
14071raw_ostream &operator<<(raw_ostream &OS,
14072 ScalarEvolution::LoopDisposition LD) {
14073 switch (LD) {
14074 case ScalarEvolution::LoopVariant:
14075 OS << "Variant";
14076 break;
14077 case ScalarEvolution::LoopInvariant:
14078 OS << "Invariant";
14079 break;
14080 case ScalarEvolution::LoopComputable:
14081 OS << "Computable";
14082 break;
14083 }
14084 return OS;
14085}
14086
14087raw_ostream &operator<<(raw_ostream &OS,
14088 ScalarEvolution::BlockDisposition BD) {
14089 switch (BD) {
14090 case ScalarEvolution::DoesNotDominateBlock:
14091 OS << "DoesNotDominate";
14092 break;
14093 case ScalarEvolution::DominatesBlock:
14094 OS << "Dominates";
14095 break;
14096 case ScalarEvolution::ProperlyDominatesBlock:
14097 OS << "ProperlyDominates";
14098 break;
14099 }
14100 return OS;
14101}
14102} // namespace llvm
14103
14104void ScalarEvolution::print(raw_ostream &OS) const {
14105 // ScalarEvolution's implementation of the print method is to print
14106 // out SCEV values of all instructions that are interesting. Doing
14107 // this potentially causes it to create new SCEV objects though,
14108 // which technically conflicts with the const qualifier. This isn't
14109 // observable from outside the class though, so casting away the
14110 // const isn't dangerous.
14111 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14112
14113 if (ClassifyExpressions) {
14114 OS << "Classifying expressions for: ";
14115 F.printAsOperand(OS, /*PrintType=*/false);
14116 OS << "\n";
14117 for (Instruction &I : instructions(F))
14118 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14119 OS << I << '\n';
14120 OS << " --> ";
14121 const SCEV *SV = SE.getSCEV(&I);
14122 SV->print(OS);
14123 if (!isa<SCEVCouldNotCompute>(SV)) {
14124 OS << " U: ";
14125 SE.getUnsignedRange(SV).print(OS);
14126 OS << " S: ";
14127 SE.getSignedRange(SV).print(OS);
14128 }
14129
14130 const Loop *L = LI.getLoopFor(I.getParent());
14131
14132 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14133 if (AtUse != SV) {
14134 OS << " --> ";
14135 AtUse->print(OS);
14136 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14137 OS << " U: ";
14138 SE.getUnsignedRange(AtUse).print(OS);
14139 OS << " S: ";
14140 SE.getSignedRange(AtUse).print(OS);
14141 }
14142 }
14143
14144 if (L) {
14145 OS << "\t\t" "Exits: ";
14146 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14147 if (!SE.isLoopInvariant(ExitValue, L)) {
14148 OS << "<<Unknown>>";
14149 } else {
14150 OS << *ExitValue;
14151 }
14152
14153 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14154 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14155 OS << LS;
14156 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14157 OS << ": " << SE.getLoopDisposition(SV, Iter);
14158 }
14159
14160 for (const auto *InnerL : depth_first(L)) {
14161 if (InnerL == L)
14162 continue;
14163 OS << LS;
14164 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14165 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14166 }
14167
14168 OS << " }";
14169 }
14170
14171 OS << "\n";
14172 }
14173 }
14174
14175 OS << "Determining loop execution counts for: ";
14176 F.printAsOperand(OS, /*PrintType=*/false);
14177 OS << "\n";
14178 for (Loop *I : LI)
14179 PrintLoopInfo(OS, &SE, I);
14180}
14181
14182ScalarEvolution::LoopDisposition
14183ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14184 auto &Values = LoopDispositions[S];
14185 for (auto &V : Values) {
14186 if (V.getPointer() == L)
14187 return V.getInt();
14188 }
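 // Not cached yet: record a conservative LoopVariant placeholder first, so
 // any recursive query issued while computing the real disposition sees a
 // safe default; the placeholder is overwritten with the result below.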
14189 Values.emplace_back(L, LoopVariant);
14190 LoopDisposition D = computeLoopDisposition(S, L);
14191 auto &Values2 = LoopDispositions[S];
14192 for (auto &V : llvm::reverse(Values2)) {
14193 if (V.getPointer() == L) {
14194 V.setInt(D);
14195 break;
14196 }
14197 }
14198 return D;
14199}
14200
14201ScalarEvolution::LoopDisposition
14202ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14203 switch (S->getSCEVType()) {
14204 case scConstant:
14205 case scVScale:
14206 return LoopInvariant;
14207 case scAddRecExpr: {
14208 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14209
14210 // If L is the addrec's loop, it's computable.
14211 if (AR->getLoop() == L)
14212 return LoopComputable;
14213
14214 // Add recurrences are never invariant in the function-body (null loop).
14215 if (!L)
14216 return LoopVariant;
14217
14218 // Everything that is not defined at loop entry is variant.
14219 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14220 return LoopVariant;
14221 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14222 " dominate the contained loop's header?");
14223
14224 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14225 if (AR->getLoop()->contains(L))
14226 return LoopInvariant;
14227
14228 // This recurrence is variant w.r.t. L if any of its operands
14229 // are variant.
14230 for (const auto *Op : AR->operands())
14231 if (!isLoopInvariant(Op, L))
14232 return LoopVariant;
14233
14234 // Otherwise it's loop-invariant.
14235 return LoopInvariant;
14236 }
14237 case scTruncate:
14238 case scZeroExtend:
14239 case scSignExtend:
14240 case scPtrToAddr:
14241 case scPtrToInt:
14242 case scAddExpr:
14243 case scMulExpr:
14244 case scUDivExpr:
14245 case scUMaxExpr:
14246 case scSMaxExpr:
14247 case scUMinExpr:
14248 case scSMinExpr:
14249 case scSequentialUMinExpr: {
14250 bool HasVarying = false;
14251 for (const auto *Op : S->operands()) {
14252 LoopDisposition D = getLoopDisposition(Op, L);
14253 if (D == LoopVariant)
14254 return LoopVariant;
14255 if (D == LoopComputable)
14256 HasVarying = true;
14257 }
14258 return HasVarying ? LoopComputable : LoopInvariant;
14259 }
14260 case scUnknown:
14261 // All non-instruction values are loop invariant. All instructions are loop
14262 // invariant if they are not contained in the specified loop.
14263 // Instructions are never considered invariant in the function body
14264 // (null loop) because they are defined within the "loop".
14265 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14266 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14267 return LoopInvariant;
14268 case scCouldNotCompute:
14269 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14270 }
14271 llvm_unreachable("Unknown SCEV kind!");
14272}
14273
14274bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14275 return getLoopDisposition(S, L) == LoopInvariant;
14276}
14277
14278bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14279 return getLoopDisposition(S, L) == LoopComputable;
14280}
14281
14282ScalarEvolution::BlockDisposition
14283ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14284 auto &Values = BlockDispositions[S];
14285 for (auto &V : Values) {
14286 if (V.getPointer() == BB)
14287 return V.getInt();
14288 }
14289 Values.emplace_back(BB, DoesNotDominateBlock);
14290 BlockDisposition D = computeBlockDisposition(S, BB);
14291 auto &Values2 = BlockDispositions[S];
14292 for (auto &V : llvm::reverse(Values2)) {
14293 if (V.getPointer() == BB) {
14294 V.setInt(D);
14295 break;
14296 }
14297 }
14298 return D;
14299}
14300
14301ScalarEvolution::BlockDisposition
14302ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14303 switch (S->getSCEVType()) {
14304 case scConstant:
14305 case scVScale:
14306 return ProperlyDominatesBlock;
14307 case scAddRecExpr: {
14308 // This uses a "dominates" query instead of "properly dominates" query
14309 // to test for proper dominance too, because the instruction which
14310 // produces the addrec's value is a PHI, and a PHI effectively properly
14311 // dominates its entire containing block.
14312 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14313 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14314 return DoesNotDominateBlock;
14315
14316 // Fall through into SCEVNAryExpr handling.
14317 [[fallthrough]];
14318 }
14319 case scTruncate:
14320 case scZeroExtend:
14321 case scSignExtend:
14322 case scPtrToAddr:
14323 case scPtrToInt:
14324 case scAddExpr:
14325 case scMulExpr:
14326 case scUDivExpr:
14327 case scUMaxExpr:
14328 case scSMaxExpr:
14329 case scUMinExpr:
14330 case scSMinExpr:
14331 case scSequentialUMinExpr: {
14332 bool Proper = true;
14333 for (const SCEV *NAryOp : S->operands()) {
14334 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14335 if (D == DoesNotDominateBlock)
14336 return DoesNotDominateBlock;
14337 if (D == DominatesBlock)
14338 Proper = false;
14339 }
14340 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14341 }
14342 case scUnknown:
14343 if (Instruction *I =
14344 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14345 if (I->getParent() == BB)
14346 return DominatesBlock;
14347 if (DT.properlyDominates(I->getParent(), BB))
14348 return ProperlyDominatesBlock;
14349 return DoesNotDominateBlock;
14350 }
14351 return ProperlyDominatesBlock;
14352 case scCouldNotCompute:
14353 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14354 }
14355 llvm_unreachable("Unknown SCEV kind!");
14356}
14357
14358bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14359 return getBlockDisposition(S, BB) >= DominatesBlock;
14360}
14361
14362bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14363 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14364}
14365
14366bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14367 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14368}
14369
14370void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14371 bool Predicated) {
14372 auto &BECounts =
14373 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14374 auto It = BECounts.find(L);
14375 if (It != BECounts.end()) {
14376 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14377 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14378 if (!isa<SCEVConstant>(S)) {
14379 auto UserIt = BECountUsers.find(S);
14380 assert(UserIt != BECountUsers.end());
14381 UserIt->second.erase({L, Predicated});
14382 }
14383 }
14384 }
14385 BECounts.erase(It);
14386 }
14387}
14388
14389void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14390 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14391 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14392
14393 while (!Worklist.empty()) {
14394 const SCEV *Curr = Worklist.pop_back_val();
14395 auto Users = SCEVUsers.find(Curr);
14396 if (Users != SCEVUsers.end())
14397 for (const auto *User : Users->second)
14398 if (ToForget.insert(User).second)
14399 Worklist.push_back(User);
14400 }
14401
14402 for (const auto *S : ToForget)
14403 forgetMemoizedResultsImpl(S);
14404
14405 for (auto I = PredicatedSCEVRewrites.begin();
14406 I != PredicatedSCEVRewrites.end();) {
14407 std::pair<const SCEV *, const Loop *> Entry = I->first;
14408 if (ToForget.count(Entry.first))
14409 PredicatedSCEVRewrites.erase(I++);
14410 else
14411 ++I;
14412 }
14413}
14414
14415void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14416 LoopDispositions.erase(S);
14417 BlockDispositions.erase(S);
14418 UnsignedRanges.erase(S);
14419 SignedRanges.erase(S);
14420 HasRecMap.erase(S);
14421 ConstantMultipleCache.erase(S);
14422
14423 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14424 UnsignedWrapViaInductionTried.erase(AR);
14425 SignedWrapViaInductionTried.erase(AR);
14426 }
14427
14428 auto ExprIt = ExprValueMap.find(S);
14429 if (ExprIt != ExprValueMap.end()) {
14430 for (Value *V : ExprIt->second) {
14431 auto ValueIt = ValueExprMap.find_as(V);
14432 if (ValueIt != ValueExprMap.end())
14433 ValueExprMap.erase(ValueIt);
14434 }
14435 ExprValueMap.erase(ExprIt);
14436 }
14437
14438 auto ScopeIt = ValuesAtScopes.find(S);
14439 if (ScopeIt != ValuesAtScopes.end()) {
14440 for (const auto &Pair : ScopeIt->second)
14441 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14442 llvm::erase(ValuesAtScopesUsers[Pair.second],
14443 std::make_pair(Pair.first, S));
14444 ValuesAtScopes.erase(ScopeIt);
14445 }
14446
14447 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14448 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14449 for (const auto &Pair : ScopeUserIt->second)
14450 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14451 ValuesAtScopesUsers.erase(ScopeUserIt);
14452 }
14453
14454 auto BEUsersIt = BECountUsers.find(S);
14455 if (BEUsersIt != BECountUsers.end()) {
14456 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14457 auto Copy = BEUsersIt->second;
14458 for (const auto &Pair : Copy)
14459 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14460 BECountUsers.erase(BEUsersIt);
14461 }
14462
14463 auto FoldUser = FoldCacheUser.find(S);
14464 if (FoldUser != FoldCacheUser.end())
14465 for (auto &KV : FoldUser->second)
14466 FoldCache.erase(KV);
14467 FoldCacheUser.erase(S);
14468}
14469
14470void
14471ScalarEvolution::getUsedLoops(const SCEV *S,
14472 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14473 struct FindUsedLoops {
14474 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14475 : LoopsUsed(LoopsUsed) {}
14476 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14477 bool follow(const SCEV *S) {
14478 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14479 LoopsUsed.insert(AR->getLoop());
14480 return true;
14481 }
14482
14483 bool isDone() const { return false; }
14484 };
14485
14486 FindUsedLoops F(LoopsUsed);
14487 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14488}
14489
14490void ScalarEvolution::getReachableBlocks(
14491 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14492 SmallVector<BasicBlock *> Worklist;
14493 Worklist.push_back(&F.getEntryBlock());
14494 while (!Worklist.empty()) {
14495 BasicBlock *BB = Worklist.pop_back_val();
14496 if (!Reachable.insert(BB).second)
14497 continue;
14498
14499 Value *Cond;
14500 BasicBlock *TrueBB, *FalseBB;
14501 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14502 m_BasicBlock(FalseBB)))) {
14503 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14504 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14505 continue;
14506 }
14507
14508 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14509 const SCEV *L = getSCEV(Cmp->getOperand(0));
14510 const SCEV *R = getSCEV(Cmp->getOperand(1));
14511 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14512 Worklist.push_back(TrueBB);
14513 continue;
14514 }
14515 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14516 R)) {
14517 Worklist.push_back(FalseBB);
14518 continue;
14519 }
14520 }
14521 }
14522
14523 append_range(Worklist, successors(BB));
14524 }
14525}
14526
14527void ScalarEvolution::verify() const {
14528 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14529 ScalarEvolution SE2(F, TLI, AC, DT, LI);
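 // SE2 below is a fresh ScalarEvolution instance over the same function; the
 // checks that follow recompute expressions and counts with it and compare
 // them against the values cached in this instance.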
14530
14531 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14532
14533 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14534 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14535 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14536
14537 const SCEV *visitConstant(const SCEVConstant *Constant) {
14538 return SE.getConstant(Constant->getAPInt());
14539 }
14540
14541 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14542 return SE.getUnknown(Expr->getValue());
14543 }
14544
14545 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14546 return SE.getCouldNotCompute();
14547 }
14548 };
14549
14550 SCEVMapper SCM(SE2);
14551 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14552 SE2.getReachableBlocks(ReachableBlocks, F);
14553
14554 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14555 if (containsUndefs(Old) || containsUndefs(New)) {
14556 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14557 // not propagate undef aggressively). This means we can (and do) fail
14558 // verification in cases where a transform makes a value go from "undef"
14559 // to "undef+1" (say). The transform is fine, since in both cases the
14560 // result is "undef", but SCEV thinks the value increased by 1.
14561 return nullptr;
14562 }
14563
14564 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14565 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14566 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14567 return nullptr;
14568
14569 return Delta;
14570 };
14571
14572 while (!LoopStack.empty()) {
14573 auto *L = LoopStack.pop_back_val();
14574 llvm::append_range(LoopStack, *L);
14575
14576 // Only verify BECounts in reachable loops. For an unreachable loop,
14577 // any BECount is legal.
14578 if (!ReachableBlocks.contains(L->getHeader()))
14579 continue;
14580
14581 // Only verify cached BECounts. Computing new BECounts may change the
14582 // results of subsequent SCEV uses.
14583 auto It = BackedgeTakenCounts.find(L);
14584 if (It == BackedgeTakenCounts.end())
14585 continue;
14586
14587 auto *CurBECount =
14588 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14589 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14590
14591 if (CurBECount == SE2.getCouldNotCompute() ||
14592 NewBECount == SE2.getCouldNotCompute()) {
14593 // NB! This situation is legal, but is very suspicious -- whatever pass
14594 // changed the loop to make a trip count go from could not compute to
14595 // computable or vice-versa *should have* invalidated SCEV. However, we
14596 // choose not to assert here (for now) since we don't want false
14597 // positives.
14598 continue;
14599 }
14600
14601 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14602 SE.getTypeSizeInBits(NewBECount->getType()))
14603 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14604 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14605 SE.getTypeSizeInBits(NewBECount->getType()))
14606 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14607
14608 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14609 if (Delta && !Delta->isZero()) {
14610 dbgs() << "Trip Count for " << *L << " Changed!\n";
14611 dbgs() << "Old: " << *CurBECount << "\n";
14612 dbgs() << "New: " << *NewBECount << "\n";
14613 dbgs() << "Delta: " << *Delta << "\n";
14614 std::abort();
14615 }
14616 }
14617
14618 // Collect all valid loops currently in LoopInfo.
14619 SmallPtrSet<Loop *, 32> ValidLoops;
14620 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14621 while (!Worklist.empty()) {
14622 Loop *L = Worklist.pop_back_val();
14623 if (ValidLoops.insert(L).second)
14624 Worklist.append(L->begin(), L->end());
14625 }
14626 for (const auto &KV : ValueExprMap) {
14627#ifndef NDEBUG
14628 // Check for SCEV expressions referencing invalid/deleted loops.
14629 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14630 assert(ValidLoops.contains(AR->getLoop()) &&
14631 "AddRec references invalid loop");
14632 }
14633#endif
14634
14635 // Check that the value is also part of the reverse map.
14636 auto It = ExprValueMap.find(KV.second);
14637 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14638 dbgs() << "Value " << *KV.first
14639 << " is in ValueExprMap but not in ExprValueMap\n";
14640 std::abort();
14641 }
14642
14643 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14644 if (!ReachableBlocks.contains(I->getParent()))
14645 continue;
14646 const SCEV *OldSCEV = SCM.visit(KV.second);
14647 const SCEV *NewSCEV = SE2.getSCEV(I);
14648 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14649 if (Delta && !Delta->isZero()) {
14650 dbgs() << "SCEV for value " << *I << " changed!\n"
14651 << "Old: " << *OldSCEV << "\n"
14652 << "New: " << *NewSCEV << "\n"
14653 << "Delta: " << *Delta << "\n";
14654 std::abort();
14655 }
14656 }
14657 }
14658
14659 for (const auto &KV : ExprValueMap) {
14660 for (Value *V : KV.second) {
14661 const SCEV *S = ValueExprMap.lookup(V);
14662 if (!S) {
14663 dbgs() << "Value " << *V
14664 << " is in ExprValueMap but not in ValueExprMap\n";
14665 std::abort();
14666 }
14667 if (S != KV.first) {
14668 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14669 << *KV.first << "\n";
14670 std::abort();
14671 }
14672 }
14673 }
14674
14675 // Verify integrity of SCEV users.
14676 for (const auto &S : UniqueSCEVs) {
14677 for (const auto *Op : S.operands()) {
14678 // We do not store dependencies of constants.
14679 if (isa<SCEVConstant>(Op))
14680 continue;
14681 auto It = SCEVUsers.find(Op);
14682 if (It != SCEVUsers.end() && It->second.count(&S))
14683 continue;
14684 dbgs() << "Use of operand " << *Op << " by user " << S
14685 << " is not being tracked!\n";
14686 std::abort();
14687 }
14688 }
14689
14690 // Verify integrity of ValuesAtScopes users.
14691 for (const auto &ValueAndVec : ValuesAtScopes) {
14692 const SCEV *Value = ValueAndVec.first;
14693 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14694 const Loop *L = LoopAndValueAtScope.first;
14695 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14696 if (!isa<SCEVConstant>(ValueAtScope)) {
14697 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14698 if (It != ValuesAtScopesUsers.end() &&
14699 is_contained(It->second, std::make_pair(L, Value)))
14700 continue;
14701 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14702 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14703 std::abort();
14704 }
14705 }
14706 }
14707
14708 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14709 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14710 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14711 const Loop *L = LoopAndValue.first;
14712 const SCEV *Value = LoopAndValue.second;
14713 assert(!isa<SCEVConstant>(Value));
14714 auto It = ValuesAtScopes.find(Value);
14715 if (It != ValuesAtScopes.end() &&
14716 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14717 continue;
14718 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14719 << *ValueAtScope << " missing in ValuesAtScopes\n";
14720 std::abort();
14721 }
14722 }
14723
14724 // Verify integrity of BECountUsers.
14725 auto VerifyBECountUsers = [&](bool Predicated) {
14726 auto &BECounts =
14727 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14728 for (const auto &LoopAndBEInfo : BECounts) {
14729 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14730 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14731 if (!isa<SCEVConstant>(S)) {
14732 auto UserIt = BECountUsers.find(S);
14733 if (UserIt != BECountUsers.end() &&
14734 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14735 continue;
14736 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14737 << " missing from BECountUsers\n";
14738 std::abort();
14739 }
14740 }
14741 }
14742 }
14743 };
14744 VerifyBECountUsers(/* Predicated */ false);
14745 VerifyBECountUsers(/* Predicated */ true);
14746
14747 // Verify integrity of loop disposition cache.
14748 for (auto &[S, Values] : LoopDispositions) {
14749 for (auto [Loop, CachedDisposition] : Values) {
14750 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14751 if (CachedDisposition != RecomputedDisposition) {
14752 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14753 << " is incorrect: cached " << CachedDisposition << ", actual "
14754 << RecomputedDisposition << "\n";
14755 std::abort();
14756 }
14757 }
14758 }
14759
14760 // Verify integrity of the block disposition cache.
14761 for (auto &[S, Values] : BlockDispositions) {
14762 for (auto [BB, CachedDisposition] : Values) {
14763 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14764 if (CachedDisposition != RecomputedDisposition) {
14765 dbgs() << "Cached disposition of " << *S << " for block %"
14766 << BB->getName() << " is incorrect: cached " << CachedDisposition
14767 << ", actual " << RecomputedDisposition << "\n";
14768 std::abort();
14769 }
14770 }
14771 }
14772
14773 // Verify FoldCache/FoldCacheUser caches.
14774 for (auto [FoldID, Expr] : FoldCache) {
14775 auto I = FoldCacheUser.find(Expr);
14776 if (I == FoldCacheUser.end()) {
14777 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14778 << "!\n";
14779 std::abort();
14780 }
14781 if (!is_contained(I->second, FoldID)) {
14782 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14783 std::abort();
14784 }
14785 }
14786 for (auto [Expr, IDs] : FoldCacheUser) {
14787 for (auto &FoldID : IDs) {
14788 const SCEV *S = FoldCache.lookup(FoldID);
14789 if (!S) {
14790 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14791 << "!\n";
14792 std::abort();
14793 }
14794 if (S != Expr) {
14795 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14796 << " != " << *Expr << "!\n";
14797 std::abort();
14798 }
14799 }
14800 }
14801
14802 // Verify that ConstantMultipleCache computations are correct. We check that
14803 // cached multiples and recomputed multiples are multiples of each other to
14804 // verify correctness. It is possible that a recomputed multiple is different
14805 // from the cached multiple due to strengthened no wrap flags or changes in
14806 // KnownBits computations.
14807 for (auto [S, Multiple] : ConstantMultipleCache) {
14808 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14809 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14810 Multiple.urem(RecomputedMultiple) != 0 &&
14811 RecomputedMultiple.urem(Multiple) != 0)) {
14812 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14813 << *S << " : Computed " << RecomputedMultiple
14814 << " but cache contains " << Multiple << "!\n";
14815 std::abort();
14816 }
14817 }
14818}
14819
14820bool ScalarEvolution::invalidate(
14821 Function &F, const PreservedAnalyses &PA,
14822 FunctionAnalysisManager::Invalidator &Inv) {
14823 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14824 // of its dependencies is invalidated.
14825 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14826 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14827 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14828 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14829 Inv.invalidate<LoopAnalysis>(F, PA);
14830}
14831
14832AnalysisKey ScalarEvolutionAnalysis::Key;
14833
14834ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14835 FunctionAnalysisManager &AM) {
14836 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14837 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14838 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14839 auto &LI = AM.getResult<LoopAnalysis>(F);
14840 return ScalarEvolution(F, TLI, AC, DT, LI);
14841}
14842
14843PreservedAnalyses
14844ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14845 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14846 return PreservedAnalyses::all();
14847}
14848
14849PreservedAnalyses
14850ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14851 // For compatibility with opt's -analyze feature under legacy pass manager
14852 // which was not ported to NPM. This keeps tests using
14853 // update_analyze_test_checks.py working.
14854 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14855 << F.getName() << "':\n";
14856 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14857 return PreservedAnalyses::all();
14858}
14859
14861 "Scalar Evolution Analysis", false, true)
14867 "Scalar Evolution Analysis", false, true)
14868
14870
14872
14874 SE.reset(new ScalarEvolution(
14876 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14878 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14879 return false;
14880}
14881
14882void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14883
14884void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14885 SE->print(OS);
14886}
14887
14888void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14889 if (!VerifySCEV)
14890 return;
14891
14892 SE->verify();
14893}
14894
14895void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14896 AU.setPreservesAll();
14897 AU.addRequiredTransitive<AssumptionCacheTracker>();
14898 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14899 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14900 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14901}
14902
14903const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14904 const SCEV *RHS) {
14905 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14906}
14907
14908const SCEVPredicate *
14909ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14910 const SCEV *LHS, const SCEV *RHS) {
14911 FoldingSetNodeID ID;
14912 assert(LHS->getType() == RHS->getType() &&
14913 "Type mismatch between LHS and RHS");
14914 // Unique this node based on the arguments
14915 ID.AddInteger(SCEVPredicate::P_Compare);
14916 ID.AddInteger(Pred);
14917 ID.AddPointer(LHS);
14918 ID.AddPointer(RHS);
14919 void *IP = nullptr;
14920 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14921 return S;
14922 SCEVComparePredicate *Eq = new (SCEVAllocator)
14923 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14924 UniquePreds.InsertNode(Eq, IP);
14925 return Eq;
14926}
14927
14928const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14929 const SCEVAddRecExpr *AR,
14930 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14931 FoldingSetNodeID ID;
14932 // Unique this node based on the arguments
14933 ID.AddInteger(SCEVPredicate::P_Wrap);
14934 ID.AddPointer(AR);
14935 ID.AddInteger(AddedFlags);
14936 void *IP = nullptr;
14937 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14938 return S;
14939 auto *OF = new (SCEVAllocator)
14940 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14941 UniquePreds.InsertNode(OF, IP);
14942 return OF;
14943}
14944
14945namespace {
14946
14947class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14948public:
14949
14950 /// Rewrites \p S in the context of a loop L and the SCEV predication
14951 /// infrastructure.
14952 ///
14953 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14954 /// equivalences present in \p Pred.
14955 ///
14956 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14957 /// \p NewPreds such that the result will be an AddRecExpr.
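 /// For example, given the equality predicate %n == 4, a use of %n in \p S is
 /// rewritten to the constant 4.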
14958 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14959 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14960 const SCEVPredicate *Pred) {
14961 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14962 return Rewriter.visit(S);
14963 }
14964
14965 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14966 if (Pred) {
14967 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14968 for (const auto *Pred : U->getPredicates())
14969 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14970 if (IPred->getLHS() == Expr &&
14971 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14972 return IPred->getRHS();
14973 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14974 if (IPred->getLHS() == Expr &&
14975 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14976 return IPred->getRHS();
14977 }
14978 }
14979 return convertToAddRecWithPreds(Expr);
14980 }
14981
14982 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14983 const SCEV *Operand = visit(Expr->getOperand());
14984 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14985 if (AR && AR->getLoop() == L && AR->isAffine()) {
14986 // This couldn't be folded because the operand didn't have the nuw
14987 // flag. Add the nusw flag as an assumption that we could make.
14988 const SCEV *Step = AR->getStepRecurrence(SE);
14989 Type *Ty = Expr->getType();
14990 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14991 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14992 SE.getSignExtendExpr(Step, Ty), L,
14993 AR->getNoWrapFlags());
14994 }
14995 return SE.getZeroExtendExpr(Operand, Expr->getType());
14996 }
14997
14998 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14999 const SCEV *Operand = visit(Expr->getOperand());
15000 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15001 if (AR && AR->getLoop() == L && AR->isAffine()) {
15002 // This couldn't be folded because the operand didn't have the nsw
15003 // flag. Add the nssw flag as an assumption that we could make.
15004 const SCEV *Step = AR->getStepRecurrence(SE);
15005 Type *Ty = Expr->getType();
15006 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15007 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15008 SE.getSignExtendExpr(Step, Ty), L,
15009 AR->getNoWrapFlags());
15010 }
15011 return SE.getSignExtendExpr(Operand, Expr->getType());
15012 }
15013
15014private:
15015 explicit SCEVPredicateRewriter(
15016 const Loop *L, ScalarEvolution &SE,
15017 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15018 const SCEVPredicate *Pred)
15019 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15020
15021 bool addOverflowAssumption(const SCEVPredicate *P) {
15022 if (!NewPreds) {
15023 // Check if we've already made this assumption.
15024 return Pred && Pred->implies(P, SE);
15025 }
15026 NewPreds->push_back(P);
15027 return true;
15028 }
15029
15030 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15031 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
15032 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15033 return addOverflowAssumption(A);
15034 }
15035
15036 // If \p Expr represents a PHINode, we try to see if it can be represented
15037 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15038 // to add this predicate as a runtime overflow check, we return the AddRec.
15039 // If \p Expr does not meet these conditions (is not a PHI node, or we
15040 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15041 // return \p Expr.
15042 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15043 if (!isa<PHINode>(Expr->getValue()))
15044 return Expr;
15045 std::optional<
15046 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15047 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15048 if (!PredicatedRewrite)
15049 return Expr;
15050 for (const auto *P : PredicatedRewrite->second){
15051 // Wrap predicates from outer loops are not supported.
15052 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15053 if (L != WP->getExpr()->getLoop())
15054 return Expr;
15055 }
15056 if (!addOverflowAssumption(P))
15057 return Expr;
15058 }
15059 return PredicatedRewrite->first;
15060 }
15061
15062 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15063 const SCEVPredicate *Pred;
15064 const Loop *L;
15065};
15066
15067} // end anonymous namespace
15068
15069const SCEV *
15070ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
15071 const SCEVPredicate &Preds) {
15072 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15073}
15074
15075const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
15076 const SCEV *S, const Loop *L,
15077 SmallVectorImpl<const SCEVPredicate *> &Preds) {
15078 SmallVector<const SCEVPredicate *, 4> TransformPreds;
15079 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15080 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15081
15082 if (!AddRec)
15083 return nullptr;
15084
15085 // Check if any of the transformed predicates is known to be false. In that
15086 // case, it doesn't make sense to convert to a predicated AddRec, as the
15087 // versioned loop will never execute.
15088 for (const SCEVPredicate *Pred : TransformPreds) {
15089 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15090 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15091 continue;
15092
15093 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15094 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15095 if (isa<SCEVCouldNotCompute>(ExitCount))
15096 continue;
15097
15098 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15099 if (!Step->isOne())
15100 continue;
15101
15102 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15103 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15104 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15105 return nullptr;
15106 }
15107
15108 // Since the transformation was successful, we can now transfer the SCEV
15109 // predicates.
15110 Preds.append(TransformPreds.begin(), TransformPreds.end());
15111
15112 return AddRec;
15113}
15114
15115/// SCEV predicates
15116SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
15117 SCEVPredicateKind Kind)
15118 : FastID(ID), Kind(Kind) {}
15119
15120SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15121 const ICmpInst::Predicate Pred,
15122 const SCEV *LHS, const SCEV *RHS)
15123 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15124 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15125 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15126}
15127
15128bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15129 ScalarEvolution &SE) const {
15130 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15131
15132 if (!Op)
15133 return false;
15134
15135 if (Pred != ICmpInst::ICMP_EQ)
15136 return false;
15137
15138 return Op->LHS == LHS && Op->RHS == RHS;
15139}
15140
15141bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15142
15142void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15143 if (Pred == ICmpInst::ICMP_EQ)
15145 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15146 else
15147 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15148 << *RHS << "\n";
15149
15150}
15151
15152SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15153 const SCEVAddRecExpr *AR,
15154 IncrementWrapFlags Flags)
15155 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15156
15157const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15158
15159bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15160 ScalarEvolution &SE) const {
15161 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15162 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15163 return false;
15164
15165 if (Op->AR == AR)
15166 return true;
15167
15168 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15169 Flags != SCEVWrapPredicate::IncrementNUSW)
15170 return false;
15171
15172 const SCEV *Start = AR->getStart();
15173 const SCEV *OpStart = Op->AR->getStart();
15174 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15175 return false;
15176
15177 // Reject pointers to different address spaces.
15178 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15179 return false;
15180
15181 const SCEV *Step = AR->getStepRecurrence(SE);
15182 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15183 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15184 return false;
15185
15186 // If both steps are positive, this predicate implies N if N's start and
15187 // step are ULE/SLE (for NUSW/NSSW) compared to this predicate's.
15188 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15189 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15190 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15191
15192 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15193 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15194 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15195 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15196 : SE.getNoopOrSignExtend(Start, WiderTy);
15197 ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15198 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15199 SE.isKnownPredicate(Pred, OpStart, Start);
15200}
15201
15202bool SCEVWrapPredicate::isAlwaysTrue() const {
15203 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15204 IncrementWrapFlags IFlags = Flags;
15205
15206 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15207 IFlags = clearFlags(IFlags, IncrementNSSW);
15208
15209 return IFlags == IncrementAnyWrap;
15210}
15211
15212void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15213 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15214 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
15215 OS << "<nusw>";
15216 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
15217 OS << "<nssw>";
15218 OS << "\n";
15219}
15220
15221SCEVWrapPredicate::IncrementWrapFlags
15222SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15223 ScalarEvolution &SE) {
15224 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15225 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15226
15227 // We can safely transfer the NSW flag as NSSW.
15228 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15229 ImpliedFlags = IncrementNSSW;
15230
15231 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15232 // If the increment is positive, the SCEV NUW flag will also imply the
15233 // WrapPredicate NUSW flag.
15234 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15235 if (Step->getValue()->getValue().isNonNegative())
15236 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15237 }
15238
15239 return ImpliedFlags;
15240}
15241
15242/// Union predicates don't get cached so create a dummy set ID for it.
15243SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15244 ScalarEvolution &SE)
15245 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15246 for (const auto *P : Preds)
15247 add(P, SE);
15248}
15249
15250bool SCEVUnionPredicate::isAlwaysTrue() const {
15251 return all_of(Preds,
15252 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15253}
15254
15255bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15256 ScalarEvolution &SE) const {
15257 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15258 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15259 return this->implies(I, SE);
15260 });
15261
15262 return any_of(Preds,
15263 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15264}
15265
15266void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15267 for (const auto *Pred : Preds)
15268 Pred->print(OS, Depth);
15269}
15270
15271void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15272 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15273 for (const auto *Pred : Set->Preds)
15274 add(Pred, SE);
15275 return;
15276 }
15277
15278 // Implication checks are quadratic in the number of predicates. Stop doing
15279 // them if there are many predicates, as they should be too expensive to use
15280 // anyway at that point.
15281 bool CheckImplies = Preds.size() < 16;
15282
15283 // Only add predicate if it is not already implied by this union predicate.
15284 if (CheckImplies && implies(N, SE))
15285 return;
15286
15287 // Build a new vector containing the current predicates, except the ones that
15288 // are implied by the new predicate N.
15289 SmallVector<const SCEVPredicate *> PrunedPreds;
15290 for (auto *P : Preds) {
15291 if (CheckImplies && N->implies(P, SE))
15292 continue;
15293 PrunedPreds.push_back(P);
15294 }
15295 Preds = std::move(PrunedPreds);
15296 Preds.push_back(N);
15297}
15298
15299PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15300 Loop &L)
15301 : SE(SE), L(L) {
15302 SmallVector<const SCEVPredicate *, 4> Empty;
15303 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15304}
15305
15308 for (const auto *Op : Ops)
15309 // We do not expect that forgetting cached data for SCEVConstants will ever
15310 // open any prospects for sharpening or introduce any correctness issues,
15311 // so we don't bother storing their dependencies.
15312 if (!isa<SCEVConstant>(Op))
15313 SCEVUsers[Op].insert(User);
15314}
15315
15316const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15317 const SCEV *Expr = SE.getSCEV(V);
15318 return getPredicatedSCEV(Expr);
15319}
15320
15321const SCEV *PredicatedScalarEvolution::getPredicatedSCEV(const SCEV *Expr) {
15322 RewriteEntry &Entry = RewriteMap[Expr];
15323
15324 // If we already have an entry and the version matches, return it.
15325 if (Entry.second && Generation == Entry.first)
15326 return Entry.second;
15327
15328 // We found an entry but it's stale. Rewrite the stale entry
15329 // according to the current predicate.
15330 if (Entry.second)
15331 Expr = Entry.second;
15332
15333 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15334 Entry = {Generation, NewSCEV};
15335
15336 return NewSCEV;
15337}
15338
15339const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15340 if (!BackedgeCount) {
15341 SmallVector<const SCEVPredicate *, 4> Preds;
15342 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15343 for (const auto *P : Preds)
15344 addPredicate(*P);
15345 }
15346 return BackedgeCount;
15347}
15348
15349const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15350 if (!SymbolicMaxBackedgeCount) {
15351 SmallVector<const SCEVPredicate *, 4> Preds;
15352 SymbolicMaxBackedgeCount =
15353 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15354 for (const auto *P : Preds)
15355 addPredicate(*P);
15356 }
15357 return SymbolicMaxBackedgeCount;
15358}
15359
15360unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15361 if (!SmallConstantMaxTripCount) {
15362 SmallVector<const SCEVPredicate *> Preds;
15363 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15364 for (const auto *P : Preds)
15365 addPredicate(*P);
15366 }
15367 return *SmallConstantMaxTripCount;
15368}
15369
15370void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15371 if (Preds->implies(&Pred, SE))
15372 return;
15373
15374 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15375 NewPreds.push_back(&Pred);
15376 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15377 updateGeneration();
15378}
15379
15380const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15381 return *Preds;
15382}
15383
15384void PredicatedScalarEvolution::updateGeneration() {
15385 // If the generation number wrapped recompute everything.
15386 if (++Generation == 0) {
15387 for (auto &II : RewriteMap) {
15388 const SCEV *Rewritten = II.second.second;
15389 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15390 }
15391 }
15392}
15393
15394void PredicatedScalarEvolution::setNoOverflow(
15395 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15396 const SCEV *Expr = getSCEV(V);
15397 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15398
15399 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15400
15401 // Clear the statically implied flags.
15402 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15403 addPredicate(*SE.getWrapPredicate(AR, Flags));
15404
15405 auto II = FlagsMap.insert({V, Flags});
15406 if (!II.second)
15407 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15408}
15409
15410bool PredicatedScalarEvolution::hasNoOverflow(
15411 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15412 const SCEV *Expr = getSCEV(V);
15413 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15414
15415 Flags = SCEVWrapPredicate::clearFlags(
15416 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15417
15418 auto II = FlagsMap.find(V);
15419
15420 if (II != FlagsMap.end())
15421 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15422
15423 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15424}
15425
15426const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15427 const SCEV *Expr = this->getSCEV(V);
15428 SmallVector<const SCEVPredicate *, 4> NewPreds;
15429 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15430
15431 if (!New)
15432 return nullptr;
15433
15434 for (const auto *P : NewPreds)
15435 addPredicate(*P);
15436
15437 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15438 return New;
15439}
15440
15441PredicatedScalarEvolution::PredicatedScalarEvolution(
15442 const PredicatedScalarEvolution &Init)
15443 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15444 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15445 SE)),
15446 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15447 for (auto I : Init.FlagsMap)
15448 FlagsMap.insert(I);
15449}
15450
15451void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15452 // For each block.
15453 for (auto *BB : L.getBlocks())
15454 for (auto &I : *BB) {
15455 if (!SE.isSCEVable(I.getType()))
15456 continue;
15457
15458 auto *Expr = SE.getSCEV(&I);
15459 auto II = RewriteMap.find(Expr);
15460
15461 if (II == RewriteMap.end())
15462 continue;
15463
15464 // Don't print things that are not interesting.
15465 if (II->second.second == Expr)
15466 continue;
15467
15468 OS.indent(Depth) << "[PSE]" << I << ":\n";
15469 OS.indent(Depth + 2) << *Expr << "\n";
15470 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15471 }
15472}
15473
15474ScalarEvolution::LoopGuards
15475ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15476 BasicBlock *Header = L->getHeader();
15477 BasicBlock *Pred = L->getLoopPredecessor();
15478 LoopGuards Guards(SE);
15479 if (!Pred)
15480 return Guards;
15481 SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15482 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15483 return Guards;
15484}
15485
15486void ScalarEvolution::LoopGuards::collectFromPHI(
15487 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15488 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15489 SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15490 unsigned Depth) {
15491 if (!SE.isSCEVable(Phi.getType()))
15492 return;
15493
15494 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15495 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15496 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15497 if (!VisitedBlocks.insert(InBlock).second)
15498 return {nullptr, scCouldNotCompute};
15499
15500 // Avoid analyzing unreachable blocks so that we don't get trapped
15501 // traversing cycles with ill-formed dominance or infinite cycles
15502 if (!SE.DT.isReachableFromEntry(InBlock))
15503 return {nullptr, scCouldNotCompute};
15504
15505 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15506 if (Inserted)
15507 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15508 Depth + 1);
15509 auto &RewriteMap = G->second.RewriteMap;
15510 if (RewriteMap.empty())
15511 return {nullptr, scCouldNotCompute};
15512 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15513 if (S == RewriteMap.end())
15514 return {nullptr, scCouldNotCompute};
15515 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15516 if (!SM)
15517 return {nullptr, scCouldNotCompute};
15518 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15519 return {C0, SM->getSCEVType()};
15520 return {nullptr, scCouldNotCompute};
15521 };
15522 auto MergeMinMaxConst = [](MinMaxPattern P1,
15523 MinMaxPattern P2) -> MinMaxPattern {
15524 auto [C1, T1] = P1;
15525 auto [C2, T2] = P2;
15526 if (!C1 || !C2 || T1 != T2)
15527 return {nullptr, scCouldNotCompute};
15528 switch (T1) {
15529 case scUMaxExpr:
15530 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15531 case scSMaxExpr:
15532 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15533 case scUMinExpr:
15534 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15535 case scSMinExpr:
15536 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15537 default:
15538 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15539 }
15540 };
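 // Merging keeps the weaker constant bound so the merged fact holds on every
 // incoming edge, e.g. incoming rewrites umin(4, ...) and umin(6, ...) merge
 // to a umin with constant 6.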
15541 auto P = GetMinMaxConst(0);
15542 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15543 if (!P.first)
15544 break;
15545 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15546 }
15547 if (P.first) {
15548 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15549 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15550 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15551 Guards.RewriteMap.insert({LHS, RHS});
15552 }
15553}
15554
15555// Return a new SCEV that modifies \p Expr to the closest number divisible by
15556// \p Divisor and less than or equal to Expr. For now, only handle constant
15557// Expr.
15558static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15559 const APInt &DivisorVal,
15560 ScalarEvolution &SE) {
15561 const APInt *ExprVal;
15562 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15563 DivisorVal.isNonPositive())
15564 return Expr;
15565 APInt Rem = ExprVal->urem(DivisorVal);
15566 // return the SCEV: Expr - Expr % Divisor
15567 return SE.getConstant(*ExprVal - Rem);
15568}
15569
15570// Return a new SCEV that modifies \p Expr to the closest number divisible by
15571// \p Divisor and greater than or equal to Expr. For now, only handle constant
15572// Expr.
15573static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15574 const APInt &DivisorVal,
15575 ScalarEvolution &SE) {
15576 const APInt *ExprVal;
15577 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15578 DivisorVal.isNonPositive())
15579 return Expr;
15580 APInt Rem = ExprVal->urem(DivisorVal);
15581 if (Rem.isZero())
15582 return Expr;
15583 // return the SCEV: Expr + Divisor - Expr % Divisor
15584 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15585}
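// For example, with Expr = 7 and Divisor = 4 the helpers above return 4
// (previous multiple) and 8 (next multiple) respectively.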
15586
15587static bool collectDivisibilityInformation(
15588 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15589 DenseMap<const SCEV *, const SCEV *> &DivInfo,
15590 DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15591 // If we have LHS == 0, check if LHS is computing a property of some unknown
15592 // SCEV %v which we can rewrite %v to express explicitly.
15593 if (Predicate != CmpInst::ICMP_EQ || !RHS->isZero())
15594 return false;
15595 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15596 // explicitly express that.
15597 const SCEVUnknown *URemLHS = nullptr;
15598 const SCEV *URemRHS = nullptr;
15599 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15600 return false;
15601
15602 const SCEV *Multiple =
15603 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15604 DivInfo[URemLHS] = Multiple;
15605 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15606 Multiples[URemLHS] = C->getAPInt();
15607 return true;
15608}
15609
15610// Check if the condition is a divisibility guard (A % B == 0).
15611static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15612 ScalarEvolution &SE) {
15613 const SCEV *X, *Y;
15614 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15615}
15616
15617// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15618// recursively. This is done by aligning up/down the constant value to the
15619// Divisor.
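// For example, with a divisor of 4, umin(9, %x) becomes umin(8, %x) and
// umax(9, %x) becomes umax(12, %x).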
15620static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15621 APInt Divisor,
15622 ScalarEvolution &SE) {
15623 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15624 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15625 // the non-constant operand and in \p LHS the constant operand.
15626 auto IsMinMaxSCEVWithNonNegativeConstant =
15627 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15628 const SCEV *&RHS) {
15629 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15630 if (MinMax->getNumOperands() != 2)
15631 return false;
15632 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15633 if (C->getAPInt().isNegative())
15634 return false;
15635 SCTy = MinMax->getSCEVType();
15636 LHS = MinMax->getOperand(0);
15637 RHS = MinMax->getOperand(1);
15638 return true;
15639 }
15640 }
15641 return false;
15642 };
15643
15644 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15645 SCEVTypes SCTy;
15646 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15647 MinMaxRHS))
15648 return MinMaxExpr;
15649 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15650 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15651 auto *DivisibleExpr =
15652 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15653 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15655 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15656 return SE.getMinMaxExpr(SCTy, Ops);
15657}
15658
15659void ScalarEvolution::LoopGuards::collectFromBlock(
15660 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15661 const BasicBlock *Block, const BasicBlock *Pred,
15662 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15663
15665
15666 SmallVector<const SCEV *> ExprsToRewrite;
15667 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15668 const SCEV *RHS,
15669 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15670 const LoopGuards &DivGuards) {
15671 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15672 // replacement SCEV which isn't directly implied by the structure of that
15673 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15674 // legal. See the scoping rules for flags in the header to understand why.
15675
15676 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15677 // create this form when combining two checks of the form (X u< C2 + C1) and
15678 // (X >=u C1).
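 // For example, InstCombine turns (X u>= 5 && X u< 13) into (-5 + X) u< 8;
 // matching that form here lets us rewrite X to umax(5, umin(X, 12)).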
15679 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15680 &ExprsToRewrite]() {
15681 const SCEVConstant *C1;
15682 const SCEVUnknown *LHSUnknown;
15683 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15684 if (!match(LHS,
15685 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15686 !C2)
15687 return false;
15688
15689 auto ExactRegion =
15690 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15691 .sub(C1->getAPInt());
15692
15693 // Bail out, unless we have a non-wrapping, monotonic range.
15694 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15695 return false;
15696 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15697 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15698 I->second = SE.getUMaxExpr(
15699 SE.getConstant(ExactRegion.getUnsignedMin()),
15700 SE.getUMinExpr(RewrittenLHS,
15701 SE.getConstant(ExactRegion.getUnsignedMax())));
15702 ExprsToRewrite.push_back(LHSUnknown);
15703 return true;
15704 };
15705 if (MatchRangeCheckIdiom())
15706 return;
15707
15708 // Do not apply information for constants or if RHS contains an AddRec.
15709 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15710 return;
15711
15712 // If RHS is SCEVUnknown, make sure the information is applied to it.
15713 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15714 std::swap(LHS, RHS);
15715 Predicate = CmpInst::getSwappedPredicate(Predicate);
15716 }
15717
15718 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15719 // and \p FromRewritten are the same (i.e. there has been no rewrite
15720 // registered for \p From), then puts this value in the list of rewritten
15721 // expressions.
15722 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15723 const SCEV *To) {
15724 if (From == FromRewritten)
15725 ExprsToRewrite.push_back(From);
15726 RewriteMap[From] = To;
15727 };
15728
15729 // Checks whether \p S has already been rewritten. In that case returns the
15730 // existing rewrite because we want to chain further rewrites onto the
15731 // already rewritten value. Otherwise returns \p S.
15732 auto GetMaybeRewritten = [&](const SCEV *S) {
15733 return RewriteMap.lookup_or(S, S);
15734 };
15735
15736 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15737 // Apply divisibility information when computing the constant multiple.
15738 const APInt &DividesBy =
15739 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15740
15741 // Collect rewrites for LHS and its transitive operands based on the
15742 // condition.
15743 // For min/max expressions, also apply the guard to its operands:
15744 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15745 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15746 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15747 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
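// For example, from umin(%a, %b) u>= %c it follows that both %a u>= %c and
// %b u>= %c hold, which is why the operands of a matching min/max are
// enqueued on the worklist below.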
15748
15749 // We cannot express strict predicates in SCEV, so instead we replace them
15750 // with non-strict ones against plus or minus one of RHS depending on the
15751 // predicate.
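// For example, '%x u< 13' with DividesBy == 4 first becomes '%x u<= 12';
// 12 is already a multiple of 4, so %x is later rewritten to umin(%x, 12).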
15752 const SCEV *One = SE.getOne(RHS->getType());
15753 switch (Predicate) {
15754 case CmpInst::ICMP_ULT:
15755 if (RHS->getType()->isPointerTy())
15756 return;
15757 RHS = SE.getUMaxExpr(RHS, One);
15758 [[fallthrough]];
15759 case CmpInst::ICMP_SLT: {
15760 RHS = SE.getMinusSCEV(RHS, One);
15761 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15762 break;
15763 }
15764 case CmpInst::ICMP_UGT:
15765 case CmpInst::ICMP_SGT:
15766 RHS = SE.getAddExpr(RHS, One);
15767 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15768 break;
15769 case CmpInst::ICMP_ULE:
15770 case CmpInst::ICMP_SLE:
15771 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15772 break;
15773 case CmpInst::ICMP_UGE:
15774 case CmpInst::ICMP_SGE:
15775 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15776 break;
15777 default:
15778 break;
15779 }
15780
15781 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15782 SmallPtrSet<const SCEV *, 16> Visited;
15783
15784 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15785 append_range(Worklist, S->operands());
15786 };
15787
15788 while (!Worklist.empty()) {
15789 const SCEV *From = Worklist.pop_back_val();
15790 if (isa<SCEVConstant>(From))
15791 continue;
15792 if (!Visited.insert(From).second)
15793 continue;
15794 const SCEV *FromRewritten = GetMaybeRewritten(From);
15795 const SCEV *To = nullptr;
15796
15797 switch (Predicate) {
15798 case CmpInst::ICMP_ULT:
15799 case CmpInst::ICMP_ULE:
15800 To = SE.getUMinExpr(FromRewritten, RHS);
15801 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15802 EnqueueOperands(UMax);
15803 break;
15804 case CmpInst::ICMP_SLT:
15805 case CmpInst::ICMP_SLE:
15806 To = SE.getSMinExpr(FromRewritten, RHS);
15807 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15808 EnqueueOperands(SMax);
15809 break;
15810 case CmpInst::ICMP_UGT:
15811 case CmpInst::ICMP_UGE:
15812 To = SE.getUMaxExpr(FromRewritten, RHS);
15813 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15814 EnqueueOperands(UMin);
15815 break;
15816 case CmpInst::ICMP_SGT:
15817 case CmpInst::ICMP_SGE:
15818 To = SE.getSMaxExpr(FromRewritten, RHS);
15819 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15820 EnqueueOperands(SMin);
15821 break;
15822 case CmpInst::ICMP_EQ:
15823 if (isa<SCEVConstant>(RHS))
15824 To = RHS;
15825 break;
15826 case CmpInst::ICMP_NE:
15827 if (match(RHS, m_scev_Zero())) {
15828 const SCEV *OneAlignedUp =
15829 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15830 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15831 } else {
15832 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15833 // but creating the subtraction eagerly is expensive. Track the
15834 // inequalities in a separate map, and materialize the rewrite lazily
15835 // when encountering a suitable subtraction while re-writing.
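// For example, a guard '%a != %b' is recorded here so that the rewriter can
// later turn the difference (%a - %b) into umax(1, %a - %b), aligned up to
// the known constant multiple of the difference.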
15836 if (LHS->getType()->isPointerTy()) {
15837 LHS = SE.getLosslessPtrToIntExpr(LHS);
15838 RHS = SE.getLosslessPtrToIntExpr(RHS);
15839 if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
15840 break;
15841 }
15842 const SCEVConstant *C;
15843 const SCEV *A, *B;
15844 if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
15845 match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
15846 RHS = A;
15847 LHS = B;
15848 }
15849 if (LHS > RHS)
15850 std::swap(LHS, RHS);
15851 Guards.NotEqual.insert({LHS, RHS});
15852 continue;
15853 }
15854 break;
15855 default:
15856 break;
15857 }
15858
15859 if (To)
15860 AddRewrite(From, FromRewritten, To);
15861 }
15862 };
15863
15864 SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
15865 // First, collect information from assumptions dominating the loop.
15866 for (auto &AssumeVH : SE.AC.assumptions()) {
15867 if (!AssumeVH)
15868 continue;
15869 auto *AssumeI = cast<CallInst>(AssumeVH);
15870 if (!SE.DT.dominates(AssumeI, Block))
15871 continue;
15872 Terms.emplace_back(AssumeI->getOperand(0), true);
15873 }
15874
15875 // Second, collect information from llvm.experimental.guards dominating the loop.
15876 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15877 SE.F.getParent(), Intrinsic::experimental_guard);
15878 if (GuardDecl)
15879 for (const auto *GU : GuardDecl->users())
15880 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15881 if (Guard->getFunction() == Block->getParent() &&
15882 SE.DT.dominates(Guard, Block))
15883 Terms.emplace_back(Guard->getArgOperand(0), true);
15884
15885 // Third, collect conditions from dominating branches. Starting at the loop
15886 // predecessor, climb up the predecessor chain, as long as there are
15887 // predecessors that can be found that have unique successors leading to the
15888 // original header.
15889 // TODO: share this logic with isLoopEntryGuardedByCond.
15890 unsigned NumCollectedConditions = 0;
15891 VisitedBlocks.insert(Block);
15892 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15893 for (; Pair.first;
15894 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15895 VisitedBlocks.insert(Pair.second);
15896 const BranchInst *LoopEntryPredicate =
15897 dyn_cast<BranchInst>(Pair.first->getTerminator());
15898 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15899 continue;
15900
15901 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15902 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15903 NumCollectedConditions++;
15904
15905 // If we are recursively collecting guards, stop after 2
15906 // conditions to limit compile-time impact for now.
15907 if (Depth > 0 && NumCollectedConditions == 2)
15908 break;
15909 }
15910 // Finally, if we stopped climbing the predecessor chain because
15911 // there wasn't a unique one to continue, try to collect conditions
15912 // for PHINodes by recursively following all of their incoming
15913 // blocks and try to merge the found conditions to build a new one
15914 // for the Phi.
15915 if (Pair.second->hasNPredecessorsOrMore(2) &&
15916 Depth < MaxLoopGuardCollectionDepth) {
15917 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15918 for (auto &Phi : Pair.second->phis())
15919 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15920 }
15921
15922 // Now apply the information from the collected conditions to
15923 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15924 // earliest condition is processed first, except guards with divisibility
15925 // information, which are moved to the back. This ensures the SCEVs with the
15926 // shortest dependency chains are constructed first.
15927 SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15928 GuardsToProcess;
15929 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15930 SmallVector<Value *, 8> Worklist;
15931 SmallPtrSet<Value *, 8> Visited;
15932 Worklist.push_back(Term);
15933 while (!Worklist.empty()) {
15934 Value *Cond = Worklist.pop_back_val();
15935 if (!Visited.insert(Cond).second)
15936 continue;
15937
15938 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15939 auto Predicate =
15940 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15941 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15942 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15943 // If LHS is a constant, apply information to the other expression.
15944 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15945 // can improve results.
15946 if (isa<SCEVConstant>(LHS)) {
15947 std::swap(LHS, RHS);
15948 Predicate = CmpInst::getSwappedPredicate(Predicate);
15949 }
15950 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
15951 continue;
15952 }
15953
15954 Value *L, *R;
15955 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15956 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15957 Worklist.push_back(L);
15958 Worklist.push_back(R);
15959 }
15960 }
15961 }
15962
15963 // Process divisibility guards in reverse order to populate DivGuards early.
15964 DenseMap<const SCEV *, APInt> Multiples;
15965 LoopGuards DivGuards(SE);
15966 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15967 if (!isDivisibilityGuard(LHS, RHS, SE))
15968 continue;
15969 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15970 Multiples, SE);
15971 }
15972
15973 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15974 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15975
15976 // Apply divisibility information last. This ensures it is applied to the
15977 // outermost expression after other rewrites for the given value.
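// For example, if %n is known to be a multiple of 4, a rewrite-map entry
// %n -> umin(13, %n) becomes ((umin(12, %n) /u 4) * 4), i.e. the whole
// rewritten expression is rounded down to a multiple of 4.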
15978 for (const auto &[K, Divisor] : Multiples) {
15979 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15980 Guards.RewriteMap[K] =
15981 SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15982 Guards.rewrite(K), Divisor, SE),
15983 DivisorSCEV),
15984 DivisorSCEV);
15985 ExprsToRewrite.push_back(K);
15986 }
15987
15988 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15989 // the replacement expressions are contained in the ranges of the replaced
15990 // expressions.
15991 Guards.PreserveNUW = true;
15992 Guards.PreserveNSW = true;
15993 for (const SCEV *Expr : ExprsToRewrite) {
15994 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15995 Guards.PreserveNUW &=
15996 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15997 Guards.PreserveNSW &=
15998 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15999 }
16000
16001 // Now that all rewrite information is collected, rewrite the collected
16002 // expressions with the information in the map. This applies information to
16003 // sub-expressions.
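// For example, if the map contains %a -> umin(%a, %b) and
// %b -> umin(%b, 100), this second pass turns the first entry into
// umin(%a, %b, 100).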
16004 if (ExprsToRewrite.size() > 1) {
16005 for (const SCEV *Expr : ExprsToRewrite) {
16006 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16007 Guards.RewriteMap.erase(Expr);
16008 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16009 }
16010 }
16011}
16012
16013 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
16014 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16015 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16016 /// replacement is loop invariant in the loop of the AddRec.
16017 class SCEVLoopGuardRewriter
16018 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16019 const DenseMap<const SCEV *, const SCEV *> &Map;
16020 const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
16021
16022 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
16023
16024 public:
16025 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16026 const ScalarEvolution::LoopGuards &Guards)
16027 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16028 NotEqual(Guards.NotEqual) {
16029 if (Guards.PreserveNUW)
16030 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16031 if (Guards.PreserveNSW)
16032 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16033 }
16034
16035 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16036
16037 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16038 return Map.lookup_or(Expr, Expr);
16039 }
16040
16041 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16042 if (const SCEV *S = Map.lookup(Expr))
16043 return S;
16044
16045 // If we didn't find the exact ZExt expr in the map, check if there's
16046 // an entry for a smaller ZExt we can use instead.
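// For example, when rewriting (zext i8 %x to i64) and the map only has an
// entry for (zext i8 %x to i32), the result is that i32 entry zero-extended
// to i64.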
16047 Type *Ty = Expr->getType();
16048 const SCEV *Op = Expr->getOperand(0);
16049 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16050 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16051 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16052 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16053 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16054 if (const SCEV *S = Map.lookup(NarrowExt))
16055 return SE.getZeroExtendExpr(S, Ty);
16056 Bitwidth = Bitwidth / 2;
16057 }
16058
16059 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
16060 Expr);
16061 }
16062
16063 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16064 if (const SCEV *S = Map.lookup(Expr))
16065 return S;
16066 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16067 Expr);
16068 }
16069
16070 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16071 if (const SCEV *S = Map.lookup(Expr))
16072 return S;
16073 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16074 }
16075
16076 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16077 if (const SCEV *S = Map.lookup(Expr))
16078 return S;
16079 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16080 }
16081
16082 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16083 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16084 // return UMax(S, 1).
16085 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16086 const SCEV *LHS, *RHS;
16087 if (MatchBinarySub(S, LHS, RHS)) {
16088 if (LHS > RHS)
16089 std::swap(LHS, RHS);
16090 if (NotEqual.contains({LHS, RHS})) {
16091 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16092 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16093 return SE.getUMaxExpr(OneAlignedUp, S);
16094 }
16095 }
16096 return nullptr;
16097 };
16098
16099 // Check if Expr itself is a subtraction pattern with guard info.
16100 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16101 return Rewritten;
16102
16103 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16104 // (Const + A + B). There may be guard info for A + B, and if so, apply
16105 // it.
16106 // TODO: Could more generally apply guards to Add sub-expressions.
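// For example, in (-1 + %end + (-1 * %start)) the (%end - %start)
// sub-expression can pick up a rewrite or a recorded inequality
// %end != %start.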
16107 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16108 Expr->getNumOperands() == 3) {
16109 const SCEV *Add =
16110 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16111 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16112 return SE.getAddExpr(
16113 Expr->getOperand(0), Rewritten,
16114 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16115 if (const SCEV *S = Map.lookup(Add))
16116 return SE.getAddExpr(Expr->getOperand(0), S);
16117 }
16118 SmallVector<const SCEV *, 2> Operands;
16119 bool Changed = false;
16120 for (const auto *Op : Expr->operands()) {
16121 Operands.push_back(
16122 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16123 Changed |= Op != Operands.back();
16124 }
16125 // We are only replacing operands with equivalent values, so transfer the
16126 // flags from the original expression.
16127 return !Changed ? Expr
16128 : SE.getAddExpr(Operands,
16129 ScalarEvolution::maskFlags(
16130 Expr->getNoWrapFlags(), FlagMask));
16131 }
16132
16133 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16134 SmallVector<const SCEV *, 2> Operands;
16135 bool Changed = false;
16136 for (const auto *Op : Expr->operands()) {
16137 Operands.push_back(
16138 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16139 Changed |= Op != Operands.back();
16140 }
16141 // We are only replacing operands with equivalent values, so transfer the
16142 // flags from the original expression.
16143 return !Changed ? Expr
16144 : SE.getMulExpr(Operands,
16145 ScalarEvolution::maskFlags(
16146 Expr->getNoWrapFlags(), FlagMask));
16147 }
16148 };
16149
16150 if (RewriteMap.empty() && NotEqual.empty())
16151 return Expr;
16152
16153 SCEVLoopGuardRewriter Rewriter(SE, *this);
16154 return Rewriter.visit(Expr);
16155}
16156
16157const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16158 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16159}
16160
16161 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16162 const LoopGuards &Guards) {
16163 return Guards.rewrite(Expr);
16164}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1982
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1023
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1549
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1400
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1521
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1804
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1202
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1183
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1677
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1497
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1112
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1167
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1656
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1770
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:828
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1285
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1151
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:996
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:874
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1131
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1222
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:470
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:493
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition Constants.h:1284
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:771
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:283
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:321
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:164
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:172
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:209
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:318
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1078
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
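A short sketch of the SCEVAddRecExpr queries listed above; valueAtIteration is a hypothetical helper and SE, AR, It come from the caller.

#include <cassert>
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

// Hypothetical helper: the value of the affine recurrence AR at iteration It,
// i.e. Start + It * Step, folded by ScalarEvolution.
static const SCEV *valueAtIteration(ScalarEvolution &SE,
                                    const SCEVAddRecExpr *AR, const SCEV *It) {
  assert(AR->isAffine() && "only handling {Start,+,Step} recurrences here");
  return AR->evaluateAtIteration(It, SE);
}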
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class for min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified SCEV is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
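A minimal sketch of the LoopGuards workflow above: collect the guard rewrite map once, then reuse it on any expression. The helper name tightenWithGuards is hypothetical.

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Hypothetical helper: apply information from L's guarding conditions to Expr.
static const SCEV *tightenWithGuards(ScalarEvolution &SE, const Loop *L,
                                     const SCEV *Expr) {
  auto Guards = ScalarEvolution::LoopGuards::collect(L, SE);
  return Guards.rewrite(Expr); // returns Expr unchanged if nothing applies
}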
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
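A small sketch of a backedge-taken-count query, guarding against SCEVCouldNotCompute; getExactBTCOrNull is a hypothetical helper name.

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Hypothetical helper: the exact backedge-taken count of L, or nullptr if
// ScalarEvolution cannot compute a loop-invariant count.
static const SCEV *getExactBTCOrNull(ScalarEvolution &SE, const Loop *L) {
  if (!SE.hasLoopInvariantBackedgeTakenCount(L))
    return nullptr;
  return SE.getBackedgeTakenCount(L);
}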
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Return true if the operation BinOp between LHS and RHS provably does not have signed/unsigned overflow (as selected by Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
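A sketch of the range queries above; fitsInUnsigned32 is a hypothetical helper, and the answer is conservative (a false result does not prove the value needs more than 32 bits).

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Hypothetical helper: is every value S may take representable in 32 unsigned bits?
static bool fitsInUnsigned32(ScalarEvolution &SE, const SCEV *S) {
  return SE.getUnsignedRangeMax(S).getActiveBits() <= 32;
}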
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
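A sketch showing the typical use of computeConstantDifference as a cheap alternative to building a full subtraction; differByOne is a hypothetical helper.

#include <optional>
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Hypothetical helper: do LHS and RHS provably differ by exactly 1?
static bool differByOne(ScalarEvolution &SE, const SCEV *LHS, const SCEV *RHS) {
  if (std::optional<APInt> Diff = SE.computeConstantDifference(LHS, RHS))
    return Diff->isOne();
  return false;
}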
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
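A sketch of building expressions with the canonicalizing constructors above; twiceSum is a hypothetical helper and A, B are caller-provided integer SCEVs of the same type.

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Hypothetical helper: the canonical form of (A + B) * 2.
static const SCEV *twiceSum(ScalarEvolution &SE, const SCEV *A, const SCEV *B) {
  const SCEV *Sum = SE.getAddExpr(A, B);
  const SCEV *Two = SE.getConstant(A->getType(), 2);
  return SE.getMulExpr(Sum, Two); // folded to something simpler where possible
}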
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimized out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
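A sketch of the isKnownPredicate entry above; knownULT is a hypothetical helper.

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Hypothetical helper: is LHS provably unsigned-less-than RHS?
static bool knownULT(ScalarEvolution &SE, const SCEV *LHS, const SCEV *RHS) {
  return SE.isKnownPredicate(ICmpInst::ICMP_ULT, LHS, RHS);
}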
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:723
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:754
TypeSize getSizeInBits() const
Definition DataLayout.h:734
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:296
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:293
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:300
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1106
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2257
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2262
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2267
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2823
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2272
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
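A small sketch with the APInt helpers listed above; gcdOfUnsignedMax is a hypothetical helper and all operands are assumed to have the same bit width.

#include "llvm/ADT/APInt.h"
using namespace llvm;

// Hypothetical helper: GCD of C with the unsigned max of A and B.
static APInt gcdOfUnsignedMax(const APInt &A, const APInt &B, const APInt &C) {
  const APInt &M = APIntOps::umax(A, B);
  return APIntOps::GreatestCommonDivisor(M, C);
}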
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
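A sketch of the IR pattern matchers listed above (PatternMatch.h); isShlByConstant is a hypothetical helper.

#include "llvm/ADT/APInt.h"
#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace llvm::PatternMatch;

// Hypothetical helper: does V compute "shl X, C" for some constant C?
// On success the constant shift amount is bound to ShAmt.
static bool isShlByConstant(Value *V, const APInt *&ShAmt) {
  return match(V, m_Shl(m_Value(), m_APInt(ShAmt)));
}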
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
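A hedged sketch of the SCEV pattern matchers listed above; they live in a relatively new header (assumed here to be llvm/Analysis/ScalarEvolutionPatternMatch.h with namespace SCEVPatternMatch), and isUDivByConstant is a hypothetical helper.

#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
using namespace llvm;
using namespace llvm::SCEVPatternMatch;

// Hypothetical helper: does S have the form "X /u C"? Binds the constant
// divisor to Divisor on success.
static bool isUDivByConstant(const SCEV *S, const APInt *&Divisor) {
  return match(S, m_scev_UDiv(m_SCEV(), m_scev_APInt(Divisor)));
}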
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2106
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1737
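A tiny sketch of the range-based all_of wrapper above; allEven is a hypothetical helper.

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
using namespace llvm;

// Hypothetical helper: true iff every element of Xs is even.
static bool allEven(const SmallVectorImpl<int> &Xs) {
  return all_of(Xs, [](int X) { return X % 2 == 0; });
}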
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2198
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2101
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant bit, stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2190
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1744
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
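A sketch of a computeKnownBits query (ValueTracking.h), using the overload that returns a KnownBits value; lowBitKnownZero is a hypothetical helper and DL comes from the enclosing Module.

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

// Hypothetical helper: is bit 0 of V provably zero?
static bool lowBitKnownZero(const Value *V, const DataLayout &DL) {
  KnownBits Known = computeKnownBits(V, DL);
  return Known.Zero[0];
}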
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:361
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2002
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2078
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1915
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2009
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1945
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
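A sketch of SCEVExprContains with a predicate lambda, as described in the entry above; containsUDiv is a hypothetical helper.

#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

// Hypothetical helper: does the expression tree rooted at Root contain a udiv?
static bool containsUDiv(const SCEV *Root) {
  return SCEVExprContains(Root, [](const SCEV *S) { return isa<SCEVUDivExpr>(S); });
}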
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:304
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:199
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:148
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:132
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
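A sketch combining the KnownBits helpers above; knownShlOfConstant is a hypothetical helper.

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

// Hypothetical helper: known bits of (C << S) when only S's known bits are available.
static KnownBits knownShlOfConstant(const APInt &C, const KnownBits &S) {
  KnownBits LHS = KnownBits::makeConstant(C);
  return KnownBits::shl(LHS, S); // conservative when S is not fully known
}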
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.