LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
// LLVM STATISTIC counters recording how often loop exit counts could, or
// could not, be computed, and how many trip counts came from brute force.
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
// SCEV verification is on by default only in EXPENSIVE_CHECKS builds. The
// flag is also the storage for the -verify-scev option (via cl::location).
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
// Command-line knobs for ScalarEvolution: recursion/size limits, inlining
// thresholds, and debugging/verification switches.
// NOTE(review): in this copy the declaration line (`static cl::opt<...> Name(`)
// of most options is missing; only their argument lists are visible.
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
// NOTE(review): the help text below reads "with -verify-scev is passed"; it
// presumably means "when -verify-scev is passed". It is left unchanged because
// it is user-visible output.
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
// Debug helper: prints this SCEV to dbgs() followed by a newline. It is only
// compiled in assert-enabled builds or when LLVM_ENABLE_DUMP is defined.
// NOTE(review): the function's signature line is missing from this copy.
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
/// Print a textual form of this expression to \p OS.
///  * Constants and unknowns print their IR value as an operand, without type.
///  * Casts print as "(<kind> <src type> <operand> to <dst type>)", where kind
///    is ptrtoaddr/ptrtoint/trunc/zext/sext.
///  * Add recurrences print as "{Start,+,Step,...}", then one "<flag>" group
///    per no-wrap flag, then the loop header in "<...>".
///  * N-ary expressions print parenthesized, using the operator text in OpStr.
///    Add and mul then append "<nuw>"/"<nsw>".
///  * Unsigned division prints as "(LHS /u RHS)".
/// NOTE(review): several source lines of this function are missing from this
/// copy: the ZExt/SExt cast declarations, part of the no-self-wrap condition,
/// two case labels, and the expression that joins the n-ary operands.
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToAddr:
281 case scPtrToInt: {
282 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
283 const SCEV *Op = PtrCast->getOperand();
// Both pointer casts share one format and differ only in the "addr"/"int" suffix.
284 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
285 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
286 << *PtrCast->getType() << ")";
287 return;
288 }
289 case scTruncate: {
290 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
291 const SCEV *Op = Trunc->getOperand();
292 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
293 << *Trunc->getType() << ")";
294 return;
295 }
296 case scZeroExtend: {
298 const SCEV *Op = ZExt->getOperand();
299 OS << "(zext " << *Op->getType() << " " << *Op << " to "
300 << *ZExt->getType() << ")";
301 return;
302 }
303 case scSignExtend: {
305 const SCEV *Op = SExt->getOperand();
306 OS << "(sext " << *Op->getType() << " " << *Op << " to "
307 << *SExt->getType() << ")";
308 return;
309 }
310 case scAddRecExpr: {
311 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
312 OS << "{" << *AR->getOperand(0);
313 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
314 OS << ",+," << *AR->getOperand(i);
// Each flag is emitted as "<flag>"; the opening '<' of the next group (or
// of the loop header) is printed together with the preceding '>'.
315 OS << "}<";
316 if (AR->hasNoUnsignedWrap())
317 OS << "nuw><";
318 if (AR->hasNoSignedWrap())
319 OS << "nsw><";
320 if (AR->hasNoSelfWrap() &&
322 OS << "nw><";
323 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
324 OS << ">";
325 return;
326 }
327 case scAddExpr:
328 case scMulExpr:
329 case scUMaxExpr:
330 case scSMaxExpr:
331 case scUMinExpr:
332 case scSMinExpr:
334 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
335 const char *OpStr = nullptr;
336 switch (NAry->getSCEVType()) {
337 case scAddExpr: OpStr = " + "; break;
338 case scMulExpr: OpStr = " * "; break;
339 case scUMaxExpr: OpStr = " umax "; break;
340 case scSMaxExpr: OpStr = " smax "; break;
341 case scUMinExpr:
342 OpStr = " umin ";
343 break;
344 case scSMinExpr:
345 OpStr = " smin ";
346 break;
348 OpStr = " umin_seq ";
349 break;
350 default:
351 llvm_unreachable("There are no other nary expression types.");
352 }
353 OS << "("
355 << ")";
// Only add and mul print wrap flags here.
356 switch (NAry->getSCEVType()) {
357 case scAddExpr:
358 case scMulExpr:
359 if (NAry->hasNoUnsignedWrap())
360 OS << "<nuw>";
361 if (NAry->hasNoSignedWrap())
362 OS << "<nsw>";
363 break;
364 default:
365 // Nothing to print for other nary expressions.
366 break;
367 }
368 return;
369 }
370 case scUDivExpr: {
371 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
372 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
373 return;
374 }
375 case scUnknown:
376 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
377 return;
379 OS << "***COULDNOTCOMPUTE***";
380 return;
381 }
382 llvm_unreachable("Unknown SCEV kind!");
383}
384
// NOTE(review): the signature of this function is missing from this copy.
// The body dispatches on getSCEVType() and returns the getType() of the
// concrete subclass.
// The case labels for the sequential min/max branch and for the branch that
// reports "Attempt to use a SCEVCouldNotCompute object!" are also missing.
// Any unhandled kind is a fatal llvm_unreachable.
386 switch (getSCEVType()) {
387 case scConstant:
388 return cast<SCEVConstant>(this)->getType();
389 case scVScale:
390 return cast<SCEVVScale>(this)->getType();
391 case scPtrToAddr:
392 case scPtrToInt:
393 case scTruncate:
394 case scZeroExtend:
395 case scSignExtend:
396 return cast<SCEVCastExpr>(this)->getType();
397 case scAddRecExpr:
398 return cast<SCEVAddRecExpr>(this)->getType();
399 case scMulExpr:
400 return cast<SCEVMulExpr>(this)->getType();
401 case scUMaxExpr:
402 case scSMaxExpr:
403 case scUMinExpr:
404 case scSMinExpr:
405 return cast<SCEVMinMaxExpr>(this)->getType();
407 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
408 case scAddExpr:
409 return cast<SCEVAddExpr>(this)->getType();
410 case scUDivExpr:
411 return cast<SCEVUDivExpr>(this)->getType();
412 case scUnknown:
413 return cast<SCEVUnknown>(this)->getType();
415 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
416 }
417 llvm_unreachable("Unknown SCEV kind!");
418}
419
// NOTE(review): the signature of this function is missing from this copy.
// The body returns the operand list of this expression:
//  * empty for leaves (constants, vscale, unknowns);
//  * the single cast operand for cast kinds;
//  * the SCEVNAryExpr operand list for add recurrences, add/mul and min/max;
//  * LHS/RHS for unsigned division.
// Two case labels (one n-ary kind and the could-not-compute branch) are
// missing from this copy.
421 switch (getSCEVType()) {
422 case scConstant:
423 case scVScale:
424 case scUnknown:
425 return {};
426 case scPtrToAddr:
427 case scPtrToInt:
428 case scTruncate:
429 case scZeroExtend:
430 case scSignExtend:
431 return cast<SCEVCastExpr>(this)->operands();
432 case scAddRecExpr:
433 case scAddExpr:
434 case scMulExpr:
435 case scUMaxExpr:
436 case scSMaxExpr:
437 case scUMinExpr:
438 case scSMinExpr:
440 return cast<SCEVNAryExpr>(this)->operands();
441 case scUDivExpr:
442 return cast<SCEVUDivExpr>(this)->operands();
444 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
445 }
446 llvm_unreachable("Unknown SCEV kind!");
447}
448
449bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
450
451bool SCEV::isOne() const { return match(this, m_scev_One()); }
452
453bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
454
// NOTE(review): the signature and the line that defines `Mul` are missing from
// this copy. Visible logic: returns true only when `Mul` is non-null and its
// first operand is a SCEVConstant whose value is negative, as in (-42 * V).
457 if (!Mul) return false;
458
459 // If there is a constant factor, it will be first.
460 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
461 if (!SC) return false;
462
463 // Return true if the value is negative, this matches things like (-42 * V).
464 return SC->getAPInt().isNegative();
465}
466
469
// NOTE(review): the signature is missing from this copy; presumably this is a
// classof-style predicate. It returns true iff S has kind scCouldNotCompute.
471 return S->getSCEVType() == scCouldNotCompute;
472}
473
// NOTE(review): the signature and the FoldingSetNodeID declaration are missing
// from this copy.
// Uniquing constructor: looks up a node keyed on (scConstant, V) in
// UniqueSCEVs and returns it if present. Otherwise it allocates a new
// SCEVConstant in SCEVAllocator and inserts it at the position the failed
// lookup reported.
476 ID.AddInteger(scConstant);
477 ID.AddPointer(V);
478 void *IP = nullptr;
479 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
480 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
481 UniqueSCEVs.InsertNode(S, IP);
482 return S;
483}
484
// NOTE(review): the signature is missing from this copy. The body wraps `Val`
// (presumably an APInt) in a ConstantInt in this context and returns its SCEV.
486 return getConstant(ConstantInt::get(getContext(), Val));
487}
488
// Builds a ConstantInt of type ITy from V (signed iff isSigned) and returns its
// SCEV. V is explicitly allowed to be truncated to ITy's width.
// NOTE(review): the parameter-list lines of this overload are missing from
// this copy.
489const SCEV *
492 // TODO: Avoid implicit trunc?
493 // See https://github.com/llvm/llvm-project/issues/112510.
494 return getConstant(
495 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
496}
497
// NOTE(review): the signature and the FoldingSetNodeID declaration are missing
// from this copy. The body returns the unique SCEVVScale node for type Ty,
// creating and registering it in UniqueSCEVs the first time it is requested.
500 ID.AddInteger(scVScale);
501 ID.AddPointer(Ty);
502 void *IP = nullptr;
503 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
504 return S;
505 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
506 UniqueSCEVs.InsertNode(S, IP);
507 return S;
508}
509
// NOTE(review): the first line of the signature is missing from this copy.
// Returns EC's known-minimum value as a constant of type Ty. When EC is
// scalable, that constant is multiplied by vscale using the given Flags.
511 SCEV::NoWrapFlags Flags) {
512 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
513 if (EC.isScalable())
514 Res = getMulExpr(Res, getVScale(Ty), Flags);
515 return Res;
516}
517
// NOTE(review): the first line of this constructor is missing from this copy.
// It stores the single operand and the destination type; the expression size
// is computed from that operand.
519 const SCEV *op, Type *ty)
520 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
521
522SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
523 const SCEV *Op, Type *ITy)
524 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
525 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
526 "Must be a non-bit-width-changing pointer-to-integer cast!");
527}
528
529SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
530 Type *ITy)
531 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
532 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
533 "Must be a non-bit-width-changing pointer-to-integer cast!");
534}
535
// NOTE(review): the first line of this constructor is missing from this copy.
// It forwards all of its arguments to SCEVCastExpr unchanged.
537 SCEVTypes SCEVTy, const SCEV *op,
538 Type *ty)
539 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
540
// Truncation node. Assert-enabled builds require both the source and the
// destination type to be integer or pointer types.
// NOTE(review): the member-initializer line is missing from this copy.
541SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
542 Type *ty)
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot truncate non-integer value!");
546}
547
// Zero-extension node. Assert-enabled builds require both the source and the
// destination type to be integer or pointer types.
// NOTE(review): the member-initializer line is missing from this copy.
548SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
549 const SCEV *op, Type *ty)
551 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
552 "Cannot zero extend non-integer value!");
553}
554
// Sign-extension node. Assert-enabled builds require both the source and the
// destination type to be integer or pointer types.
// NOTE(review): the member-initializer line is missing from this copy.
555SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
556 const SCEV *op, Type *ty)
558 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
559 "Cannot sign extend non-integer value!");
560}
561
// NOTE(review): the signature is missing from this copy; presumably this is the
// handler run when the wrapped value is deleted. It forgets memoized results
// for this node, removes it from the uniquing map, and nulls the value pointer.
563 // Clear this SCEVUnknown from various maps.
564 SE->forgetMemoizedResults(this);
565
566 // Remove this SCEVUnknown from the uniquing map.
567 SE->UniqueSCEVs.RemoveNode(this);
568
569 // Release the value.
570 setValPtr(nullptr);
571}
572
573void SCEVUnknown::allUsesReplacedWith(Value *New) {
574 // Clear this SCEVUnknown from various maps.
575 SE->forgetMemoizedResults(this);
576
577 // Remove this SCEVUnknown from the uniquing map.
578 SE->UniqueSCEVs.RemoveNode(this);
579
580 // Replace the value pointer in case someone is still using this SCEVUnknown.
581 setValPtr(New);
582}
583
584//===----------------------------------------------------------------------===//
585// SCEV Utilities
586//===----------------------------------------------------------------------===//
587
588/// Compare the two values \p LV and \p RV in terms of their "complexity" where
589/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
590/// operands in SCEV expressions.
///
/// Returns a negative, zero, or positive value (three-way comparison). Keys
/// are tried in this order:
///  1. pointer-ness (non-pointers first);
///  2. Value ID;
///  3. argument number, for Arguments;
///  4. linkage, then name, for GlobalValues (names only when both are
///     semantically meaningful);
///  5. for Instructions: loop depth of the parent block, operand count, then
///     a recursive operand-by-operand comparison at \p Depth + 1.
591static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
592 Value *RV, unsigned Depth) {
// NOTE(review): the guard for this early return (source line 593) is missing
// from this copy; presumably it stops recursion once Depth exceeds the
// value-compare depth limit.
594 return 0;
595
596 // Order pointer values after integer values. This helps SCEVExpander form
597 // GEPs.
598 bool LIsPointer = LV->getType()->isPointerTy(),
599 RIsPointer = RV->getType()->isPointerTy();
600 if (LIsPointer != RIsPointer)
601 return (int)LIsPointer - (int)RIsPointer;
602
603 // Compare getValueID values.
604 unsigned LID = LV->getValueID(), RID = RV->getValueID();
605 if (LID != RID)
606 return (int)LID - (int)RID;
607
608 // Sort arguments by their position.
609 if (const auto *LA = dyn_cast<Argument>(LV)) {
610 const auto *RA = cast<Argument>(RV);
611 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
612 return (int)LArgNo - (int)RArgNo;
613 }
614
615 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
616 const auto *RGV = cast<GlobalValue>(RV);
617
618 if (auto L = LGV->getLinkage() - RGV->getLinkage())
619 return L;
620
// A name is treated as meaningful unless the linkage is private, or matches the
// second check on the `||` (source line 624, missing from this copy).
621 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
622 auto LT = GV->getLinkage();
623 return !(GlobalValue::isPrivateLinkage(LT) ||
625 };
626
627 // Use the names to distinguish the two values, but only if the
628 // names are semantically important.
629 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
630 return LGV->getName().compare(RGV->getName());
631 }
632
633 // For instructions, compare their loop depth, and their operand count. This
634 // is pretty loose.
// Remaining ties are then broken by comparing operands pairwise, recursively.
635 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
636 const auto *RInst = cast<Instruction>(RV);
637
638 // Compare loop depths.
639 const BasicBlock *LParent = LInst->getParent(),
640 *RParent = RInst->getParent();
641 if (LParent != RParent) {
642 unsigned LDepth = LI->getLoopDepth(LParent),
643 RDepth = LI->getLoopDepth(RParent);
644 if (LDepth != RDepth)
645 return (int)LDepth - (int)RDepth;
646 }
647
648 // Compare the number of operands.
649 unsigned LNumOps = LInst->getNumOperands(),
650 RNumOps = RInst->getNumOperands();
651 if (LNumOps != RNumOps)
652 return (int)LNumOps - (int)RNumOps;
653
654 for (unsigned Idx : seq(LNumOps)) {
655 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
656 RInst->getOperand(Idx), Depth + 1);
657 if (Result != 0)
658 return Result;
659 }
660 }
661
662 return 0;
663}
664
665// Return negative, zero, or positive, if LHS is less than, equal to, or greater
666// than RHS, respectively. A three-way result allows recursive comparisons to be
667// more efficient.
668// If the max analysis depth was reached, return std::nullopt, assuming we do
669// not know if they are equivalent for sure.
// NOTE(review): this copy is missing several source lines of this function:
// the depth guard (682), the LC/RC (699-700) and LA/RA (718-719) cast
// declarations, and two case labels (750, 767).
670static std::optional<int>
671CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
672 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
673 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
674 if (LHS == RHS)
675 return 0;
676
677 // Primarily, sort the SCEVs by their getSCEVType().
678 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
679 if (LType != RType)
680 return (int)LType - (int)RType;
681
// The guard for this return (source line 682) is missing from this copy; per
// the comment above, it fires once the maximum analysis depth is reached.
683 return std::nullopt;
684
685 // Aside from the getSCEVType() ordering, the particular ordering
686 // isn't very important except that it's beneficial to be consistent,
687 // so that (a + b) and (b + a) don't end up as different expressions.
688 switch (LType) {
689 case scUnknown: {
690 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
691 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
692
693 int X =
694 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
695 return X;
696 }
697
698 case scConstant: {
701
702 // Compare constant values.
// LHS != RHS and SCEVs are uniqued, so two constants of equal width differ in
// value and this branch never returns 0.
703 const APInt &LA = LC->getAPInt();
704 const APInt &RA = RC->getAPInt();
705 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
706 if (LBitWidth != RBitWidth)
707 return (int)LBitWidth - (int)RBitWidth;
708 return LA.ult(RA) ? -1 : 1;
709 }
710
711 case scVScale: {
712 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
713 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
// NOTE(review): unlike the other branches, this subtracts the unsigned widths
// before converting to int. The wrapped result converts back to the intended
// negative value, but `(int)L - (int)R` would match the rest of the function.
714 return LTy->getBitWidth() - RTy->getBitWidth();
715 }
716
717 case scAddRecExpr: {
720
721 // There is always a dominance between two recs that are used by one SCEV,
722 // so we can safely sort recs by loop header dominance. We require such
723 // order in getAddExpr.
724 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
725 if (LLoop != RLoop) {
726 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
727 assert(LHead != RHead && "Two loops share the same header?");
728 if (DT.dominates(LHead, RHead))
729 return 1;
730 assert(DT.dominates(RHead, LHead) &&
731 "No dominance between recurrences used by one SCEV?");
732 return -1;
733 }
734
735 [[fallthrough]];
736 }
737
738 case scTruncate:
739 case scZeroExtend:
740 case scSignExtend:
741 case scPtrToAddr:
742 case scPtrToInt:
743 case scAddExpr:
744 case scMulExpr:
745 case scUDivExpr:
746 case scSMaxExpr:
747 case scUMaxExpr:
748 case scSMinExpr:
749 case scUMinExpr:
751 ArrayRef<const SCEV *> LOps = LHS->operands();
752 ArrayRef<const SCEV *> ROps = RHS->operands();
753
754 // Lexicographically compare n-ary-like expressions.
755 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
756 if (LNumOps != RNumOps)
757 return (int)LNumOps - (int)RNumOps;
758
// A std::nullopt result also compares unequal to 0, so "unknown" is
// propagated to the caller unchanged.
759 for (unsigned i = 0; i != LNumOps; ++i) {
760 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
761 if (X != 0)
762 return X;
763 }
764 return 0;
765 }
766
768 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
769 }
770 llvm_unreachable("Unknown SCEV kind!");
771}
772
773/// Given a list of SCEV objects, order them by their complexity, and group
774/// objects of the same complexity together by value. When this routine is
775/// finished, we know that any duplicates in the vector are consecutive and that
776/// complexity is monotonically increasing.
777///
778/// Note that we take special precautions to ensure that we get deterministic
779/// results from this routine. In other words, we don't want the results of
780/// this to depend on where the addresses of various SCEV objects happened to
781/// land in memory.
///
/// Pairs whose comparison is inconclusive (std::nullopt) are treated as "not
/// less", so stable_sort keeps their original relative order.
/// NOTE(review): the first signature line (source line 782) is missing from
/// this copy.
783 LoopInfo *LI, DominatorTree &DT) {
784 if (Ops.size() < 2) return; // Noop
785
786 // Whether LHS has provably less complexity than RHS.
787 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
788 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
789 return Complexity && *Complexity < 0;
790 };
791 if (Ops.size() == 2) {
792 // This is the common case, which also happens to be trivially simple.
793 // Special case it.
794 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
795 if (IsLessComplex(RHS, LHS))
796 std::swap(LHS, RHS);
797 return;
798 }
799
800 // Do the rough sort by complexity.
801 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
802 return IsLessComplex(LHS, RHS);
803 });
804
805 // Now that we are sorted by complexity, group elements of the same
806 // complexity. Note that this is, at worst, N^2, but the vector is likely to
807 // be extremely short in practice. Note that we take this approach because we
808 // do not want to depend on the addresses of the objects we are grouping.
// Ops.size() >= 3 here, so `e-2` cannot underflow. The last pair never needs
// a scan: any duplicates there are already adjacent.
809 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
810 const SCEV *S = Ops[i];
811 unsigned Complexity = S->getSCEVType();
812
813 // If there are any objects of the same complexity and same value as this
814 // one, group them.
815 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
816 if (Ops[j] == S) { // Found a duplicate.
817 // Move it to immediately after i'th element.
818 std::swap(Ops[i+1], Ops[j]);
819 ++i; // no need to rescan it.
820 if (i == e-2) return; // Done!
821 }
822 }
823 }
824}
825
826/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
827/// least HugeExprThreshold nodes).
/// NOTE(review): the signature (source line 828) and the predicate's body
/// (source line 830) are missing from this copy.
829 return any_of(Ops, [](const SCEV *S) {
831 });
832}
833
834/// Performs a number of common optimizations on the passed \p Ops. If the
835/// whole expression reduces down to a single operand, it will be returned.
836///
837/// The following optimizations are performed:
838/// * Fold constants using the \p Fold function.
839/// * Remove identity constants satisfying \p IsIdentity.
840/// * If a constant satisfies \p IsAbsorber, return it.
841/// * Sort operands by complexity.
///
/// Returns nullptr when more than one operand remains. In that case \p Ops
/// holds the grouped operands, with any non-identity folded constant at the
/// front. If every operand was a constant, the folded constant is returned
/// even when it is an identity value. \p Ops must not be empty on entry; the
/// assert below fires otherwise.
/// NOTE(review): the parameter-list lines (source lines 844-845) are missing
/// from this copy.
842template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
843static const SCEV *
846 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
847 const SCEVConstant *Folded = nullptr;
// Pull every constant out of Ops and fold them into a single value.
848 for (unsigned Idx = 0; Idx < Ops.size();) {
849 const SCEV *Op = Ops[Idx];
850 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
851 if (!Folded)
852 Folded = C;
853 else
854 Folded = cast<SCEVConstant>(
855 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
856 Ops.erase(Ops.begin() + Idx);
857 continue;
858 }
859 ++Idx;
860 }
861
862 if (Ops.empty()) {
863 assert(Folded && "Must have folded value");
864 return Folded;
865 }
866
867 if (Folded && IsAbsorber(Folded->getAPInt()))
868 return Folded;
869
870 GroupByComplexity(Ops, &LI, DT);
871 if (Folded && !IsIdentity(Folded->getAPInt()))
872 Ops.insert(Ops.begin(), Folded);
873
874 return Ops.size() == 1 ? Ops[0] : nullptr;
875}
876
877//===----------------------------------------------------------------------===//
878// Simple SCEV method implementations
879//===----------------------------------------------------------------------===//
880
881/// Compute BC(It, K). The result has width W. Assume, K > 0.
882static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
883 ScalarEvolution &SE,
884 Type *ResultTy) {
885 // Handle the simplest case efficiently.
886 if (K == 1)
887 return SE.getTruncateOrZeroExtend(It, ResultTy);
888
889 // We are using the following formula for BC(It, K):
890 //
891 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
892 //
893 // Suppose, W is the bitwidth of the return value. We must be prepared for
894 // overflow. Hence, we must assure that the result of our computation is
895 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
896 // safe in modular arithmetic.
897 //
898 // However, this code doesn't use exactly that formula; the formula it uses
899 // is something like the following, where T is the number of factors of 2 in
900 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
901 // exponentiation:
902 //
903 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
904 //
905 // This formula is trivially equivalent to the previous formula. However,
906 // this formula can be implemented much more efficiently. The trick is that
907 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
908 // arithmetic. To do exact division in modular arithmetic, all we have
909 // to do is multiply by the inverse. Therefore, this step can be done at
910 // width W.
911 //
912 // The next issue is how to safely do the division by 2^T. The way this
913 // is done is by doing the multiplication step at a width of at least W + T
914 // bits. This way, the bottom W+T bits of the product are accurate. Then,
915 // when we perform the division by 2^T (which is equivalent to a right shift
916 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
917 // truncated out after the division by 2^T.
918 //
919 // In comparison to just directly using the first formula, this technique
920 // is much more efficient; using the first formula requires W * K bits,
921 // but this formula less than W + K bits. Also, the first formula requires
922 // a division step, whereas this formula only requires multiplies and shifts.
923 //
924 // It doesn't matter whether the subtraction step is done in the calculation
925 // width or the input iteration count's width; if the subtraction overflows,
926 // the result must be zero anyway. We prefer here to do it in the width of
927 // the induction variable because it helps a lot for certain cases; CodeGen
928 // isn't smart enough to ignore the overflow, which leads to much less
929 // efficient code if the width of the subtraction is wider than the native
930 // register width.
931 //
932 // (It's possible to not widen at all by pulling out factors of 2 before
933 // the multiplication; for example, K=2 can be calculated as
934 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
935 // extra arithmetic, so it's not an obvious win, and it gets
936 // much more complicated for K > 3.)
937
938 // Protection from insane SCEVs; this bound is conservative,
939 // but it probably doesn't matter.
940 if (K > 1000)
941 return SE.getCouldNotCompute();
942
943 unsigned W = SE.getTypeSizeInBits(ResultTy);
944
945 // Calculate K! / 2^T and T; we divide out the factors of two before
946 // multiplying for calculating K! / 2^T to avoid overflow.
947 // Other overflow doesn't matter because we only care about the bottom
948 // W bits of the result.
949 APInt OddFactorial(W, 1);
950 unsigned T = 1;
951 for (unsigned i = 3; i <= K; ++i) {
952 unsigned TwoFactors = countr_zero(i);
953 T += TwoFactors;
954 OddFactorial *= (i >> TwoFactors);
955 }
956
957 // We need at least W + T bits for the multiplication step
958 unsigned CalculationBits = W + T;
959
960 // Calculate 2^T, at width T+W.
961 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
962
963 // Calculate the multiplicative inverse of K! / 2^T;
964 // this multiplication factor will perform the exact division by
965 // K! / 2^T.
966 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
967
968 // Calculate the product, at width T+W
969 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
970 CalculationBits);
971 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
972 for (unsigned i = 1; i != K; ++i) {
973 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
974 Dividend = SE.getMulExpr(Dividend,
975 SE.getTruncateOrZeroExtend(S, CalculationTy));
976 }
977
978 // Divide by 2^T
979 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
980
981 // Truncate the result, and divide by K! / 2^T.
982
983 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
984 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
985}
986
987/// Return the value of this chain of recurrences at the specified iteration
988/// number. We can evaluate this recurrence by multiplying each element in the
989/// chain by the binomial coefficient corresponding to it. In other words, we
990/// can evaluate {A,+,B,+,C,+,D} as:
991///
992/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
993///
994/// where BC(It, k) stands for binomial coefficient.
///
/// This member forwards this recurrence's operands to the static overload.
/// NOTE(review): the first signature line (source line 995) is missing from
/// this copy.
996 ScalarEvolution &SE) const {
997 return evaluateAtIteration(operands(), It, SE);
998}
999
// Evaluates the chain of recurrences {Operands[0],+,Operands[1],+,...} at
// iteration It as the sum over k of Operands[k] * BC(It, k). Operands must be
// non-empty (asserted). Returns SCEVCouldNotCompute as soon as any binomial
// coefficient cannot be computed.
1000const SCEV *
1002 const SCEV *It, ScalarEvolution &SE) {
1003 assert(Operands.size() > 0);
// BC(It, 0) == 1, so the k == 0 term is Operands[0] itself.
1004 const SCEV *Result = Operands[0];
1005 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1006 // The computation is correct in the face of overflow provided that the
1007 // multiplication is performed _after_ the evaluation of the binomial
1008 // coefficient.
1009 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1010 if (isa<SCEVCouldNotCompute>(Coeff))
1011 return Coeff;
1012
1013 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1014 }
1015 return Result;
1016}
1017
1018//===----------------------------------------------------------------------===//
1019// SCEV Expression folder implementations
1020//===----------------------------------------------------------------------===//
1021
1022/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1023/// which computes a pointer-typed value, and rewrites the whole expression
1024/// tree so that *all* the computations are done on integers, and the only
1025/// pointer-typed operands in the expression are SCEVUnknown.
1026/// The CreatePtrCast callback is invoked to create the actual conversion
1027/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
1029 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
// Callback that builds the integer-typed cast for one pointer-typed
// SCEVUnknown leaf. function_ref does not own the callable it refers to, so a
// rewriter must not outlive the callable passed to it.
1031 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
// Integer type of the zero constant produced when a leaf folds away
// (see visitUnknown).
1032 Type *TargetTy;
1033 ConversionFn CreatePtrCast;
1034
1035public:
1037 ConversionFn CreatePtrCast)
1038 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1039
// One-shot entry point: builds a rewriter for (TargetTy, CreatePtrCast) and
// applies it to Scev.
1040 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1041 Type *TargetTy, ConversionFn CreatePtrCast) {
1042 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1043 return Rewriter.visit(Scev);
1044 }
1045
1046 const SCEV *visit(const SCEV *S) {
1047 Type *STy = S->getType();
1048 // If the expression is not pointer-typed, just keep it as-is.
1049 if (!STy->isPointerTy())
1050 return S;
1051 // Else, recursively sink the cast down into it.
1052 return Base::visit(S);
1053 }
1054
1055 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1056 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1057 // implementation drops.
1059 bool Changed = false;
1060 for (const auto *Op : Expr->operands()) {
1061 Operands.push_back(visit(Op));
1062 Changed |= Op != Operands.back();
1063 }
1064 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1065 }
1066
// Same as visitAddExpr: rebuild only when an operand changed, keeping the
// original no-wrap flags.
1067 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1069 bool Changed = false;
1070 for (const auto *Op : Expr->operands()) {
1071 Operands.push_back(visit(Op));
1072 Changed |= Op != Operands.back();
1073 }
1074 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1075 }
1076
1077 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1078 assert(Expr->getType()->isPointerTy() &&
1079 "Should only reach pointer-typed SCEVUnknown's.");
1080 // Perform some basic constant folding. If the operand of the cast is a
1081 // null pointer, don't create a cast SCEV expression (that will be left
1082 // as-is), but produce a zero constant.
1084 return SE.getZero(TargetTy);
1085 return CreatePtrCast(Expr);
1086 }
1087};
1088
// Lossless ptrtoint: rewrites pointer-typed Op into an integer-typed SCEV of
// the DataLayout's intptr type for Op's pointer type. Returns
// SCEVCouldNotCompute for non-integral pointers and when the width check
// below fails.
1090 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1091
1092 // It isn't legal for optimizations to construct new ptrtoint expressions
1093 // for non-integral pointers.
1094 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1095 return getCouldNotCompute();
1096
1097 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1098
1099 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1100 // is sufficiently wide to represent all possible pointer values.
1101 // We could theoretically teach SCEV to truncate wider pointers, but
1102 // that isn't implemented for now.
1104 getDataLayout().getTypeSizeInBits(IntPtrTy))
1105 return getCouldNotCompute();
1106
1107 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1109 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
// Unique one SCEVPtrToIntExpr per leaf; the key is (scPtrToInt, U) only.
1111 ID.AddInteger(scPtrToInt);
1112 ID.AddPointer(U);
1113 void *IP = nullptr;
1114 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1115 return S;
1116 SCEV *S = new (SCEVAllocator)
1117 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1118 UniqueSCEVs.InsertNode(S, IP);
1119 registerUser(S, U);
// Cast so that both return statements of the lambda yield const SCEV *.
1120 return static_cast<const SCEV *>(S);
1121 });
1122 assert(IntOp->getType()->isIntegerTy() &&
1123 "We must have succeeded in sinking the cast, "
1124 "and ending up with an integer-typed expression!");
1125 return IntOp;
1126}
1127
// Rewrites pointer-typed Op into an integer-typed SCEV of DL's address type
// for Op's pointer type, casting each pointer leaf with SCEVPtrToAddrExpr.
// Returns SCEVCouldNotCompute for pointers with an unstable representation.
1129 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1130
1131 // Treat pointers with unstable representation conservatively, since the
1132 // address bits may change.
1133 if (DL.hasUnstableRepresentation(Op->getType()))
1134 return getCouldNotCompute();
1135
1136 Type *Ty = DL.getAddressType(Op->getType());
1137
1138 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1139 // The rewriter handles null pointer constant folding.
1141 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
// Unique one SCEVPtrToAddrExpr per (leaf, result type) pair.
1143 ID.AddInteger(scPtrToAddr);
1144 ID.AddPointer(U);
1145 ID.AddPointer(Ty);
1146 void *IP = nullptr;
1147 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1148 return S;
1149 SCEV *S = new (SCEVAllocator)
1150 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1151 UniqueSCEVs.InsertNode(S, IP);
1152 registerUser(S, U);
// Cast so that both return statements of the lambda yield const SCEV *.
1153 return static_cast<const SCEV *>(S);
1154 });
1155 assert(IntOp->getType()->isIntegerTy() &&
1156 "We must have succeeded in sinking the cast, "
1157 "and ending up with an integer-typed expression!");
1158 return IntOp;
1159}
1160
// ptrtoint to an arbitrary integer type Ty: convert losslessly via
// getLosslessPtrToIntExpr (propagating SCEVCouldNotCompute), then truncate or
// zero-extend the result to Ty.
1162 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1163
1164 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1165 if (isa<SCEVCouldNotCompute>(IntOp))
1166 return IntOp;
1167
1168 return getTruncateOrZeroExtend(IntOp, Ty);
1169}
1170
1172 unsigned Depth) {
// Returns trunc(Op) to Ty, folding the cast into Op where possible. Depth
// counts nested cast construction; once it exceeds MaxCastDepth an explicit
// SCEVTruncateExpr is created without attempting further folds.
1173 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1174 "This is not a truncating conversion!");
1175 assert(isSCEVable(Ty) &&
1176 "This is not a conversion to a SCEVable type!");
1177 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1178 Ty = getEffectiveSCEVType(Ty);
1179
1181 ID.AddInteger(scTruncate);
1182 ID.AddPointer(Op);
1183 ID.AddPointer(Ty);
1184 void *IP = nullptr;
1185 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1186
1187 // Fold if the operand is constant.
1188 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1189 return getConstant(
1190 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1191
1192 // trunc(trunc(x)) --> trunc(x)
1194 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1195
1196 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1198 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1199
1200 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1202 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1203
// Recursion budget exhausted: create the node without further folding.
1204 if (Depth > MaxCastDepth) {
1205 SCEV *S =
1206 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1207 UniqueSCEVs.InsertNode(S, IP);
1208 registerUser(S, Op);
1209 return S;
1210 }
1211
1212 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1213 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1214 // if after transforming we have at most one truncate, not counting truncates
1215 // that replace other casts.
1217 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1219 unsigned numTruncs = 0;
// The loop stops as soon as a second non-cast operand has been truncated; the
// distribution is then abandoned.
1220 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1221 ++i) {
1222 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1223 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1225 numTruncs++;
1226 Operands.push_back(S);
1227 }
1228 if (numTruncs < 2) {
1229 if (isa<SCEVAddExpr>(Op))
1230 return getAddExpr(Operands);
1231 if (isa<SCEVMulExpr>(Op))
1232 return getMulExpr(Operands);
1233 llvm_unreachable("Unexpected SCEV type for Op.");
1234 }
1235 // Although we checked in the beginning that ID is not in the cache, the
1236 // recursive calls above may have inserted it in the meantime (and they may
1237 // have invalidated IP). So look it up again and return it if found.
1238 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1239 return S;
1240 }
1241
1242 // If the input value is a chrec scev, truncate the chrec's operands.
1243 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
// NOTE: the loop variable Op below shadows the function parameter Op.
1245 for (const SCEV *Op : AddRec->operands())
1246 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1247 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1248 }
1249
1250 // Return zero if truncating to known zeros.
1251 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1252 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1253 return getZero(Ty);
1254
1255 // The cast wasn't folded; create an explicit cast node. We can reuse
1256 // the existing insert position since if we get here, we won't have
1257 // made any changes which would invalidate it.
1258 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1259 Op, Ty);
1260 UniqueSCEVs.InsertNode(S, IP);
1261 registerUser(S, Op);
1262 return S;
1263}
1264
1265// Get the limit of a recurrence such that incrementing by Step cannot cause
1266// signed overflow as long as the value of the recurrence within the
1267// loop does not exceed this limit before incrementing.
//
// If Step is known positive, *Pred is set to ICMP_SLT; if Step is known
// negative, *Pred is set to ICMP_SGT. Returns nullptr, leaving *Pred
// untouched, when the sign of Step is not known.
1268static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1269 ICmpInst::Predicate *Pred,
1270 ScalarEvolution *SE) {
1271 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1272 if (SE->isKnownPositive(Step)) {
1273 *Pred = ICmpInst::ICMP_SLT;
1275 SE->getSignedRangeMax(Step));
1276 }
1277 if (SE->isKnownNegative(Step)) {
1278 *Pred = ICmpInst::ICMP_SGT;
1280 SE->getSignedRangeMin(Step));
1281 }
1282 return nullptr;
1283}
1284
1285// Get the limit of a recurrence such that incrementing by Step cannot cause
1286// unsigned overflow as long as the value of the recurrence within the loop does
1287// not exceed this limit before incrementing.
//
// Unlike the signed variant, there is no sign test: *Pred is always set to
// ICMP_ULT and Step is treated as an unsigned quantity.
1289 ICmpInst::Predicate *Pred,
1290 ScalarEvolution *SE) {
1291 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1292 *Pred = ICmpInst::ICMP_ULT;
1293
1295 SE->getUnsignedRangeMax(Step));
1296}
1297
1298namespace {
1299
// Common base of the ExtendOpTraits specializations. It provides the
// member-function pointer type of a ScalarEvolution extension builder, called
// as (operand, destination type, recursion depth).
1300struct ExtendOpTraitsBase {
1301 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1302 unsigned);
1303};
1304
1305// Used to make code generic over signed and unsigned overflow.
1306template <typename ExtendOp> struct ExtendOpTraits {
1307 // Members present:
1308 //
1309 // static const SCEV::NoWrapFlags WrapType;
1310 //
1311 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1312 //
1313 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1314 // ICmpInst::Predicate *Pred,
1315 // ScalarEvolution *SE);
1316};
1317
// Sign extension: NSW wrap flag and the signed overflow limit.
1318template <>
1319struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1320 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1321
1322 static const GetExtendExprTy GetExtendExpr;
1323
1324 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1325 ICmpInst::Predicate *Pred,
1326 ScalarEvolution *SE) {
1327 return getSignedOverflowLimitForStep(Step, Pred, SE);
1328 }
1329};
1330
1331const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1333
// Zero extension: NUW wrap flag and the unsigned overflow limit.
1334template <>
1335struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1336 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1337
1338 static const GetExtendExprTy GetExtendExpr;
1339
1340 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1341 ICmpInst::Predicate *Pred,
1342 ScalarEvolution *SE) {
1343 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1344 }
1345};
1346
1347const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1349
1350} // end anonymous namespace
1351
1352// The recurrence AR has been shown to have no signed/unsigned wrap or something
1353// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1354// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1355// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1356// PreStart),+,Step} => {Step + sext/zext(PreStart),+,Step}. As a result, the
1357// expression "Step + sext/zext(PreIncAR)" is congruent with
1358// "sext/zext(PostIncAR)".
//
// Returns PreStart, which is AR's start with one occurrence of Step removed
// from its top-level add, when one of the three checks below shows that
// PreStart + Step does not wrap in the WrapType (NSW/NUW) sense. Returns
// nullptr if the start is not an add containing Step, or if no check succeeds.
1359template <typename ExtendOpTy>
1360static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1361 ScalarEvolution *SE, unsigned Depth) {
1362 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1363 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1364
1365 const Loop *L = AR->getLoop();
1366 const SCEV *Start = AR->getStart();
1367 const SCEV *Step = AR->getStepRecurrence(*SE);
1368
1369 // Check for a simple looking step prior to loop entry.
1370 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1371 if (!SA)
1372 return nullptr;
1373
1374 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1375 // subtraction is expensive. For this purpose, perform a quick and dirty
1376 // difference, by checking for Step in the operand list. Note that
1377 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1379 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1380 if (*It == Step) {
1381 DiffOps.erase(It);
1382 break;
1383 }
1384
// Step was not among SA's operands, so no PreStart can be formed cheaply.
1385 if (DiffOps.size() == SA->getNumOperands())
1386 return nullptr;
1387
1388 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1389 // `Step`:
1390
1391 // 1. NSW/NUW flags on the step increment.
1392 auto PreStartFlags =
1394 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1396 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1397
1398 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1399 // "S+X does not overflow in the signed/unsigned sense".
1400 //
1401
1402 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1403 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1404 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1405 return PreStart;
1406
1407 // 2. Direct overflow check on the step operation's expression.
1408 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1409 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1410 const SCEV *OperandExtendedStart =
1411 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1412 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1413 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1414 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1415 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1416 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1417 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1418 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1419 }
1420 return PreStart;
1421 }
1422
1423 // 3. Loop precondition.
1425 const SCEV *OverflowLimit =
1426 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1427
1428 if (OverflowLimit &&
1429 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1430 return PreStart;
1431
1432 return nullptr;
1433}
1434
1435// Get the normalized zero or sign extended expression for this AddRec's Start.
1436template <typename ExtendOpTy>
1437static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1438 ScalarEvolution *SE,
1439 unsigned Depth) {
1440 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1441
1442 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1443 if (!PreStart)
1444 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1445
1446 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1447 Depth),
1448 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1449}
1450
1451// Try to prove away overflow by looking at "nearby" add recurrences. A
1452// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1453// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1454//
1455// Formally:
1456//
1457// {S,+,X} == {S-T,+,X} + T
1458// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1459//
1460// If ({S-T,+,X} + T) does not overflow ... (1)
1461//
1462// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1463//
1464// If {S-T,+,X} does not overflow ... (2)
1465//
1466// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1467// == {Ext(S-T)+Ext(T),+,Ext(X)}
1468//
1469// If (S-T)+T does not overflow ... (3)
1470//
1471// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1472// == {Ext(S),+,Ext(X)} == LHS
1473//
1474// Thus, if (1), (2) and (3) are true for some T, then
1475// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1476//
1477// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1478// does not overflow" restricted to the 0th iteration. Therefore we only need
1479// to check for (1) and (2).
1480//
1481// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1482// is `Delta` (defined below).
1483template <typename ExtendOpTy>
1484bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1485 const SCEV *Step,
1486 const Loop *L) {
1487 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1488
1489 // We restrict `Start` to a constant to prevent SCEV from spending too much
1490 // time here. It is correct (but more expensive) to continue with a
1491 // non-constant `Start` and do a general SCEV subtraction to compute
1492 // `PreStart` below.
1493 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1494 if (!StartC)
1495 return false;
1496
1497 APInt StartAI = StartC->getAPInt();
1498
1499 for (unsigned Delta : {-2, -1, 1, 2}) {
1500 const SCEV *PreStart = getConstant(StartAI - Delta);
1501
1502 FoldingSetNodeID ID;
1503 ID.AddInteger(scAddRecExpr);
1504 ID.AddPointer(PreStart);
1505 ID.AddPointer(Step);
1506 ID.AddPointer(L);
1507 void *IP = nullptr;
1508 const auto *PreAR =
1509 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1510
1511 // Give up if we don't already have the add recurrence we need because
1512 // actually constructing an add recurrence is relatively expensive.
1513 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1514 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1516 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1517 DeltaS, &Pred, this);
1518 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1519 return true;
1520 }
1521 }
1522
1523 return false;
1524}
1525
1526// Finds an integer D for an expression (C + x + y + ...) such that the top
1527// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1528// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1529// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1530// the (C + x + y + ...) expression is \p WholeAddExpr.
//
// ConstantTerm is expected to be operand 0 of WholeAddExpr: only operands
// 1..N-1 are inspected below. D is the low TZ bits of C (all of C when TZ
// reaches the bit width), where TZ is the minimum known number of trailing
// zeros of x, y, ...; D is 0 when TZ is 0.
1532 const SCEVConstant *ConstantTerm,
1533 const SCEVAddExpr *WholeAddExpr) {
1534 const APInt &C = ConstantTerm->getAPInt();
1535 const unsigned BitWidth = C.getBitWidth();
1536 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1537 uint32_t TZ = BitWidth;
1538 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1539 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1540 if (TZ) {
1541 // Set D to be as many least significant bits of C as possible while still
1542 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1543 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1544 }
1545 return APInt(BitWidth, 0);
1546}
1547
1548// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1549// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1550// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1551// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
//
// Only Step's trailing zeros are consulted (x * n has at least as many):
// D is the low TZ bits of ConstantStart (all of it when TZ reaches the bit
// width), or 0 when Step has no known trailing zeros.
1553 const APInt &ConstantStart,
1554 const SCEV *Step) {
1555 const unsigned BitWidth = ConstantStart.getBitWidth();
1556 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1557 if (TZ)
1558 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1559 : ConstantStart;
1560 return APInt(BitWidth, 0);
1561}
1562
1564 const ScalarEvolution::FoldID &ID, const SCEV *S,
1567 &FoldCacheUser) {
// Records FoldCache[ID] = S and the reverse link ID in FoldCacheUser[S]. If ID
// already had a cached result, ID is first unlinked from that old result's
// FoldCacheUser list, and the FoldCache entry is then overwritten with S.
1568 auto I = FoldCache.insert({ID, S});
1569 if (!I.second) {
1570 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1571 // entry.
1572 auto &UserIDs = FoldCacheUser[I.first->second];
1573 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
// Swap ID to the back and pop it; the order of UserIDs is not preserved.
// NOTE(review): the loop index `I` shadows the insert result `I`, which is
// used again after the loop. If ID were absent (assertions disabled),
// pop_back() would remove an unrelated entry.
1574 for (unsigned I = 0; I != UserIDs.size(); ++I)
1575 if (UserIDs[I] == ID) {
1576 std::swap(UserIDs[I], UserIDs.back());
1577 break;
1578 }
1579 UserIDs.pop_back();
1580 I.first->second = S;
1581 }
1582 FoldCacheUser[S].push_back(ID);
1583}
1584
// Memoizing front end for getZeroExtendExprImpl. FoldCache is keyed on
// (scZeroExtend, Op, effective SCEV type of Ty). Op must be a non-pointer
// value strictly narrower than Ty (asserted below).
1585const SCEV *
1587 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1588 "This is not an extending conversion!");
1589 assert(isSCEVable(Ty) &&
1590 "This is not a conversion to a SCEVable type!");
1591 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1592 Ty = getEffectiveSCEVType(Ty);
1593
1594 FoldID ID(scZeroExtend, Op, Ty);
1595 if (const SCEV *S = FoldCache.lookup(ID))
1596 return S;
1597
// Cache miss: compute the extension; the result is then handed to
// insertFoldCacheEntry for later lookups.
1598 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1600 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1601 return S;
1602}
1603
1605 unsigned Depth) {
1606 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1607 "This is not an extending conversion!");
1608 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1609 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1610
1611 // Fold if the operand is constant.
1612 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1613 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1614
1615 // zext(zext(x)) --> zext(x)
1617 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1618
1619 // Before doing any expensive analysis, check to see if we've already
1620 // computed a SCEV for this Op and Ty.
1622 ID.AddInteger(scZeroExtend);
1623 ID.AddPointer(Op);
1624 ID.AddPointer(Ty);
1625 void *IP = nullptr;
1626 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1627 if (Depth > MaxCastDepth) {
1628 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1629 Op, Ty);
1630 UniqueSCEVs.InsertNode(S, IP);
1631 registerUser(S, Op);
1632 return S;
1633 }
1634
1635 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1637 // It's possible the bits taken off by the truncate were all zero bits. If
1638 // so, we should be able to simplify this further.
1639 const SCEV *X = ST->getOperand();
1641 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1642 unsigned NewBits = getTypeSizeInBits(Ty);
1643 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1644 CR.zextOrTrunc(NewBits)))
1645 return getTruncateOrZeroExtend(X, Ty, Depth);
1646 }
1647
1648 // If the input value is a chrec scev, and we can prove that the value
1649 // did not overflow the old, smaller, value, we can zero extend all of the
1650 // operands (often constants). This allows analysis of something like
1651 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1653 if (AR->isAffine()) {
1654 const SCEV *Start = AR->getStart();
1655 const SCEV *Step = AR->getStepRecurrence(*this);
1656 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1657 const Loop *L = AR->getLoop();
1658
1659 // If we have special knowledge that this addrec won't overflow,
1660 // we don't need to do any further analysis.
1661 if (AR->hasNoUnsignedWrap()) {
1662 Start =
1664 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1665 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1666 }
1667
1668 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1669 // Note that this serves two purposes: It filters out loops that are
1670 // simply not analyzable, and it covers the case where this code is
1671 // being called from within backedge-taken count analysis, such that
1672 // attempting to ask for the backedge-taken count would likely result
1673 // in infinite recursion. In the later case, the analysis code will
1674 // cope with a conservative value, and it will take care to purge
1675 // that value once it has finished.
1676 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1677 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1678 // Manually compute the final value for AR, checking for overflow.
1679
1680 // Check whether the backedge-taken count can be losslessly casted to
1681 // the addrec's type. The count is always unsigned.
1682 const SCEV *CastedMaxBECount =
1683 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1684 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1685 CastedMaxBECount, MaxBECount->getType(), Depth);
1686 if (MaxBECount == RecastedMaxBECount) {
1687 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1688 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1689 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1691 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1693 Depth + 1),
1694 WideTy, Depth + 1);
1695 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1696 const SCEV *WideMaxBECount =
1697 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1698 const SCEV *OperandExtendedAdd =
1699 getAddExpr(WideStart,
1700 getMulExpr(WideMaxBECount,
1701 getZeroExtendExpr(Step, WideTy, Depth + 1),
1704 if (ZAdd == OperandExtendedAdd) {
1705 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1706 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1707 // Return the expression with the addrec on the outside.
1708 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1709 Depth + 1);
1710 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1711 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1712 }
1713 // Similar to above, only this time treat the step value as signed.
1714 // This covers loops that count down.
1715 OperandExtendedAdd =
1716 getAddExpr(WideStart,
1717 getMulExpr(WideMaxBECount,
1718 getSignExtendExpr(Step, WideTy, Depth + 1),
1721 if (ZAdd == OperandExtendedAdd) {
1722 // Cache knowledge of AR NW, which is propagated to this AddRec.
1723 // Negative step causes unsigned wrap, but it still can't self-wrap.
1724 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1725 // Return the expression with the addrec on the outside.
1726 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1727 Depth + 1);
1728 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1729 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1730 }
1731 }
1732 }
1733
1734 // Normally, in the cases we can prove no-overflow via a
1735 // backedge guarding condition, we can also compute a backedge
1736 // taken count for the loop. The exceptions are assumptions and
1737 // guards present in the loop -- SCEV is not great at exploiting
1738 // these to compute max backedge taken counts, but can still use
1739 // these to prove lack of overflow. Use this fact to avoid
1740 // doing extra work that may not pay off.
1741 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1742 !AC.assumptions().empty()) {
1743
1744 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1745 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1746 if (AR->hasNoUnsignedWrap()) {
1747 // Same as nuw case above - duplicated here to avoid a compile time
1748 // issue. It's not clear that the order of checks does matter, but
1749 // it's one of two issue possible causes for a change which was
1750 // reverted. Be conservative for the moment.
1751 Start =
1753 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1754 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1755 }
1756
1757 // For a negative step, we can extend the operands iff doing so only
1758 // traverses values in the range zext([0,UINT_MAX]).
1759 if (isKnownNegative(Step)) {
1761 getSignedRangeMin(Step));
1764 // Cache knowledge of AR NW, which is propagated to this
1765 // AddRec. Negative step causes unsigned wrap, but it
1766 // still can't self-wrap.
1767 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1768 // Return the expression with the addrec on the outside.
1769 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1770 Depth + 1);
1771 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1772 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1773 }
1774 }
1775 }
1776
1777 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1778 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1779 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1780 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1781 const APInt &C = SC->getAPInt();
1782 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1783 if (D != 0) {
1784 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1785 const SCEV *SResidual =
1786 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1787 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1788 return getAddExpr(SZExtD, SZExtR,
1790 Depth + 1);
1791 }
1792 }
1793
1794 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1795 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1796 Start =
1798 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1799 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1800 }
1801 }
1802
1803 // zext(A % B) --> zext(A) % zext(B)
1804 {
1805 const SCEV *LHS;
1806 const SCEV *RHS;
1807 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1808 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1809 getZeroExtendExpr(RHS, Ty, Depth + 1));
1810 }
1811
1812 // zext(A / B) --> zext(A) / zext(B).
1813 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1814 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1815 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1816
1817 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1818 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1819 if (SA->hasNoUnsignedWrap()) {
1820 // If the addition does not unsign overflow then we can, by definition,
1821 // commute the zero extension with the addition operation.
1823 for (const auto *Op : SA->operands())
1824 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1825 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1826 }
1827
1828 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1829 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1830 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1831 //
1832 // Often address arithmetics contain expressions like
1833 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1834 // This transformation is useful while proving that such expressions are
1835 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1836 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1837 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1838 if (D != 0) {
1839 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1840 const SCEV *SResidual =
1842 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1843 return getAddExpr(SZExtD, SZExtR,
1845 Depth + 1);
1846 }
1847 }
1848 }
1849
1850 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1851 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1852 if (SM->hasNoUnsignedWrap()) {
1853 // If the multiply does not unsign overflow then we can, by definition,
1854 // commute the zero extension with the multiply operation.
1856 for (const auto *Op : SM->operands())
1857 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1858 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1859 }
1860
1861 // zext(2^K * (trunc X to iN)) to iM ->
1862 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1863 //
1864 // Proof:
1865 //
1866 // zext(2^K * (trunc X to iN)) to iM
1867 // = zext((trunc X to iN) << K) to iM
1868 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1869 // (because shl removes the top K bits)
1870 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1871 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1872 //
1873 const APInt *C;
1874 const SCEV *TruncRHS;
1875 if (match(SM,
1876 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1877 C->isPowerOf2()) {
1878 int NewTruncBits =
1879 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1880 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1881 return getMulExpr(
1882 getZeroExtendExpr(SM->getOperand(0), Ty),
1883 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1884 SCEV::FlagNUW, Depth + 1);
1885 }
1886 }
1887
1888 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1889 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1893 for (auto *Operand : MinMax->operands())
1894 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1896 return getUMinExpr(Operands);
1897 return getUMaxExpr(Operands);
1898 }
1899
1900 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1902 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1904 for (auto *Operand : MinMax->operands())
1905 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1906 return getUMinExpr(Operands, /*Sequential*/ true);
1907 }
1908
1909 // The cast wasn't folded; create an explicit cast node.
1910 // Recompute the insert position, as it may have been invalidated.
1911 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1912 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1913 Op, Ty);
1914 UniqueSCEVs.InsertNode(S, IP);
1915 registerUser(S, Op);
1916 return S;
1917}
1918
// Return the SCEV for sign-extending Op to the strictly wider integer type Ty.
//
// Results are memoized in FoldCache under the key (scSignExtend, Op, Ty): a
// cache hit is returned directly; otherwise the folding work is delegated to
// getSignExtendExprImpl and the result is recorded with insertFoldCacheEntry.
// Asserts that Ty is wider than Op's type, that Ty is SCEVable, and that Op
// is not a pointer.
//
// NOTE(review): this rendering is missing line 1920 (the declarator) and
// line 1933; the embedded line numbers jump at both places.
1919const SCEV *
1921 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1922 "This is not an extending conversion!");
1923 assert(isSCEVable(Ty) &&
1924 "This is not a conversion to a SCEVable type!");
1925 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1926 Ty = getEffectiveSCEVType(Ty);
1927
1928 FoldID ID(scSignExtend, Op, Ty);
1929 if (const SCEV *S = FoldCache.lookup(ID))
1930 return S;
1931
1932 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
// NOTE(review): line 1933 is absent from this rendering, so the cache
// insertion below may be guarded by a condition that is not shown here.
1934 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1935 return S;
1936}
1937
// Sign-extension worker used when the FoldCache has no entry for (Op, Ty).
//
// Simplifications tried, in source order:
//  * constant folding;
//  * collapsing sext(sext(x)) and sext(zext(x));
//  * returning an already-uniqued node for (scSignExtend, Op, Ty);
//  * emitting a plain cast node once Depth exceeds MaxCastDepth;
//  * peeling a truncate whose removed bits were all sign bits;
//  * distributing over <nsw> adds, and splitting a constant off adds;
//  * for affine add recurrences: proving <nsw>/<nw> (from existing flags, the
//    constant max backedge-taken count, induction reasoning, or a varying
//    start) so the extension can be pushed into start and step, and splitting
//    a constant off the start;
//  * building a zext when the input is provably positive;
//  * distributing over smin/smax.
// If nothing folds, a new SCEVSignExtendExpr node is created and uniqued.
//
// NOTE(review): this rendering drops many source lines, visible as jumps in
// the embedded line numbers (e.g. the declarator at 1938, and guards or
// declarations at 1951, 1955, 1960, 1976, 1980, 1993, 2025, 2156, 2161-2163).
// Several statements below therefore appear without the condition or
// declaration that introduces them.
1939 unsigned Depth) {
1940 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1941 "This is not an extending conversion!");
1942 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1943 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1944 Ty = getEffectiveSCEVType(Ty);
1945
1946 // Fold if the operand is constant.
1947 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1948 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1949
1950 // sext(sext(x)) --> sext(x)
// NOTE(review): the guard that binds SS (line 1951) is missing here.
1952 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1953
1954 // sext(zext(x)) --> zext(x)
// NOTE(review): the guard that binds SZ (line 1955) is missing here.
1956 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1957
1958 // Before doing any expensive analysis, check to see if we've already
1959 // computed a SCEV for this Op and Ty.
// NOTE(review): the declaration of ID (line 1960) is missing here.
1961 ID.AddInteger(scSignExtend);
1962 ID.AddPointer(Op);
1963 ID.AddPointer(Ty);
1964 void *IP = nullptr;
1965 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1966 // Limit recursion depth.
1967 if (Depth > MaxCastDepth) {
1968 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1969 Op, Ty);
1970 UniqueSCEVs.InsertNode(S, IP);
1971 registerUser(S, Op);
1972 return S;
1973 }
1974
1975 // sext(trunc(x)) --> sext(x) or x or trunc(x)
// NOTE(review): the `if` that binds ST and opens this scope (line 1976) is
// missing here; the scope is closed by the `}` on line 1986.
1977 // It's possible the bits taken off by the truncate were all sign bits. If
1978 // so, we should be able to simplify this further.
1979 const SCEV *X = ST->getOperand();
// NOTE(review): the declaration of CR (line 1980) is missing here.
1981 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1982 unsigned NewBits = getTypeSizeInBits(Ty);
1983 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1984 CR.sextOrTrunc(NewBits)))
1985 return getTruncateOrSignExtend(X, Ty, Depth);
1986 }
1987
1988 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1989 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1990 if (SA->hasNoSignedWrap()) {
1991 // If the addition does not sign overflow then we can, by definition,
1992 // commute the sign extension with the addition operation.
// NOTE(review): the declaration of Ops (line 1993) is missing here.
1994 for (const auto *Op : SA->operands())
1995 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1996 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1997 }
1998
1999 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2000 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2001 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2002 //
2003 // For instance, this will bring two seemingly different expressions:
2004 // 1 + sext(5 + 20 * %x + 24 * %y) and
2005 // sext(6 + 20 * %x + 24 * %y)
2006 // to the same form:
2007 // 2 + sext(4 + 20 * %x + 24 * %y)
2008 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2009 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2010 if (D != 0) {
2011 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2012 const SCEV *SResidual =
2014 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2015 return getAddExpr(SSExtD, SSExtR,
2017 Depth + 1);
2018 }
2019 }
2020 }
2021 // If the input value is a chrec scev, and we can prove that the value
2022 // did not overflow the old, smaller, value, we can sign extend all of the
2023 // operands (often constants). This allows analysis of something like
2024 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
// NOTE(review): the guard that binds AR (line 2025) is missing here.
2026 if (AR->isAffine()) {
2027 const SCEV *Start = AR->getStart();
2028 const SCEV *Step = AR->getStepRecurrence(*this);
2029 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2030 const Loop *L = AR->getLoop();
2031
2032 // If we have special knowledge that this addrec won't overflow,
2033 // we don't need to do any further analysis.
2034 if (AR->hasNoSignedWrap()) {
2035 Start =
2037 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2038 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2039 }
2040
2041 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2042 // Note that this serves two purposes: It filters out loops that are
2043 // simply not analyzable, and it covers the case where this code is
2044 // being called from within backedge-taken count analysis, such that
2045 // attempting to ask for the backedge-taken count would likely result
2046 // in infinite recursion. In the latter case, the analysis code will
2047 // cope with a conservative value, and it will take care to purge
2048 // that value once it has finished.
2049 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2050 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2051 // Manually compute the final value for AR, checking for
2052 // overflow.
2053
2054 // Check whether the backedge-taken count can be losslessly cast to
2055 // the addrec's type. The count is always unsigned.
2056 const SCEV *CastedMaxBECount =
2057 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2058 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2059 CastedMaxBECount, MaxBECount->getType(), Depth);
2060 if (MaxBECount == RecastedMaxBECount) {
2061 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2062 // Check whether Start+Step*MaxBECount has no signed overflow.
2063 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2065 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2067 Depth + 1),
2068 WideTy, Depth + 1);
2069 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2070 const SCEV *WideMaxBECount =
2071 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2072 const SCEV *OperandExtendedAdd =
2073 getAddExpr(WideStart,
2074 getMulExpr(WideMaxBECount,
2075 getSignExtendExpr(Step, WideTy, Depth + 1),
2078 if (SAdd == OperandExtendedAdd) {
2079 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2080 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2081 // Return the expression with the addrec on the outside.
2082 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2083 Depth + 1);
2084 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2085 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2086 }
2087 // Similar to above, only this time treat the step value as unsigned.
2088 // This covers loops that count up with an unsigned step.
2089 OperandExtendedAdd =
2090 getAddExpr(WideStart,
2091 getMulExpr(WideMaxBECount,
2092 getZeroExtendExpr(Step, WideTy, Depth + 1),
2095 if (SAdd == OperandExtendedAdd) {
2096 // If AR wraps around then
2097 //
2098 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2099 // => SAdd != OperandExtendedAdd
2100 //
2101 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2102 // (SAdd == OperandExtendedAdd => AR is NW)
2103
2104 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2105
2106 // Return the expression with the addrec on the outside.
2107 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2108 Depth + 1);
2109 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2110 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2111 }
2112 }
2113 }
2114
2115 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2116 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2117 if (AR->hasNoSignedWrap()) {
2118 // Same as nsw case above - duplicated here to avoid a compile time
2119 // issue. It's not clear that the order of checks does matter, but
2120 // it's one of two possible causes for a change which was
2121 // reverted. Be conservative for the moment.
2122 Start =
2124 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2125 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2126 }
2127
2128 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2129 // if D + (C - D + Step * n) could be proven to not signed wrap
2130 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2131 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2132 const APInt &C = SC->getAPInt();
2133 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2134 if (D != 0) {
2135 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2136 const SCEV *SResidual =
2137 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2138 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2139 return getAddExpr(SSExtD, SSExtR,
2141 Depth + 1);
2142 }
2143 }
2144
2145 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2146 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2147 Start =
2149 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2150 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2151 }
2152 }
2153
2154 // If the input value is provably positive and we could not simplify
2155 // away the sext build a zext instead.
// NOTE(review): the guard for this return (line 2156) is missing here.
2157 return getZeroExtendExpr(Op, Ty, Depth + 1);
2158
2159 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2160 // sext(smax(x, y)) -> smax(sext(x), sext(y))
// NOTE(review): lines 2161-2163 (which open this scope and declare MinMax and
// Operands) and line 2166 (which selects smin vs. smax) are missing here.
2164 for (auto *Operand : MinMax->operands())
2165 Operands.push_back(getSignExtendExpr(Operand, Ty));
2167 return getSMinExpr(Operands);
2168 return getSMaxExpr(Operands);
2169 }
2170
2171 // The cast wasn't folded; create an explicit cast node.
2172 // Recompute the insert position, as it may have been invalidated.
2173 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2174 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2175 Op, Ty);
2176 UniqueSCEVs.InsertNode(S, IP);
2177 registerUser(S, { Op });
2178 return S;
2179}
2180
// Build the cast SCEV selected by Kind (truncate, zero-extend, sign-extend or
// ptrtoint) applied to Op, with result type Ty, by forwarding to the matching
// constructor. Any other Kind is a programming error and reaches
// llvm_unreachable.
// NOTE(review): the first declarator line (2181, which introduces Kind and
// Op) is missing from this rendering.
2182 Type *Ty) {
2183 switch (Kind) {
2184 case scTruncate:
2185 return getTruncateExpr(Op, Ty);
2186 case scZeroExtend:
2187 return getZeroExtendExpr(Op, Ty);
2188 case scSignExtend:
2189 return getSignExtendExpr(Op, Ty);
2190 case scPtrToInt:
2191 return getPtrToIntExpr(Op, Ty);
2192 default:
2193 llvm_unreachable("Not a SCEV cast expression!");
2194 }
2195}
2196
2197/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2198/// unspecified bits out to the given type.
///
/// Since the new high bits are unspecified, this picks whichever extension
/// simplifies best, in order:
///  * negative constants are sign-extended;
///  * a truncate is peeled off;
///  * a zero- or sign-extension whose result is not a plain cast node is used;
///  * add recurrences are extended operand-wise, rebuilt with only FlagNW;
///  * smax expressions take the sext form;
///  * otherwise the zext form is returned.
// NOTE(review): the declarator line (2199) is missing from this rendering.
2200 Type *Ty) {
2201 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2202 "This is not an extending conversion!");
2203 assert(isSCEVable(Ty) &&
2204 "This is not a conversion to a SCEVable type!");
2205 Ty = getEffectiveSCEVType(Ty);
2206
2207 // Sign-extend negative constants.
2208 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2209 if (SC->getAPInt().isNegative())
2210 return getSignExtendExpr(Op, Ty);
2211
2212 // Peel off a truncate cast.
// NOTE(review): the `if` that binds T and opens this scope (line 2213) is
// missing here; the scope is closed on line 2218.
2214 const SCEV *NewOp = T->getOperand();
// If the truncate's source is still narrower than Ty, keep extending it;
// otherwise shrink it (or leave it unchanged) to exactly Ty.
2215 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2216 return getAnyExtendExpr(NewOp, Ty);
2217 return getTruncateOrNoop(NewOp, Ty);
2218 }
2219
2220 // Next try a zext cast. If the cast is folded, use it.
2221 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2222 if (!isa<SCEVZeroExtendExpr>(ZExt))
2223 return ZExt;
2224
2225 // Next try a sext cast. If the cast is folded, use it.
2226 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2227 if (!isa<SCEVSignExtendExpr>(SExt))
2228 return SExt;
2229
2230 // Force the cast to be folded into the operands of an addrec.
2231 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
// NOTE(review): the declaration of Ops (line 2232) is missing here.
2233 for (const SCEV *Op : AR->operands())
2234 Ops.push_back(getAnyExtendExpr(Op, Ty));
2235 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2236 }
2237
2238 // If the expression is obviously signed, use the sext cast value.
2239 if (isa<SCEVSMaxExpr>(Op))
2240 return SExt;
2241
2242 // Absent any other information, use the zext cast value.
2243 return ZExt;
2244}
2245
2246/// Process the given Ops list, which is a list of operands to be added under
2247/// the given scale, and update the given map. This is a helper function for
2248/// getAddExpr. As an example of what it does, given a sequence of operands
2249/// that would form an add expression like this:
2250///
2251/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2252///
2253/// where A and B are constants, update the map with these values:
2254///
2255/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2256///
2257/// and add 13 + A*B*29 to AccumulatedConstant.
2258/// This will allow getAddExpr to produce this:
2259///
2260/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2261///
2262/// This form often exposes folding opportunities that are hidden in
2263/// the original operand list.
2264///
2265/// Return true iff it appears that any interesting folding opportunities
2266/// may be exposed. This helps getAddExpr short-circuit extra work in
2267/// the common case where no interesting opportunities are present, and
2268/// is also used as a check to avoid infinite recursion.
///
/// M maps each non-constant term to its accumulated scale. NewOps receives
/// each term the first time it is inserted into M, so it lists the distinct
/// terms in discovery order. Constants found at the front of Ops are scaled
/// and summed into AccumulatedConstant.
// NOTE(review): the parameter lines declaring M and NewOps (2270-2271) and
// the declaration of Mul (line 2290) are missing from this rendering.
2269static bool
2272 APInt &AccumulatedConstant,
2273 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2274 ScalarEvolution &SE) {
2275 bool Interesting = false;
2276
2277 // Iterate over the add operands. They are sorted, with constants first.
// NOTE(review): this loop indexes Ops[i] without a bounds check, so it relies
// on Ops containing at least one non-constant operand.
2278 unsigned i = 0;
2279 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2280 ++i;
2281 // Pull a buried constant out to the outside.
// Report an opportunity when the constant is scaled, when it merges with an
// already-accumulated constant, or when it is zero.
2282 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2283 Interesting = true;
2284 AccumulatedConstant += Scale * C->getAPInt();
2285 }
2286
2287 // Next comes everything else. We're especially interested in multiplies
2288 // here, but they're in the middle, so just visit the rest with one loop.
2289 for (; i != Ops.size(); ++i) {
2291 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2292 APInt NewScale =
2293 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2294 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2295 // A multiplication of a constant with another add; recurse.
2296 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2297 Interesting |=
2298 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2299 Add->operands(), NewScale, SE);
2300 } else {
2301 // A multiplication of a constant with some other value. Update
2302 // the map.
2303 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2304 const SCEV *Key = SE.getMulExpr(MulOps);
2305 auto Pair = M.insert({Key, NewScale});
2306 if (Pair.second) {
2307 NewOps.push_back(Pair.first->first);
2308 } else {
2309 Pair.first->second += NewScale;
2310 // The map already had an entry for this value, which may indicate
2311 // a folding opportunity.
2312 Interesting = true;
2313 }
2314 }
2315 } else {
2316 // An ordinary operand. Update the map.
2317 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2318 M.insert({Ops[i], Scale});
2319 if (Pair.second) {
2320 NewOps.push_back(Pair.first->first);
2321 } else {
2322 Pair.first->second += Scale;
2323 // The map already had an entry for this value, which may indicate
2324 // a folding opportunity.
2325 Interesting = true;
2326 }
2327 }
2328 }
2329
2330 return Interesting;
2331}
2332
// Return true if applying BinOp (Add, Sub or Mul) to LHS and RHS is known not
// to overflow. Signed chooses between signed and unsigned overflow; it
// selects the MIN/MAX bounds used below.
//
// First, a purely algebraic check: extend the operands to twice the bit
// width and test whether ext(LHS op RHS) folds to the same SCEV as
// ext(LHS) op ext(RHS). If that fails, and a context instruction CtxI is
// given, Add/Sub with a constant RHS are handled by proving at CtxI that LHS
// is far enough from the type's MIN or MAX for the constant step to stay in
// range. Mul and non-constant RHS are not handled in that second stage.
//
// NOTE(review): this rendering is missing the first declarator line (2333),
// the assignment of Operation in each switch case (2342, 2345, 2348), the
// initializer of Extension (2353-2354) and the declaration of Pred (2393).
2334 const SCEV *LHS, const SCEV *RHS,
2335 const Instruction *CtxI) {
2336 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2337 SCEV::NoWrapFlags, unsigned);
2338 switch (BinOp) {
2339 default:
2340 llvm_unreachable("Unsupported binary op");
2341 case Instruction::Add:
2343 break;
2344 case Instruction::Sub:
2346 break;
2347 case Instruction::Mul:
2349 break;
2350 }
2351
2352 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2355
2356 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2357 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2358 auto *WideTy =
2359 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2360
2361 const SCEV *A = (this->*Extension)(
2362 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2363 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2364 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2365 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2366 if (A == B)
2367 return true;
2368 // Can we use context to prove the fact we need?
2369 if (!CtxI)
2370 return false;
2371 // TODO: Support mul.
2372 if (BinOp == Instruction::Mul)
2373 return false;
2374 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2375 // TODO: Lift this limitation.
2376 if (!RHSC)
2377 return false;
2378 APInt C = RHSC->getAPInt();
2379 unsigned NumBits = C.getBitWidth();
2380 bool IsSub = (BinOp == Instruction::Sub);
2381 bool IsNegativeConst = (Signed && C.isNegative());
2382 // Compute the direction and magnitude by which we need to check overflow.
// Subtracting a non-negative constant, or adding a negative one, moves
// toward MIN; the other two combinations move toward MAX.
2383 bool OverflowDown = IsSub ^ IsNegativeConst;
2384 APInt Magnitude = C;
2385 if (IsNegativeConst) {
2386 if (C == APInt::getSignedMinValue(NumBits))
2387 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2388 // want to deal with that.
2389 return false;
2390 Magnitude = -C;
2391 }
2392
// NOTE(review): Pred is declared on the missing line 2393; presumably a
// signed or unsigned "<=" chosen by Signed, matching the comments below --
// confirm.
2394 if (OverflowDown) {
2395 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2396 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2397 : APInt::getMinValue(NumBits);
2398 APInt Limit = Min + Magnitude;
2399 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2400 } else {
2401 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2402 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2403 : APInt::getMaxValue(NumBits);
2404 APInt Limit = Max - Magnitude;
2405 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2406 }
2407}
2408
// Try to prove no-wrap flags for the overflowing binary operator OBO beyond
// the nuw/nsw flags the instruction already carries.
//
// Returns std::nullopt when:
//  * OBO already has both nuw and nsw;
//  * its opcode is not Add, Sub or Mul; or
//  * no new flag could be deduced.
// Otherwise returns Flags. Presumably these are the instruction's existing
// flags plus the deduced ones, but the lines that update Flags are missing
// from this rendering -- confirm.
//
// NOTE(review): missing lines include the declarator (2410), the declaration
// of Flags (2416), the statements guarded by the `if`s at 2418 and 2420
// (2419, 2421), the CtxI initializer (2434), the overflow-check calls (2436,
// 2443) and the flag updates (2438, 2445).
2409std::optional<SCEV::NoWrapFlags>
2411 const OverflowingBinaryOperator *OBO) {
2412 // It cannot be done any better.
2413 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2414 return std::nullopt;
2415
2417
2418 if (OBO->hasNoUnsignedWrap())
2420 if (OBO->hasNoSignedWrap())
2422
2423 bool Deduced = false;
2424
2425 if (OBO->getOpcode() != Instruction::Add &&
2426 OBO->getOpcode() != Instruction::Sub &&
2427 OBO->getOpcode() != Instruction::Mul)
2428 return std::nullopt;
2429
2430 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2431 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2432
2433 const Instruction *CtxI =
// Only try to prove the flags the instruction does not already have.
2435 if (!OBO->hasNoUnsignedWrap() &&
2437 /* Signed */ false, LHS, RHS, CtxI)) {
2439 Deduced = true;
2440 }
2441
2442 if (!OBO->hasNoSignedWrap() &&
2444 /* Signed */ true, LHS, RHS, CtxI)) {
2446 Deduced = true;
2447 }
2448
2449 if (Deduced)
2450 return Flags;
2451 return std::nullopt;
2452}
2453
2454// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2455// `Flags' as can't-wrap behavior. Infer a more aggressive set of
2456// can't-overflow flags for the operation if possible.
//
// Inference rules, in order:
//  * nsw without nuw, plus all operands known non-negative, implies nuw.
//  * Two-operand add/mul whose first operand is a constant C: add nsw (nuw)
//    when the signed (unsigned) range of the other operand lies inside the
//    no-wrap region computed for that opcode and C. The call that computes
//    the region is on a missing line.
//  * <0,+,nonnegative><nw> also gets nuw.
//  * (udiv X, Y) * Y and Y * (udiv X, Y) get nuw.
// The strengthened flag set is returned.
//
// NOTE(review): this rendering is missing the declarator (2457-2459), the
// CanAnalyze initializer (2466), the region declarations (2504, 2512), the
// flag-setting statements (2507, 2515, 2524, 2531, 2534) and the first lines
// of the last two `if` conditions (2521, 2527).
2460 SCEV::NoWrapFlags Flags) {
2461 using namespace std::placeholders;
2462
2463 using OBO = OverflowingBinaryOperator;
2464
2465 bool CanAnalyze =
2467 (void)CanAnalyze;
2468 assert(CanAnalyze && "don't call from other places!");
2469
2470 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2471 SCEV::NoWrapFlags SignOrUnsignWrap =
2472 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2473
2474 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2475 auto IsKnownNonNegative = [&](const SCEV *S) {
2476 return SE->isKnownNonNegative(S);
2477 };
2478
2479 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2480 Flags =
2481 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2482
// Re-read the nuw/nsw bits: the rule above may have just added nuw.
2483 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2484
2485 if (SignOrUnsignWrap != SignOrUnsignMask &&
2486 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2487 isa<SCEVConstant>(Ops[0])) {
2488
2489 auto Opcode = [&] {
2490 switch (Type) {
2491 case scAddExpr:
2492 return Instruction::Add;
2493 case scMulExpr:
2494 return Instruction::Mul;
2495 default:
2496 llvm_unreachable("Unexpected SCEV op.");
2497 }
2498 }();
2499
2500 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2501
2502 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2503 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2505 Opcode, C, OBO::NoSignedWrap);
2506 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2508 }
2509
2510 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2511 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2513 Opcode, C, OBO::NoUnsignedWrap);
2514 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2516 }
2517 }
2518
2519 // <0,+,nonnegative><nw> is also nuw
2520 // TODO: Add corresponding nsw case
2522 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2523 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2525
2526 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2528 Ops.size() == 2) {
2529 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2530 if (UDiv->getOperand(1) == Ops[1])
2532 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2533 if (UDiv->getOperand(1) == Ops[0])
2535 }
2536
2537 return Flags;
2538}
2539
// True when S is invariant in loop L and properly dominates L's header block.
// NOTE(review): the declarator line (2540) is missing from this rendering.
2541 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2542}
2543
2544/// Get a canonical add expression, or something simpler if possible.
2546 SCEV::NoWrapFlags OrigFlags,
2547 unsigned Depth) {
2548 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2549 "only nuw or nsw allowed");
2550 assert(!Ops.empty() && "Cannot get empty add!");
2551 if (Ops.size() == 1) return Ops[0];
2552#ifndef NDEBUG
2553 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2554 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2555 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2556 "SCEVAddExpr operand types don't match!");
2557 unsigned NumPtrs = count_if(
2558 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2559 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2560#endif
2561
2562 const SCEV *Folded = constantFoldAndGroupOps(
2563 *this, LI, DT, Ops,
2564 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2565 [](const APInt &C) { return C.isZero(); }, // identity
2566 [](const APInt &C) { return false; }); // absorber
2567 if (Folded)
2568 return Folded;
2569
2570 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2571
2572 // Delay expensive flag strengthening until necessary.
2573 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
2574 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2575 };
2576
2577 // Limit recursion calls depth.
2579 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2580
2581 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2582 // Don't strengthen flags if we have no new information.
2583 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2584 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2585 Add->setNoWrapFlags(ComputeFlags(Ops));
2586 return S;
2587 }
2588
2589 // Okay, check to see if the same value occurs in the operand list more than
2590 // once. If so, merge them together into an multiply expression. Since we
2591 // sorted the list, these values are required to be adjacent.
2592 Type *Ty = Ops[0]->getType();
2593 bool FoundMatch = false;
2594 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2595 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2596 // Scan ahead to count how many equal operands there are.
2597 unsigned Count = 2;
2598 while (i+Count != e && Ops[i+Count] == Ops[i])
2599 ++Count;
2600 // Merge the values into a multiply.
2601 const SCEV *Scale = getConstant(Ty, Count);
2602 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2603 if (Ops.size() == Count)
2604 return Mul;
2605 Ops[i] = Mul;
2606 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2607 --i; e -= Count - 1;
2608 FoundMatch = true;
2609 }
2610 if (FoundMatch)
2611 return getAddExpr(Ops, OrigFlags, Depth + 1);
2612
2613 // Check for truncates. If all the operands are truncated from the same
2614 // type, see if factoring out the truncate would permit the result to be
2615 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2616 // if the contents of the resulting outer trunc fold to something simple.
2617 auto FindTruncSrcType = [&]() -> Type * {
2618 // We're ultimately looking to fold an addrec of truncs and muls of only
2619 // constants and truncs, so if we find any other types of SCEV
2620 // as operands of the addrec then we bail and return nullptr here.
2621 // Otherwise, we return the type of the operand of a trunc that we find.
2622 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2623 return T->getOperand()->getType();
2624 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2625 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2626 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2627 return T->getOperand()->getType();
2628 }
2629 return nullptr;
2630 };
2631 if (auto *SrcType = FindTruncSrcType()) {
2633 bool Ok = true;
2634 // Check all the operands to see if they can be represented in the
2635 // source type of the truncate.
2636 for (const SCEV *Op : Ops) {
2638 if (T->getOperand()->getType() != SrcType) {
2639 Ok = false;
2640 break;
2641 }
2642 LargeOps.push_back(T->getOperand());
2643 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2644 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2645 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2646 SmallVector<const SCEV *, 8> LargeMulOps;
2647 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2648 if (const SCEVTruncateExpr *T =
2649 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2650 if (T->getOperand()->getType() != SrcType) {
2651 Ok = false;
2652 break;
2653 }
2654 LargeMulOps.push_back(T->getOperand());
2655 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2656 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2657 } else {
2658 Ok = false;
2659 break;
2660 }
2661 }
2662 if (Ok)
2663 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2664 } else {
2665 Ok = false;
2666 break;
2667 }
2668 }
2669 if (Ok) {
2670 // Evaluate the expression in the larger type.
2671 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2672 // If it folds to something simple, use it. Otherwise, don't.
2673 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2674 return getTruncateExpr(Fold, Ty);
2675 }
2676 }
2677
2678 if (Ops.size() == 2) {
2679 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2680 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2681 // C1).
2682 const SCEV *A = Ops[0];
2683 const SCEV *B = Ops[1];
2684 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2685 auto *C = dyn_cast<SCEVConstant>(A);
2686 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2687 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2688 auto C2 = C->getAPInt();
2689 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2690
2691 APInt ConstAdd = C1 + C2;
2692 auto AddFlags = AddExpr->getNoWrapFlags();
2693 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2695 ConstAdd.ule(C1)) {
2696 PreservedFlags =
2698 }
2699
2700 // Adding a constant with the same sign and small magnitude is NSW, if the
2701 // original AddExpr was NSW.
2703 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2704 ConstAdd.abs().ule(C1.abs())) {
2705 PreservedFlags =
2707 }
2708
2709 if (PreservedFlags != SCEV::FlagAnyWrap) {
2710 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2711 NewOps[0] = getConstant(ConstAdd);
2712 return getAddExpr(NewOps, PreservedFlags);
2713 }
2714 }
2715
2716 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2717 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2718 const SCEVAddExpr *InnerAdd;
2719 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2720 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2721 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2722 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2723 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2725 SCEV::FlagNUW)) {
2726 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2727 }
2728 }
2729 }
2730
2731 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2732 const SCEV *Y;
2733 if (Ops.size() == 2 &&
2734 match(Ops[0],
2736 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2737 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2738
2739 // Skip past any other cast SCEVs.
2740 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2741 ++Idx;
2742
2743 // If there are add operands they would be next.
2744 if (Idx < Ops.size()) {
2745 bool DeletedAdd = false;
2746 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2747 // common NUW flag for expression after inlining. Other flags cannot be
2748 // preserved, because they may depend on the original order of operations.
2749 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2750 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2751 if (Ops.size() > AddOpsInlineThreshold ||
2752 Add->getNumOperands() > AddOpsInlineThreshold)
2753 break;
2754 // If we have an add, expand the add operands onto the end of the operands
2755 // list.
2756 Ops.erase(Ops.begin()+Idx);
2757 append_range(Ops, Add->operands());
2758 DeletedAdd = true;
2759 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2760 }
2761
2762 // If we deleted at least one add, we added operands to the end of the list,
2763 // and they are not necessarily sorted. Recurse to resort and resimplify
2764 // any operands we just acquired.
2765 if (DeletedAdd)
2766 return getAddExpr(Ops, CommonFlags, Depth + 1);
2767 }
2768
2769 // Skip over the add expression until we get to a multiply.
2770 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2771 ++Idx;
2772
2773 // Check to see if there are any folding opportunities present with
2774 // operands multiplied by constant values.
2775 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2779 APInt AccumulatedConstant(BitWidth, 0);
2780 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2781 Ops, APInt(BitWidth, 1), *this)) {
2782 struct APIntCompare {
2783 bool operator()(const APInt &LHS, const APInt &RHS) const {
2784 return LHS.ult(RHS);
2785 }
2786 };
2787
2788 // Some interesting folding opportunity is present, so its worthwhile to
2789 // re-generate the operands list. Group the operands by constant scale,
2790 // to avoid multiplying by the same constant scale multiple times.
2791 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2792 for (const SCEV *NewOp : NewOps)
2793 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2794 // Re-generate the operands list.
2795 Ops.clear();
2796 if (AccumulatedConstant != 0)
2797 Ops.push_back(getConstant(AccumulatedConstant));
2798 for (auto &MulOp : MulOpLists) {
2799 if (MulOp.first == 1) {
2800 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2801 } else if (MulOp.first != 0) {
2802 Ops.push_back(getMulExpr(
2803 getConstant(MulOp.first),
2804 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2805 SCEV::FlagAnyWrap, Depth + 1));
2806 }
2807 }
2808 if (Ops.empty())
2809 return getZero(Ty);
2810 if (Ops.size() == 1)
2811 return Ops[0];
2812 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2813 }
2814 }
2815
2816 // If we are adding something to a multiply expression, make sure the
2817 // something is not already an operand of the multiply. If so, merge it into
2818 // the multiply.
2819 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2820 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2821 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2822 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2823 if (isa<SCEVConstant>(MulOpSCEV))
2824 continue;
2825 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2826 if (MulOpSCEV == Ops[AddOp]) {
2827 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2828 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2829 if (Mul->getNumOperands() != 2) {
2830 // If the multiply has more than two operands, we must get the
2831 // Y*Z term.
2833 Mul->operands().take_front(MulOp));
2834 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2835 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2836 }
2837 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2838 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2839 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2841 if (Ops.size() == 2) return OuterMul;
2842 if (AddOp < Idx) {
2843 Ops.erase(Ops.begin()+AddOp);
2844 Ops.erase(Ops.begin()+Idx-1);
2845 } else {
2846 Ops.erase(Ops.begin()+Idx);
2847 Ops.erase(Ops.begin()+AddOp-1);
2848 }
2849 Ops.push_back(OuterMul);
2850 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2851 }
2852
2853 // Check this multiply against other multiplies being added together.
2854 for (unsigned OtherMulIdx = Idx+1;
2855 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2856 ++OtherMulIdx) {
2857 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2858 // If MulOp occurs in OtherMul, we can fold the two multiplies
2859 // together.
2860 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2861 OMulOp != e; ++OMulOp)
2862 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2863 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2864 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2865 if (Mul->getNumOperands() != 2) {
2867 Mul->operands().take_front(MulOp));
2868 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2869 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2870 }
2871 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2872 if (OtherMul->getNumOperands() != 2) {
2874 OtherMul->operands().take_front(OMulOp));
2875 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2876 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2877 }
2878 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2879 const SCEV *InnerMulSum =
2880 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2881 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2883 if (Ops.size() == 2) return OuterMul;
2884 Ops.erase(Ops.begin()+Idx);
2885 Ops.erase(Ops.begin()+OtherMulIdx-1);
2886 Ops.push_back(OuterMul);
2887 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2888 }
2889 }
2890 }
2891 }
2892
2893 // If there are any add recurrences in the operands list, see if any other
2894 // added values are loop invariant. If so, we can fold them into the
2895 // recurrence.
2896 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2897 ++Idx;
2898
2899 // Scan over all recurrences, trying to fold loop invariants into them.
2900 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2901 // Scan all of the other operands to this add and add them to the vector if
2902 // they are loop invariant w.r.t. the recurrence.
2904 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2905 const Loop *AddRecLoop = AddRec->getLoop();
2906 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2907 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2908 LIOps.push_back(Ops[i]);
2909 Ops.erase(Ops.begin()+i);
2910 --i; --e;
2911 }
2912
2913 // If we found some loop invariants, fold them into the recurrence.
2914 if (!LIOps.empty()) {
2915 // Compute nowrap flags for the addition of the loop-invariant ops and
2916 // the addrec. Temporarily push it as an operand for that purpose. These
2917 // flags are valid in the scope of the addrec only.
2918 LIOps.push_back(AddRec);
2919 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2920 LIOps.pop_back();
2921
2922 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2923 LIOps.push_back(AddRec->getStart());
2924
2925 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2926
2927 // It is not in general safe to propagate flags valid on an add within
2928 // the addrec scope to one outside it. We must prove that the inner
2929 // scope is guaranteed to execute if the outer one does to be able to
2930 // safely propagate. We know the program is undefined if poison is
2931 // produced on the inner scoped addrec. We also know that *for this use*
2932 // the outer scoped add can't overflow (because of the flags we just
2933 // computed for the inner scoped add) without the program being undefined.
2934 // Proving that entry to the outer scope neccesitates entry to the inner
2935 // scope, thus proves the program undefined if the flags would be violated
2936 // in the outer scope.
2937 SCEV::NoWrapFlags AddFlags = Flags;
2938 if (AddFlags != SCEV::FlagAnyWrap) {
2939 auto *DefI = getDefiningScopeBound(LIOps);
2940 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2941 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2942 AddFlags = SCEV::FlagAnyWrap;
2943 }
2944 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2945
2946 // Build the new addrec. Propagate the NUW and NSW flags if both the
2947 // outer add and the inner addrec are guaranteed to have no overflow.
2948 // Always propagate NW.
2949 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2950 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2951
2952 // If all of the other operands were loop invariant, we are done.
2953 if (Ops.size() == 1) return NewRec;
2954
2955 // Otherwise, add the folded AddRec by the non-invariant parts.
2956 for (unsigned i = 0;; ++i)
2957 if (Ops[i] == AddRec) {
2958 Ops[i] = NewRec;
2959 break;
2960 }
2961 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2962 }
2963
2964 // Okay, if there weren't any loop invariants to be folded, check to see if
2965 // there are multiple AddRec's with the same loop induction variable being
2966 // added together. If so, we can fold them.
2967 for (unsigned OtherIdx = Idx+1;
2968 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2969 ++OtherIdx) {
2970 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2971 // so that the 1st found AddRecExpr is dominated by all others.
2972 assert(DT.dominates(
2973 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2974 AddRec->getLoop()->getHeader()) &&
2975 "AddRecExprs are not sorted in reverse dominance order?");
2976 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2977 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2978 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2979 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2980 ++OtherIdx) {
2981 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2982 if (OtherAddRec->getLoop() == AddRecLoop) {
2983 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2984 i != e; ++i) {
2985 if (i >= AddRecOps.size()) {
2986 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2987 break;
2988 }
2990 AddRecOps[i], OtherAddRec->getOperand(i)};
2991 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2992 }
2993 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2994 }
2995 }
2996 // Step size has changed, so we cannot guarantee no self-wraparound.
2997 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2998 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2999 }
3000 }
3001
3002 // Otherwise couldn't fold anything into this recurrence. Move onto the
3003 // next one.
3004 }
3005
3006 // Okay, it looks like we really DO need an add expr. Check to see if we
3007 // already have one, otherwise create a new one.
3008 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3009}
3010
// Look up, or create and register, the uniqued SCEVAddExpr whose operand list
// is exactly Ops, in the given order. No simplification happens here; callers
// such as the tail of getAddExpr invoke this only after all folding has been
// attempted, presumably with operands already in canonical order.
3011const SCEV *
3012ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
3013 SCEV::NoWrapFlags Flags) {
// The uniquing key is the expression kind followed by each operand pointer.
// NOTE(review): the declaration of ID is not visible in this listing.
3015 ID.AddInteger(scAddExpr);
3016 for (const SCEV *Op : Ops)
3017 ID.AddPointer(Op);
3018 void *IP = nullptr;
3019 SCEVAddExpr *S =
3020 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3021 if (!S) {
// Cache miss: allocate operand storage from SCEVAllocator, construct the node
// with an interned copy of the key, insert it at the slot found by the lookup,
// and register S as a user of its operands.
// NOTE(review): the line that fills O is not visible in this listing.
3022 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3024 S = new (SCEVAllocator)
3025 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3026 UniqueSCEVs.InsertNode(S, IP);
3027 registerUser(S, Ops);
3028 }
// Flags are applied on both the cache-hit and the creation path, so a node
// found in the cache is also handed the requested flags.
3029 S->setNoWrapFlags(Flags);
3030 return S;
3031}
3032
// Look up, or create and register, the uniqued SCEVAddRecExpr with operand
// list Ops over loop L. Like the add/mul variants, no simplification is done
// here.
3033const SCEV *
3034ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3035 const Loop *L, SCEV::NoWrapFlags Flags) {
3036 FoldingSetNodeID ID;
3037 ID.AddInteger(scAddRecExpr);
3038 for (const SCEV *Op : Ops)
3039 ID.AddPointer(Op);
// The loop is part of the identity: identical operands over different loops
// produce distinct nodes.
3040 ID.AddPointer(L);
3041 void *IP = nullptr;
3042 SCEVAddRecExpr *S =
3043 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3044 if (!S) {
// NOTE(review): the line that fills O is not visible in this listing.
3045 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3047 S = new (SCEVAllocator)
3048 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3049 UniqueSCEVs.InsertNode(S, IP);
// Record the new addrec under its loop; presumably this lets loop-specific
// invalidation find it later -- confirm against LoopUsers' consumers.
3050 LoopUsers[L].push_back(S);
3051 registerUser(S, Ops);
3052 }
// NOTE(review): unlike getOrCreateAddExpr/getOrCreateMulExpr, flags are set via
// the two-argument setNoWrapFlags(S, Flags) rather than S->setNoWrapFlags;
// presumably this also updates state derived from addrec flags -- confirm
// before changing.
3053 setNoWrapFlags(S, Flags);
3054 return S;
3055}
3056
// Look up, or create and register, the uniqued SCEVMulExpr whose operand list
// is exactly Ops, in the given order. No simplification happens here; the tail
// of getMulExpr calls this once no further folding applies.
3057const SCEV *
3058ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3059 SCEV::NoWrapFlags Flags) {
3060 FoldingSetNodeID ID;
// The uniquing key is the expression kind followed by each operand pointer.
3061 ID.AddInteger(scMulExpr);
3062 for (const SCEV *Op : Ops)
3063 ID.AddPointer(Op);
3064 void *IP = nullptr;
3065 SCEVMulExpr *S =
3066 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3067 if (!S) {
// Cache miss: allocate operand storage, build the node, insert it at the slot
// found by the lookup, and register S as a user of its operands.
// NOTE(review): the line that fills O is not visible in this listing.
3068 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3070 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3071 O, Ops.size());
3072 UniqueSCEVs.InsertNode(S, IP);
3073 registerUser(S, Ops);
3074 }
// Flags are applied on both the cache-hit and the creation path.
3075 S->setNoWrapFlags(Flags);
3076 return S;
3077}
3078
/// Multiply two unsigned 64-bit values, wrapping modulo 2^64.
///
/// If the mathematical product does not fit in 64 bits, \p Overflow is set to
/// true. It is never cleared, so a caller can accumulate overflow across a
/// sequence of operations.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i * j;
  // For j > 1 the product overflowed iff dividing back does not recover i.
  // j == 0 and j == 1 can never overflow (and j == 0 must not be divided by).
  if (j > 1 && k / j != i)
    Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient. If an
/// intermediate computation overflows, Overflow will be set and the return will
/// be garbage. Overflow is not cleared on absence of overflow.
///
/// Returns 0 when k > n (including n == 0 with k > 0) and 1 when k == 0 or
/// k == n.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At the i-th iteration we multiply by the i-th factor of the numerator,
  // n-(i-1), and divide by i. After iteration i the running value equals
  // C(n, i), so each division is exact, which keeps intermediate values small.
  // The multiplication can still overflow even when the final result would
  // fit.

  // Test k > n first: C(n, k) is 0 whenever k exceeds n, including n == 0.
  // (Checking n == 0 first would wrongly yield 1 for C(0, k), k > 0.)
  if (k > n)
    return 0;
  if (k == 0 || k == n)
    return 1;

  // C(n, k) == C(n, n-k); use the smaller of the two to shorten the loop.
  if (k > n / 2)
    k = n - k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n - (i - 1), Overflow);
    r /= i;
  }
  return r;
}
3110
3111/// Determine if any of the operands in this SCEV are a constant or if
3112/// any of the add or multiply expressions in this SCEV contain a constant.
///
/// StartExpr itself is also checked. follow() returns true only for
/// SCEVAddExpr and SCEVMulExpr nodes. Assuming the usual SCEVTraversal
/// protocol (returning false prunes that node's operands; isDone() ends the
/// walk), a constant nested under any other node kind is not found. Examples
/// of such kinds are casts and addrecs.
3113static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3114 struct FindConstantInAddMulChain {
3115 bool FoundConstant = false;
3116
// Record constants, and keep descending only through add/mul nodes.
3117 bool follow(const SCEV *S) {
3118 FoundConstant |= isa<SCEVConstant>(S);
3119 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3120 }
3121
// Stop early once a constant has been found.
3122 bool isDone() const {
3123 return FoundConstant;
3124 }
3125 };
3126
3127 FindConstantInAddMulChain F;
// NOTE(review): the declaration of the traversal object ST is not visible in
// this listing.
3129 ST.visitAll(StartExpr);
3130 return F.FoundConstant;
3131}
3132
// NOTE(review): this listing omits several source lines of this function.
// They include the signature line naming it and its Ops parameter, and some
// local declarations and condition lines. The comments below describe only
// the visible code.
//
// Ops is modified in place below (erase / append_range), so it is a mutable
// vector of operands.
//
// Visible simplification steps, in order:
//  1. a single operand is returned unchanged;
//  2. constants are folded and grouped by constantFoldAndGroupOps (1 is the
//     identity, 0 is absorbing);
//  3. an early exit that emits the node as-is (its guard is not visible);
//  4. an existing uniqued node is reused, re-strengthening its flags only if
//     OrigFlags carries flags it lacks;
//  5. rewrites when Ops[0] is a constant and there are exactly two operands:
//     - distribution over an add;
//     - negation of an add or addrec;
//     - pushing the constant into a zext;
//     - power-of-two udiv folds;
//  6. nested mul operands are inlined (bounded by MulOpsInlineThreshold);
//  7. loop-invariant factors are folded into an addrec, and addrecs of the
//     same loop are multiplied together;
//  8. otherwise the node is created with ComputeFlags(Ops).
3133/// Get a canonical multiply expression, or something simpler if possible.
3135 SCEV::NoWrapFlags OrigFlags,
3136 unsigned Depth) {
3137 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3138 "only nuw or nsw allowed");
3139 assert(!Ops.empty() && "Cannot get empty mul!");
3140 if (Ops.size() == 1) return Ops[0];
3141#ifndef NDEBUG
3142 Type *ETy = Ops[0]->getType();
3143 assert(!ETy->isPointerTy());
3144 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3145 assert(Ops[i]->getType() == ETy &&
3146 "SCEVMulExpr operand types don't match!");
3147#endif
3148
3149 const SCEV *Folded = constantFoldAndGroupOps(
3150 *this, LI, DT, Ops,
3151 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3152 [](const APInt &C) { return C.isOne(); }, // identity
3153 [](const APInt &C) { return C.isZero(); }); // absorber
3154 if (Folded)
3155 return Folded;
3156
3157 // Delay expensive flag strengthening until necessary.
3158 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
3159 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3160 };
3161
3162 // Limit recursion calls depth.
// NOTE(review): the `if` guarding the following early return is not visible
// in this listing; presumably it is a depth / expression-size check -- confirm.
// On that path the operands are emitted as-is, with only flag strengthening.
3164 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3165
3166 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3167 // Don't strengthen flags if we have no new information.
// getNoWrapFlags(OrigFlags) masks the cached flags by OrigFlags; if the result
// differs from OrigFlags, the request carries flags the cached node lacks, so
// re-run the (expensive) strengthening.
3168 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3169 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3170 Mul->setNoWrapFlags(ComputeFlags(Ops));
3171 return S;
3172 }
3173
3174 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3175 if (Ops.size() == 2) {
3176 // C1*(C2+V) -> C1*C2 + C1*V
3177 // If any of Add's ops are Adds or Muls with a constant, apply this
3178 // transformation as well.
3179 //
3180 // TODO: There are some cases where this transformation is not
3181 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3182 // this transformation should be narrowed down.
3183 const SCEV *Op0, *Op1;
// NOTE(review): the second half of this condition is not visible in this
// listing; per the comment above it presumably checks
// containsConstantInAddMulChain on the add's operands -- confirm.
3184 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3186 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3187 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3188 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3189 }
3190
3191 if (Ops[0]->isAllOnesValue()) {
3192 // If we have a mul by -1 of an add, try distributing the -1 among the
3193 // add operands.
3194 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
// Distribute only if at least one product simplifies to something other than a
// SCEVMulExpr; otherwise fall through to the remaining folds.
3196 bool AnyFolded = false;
3197 for (const SCEV *AddOp : Add->operands()) {
3198 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3199 Depth + 1);
3200 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3201 NewOps.push_back(Mul);
3202 }
3203 if (AnyFolded)
3204 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3205 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3206 // Negation preserves a recurrence's no self-wrap property.
// Each addrec operand is multiplied by -1.
3208 for (const SCEV *AddRecOp : AddRec->operands())
3209 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3210 Depth + 1));
3211 // Let M be the minimum representable signed value. AddRec with nsw
3212 // multiplied by -1 can have signed overflow if and only if it takes a
3213 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3214 // maximum signed value. In all other cases signed overflow is
3215 // impossible.
3216 auto FlagsMask = SCEV::FlagNW;
3217 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3218 auto MinInt =
3219 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3220 if (getSignedRangeMin(AddRec) != MinInt)
3221 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3222 }
// FlagsMask selects which of AddRec's existing flags survive the negation:
// NW always, NSW only if the signed range shows M is never reached.
3223 return getAddRecExpr(Operands, AddRec->getLoop(),
3224 AddRec->getNoWrapFlags(FlagsMask));
3225 }
3226 }
3227
3228 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3229 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3230 const SCEVAddExpr *InnerAdd;
3231 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
// Requires InnerAdd's first operand to be a constant and LHSC to equal
// zext(trunc(LHSC)), i.e. narrowing the constant loses no bits.
3232 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3233 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3234 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3235 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3237 SCEV::FlagNUW)) {
3238 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3239 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3240 };
3241 }
3242
3243 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3244 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3245 // of C1, fold to (D /u (C2 /u C1)).
3246 const SCEV *D;
3247 APInt C1V = LHSC->getAPInt();
3248 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3249 // as -1 * 1, as it won't enable additional folds.
3250 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3251 C1V = C1V.abs();
3252 const SCEVConstant *C2;
// NOTE(review): the conjunct that matches Ops[1] as (D /u C2), binding D and
// C2, is not visible in this listing.
3253 if (C1V.isPowerOf2() &&
3255 C2->getAPInt().isPowerOf2() &&
3256 C1V.logBase2() <= getMinTrailingZeros(D)) {
3257 const SCEV *NewMul = nullptr;
3258 if (C1V.uge(C2->getAPInt())) {
3259 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3260 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3261 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3262 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3263 }
// If C1 was replaced by |C1| above, negate the result to restore the sign.
3264 if (NewMul)
3265 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3266 }
3267 }
3268 }
3269
3270 // Skip over the add expression until we get to a multiply.
3271 unsigned Idx = 0;
3272 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3273 ++Idx;
3274
3275 // If there are mul operands inline them all into this expression.
3276 if (Idx < Ops.size()) {
3277 bool DeletedMul = false;
3278 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
// Stop inlining once the operand list exceeds MulOpsInlineThreshold.
3279 if (Ops.size() > MulOpsInlineThreshold)
3280 break;
3281 // If we have an mul, expand the mul operands onto the end of the
3282 // operands list.
3283 Ops.erase(Ops.begin()+Idx);
3284 append_range(Ops, Mul->operands());
3285 DeletedMul = true;
3286 }
3287
3288 // If we deleted at least one mul, we added operands to the end of the
3289 // list, and they are not necessarily sorted. Recurse to resort and
3290 // resimplify any operands we just acquired.
3291 if (DeletedMul)
3292 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3293 }
3294
3295 // If there are any add recurrences in the operands list, see if any other
3296 // added values are loop invariant. If so, we can fold them into the
3297 // recurrence.
3298 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3299 ++Idx;
3300
3301 // Scan over all recurrences, trying to fold loop invariants into them.
3302 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3303 // Scan all of the other operands to this mul and add them to the vector
3304 // if they are loop invariant w.r.t. the recurrence.
3306 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3307 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3308 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3309 LIOps.push_back(Ops[i]);
3310 Ops.erase(Ops.begin()+i);
3311 --i; --e;
3312 }
3313
3314 // If we found some loop invariants, fold them into the recurrence.
3315 if (!LIOps.empty()) {
3316 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3318 NewOps.reserve(AddRec->getNumOperands());
3319 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3320
3321 // If both the mul and addrec are nuw, we can preserve nuw.
3322 // If both the mul and addrec are nsw, we can only preserve nsw if either
3323 // a) they are also nuw, or
3324 // b) all multiplications of addrec operands with scale are nsw.
3325 SCEV::NoWrapFlags Flags =
3326 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3327
3328 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3329 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3330 SCEV::FlagAnyWrap, Depth + 1));
3331
// Case (b): keep NSW only if each operand's signed range lies in the region
// where a Mul by Scale's signed range cannot signed-overflow.
// NOTE(review): the NSWRegion declaration is only partly visible here.
3332 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3334 Instruction::Mul, getSignedRange(Scale),
3336 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3337 Flags = clearFlags(Flags, SCEV::FlagNSW);
3338 }
3339 }
3340
3341 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3342
3343 // If all of the other operands were loop invariant, we are done.
3344 if (Ops.size() == 1) return NewRec;
3345
3346 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3347 for (unsigned i = 0;; ++i)
3348 if (Ops[i] == AddRec) {
3349 Ops[i] = NewRec;
3350 break;
3351 }
3352 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3353 }
3354
3355 // Okay, if there weren't any loop invariants to be folded, check to see
3356 // if there are multiple AddRec's with the same loop induction variable
3357 // being multiplied together. If so, we can fold them.
3358
3359 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3360 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3361 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3362 // ]]],+,...up to x=2n}.
3363 // Note that the arguments to choose() are always integers with values
3364 // known at compile time, never SCEV objects.
3365 //
3366 // The implementation avoids pointless extra computations when the two
3367 // addrec's are of different length (mathematically, it's equivalent to
3368 // an infinite stream of zeros on the right).
// As implemented below, with zero-based operand indices and operand counts
// nA = AddRec->getNumOperands() and nB = OtherAddRec->getNumOperands():
//   result operand x, for x in [0, nA+nB-2], is the sum over y in [x, 2x] and
//   z in [max(y-x, y-nA+1), min(x, nB-1)] of
//     choose(x, 2x-y) * choose(2x-y, x-z) * A_{y-z} * B_z.
// Note the first coefficient is choose(x, 2x-y) (Coeff1), not choose(x, 2x).
3369 bool OpsModified = false;
3370 for (unsigned OtherIdx = Idx+1;
3371 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3372 ++OtherIdx) {
3373 const SCEVAddRecExpr *OtherAddRec =
3374 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3375 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3376 continue;
3377
3378 // Limit max number of arguments to avoid creation of unreasonably big
3379 // SCEVAddRecs with very complex operands.
3380 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3381 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3382 continue;
3383
3384 bool Overflow = false;
3385 Type *Ty = AddRec->getType();
3386 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3388 for (int x = 0, xe = AddRec->getNumOperands() +
3389 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3391 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3392 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3393 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3394 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3395 z < ze && !Overflow; ++z) {
3396 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3397 uint64_t Coeff;
// For Ty of at most 64 bits a wrapped uint64_t product presumably still gives
// the right constant, since only Ty's low bits are kept by getConstant(Ty, ..)
// -- confirm. Wider types need the exact product, so umul_ov tracks overflow.
// Overflow reported by Choose aborts the whole fold regardless of width.
3398 if (LargerThan64Bits)
3399 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3400 else
3401 Coeff = Coeff1*Coeff2;
3402 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3403 const SCEV *Term1 = AddRec->getOperand(y-z);
3404 const SCEV *Term2 = OtherAddRec->getOperand(z);
3405 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3406 SCEV::FlagAnyWrap, Depth + 1));
3407 }
3408 }
// No term contributed to this result operand; use zero.
3409 if (SumOps.empty())
3410 SumOps.push_back(getZero(Ty));
3411 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3412 }
3413 if (!Overflow) {
// NOTE(review): the flags argument of this call is not visible in this
// listing.
3414 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3416 if (Ops.size() == 2) return NewAddRec;
3417 Ops[Idx] = NewAddRec;
3418 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3419 OpsModified = true;
// If the product simplified to a non-addrec, stop combining here; the
// OpsModified re-simplification below still runs.
3420 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3421 if (!AddRec)
3422 break;
3423 }
3424 }
3425 if (OpsModified)
3426 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3427
3428 // Otherwise couldn't fold anything into this recurrence. Move onto the
3429 // next one.
3430 }
3431
3432 // Okay, it looks like we really DO need an mul expr. Check to see if we
3433 // already have one, otherwise create a new one.
3434 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3435}
3436
3437/// Represents an unsigned remainder expression based on unsigned division.
///
/// No dedicated remainder node is built. LHS urem RHS becomes 0 for RHS == 1,
/// zext(trunc(LHS)) for a power-of-two constant RHS, and otherwise
/// LHS - (LHS udiv RHS) * RHS.
/// NOTE(review): the signature line is not visible in this listing; LHS is
/// the dividend and RHS the divisor.
3439 const SCEV *RHS) {
3440 assert(getEffectiveSCEVType(LHS->getType()) ==
3441 getEffectiveSCEVType(RHS->getType()) &&
3442 "SCEVURemExpr operand types don't match!");
3443
3444 // Short-circuit easy cases
3445 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3446 // If constant is one, the result is trivial
// This must precede the power-of-two case. 1 is a power of two with
// logBase2() == 0, which would request a zero-width integer type below.
3447 if (RHSC->getValue()->isOne())
3448 return getZero(LHS->getType()); // X urem 1 --> 0
3449
3450 // If constant is a power of two, fold into a zext(trunc(LHS)).
// X urem 2^k keeps only the low k bits of X: truncate to i<k>, then
// zero-extend back to the original type.
3451 if (RHSC->getAPInt().isPowerOf2()) {
3452 Type *FullTy = LHS->getType();
3453 Type *TruncTy =
3454 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3455 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3456 }
3457 }
3458
3459 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
// (x udiv y) * y never exceeds x. So neither the multiply nor the subtraction
// can unsigned-wrap, which justifies FlagNUW on both.
3460 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3461 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3462 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3463}
3464
3465/// Get a canonical unsigned division expression, or something simpler if
3466/// possible.
3468 const SCEV *RHS) {
3469 assert(!LHS->getType()->isPointerTy() &&
3470 "SCEVUDivExpr operand can't be pointer!");
3471 assert(LHS->getType() == RHS->getType() &&
3472 "SCEVUDivExpr operand types don't match!");
3473
3475 ID.AddInteger(scUDivExpr);
3476 ID.AddPointer(LHS);
3477 ID.AddPointer(RHS);
3478 void *IP = nullptr;
3479 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3480 return S;
3481
3482 // 0 udiv Y == 0
3483 if (match(LHS, m_scev_Zero()))
3484 return LHS;
3485
3486 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3487 if (RHSC->getValue()->isOne())
3488 return LHS; // X udiv 1 --> x
3489 // If the denominator is zero, the result of the udiv is undefined. Don't
3490 // try to analyze it, because the resolution chosen here may differ from
3491 // the resolution chosen in other parts of the compiler.
3492 if (!RHSC->getValue()->isZero()) {
3493 // Determine if the division can be folded into the operands of
3494 // its operands.
3495 // TODO: Generalize this to non-constants by using known-bits information.
3496 Type *Ty = LHS->getType();
3497 unsigned LZ = RHSC->getAPInt().countl_zero();
3498 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3499 // For non-power-of-two values, effectively round the value up to the
3500 // nearest power of two.
3501 if (!RHSC->getAPInt().isPowerOf2())
3502 ++MaxShiftAmt;
3503 IntegerType *ExtTy =
3504 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3505 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3506 if (const SCEVConstant *Step =
3507 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3508 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3509 const APInt &StepInt = Step->getAPInt();
3510 const APInt &DivInt = RHSC->getAPInt();
3511 if (!StepInt.urem(DivInt) &&
3512 getZeroExtendExpr(AR, ExtTy) ==
3513 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3514 getZeroExtendExpr(Step, ExtTy),
3515 AR->getLoop(), SCEV::FlagAnyWrap)) {
3517 for (const SCEV *Op : AR->operands())
3518 Operands.push_back(getUDivExpr(Op, RHS));
3519 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3520 }
3521 /// Get a canonical UDivExpr for a recurrence.
3522 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3523 const APInt *StartRem;
3524 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3525 m_scev_APInt(StartRem))) {
3526 bool NoWrap =
3527 getZeroExtendExpr(AR, ExtTy) ==
3528 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3529 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3531
3532 // With N <= C and both N, C as powers-of-2, the transformation
3533 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3534 // if wrapping occurs, as the division results remain equivalent for
3535 // all offsets in [[(X - X%N), X).
3536 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3537 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3538 // Only fold if the subtraction can be folded in the start
3539 // expression.
3540 const SCEV *NewStart =
3541 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3542 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3543 !isa<SCEVAddExpr>(NewStart)) {
3544 const SCEV *NewLHS =
3545 getAddRecExpr(NewStart, Step, AR->getLoop(),
3546 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3547 if (LHS != NewLHS) {
3548 LHS = NewLHS;
3549
3550 // Reset the ID to include the new LHS, and check if it is
3551 // already cached.
3552 ID.clear();
3553 ID.AddInteger(scUDivExpr);
3554 ID.AddPointer(LHS);
3555 ID.AddPointer(RHS);
3556 IP = nullptr;
3557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3558 return S;
3559 }
3560 }
3561 }
3562 }
3563 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3564 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3566 for (const SCEV *Op : M->operands())
3567 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3568 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3569 // Find an operand that's safely divisible.
3570 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3571 const SCEV *Op = M->getOperand(i);
3572 const SCEV *Div = getUDivExpr(Op, RHSC);
3573 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3574 Operands = SmallVector<const SCEV *, 4>(M->operands());
3575 Operands[i] = Div;
3576 return getMulExpr(Operands);
3577 }
3578 }
3579 }
3580
3581 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3582 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3583 if (auto *DivisorConstant =
3584 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3585 bool Overflow = false;
3586 APInt NewRHS =
3587 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3588 if (Overflow) {
3589 return getConstant(RHSC->getType(), 0, false);
3590 }
3591 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3592 }
3593 }
3594
3595 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3596 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3598 for (const SCEV *Op : A->operands())
3599 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3600 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3601 Operands.clear();
3602 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3603 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3604 if (isa<SCEVUDivExpr>(Op) ||
3605 getMulExpr(Op, RHS) != A->getOperand(i))
3606 break;
3607 Operands.push_back(Op);
3608 }
3609 if (Operands.size() == A->getNumOperands())
3610 return getAddExpr(Operands);
3611 }
3612 }
3613
3614 // Fold if both operands are constant.
3615 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3616 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3617 }
3618 }
3619
3620 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3621 const APInt *NegC, *C;
3622 if (match(LHS,
3625 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3626 return getZero(LHS->getType());
3627
3628 // TODO: Generalize to handle any common factors.
3629 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3630 const SCEV *NewLHS, *NewRHS;
3631 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3632 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3633 return getUDivExpr(NewLHS, NewRHS);
3634
3635 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3636 // changes). Make sure we get a new one.
3637 IP = nullptr;
3638 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3639 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3640 LHS, RHS);
3641 UniqueSCEVs.InsertNode(S, IP);
3642 registerUser(S, {LHS, RHS});
3643 return S;
3644}
3645
3646APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3647 APInt A = C1->getAPInt().abs();
3648 APInt B = C2->getAPInt().abs();
3649 uint32_t ABW = A.getBitWidth();
3650 uint32_t BBW = B.getBitWidth();
3651
3652 if (ABW > BBW)
3653 B = B.zext(ABW);
3654 else if (ABW < BBW)
3655 A = A.zext(BBW);
3656
3657 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3658}
3659
3660/// Get a canonical unsigned division expression, or something simpler if
3661/// possible. There is no representation for an exact udiv in SCEV IR, but we
3662/// can attempt to remove factors from the LHS and RHS. We can't do this when
3663/// it's not exact because the udiv may be clearing bits.
3665 const SCEV *RHS) {
3666 // TODO: we could try to find factors in all sorts of things, but for now we
3667 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3668 // end of this file for inspiration.
3669
3671 if (!Mul || !Mul->hasNoUnsignedWrap())
3672 return getUDivExpr(LHS, RHS);
3673
3674 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3675 // If the mulexpr multiplies by a constant, then that constant must be the
3676 // first element of the mulexpr.
3677 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3678 if (LHSCst == RHSCst) {
3679 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3680 return getMulExpr(Operands);
3681 }
3682
3683 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3684 // that there's a factor provided by one of the other terms. We need to
3685 // check.
3686 APInt Factor = gcd(LHSCst, RHSCst);
3687 if (!Factor.isIntN(1)) {
3688 LHSCst =
3689 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3690 RHSCst =
3691 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3693 Operands.push_back(LHSCst);
3694 append_range(Operands, Mul->operands().drop_front());
3695 LHS = getMulExpr(Operands);
3696 RHS = RHSCst;
3698 if (!Mul)
3699 return getUDivExactExpr(LHS, RHS);
3700 }
3701 }
3702 }
3703
3704 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3705 if (Mul->getOperand(i) == RHS) {
3707 append_range(Operands, Mul->operands().take_front(i));
3708 append_range(Operands, Mul->operands().drop_front(i + 1));
3709 return getMulExpr(Operands);
3710 }
3711 }
3712
3713 return getUDivExpr(LHS, RHS);
3714}
3715
3716/// Get an add recurrence expression for the specified loop. Simplify the
3717/// expression as much as possible.
3718const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3719 const Loop *L,
3720 SCEV::NoWrapFlags Flags) {
3722 Operands.push_back(Start);
3723 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3724 if (StepChrec->getLoop() == L) {
3725 append_range(Operands, StepChrec->operands());
3726 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3727 }
3728
3729 Operands.push_back(Step);
3730 return getAddRecExpr(Operands, L, Flags);
3731}
3732
3733/// Get an add recurrence expression for the specified loop. Simplify the
3734/// expression as much as possible.
3735const SCEV *
3737 const Loop *L, SCEV::NoWrapFlags Flags) {
3738 if (Operands.size() == 1) return Operands[0];
3739#ifndef NDEBUG
3740 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3741 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3742 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3743 "SCEVAddRecExpr operand types don't match!");
3744 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3745 }
3746 for (const SCEV *Op : Operands)
3748 "SCEVAddRecExpr operand is not available at loop entry!");
3749#endif
3750
3751 if (Operands.back()->isZero()) {
3752 Operands.pop_back();
3753 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3754 }
3755
3756 // It's tempting to want to call getConstantMaxBackedgeTakenCount count here and
3757 // use that information to infer NUW and NSW flags. However, computing a
3758 // BE count requires calling getAddRecExpr, so we may not yet have a
3759 // meaningful BE count at this point (and if we don't, we'd be stuck
3760 // with a SCEVCouldNotCompute as the cached BE count).
3761
3762 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3763
3764 // Canonicalize nested AddRecs in by nesting them in order of loop depth.
3765 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3766 const Loop *NestedLoop = NestedAR->getLoop();
3767 if (L->contains(NestedLoop)
3768 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3769 : (!NestedLoop->contains(L) &&
3770 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3771 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3772 Operands[0] = NestedAR->getStart();
3773 // AddRecs require their operands be loop-invariant with respect to their
3774 // loops. Don't perform this transformation if it would break this
3775 // requirement.
3776 bool AllInvariant = all_of(
3777 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3778
3779 if (AllInvariant) {
3780 // Create a recurrence for the outer loop with the same step size.
3781 //
3782 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3783 // inner recurrence has the same property.
3784 SCEV::NoWrapFlags OuterFlags =
3785 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3786
3787 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3788 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3789 return isLoopInvariant(Op, NestedLoop);
3790 });
3791
3792 if (AllInvariant) {
3793 // Ok, both add recurrences are valid after the transformation.
3794 //
3795 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3796 // the outer recurrence has the same property.
3797 SCEV::NoWrapFlags InnerFlags =
3798 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3799 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3800 }
3801 }
3802 // Reset Operands to its original state.
3803 Operands[0] = NestedAR;
3804 }
3805 }
3806
3807 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3808 // already have one, otherwise create a new one.
3809 return getOrCreateAddRecExpr(Operands, L, Flags);
3810}
3811
3813 ArrayRef<const SCEV *> IndexExprs) {
3814 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3815 // getSCEV(Base)->getType() has the same address space as Base->getType()
3816 // because SCEV::getType() preserves the address space.
3817 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3818 if (NW != GEPNoWrapFlags::none()) {
3819 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3820 // but to do that, we have to ensure that said flag is valid in the entire
3821 // defined scope of the SCEV.
3822 // TODO: non-instructions have global scope. We might be able to prove
3823 // some global scope cases
3824 auto *GEPI = dyn_cast<Instruction>(GEP);
3825 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3826 NW = GEPNoWrapFlags::none();
3827 }
3828
3829 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3830}
3831
3833 ArrayRef<const SCEV *> IndexExprs,
3834 Type *SrcElementTy, GEPNoWrapFlags NW) {
3836 if (NW.hasNoUnsignedSignedWrap())
3837 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3838 if (NW.hasNoUnsignedWrap())
3839 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3840
3841 Type *CurTy = BaseExpr->getType();
3842 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3843 bool FirstIter = true;
3845 for (const SCEV *IndexExpr : IndexExprs) {
3846 // Compute the (potentially symbolic) offset in bytes for this index.
3847 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3848 // For a struct, add the member offset.
3849 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3850 unsigned FieldNo = Index->getZExtValue();
3851 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3852 Offsets.push_back(FieldOffset);
3853
3854 // Update CurTy to the type of the field at Index.
3855 CurTy = STy->getTypeAtIndex(Index);
3856 } else {
3857 // Update CurTy to its element type.
3858 if (FirstIter) {
3859 assert(isa<PointerType>(CurTy) &&
3860 "The first index of a GEP indexes a pointer");
3861 CurTy = SrcElementTy;
3862 FirstIter = false;
3863 } else {
3865 }
3866 // For an array, add the element offset, explicitly scaled.
3867 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3868 // Getelementptr indices are signed.
3869 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3870
3871 // Multiply the index by the element size to compute the element offset.
3872 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3873 Offsets.push_back(LocalOffset);
3874 }
3875 }
3876
3877 // Handle degenerate case of GEP without offsets.
3878 if (Offsets.empty())
3879 return BaseExpr;
3880
3881 // Add the offsets together, assuming nsw if inbounds.
3882 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3883 // Add the base address and the offset. We cannot use the nsw flag, as the
3884 // base address is unsigned. However, if we know that the offset is
3885 // non-negative, we can use nuw.
3886 bool NUW = NW.hasNoUnsignedWrap() ||
3889 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3890 assert(BaseExpr->getType() == GEPExpr->getType() &&
3891 "GEP should not change type mid-flight.");
3892 return GEPExpr;
3893}
3894
3895SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3898 ID.AddInteger(SCEVType);
3899 for (const SCEV *Op : Ops)
3900 ID.AddPointer(Op);
3901 void *IP = nullptr;
3902 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3903}
3904
3905const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3907 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3908}
3909
3912 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3913 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3914 if (Ops.size() == 1) return Ops[0];
3915#ifndef NDEBUG
3916 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3917 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3918 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3919 "Operand types don't match!");
3920 assert(Ops[0]->getType()->isPointerTy() ==
3921 Ops[i]->getType()->isPointerTy() &&
3922 "min/max should be consistently pointerish");
3923 }
3924#endif
3925
3926 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3927 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3928
3929 const SCEV *Folded = constantFoldAndGroupOps(
3930 *this, LI, DT, Ops,
3931 [&](const APInt &C1, const APInt &C2) {
3932 switch (Kind) {
3933 case scSMaxExpr:
3934 return APIntOps::smax(C1, C2);
3935 case scSMinExpr:
3936 return APIntOps::smin(C1, C2);
3937 case scUMaxExpr:
3938 return APIntOps::umax(C1, C2);
3939 case scUMinExpr:
3940 return APIntOps::umin(C1, C2);
3941 default:
3942 llvm_unreachable("Unknown SCEV min/max opcode");
3943 }
3944 },
3945 [&](const APInt &C) {
3946 // identity
3947 if (IsMax)
3948 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3949 else
3950 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3951 },
3952 [&](const APInt &C) {
3953 // absorber
3954 if (IsMax)
3955 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3956 else
3957 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3958 });
3959 if (Folded)
3960 return Folded;
3961
3962 // Check if we have created the same expression before.
3963 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3964 return S;
3965 }
3966
3967 // Find the first operation of the same kind
3968 unsigned Idx = 0;
3969 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3970 ++Idx;
3971
3972 // Check to see if one of the operands is of the same kind. If so, expand its
3973 // operands onto our operand list, and recurse to simplify.
3974 if (Idx < Ops.size()) {
3975 bool DeletedAny = false;
3976 while (Ops[Idx]->getSCEVType() == Kind) {
3977 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3978 Ops.erase(Ops.begin()+Idx);
3979 append_range(Ops, SMME->operands());
3980 DeletedAny = true;
3981 }
3982
3983 if (DeletedAny)
3984 return getMinMaxExpr(Kind, Ops);
3985 }
3986
3987 // Okay, check to see if the same value occurs in the operand list twice. If
3988 // so, delete one. Since we sorted the list, these values are required to
3989 // be adjacent.
3994 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3995 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3996 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3997 if (Ops[i] == Ops[i + 1] ||
3998 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3999 // X op Y op Y --> X op Y
4000 // X op Y --> X, if we know X, Y are ordered appropriately
4001 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4002 --i;
4003 --e;
4004 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4005 Ops[i + 1])) {
4006 // X op Y --> Y, if we know X, Y are ordered appropriately
4007 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4008 --i;
4009 --e;
4010 }
4011 }
4012
4013 if (Ops.size() == 1) return Ops[0];
4014
4015 assert(!Ops.empty() && "Reduced smax down to nothing!");
4016
4017 // Okay, it looks like we really DO need an expr. Check to see if we
4018 // already have one, otherwise create a new one.
4020 ID.AddInteger(Kind);
4021 for (const SCEV *Op : Ops)
4022 ID.AddPointer(Op);
4023 void *IP = nullptr;
4024 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4025 if (ExistingSCEV)
4026 return ExistingSCEV;
4027 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4029 SCEV *S = new (SCEVAllocator)
4030 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4031
4032 UniqueSCEVs.InsertNode(S, IP);
4033 registerUser(S, Ops);
4034 return S;
4035}
4036
4037namespace {
4038
4039class SCEVSequentialMinMaxDeduplicatingVisitor final
4040 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4041 std::optional<const SCEV *>> {
4042 using RetVal = std::optional<const SCEV *>;
4044
4045 ScalarEvolution &SE;
4046 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4047 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4049
4050 bool canRecurseInto(SCEVTypes Kind) const {
4051 // We can only recurse into the SCEV expression of the same effective type
4052 // as the type of our root SCEV expression.
4053 return RootKind == Kind || NonSequentialRootKind == Kind;
4054 };
4055
4056 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4058 "Only for min/max expressions.");
4059 SCEVTypes Kind = S->getSCEVType();
4060
4061 if (!canRecurseInto(Kind))
4062 return S;
4063
4064 auto *NAry = cast<SCEVNAryExpr>(S);
4066 bool Changed = visit(Kind, NAry->operands(), NewOps);
4067
4068 if (!Changed)
4069 return S;
4070 if (NewOps.empty())
4071 return std::nullopt;
4072
4074 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4075 : SE.getMinMaxExpr(Kind, NewOps);
4076 }
4077
4078 RetVal visit(const SCEV *S) {
4079 // Has the whole operand been seen already?
4080 if (!SeenOps.insert(S).second)
4081 return std::nullopt;
4082 return Base::visit(S);
4083 }
4084
4085public:
4086 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4087 SCEVTypes RootKind)
4088 : SE(SE), RootKind(RootKind),
4089 NonSequentialRootKind(
4090 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4091 RootKind)) {}
4092
4093 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4094 SmallVectorImpl<const SCEV *> &NewOps) {
4095 bool Changed = false;
4097 Ops.reserve(OrigOps.size());
4098
4099 for (const SCEV *Op : OrigOps) {
4100 RetVal NewOp = visit(Op);
4101 if (NewOp != Op)
4102 Changed = true;
4103 if (NewOp)
4104 Ops.emplace_back(*NewOp);
4105 }
4106
4107 if (Changed)
4108 NewOps = std::move(Ops);
4109 return Changed;
4110 }
4111
4112 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4113
4114 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4115
4116 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4117
4118 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4119
4120 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4121
4122 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4123
4124 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4125
4126 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4127
4128 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4129
4130 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4131
4132 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4133
4134 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4135 return visitAnyMinMaxExpr(Expr);
4136 }
4137
4138 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4139 return visitAnyMinMaxExpr(Expr);
4140 }
4141
4142 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4143 return visitAnyMinMaxExpr(Expr);
4144 }
4145
4146 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4147 return visitAnyMinMaxExpr(Expr);
4148 }
4149
4150 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4151 return visitAnyMinMaxExpr(Expr);
4152 }
4153
4154 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4155
4156 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4157};
4158
4159} // namespace
4160
4162 switch (Kind) {
4163 case scConstant:
4164 case scVScale:
4165 case scTruncate:
4166 case scZeroExtend:
4167 case scSignExtend:
4168 case scPtrToAddr:
4169 case scPtrToInt:
4170 case scAddExpr:
4171 case scMulExpr:
4172 case scUDivExpr:
4173 case scAddRecExpr:
4174 case scUMaxExpr:
4175 case scSMaxExpr:
4176 case scUMinExpr:
4177 case scSMinExpr:
4178 case scUnknown:
4179 // If any operand is poison, the whole expression is poison.
4180 return true;
4182 // FIXME: if the *first* operand is poison, the whole expression is poison.
4183 return false; // Pessimistically, say that it does not propagate poison.
4184 case scCouldNotCompute:
4185 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4186 }
4187 llvm_unreachable("Unknown SCEV kind!");
4188}
4189
4190namespace {
4191// The only way poison may be introduced in a SCEV expression is from a
4192// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4193// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4194// introduce poison -- they encode guaranteed, non-speculated knowledge.
4195//
4196// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4197// with the notable exception of umin_seq, where only poison from the first
4198// operand is (unconditionally) propagated.
4199struct SCEVPoisonCollector {
4200 bool LookThroughMaybePoisonBlocking;
4201 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4202 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4203 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4204
4205 bool follow(const SCEV *S) {
4206 if (!LookThroughMaybePoisonBlocking &&
4208 return false;
4209
4210 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4211 if (!isGuaranteedNotToBePoison(SU->getValue()))
4212 MaybePoison.insert(SU);
4213 }
4214 return true;
4215 }
4216 bool isDone() const { return false; }
4217};
4218} // namespace
4219
4220/// Return true if V is poison given that AssumedPoison is already poison.
4221static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4222 // First collect all SCEVs that might result in AssumedPoison to be poison.
4223 // We need to look through potentially poison-blocking operations here,
4224 // because we want to find all SCEVs that *might* result in poison, not only
4225 // those that are *required* to.
4226 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4227 visitAll(AssumedPoison, PC1);
4228
4229 // AssumedPoison is never poison. As the assumption is false, the implication
4230 // is true. Don't bother walking the other SCEV in this case.
4231 if (PC1.MaybePoison.empty())
4232 return true;
4233
4234 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4235 // as well. We cannot look through potentially poison-blocking operations
4236 // here, as their arguments only *may* make the result poison.
4237 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4238 visitAll(S, PC2);
4239
4240 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4241 // it will also make S poison by being part of PC2.MaybePoison.
4242 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4243}
4244
4246 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4247 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4248 visitAll(S, PC);
4249 for (const SCEVUnknown *SU : PC.MaybePoison)
4250 Result.insert(SU->getValue());
4251}
4252
4254 const SCEV *S, Instruction *I,
4255 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4256 // If the instruction cannot be poison, it's always safe to reuse.
4258 return true;
4259
4260 // Otherwise, it is possible that I is more poisonous that S. Collect the
4261 // poison-contributors of S, and then check whether I has any additional
4262 // poison-contributors. Poison that is contributed through poison-generating
4263 // flags is handled by dropping those flags instead.
4265 getPoisonGeneratingValues(PoisonVals, S);
4266
4267 SmallVector<Value *> Worklist;
4269 Worklist.push_back(I);
4270 while (!Worklist.empty()) {
4271 Value *V = Worklist.pop_back_val();
4272 if (!Visited.insert(V).second)
4273 continue;
4274
4275 // Avoid walking large instruction graphs.
4276 if (Visited.size() > 16)
4277 return false;
4278
4279 // Either the value can't be poison, or the S would also be poison if it
4280 // is.
4281 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4282 continue;
4283
4284 auto *I = dyn_cast<Instruction>(V);
4285 if (!I)
4286 return false;
4287
4288 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4289 // can't replace an arbitrary add with disjoint or, even if we drop the
4290 // flag. We would need to convert the or into an add.
4291 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4292 if (PDI->isDisjoint())
4293 return false;
4294
4295 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4296 // because SCEV currently assumes it can't be poison. Remove this special
4297 // case once we proper model when vscale can be poison.
4298 if (auto *II = dyn_cast<IntrinsicInst>(I);
4299 II && II->getIntrinsicID() == Intrinsic::vscale)
4300 continue;
4301
4302 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4303 return false;
4304
4305 // If the instruction can't create poison, we can recurse to its operands.
4306 if (I->hasPoisonGeneratingAnnotations())
4307 DropPoisonGeneratingInsts.push_back(I);
4308
4309 llvm::append_range(Worklist, I->operands());
4310 }
4311 return true;
4312}
4313
4314const SCEV *
4317 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4318 "Not a SCEVSequentialMinMaxExpr!");
4319 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4320 if (Ops.size() == 1)
4321 return Ops[0];
4322#ifndef NDEBUG
4323 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4324 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4325 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4326 "Operand types don't match!");
4327 assert(Ops[0]->getType()->isPointerTy() ==
4328 Ops[i]->getType()->isPointerTy() &&
4329 "min/max should be consistently pointerish");
4330 }
4331#endif
4332
4333 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4334 // so we can *NOT* do any kind of sorting of the expressions!
4335
4336 // Check if we have created the same expression before.
4337 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4338 return S;
4339
4340 // FIXME: there are *some* simplifications that we can do here.
4341
4342 // Keep only the first instance of an operand.
4343 {
4344 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4345 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4346 if (Changed)
4347 return getSequentialMinMaxExpr(Kind, Ops);
4348 }
4349
4350 // Check to see if one of the operands is of the same kind. If so, expand its
4351 // operands onto our operand list, and recurse to simplify.
4352 {
4353 unsigned Idx = 0;
4354 bool DeletedAny = false;
4355 while (Idx < Ops.size()) {
4356 if (Ops[Idx]->getSCEVType() != Kind) {
4357 ++Idx;
4358 continue;
4359 }
4360 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4361 Ops.erase(Ops.begin() + Idx);
4362 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4363 SMME->operands().end());
4364 DeletedAny = true;
4365 }
4366
4367 if (DeletedAny)
4368 return getSequentialMinMaxExpr(Kind, Ops);
4369 }
4370
4371 const SCEV *SaturationPoint;
4373 switch (Kind) {
4375 SaturationPoint = getZero(Ops[0]->getType());
4376 Pred = ICmpInst::ICMP_ULE;
4377 break;
4378 default:
4379 llvm_unreachable("Not a sequential min/max type.");
4380 }
4381
4382 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4383 if (!isGuaranteedNotToCauseUB(Ops[i]))
4384 continue;
4385 // We can replace %x umin_seq %y with %x umin %y if either:
4386 // * %y being poison implies %x is also poison.
4387 // * %x cannot be the saturating value (e.g. zero for umin).
4388 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4389 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4390 SaturationPoint)) {
4391 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4392 Ops[i - 1] = getMinMaxExpr(
4394 SeqOps);
4395 Ops.erase(Ops.begin() + i);
4396 return getSequentialMinMaxExpr(Kind, Ops);
4397 }
4398 // Fold %x umin_seq %y to %x if %x ule %y.
4399 // TODO: We might be able to prove the predicate for a later operand.
4400 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4401 Ops.erase(Ops.begin() + i);
4402 return getSequentialMinMaxExpr(Kind, Ops);
4403 }
4404 }
4405
4406 // Okay, it looks like we really DO need an expr. Check to see if we
4407 // already have one, otherwise create a new one.
4409 ID.AddInteger(Kind);
4410 for (const SCEV *Op : Ops)
4411 ID.AddPointer(Op);
4412 void *IP = nullptr;
4413 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4414 if (ExistingSCEV)
4415 return ExistingSCEV;
4416
4417 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4419 SCEV *S = new (SCEVAllocator)
4420 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4421
4422 UniqueSCEVs.InsertNode(S, IP);
4423 registerUser(S, Ops);
4424 return S;
4425}
4426
4427const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4428 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4429 return getSMaxExpr(Ops);
4430}
4431
4435
4436const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4437 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4438 return getUMaxExpr(Ops);
4439}
4440
4444
4446 const SCEV *RHS) {
4447 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4448 return getSMinExpr(Ops);
4449}
4450
4454
4455const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4456 bool Sequential) {
4457 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4458 return getUMinExpr(Ops, Sequential);
4459}
4460
4466
4467const SCEV *
4469 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4470 if (Size.isScalable())
4471 Res = getMulExpr(Res, getVScale(IntTy));
4472 return Res;
4473}
4474
4476 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4477}
4478
4480 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4481}
4482
4484 StructType *STy,
4485 unsigned FieldNo) {
4486 // We can bypass creating a target-independent constant expression and then
4487 // folding it back into a ConstantInt. This is just a compile-time
4488 // optimization.
4489 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4490 assert(!SL->getSizeInBits().isScalable() &&
4491 "Cannot get offset for structure containing scalable vector types");
4492 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4493}
4494
4496 // Don't attempt to do anything other than create a SCEVUnknown object
4497 // here. createSCEV only calls getUnknown after checking for all other
4498 // interesting possibilities, and any other code that calls getUnknown
4499 // is doing so in order to hide a value from SCEV canonicalization.
4500
4502 ID.AddInteger(scUnknown);
4503 ID.AddPointer(V);
4504 void *IP = nullptr;
4505 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4506 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4507 "Stale SCEVUnknown in uniquing map!");
4508 return S;
4509 }
4510 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4511 FirstUnknown);
4512 FirstUnknown = cast<SCEVUnknown>(S);
4513 UniqueSCEVs.InsertNode(S, IP);
4514 return S;
4515}
4516
4517//===----------------------------------------------------------------------===//
4518// Basic SCEV Analysis and PHI Idiom Recognition Code
4519//
4520
/// Test if values of the given type are analyzable within the SCEV
/// framework. Integer and pointer types are always SCEVable; no other types
/// are.
4526 // Integers and pointers are always SCEVable.
4527 return Ty->isIntOrPtrTy();
4528}
4529
4530/// Return the size in bits of the specified type, for which isSCEVable must
4531/// return true.
4533 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4534 if (Ty->isPointerTy())
4536 return getDataLayout().getTypeSizeInBits(Ty);
4537}
4538
4539/// Return a type with the same bitwidth as the given type and which represents
4540/// how SCEV will treat the given type, for which isSCEVable must return
4541/// true. For pointer types, this is the pointer index sized integer type.
4543 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4544
4545 if (Ty->isIntegerTy())
4546 return Ty;
4547
4548 // The only other support type is pointer.
4549 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4550 return getDataLayout().getIndexType(Ty);
4551}
4552
4554 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4555}
4556
4558 const SCEV *B) {
4559 /// For a valid use point to exist, the defining scope of one operand
4560 /// must dominate the other.
4561 bool PreciseA, PreciseB;
4562 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4563 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4564 if (!PreciseA || !PreciseB)
4565 // Can't tell.
4566 return false;
4567 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4568 DT.dominates(ScopeB, ScopeA);
4569}
4570
4572 return CouldNotCompute.get();
4573}
4574
4575bool ScalarEvolution::checkValidity(const SCEV *S) const {
4576 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4577 auto *SU = dyn_cast<SCEVUnknown>(S);
4578 return SU && SU->getValue() == nullptr;
4579 });
4580
4581 return !ContainsNulls;
4582}
4583
4585 HasRecMapType::iterator I = HasRecMap.find(S);
4586 if (I != HasRecMap.end())
4587 return I->second;
4588
4589 bool FoundAddRec =
4590 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4591 HasRecMap.insert({S, FoundAddRec});
4592 return FoundAddRec;
4593}
4594
4595/// Return the ValueOffsetPair set for \p S. \p S can be represented
4596/// by the value and offset from any ValueOffsetPair in the set.
4597ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4598 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4599 if (SI == ExprValueMap.end())
4600 return {};
4601 return SI->second.getArrayRef();
4602}
4603
4604/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4605/// cannot be used separately. eraseValueFromMap should be used to remove
4606/// V from ValueExprMap and ExprValueMap at the same time.
4607void ScalarEvolution::eraseValueFromMap(Value *V) {
4608 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4609 if (I != ValueExprMap.end()) {
4610 auto EVIt = ExprValueMap.find(I->second);
4611 bool Removed = EVIt->second.remove(V);
4612 (void) Removed;
4613 assert(Removed && "Value not in ExprValueMap?");
4614 ValueExprMap.erase(I);
4615 }
4616}
4617
4618void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4619 // A recursive query may have already computed the SCEV. It should be
4620 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4621 // inferred nowrap flags.
4622 auto It = ValueExprMap.find_as(V);
4623 if (It == ValueExprMap.end()) {
4624 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4625 ExprValueMap[S].insert(V);
4626 }
4627}
4628
4629/// Return an existing SCEV if it exists, otherwise analyze the expression and
4630/// create a new one.
4632 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4633
4634 if (const SCEV *S = getExistingSCEV(V))
4635 return S;
4636 return createSCEVIter(V);
4637}
4638
4640 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4641
4642 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4643 if (I != ValueExprMap.end()) {
4644 const SCEV *S = I->second;
4645 assert(checkValidity(S) &&
4646 "existing SCEV has not been properly invalidated");
4647 return S;
4648 }
4649 return nullptr;
4650}
4651
4652/// Return a SCEV corresponding to -V = -1*V
4654 SCEV::NoWrapFlags Flags) {
4655 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4656 return getConstant(
4657 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4658
4659 Type *Ty = V->getType();
4660 Ty = getEffectiveSCEVType(Ty);
4661 return getMulExpr(V, getMinusOne(Ty), Flags);
4662}
4663
4664/// If Expr computes ~A, return A else return nullptr
4665static const SCEV *MatchNotExpr(const SCEV *Expr) {
4666 const SCEV *MulOp;
4667 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4668 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4669 return MulOp;
4670 return nullptr;
4671}
4672
4673/// Return a SCEV corresponding to ~V = -1-V
4675 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4676
4677 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4678 return getConstant(
4679 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4680
4681 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4682 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4683 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4684 SmallVector<const SCEV *, 2> MatchedOperands;
4685 for (const SCEV *Operand : MME->operands()) {
4686 const SCEV *Matched = MatchNotExpr(Operand);
4687 if (!Matched)
4688 return (const SCEV *)nullptr;
4689 MatchedOperands.push_back(Matched);
4690 }
4691 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4692 MatchedOperands);
4693 };
4694 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4695 return Replaced;
4696 }
4697
4698 Type *Ty = V->getType();
4699 Ty = getEffectiveSCEVType(Ty);
4700 return getMinusSCEV(getMinusOne(Ty), V);
4701}
4702
4704 assert(P->getType()->isPointerTy());
4705
4706 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4707 // The base of an AddRec is the first operand.
4708 SmallVector<const SCEV *> Ops{AddRec->operands()};
4709 Ops[0] = removePointerBase(Ops[0]);
4710 // Don't try to transfer nowrap flags for now. We could in some cases
4711 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4712 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4713 }
4714 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4715 // The base of an Add is the pointer operand.
4716 SmallVector<const SCEV *> Ops{Add->operands()};
4717 const SCEV **PtrOp = nullptr;
4718 for (const SCEV *&AddOp : Ops) {
4719 if (AddOp->getType()->isPointerTy()) {
4720 assert(!PtrOp && "Cannot have multiple pointer ops");
4721 PtrOp = &AddOp;
4722 }
4723 }
4724 *PtrOp = removePointerBase(*PtrOp);
4725 // Don't try to transfer nowrap flags for now. We could in some cases
4726 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4727 return getAddExpr(Ops);
4728 }
4729 // Any other expression must be a pointer base.
4730 return getZero(P->getType());
4731}
4732
4733const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4734 SCEV::NoWrapFlags Flags,
4735 unsigned Depth) {
4736 // Fast path: X - X --> 0.
4737 if (LHS == RHS)
4738 return getZero(LHS->getType());
4739
4740 // If we subtract two pointers with different pointer bases, bail.
4741 // Eventually, we're going to add an assertion to getMulExpr that we
4742 // can't multiply by a pointer.
4743 if (RHS->getType()->isPointerTy()) {
4744 if (!LHS->getType()->isPointerTy() ||
4745 getPointerBase(LHS) != getPointerBase(RHS))
4746 return getCouldNotCompute();
4747 LHS = removePointerBase(LHS);
4748 RHS = removePointerBase(RHS);
4749 }
4750
4751 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4752 // makes it so that we cannot make much use of NUW.
4753 auto AddFlags = SCEV::FlagAnyWrap;
4754 const bool RHSIsNotMinSigned =
4756 if (hasFlags(Flags, SCEV::FlagNSW)) {
4757 // Let M be the minimum representable signed value. Then (-1)*RHS
4758 // signed-wraps if and only if RHS is M. That can happen even for
4759 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4760 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4761 // (-1)*RHS, we need to prove that RHS != M.
4762 //
4763 // If LHS is non-negative and we know that LHS - RHS does not
4764 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4765 // either by proving that RHS > M or that LHS >= 0.
4766 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4767 AddFlags = SCEV::FlagNSW;
4768 }
4769 }
4770
4771 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4772 // RHS is NSW and LHS >= 0.
4773 //
4774 // The difficulty here is that the NSW flag may have been proven
4775 // relative to a loop that is to be found in a recurrence in LHS and
4776 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4777 // larger scope than intended.
4778 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4779
4780 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4781}
4782
4784 unsigned Depth) {
4785 Type *SrcTy = V->getType();
4786 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4787 "Cannot truncate or zero extend with non-integer arguments!");
4788 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4789 return V; // No conversion
4790 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4791 return getTruncateExpr(V, Ty, Depth);
4792 return getZeroExtendExpr(V, Ty, Depth);
4793}
4794
4796 unsigned Depth) {
4797 Type *SrcTy = V->getType();
4798 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4799 "Cannot truncate or zero extend with non-integer arguments!");
4800 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4801 return V; // No conversion
4802 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4803 return getTruncateExpr(V, Ty, Depth);
4804 return getSignExtendExpr(V, Ty, Depth);
4805}
4806
4807const SCEV *
4809 Type *SrcTy = V->getType();
4810 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4811 "Cannot noop or zero extend with non-integer arguments!");
4813 "getNoopOrZeroExtend cannot truncate!");
4814 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4815 return V; // No conversion
4816 return getZeroExtendExpr(V, Ty);
4817}
4818
4819const SCEV *
4821 Type *SrcTy = V->getType();
4822 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4823 "Cannot noop or sign extend with non-integer arguments!");
4825 "getNoopOrSignExtend cannot truncate!");
4826 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4827 return V; // No conversion
4828 return getSignExtendExpr(V, Ty);
4829}
4830
4831const SCEV *
4833 Type *SrcTy = V->getType();
4834 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4835 "Cannot noop or any extend with non-integer arguments!");
4837 "getNoopOrAnyExtend cannot truncate!");
4838 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4839 return V; // No conversion
4840 return getAnyExtendExpr(V, Ty);
4841}
4842
4843const SCEV *
4845 Type *SrcTy = V->getType();
4846 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4847 "Cannot truncate or noop with non-integer arguments!");
4849 "getTruncateOrNoop cannot extend!");
4850 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4851 return V; // No conversion
4852 return getTruncateExpr(V, Ty);
4853}
4854
4856 const SCEV *RHS) {
4857 const SCEV *PromotedLHS = LHS;
4858 const SCEV *PromotedRHS = RHS;
4859
4860 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4861 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4862 else
4863 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4864
4865 return getUMaxExpr(PromotedLHS, PromotedRHS);
4866}
4867
4869 const SCEV *RHS,
4870 bool Sequential) {
4871 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4872 return getUMinFromMismatchedTypes(Ops, Sequential);
4873}
4874
4875const SCEV *
4877 bool Sequential) {
4878 assert(!Ops.empty() && "At least one operand must be!");
4879 // Trivial case.
4880 if (Ops.size() == 1)
4881 return Ops[0];
4882
4883 // Find the max type first.
4884 Type *MaxType = nullptr;
4885 for (const auto *S : Ops)
4886 if (MaxType)
4887 MaxType = getWiderType(MaxType, S->getType());
4888 else
4889 MaxType = S->getType();
4890 assert(MaxType && "Failed to find maximum type!");
4891
4892 // Extend all ops to max type.
4893 SmallVector<const SCEV *, 2> PromotedOps;
4894 for (const auto *S : Ops)
4895 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4896
4897 // Generate umin.
4898 return getUMinExpr(PromotedOps, Sequential);
4899}
4900
4902 // A pointer operand may evaluate to a nonpointer expression, such as null.
4903 if (!V->getType()->isPointerTy())
4904 return V;
4905
4906 while (true) {
4907 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4908 V = AddRec->getStart();
4909 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4910 const SCEV *PtrOp = nullptr;
4911 for (const SCEV *AddOp : Add->operands()) {
4912 if (AddOp->getType()->isPointerTy()) {
4913 assert(!PtrOp && "Cannot have multiple pointer ops");
4914 PtrOp = AddOp;
4915 }
4916 }
4917 assert(PtrOp && "Must have pointer op");
4918 V = PtrOp;
4919 } else // Not something we can look further into.
4920 return V;
4921 }
4922}
4923
4924/// Push users of the given Instruction onto the given Worklist.
4928 // Push the def-use children onto the Worklist stack.
4929 for (User *U : I->users()) {
4930 auto *UserInsn = cast<Instruction>(U);
4931 if (Visited.insert(UserInsn).second)
4932 Worklist.push_back(UserInsn);
4933 }
4934}
4935
4936namespace {
4937
4938/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4939/// expression in case its Loop is L. If it is not L then
4940/// if IgnoreOtherLoops is true then use AddRec itself
4941/// otherwise rewrite cannot be done.
4942/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4943class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4944public:
4945 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4946 bool IgnoreOtherLoops = true) {
4947 SCEVInitRewriter Rewriter(L, SE);
4948 const SCEV *Result = Rewriter.visit(S);
4949 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4950 return SE.getCouldNotCompute();
4951 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4952 ? SE.getCouldNotCompute()
4953 : Result;
4954 }
4955
4956 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4957 if (!SE.isLoopInvariant(Expr, L))
4958 SeenLoopVariantSCEVUnknown = true;
4959 return Expr;
4960 }
4961
4962 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4963 // Only re-write AddRecExprs for this loop.
4964 if (Expr->getLoop() == L)
4965 return Expr->getStart();
4966 SeenOtherLoops = true;
4967 return Expr;
4968 }
4969
4970 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4971
4972 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4973
4974private:
4975 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4976 : SCEVRewriteVisitor(SE), L(L) {}
4977
4978 const Loop *L;
4979 bool SeenLoopVariantSCEVUnknown = false;
4980 bool SeenOtherLoops = false;
4981};
4982
4983/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4984/// increment expression in case its Loop is L. If it is not L then
4985/// use AddRec itself.
4986/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4987class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4988public:
4989 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4990 SCEVPostIncRewriter Rewriter(L, SE);
4991 const SCEV *Result = Rewriter.visit(S);
4992 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4993 ? SE.getCouldNotCompute()
4994 : Result;
4995 }
4996
4997 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4998 if (!SE.isLoopInvariant(Expr, L))
4999 SeenLoopVariantSCEVUnknown = true;
5000 return Expr;
5001 }
5002
5003 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5004 // Only re-write AddRecExprs for this loop.
5005 if (Expr->getLoop() == L)
5006 return Expr->getPostIncExpr(SE);
5007 SeenOtherLoops = true;
5008 return Expr;
5009 }
5010
5011 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5012
5013 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5014
5015private:
5016 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5017 : SCEVRewriteVisitor(SE), L(L) {}
5018
5019 const Loop *L;
5020 bool SeenLoopVariantSCEVUnknown = false;
5021 bool SeenOtherLoops = false;
5022};
5023
5024/// This class evaluates the compare condition by matching it against the
5025/// condition of loop latch. If there is a match we assume a true value
5026/// for the condition while building SCEV nodes.
5027class SCEVBackedgeConditionFolder
5028 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5029public:
5030 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5031 ScalarEvolution &SE) {
5032 bool IsPosBECond = false;
5033 Value *BECond = nullptr;
5034 if (BasicBlock *Latch = L->getLoopLatch()) {
5035 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
5036 if (BI && BI->isConditional()) {
5037 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5038 "Both outgoing branches should not target same header!");
5039 BECond = BI->getCondition();
5040 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5041 } else {
5042 return S;
5043 }
5044 }
5045 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5046 return Rewriter.visit(S);
5047 }
5048
5049 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5050 const SCEV *Result = Expr;
5051 bool InvariantF = SE.isLoopInvariant(Expr, L);
5052
5053 if (!InvariantF) {
5055 switch (I->getOpcode()) {
5056 case Instruction::Select: {
5057 SelectInst *SI = cast<SelectInst>(I);
5058 std::optional<const SCEV *> Res =
5059 compareWithBackedgeCondition(SI->getCondition());
5060 if (Res) {
5061 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5062 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5063 }
5064 break;
5065 }
5066 default: {
5067 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5068 if (Res)
5069 Result = *Res;
5070 break;
5071 }
5072 }
5073 }
5074 return Result;
5075 }
5076
5077private:
5078 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5079 bool IsPosBECond, ScalarEvolution &SE)
5080 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5081 IsPositiveBECond(IsPosBECond) {}
5082
5083 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5084
5085 const Loop *L;
5086 /// Loop back condition.
5087 Value *BackedgeCond = nullptr;
5088 /// Set to true if loop back is on positive branch condition.
5089 bool IsPositiveBECond;
5090};
5091
5092std::optional<const SCEV *>
5093SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5094
5095 // If value matches the backedge condition for loop latch,
5096 // then return a constant evolution node based on loopback
5097 // branch taken.
5098 if (BackedgeCond == IC)
5099 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5101 return std::nullopt;
5102}
5103
5104class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5105public:
5106 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5107 ScalarEvolution &SE) {
5108 SCEVShiftRewriter Rewriter(L, SE);
5109 const SCEV *Result = Rewriter.visit(S);
5110 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5111 }
5112
5113 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5114 // Only allow AddRecExprs for this loop.
5115 if (!SE.isLoopInvariant(Expr, L))
5116 Valid = false;
5117 return Expr;
5118 }
5119
5120 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5121 if (Expr->getLoop() == L && Expr->isAffine())
5122 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5123 Valid = false;
5124 return Expr;
5125 }
5126
5127 bool isValid() { return Valid; }
5128
5129private:
5130 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5131 : SCEVRewriteVisitor(SE), L(L) {}
5132
5133 const Loop *L;
5134 bool Valid = true;
5135};
5136
5137} // end anonymous namespace
5138
5140ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5141 if (!AR->isAffine())
5142 return SCEV::FlagAnyWrap;
5143
5144 using OBO = OverflowingBinaryOperator;
5145
5147
5148 if (!AR->hasNoSelfWrap()) {
5149 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5150 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5151 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5152 const APInt &BECountAP = BECountMax->getAPInt();
5153 unsigned NoOverflowBitWidth =
5154 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5155 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5157 }
5158 }
5159
5160 if (!AR->hasNoSignedWrap()) {
5161 ConstantRange AddRecRange = getSignedRange(AR);
5162 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5163
5165 Instruction::Add, IncRange, OBO::NoSignedWrap);
5166 if (NSWRegion.contains(AddRecRange))
5168 }
5169
5170 if (!AR->hasNoUnsignedWrap()) {
5171 ConstantRange AddRecRange = getUnsignedRange(AR);
5172 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5173
5175 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5176 if (NUWRegion.contains(AddRecRange))
5178 }
5179
5180 return Result;
5181}
5182
5184ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5186
5187 if (AR->hasNoSignedWrap())
5188 return Result;
5189
5190 if (!AR->isAffine())
5191 return Result;
5192
5193 // This function can be expensive, only try to prove NSW once per AddRec.
5194 if (!SignedWrapViaInductionTried.insert(AR).second)
5195 return Result;
5196
5197 const SCEV *Step = AR->getStepRecurrence(*this);
5198 const Loop *L = AR->getLoop();
5199
5200 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5201 // Note that this serves two purposes: It filters out loops that are
5202 // simply not analyzable, and it covers the case where this code is
5203 // being called from within backedge-taken count analysis, such that
5204 // attempting to ask for the backedge-taken count would likely result
5205 // in infinite recursion. In the later case, the analysis code will
5206 // cope with a conservative value, and it will take care to purge
5207 // that value once it has finished.
5208 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5209
5210 // Normally, in the cases we can prove no-overflow via a
5211 // backedge guarding condition, we can also compute a backedge
5212 // taken count for the loop. The exceptions are assumptions and
5213 // guards present in the loop -- SCEV is not great at exploiting
5214 // these to compute max backedge taken counts, but can still use
5215 // these to prove lack of overflow. Use this fact to avoid
5216 // doing extra work that may not pay off.
5217
5218 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5219 AC.assumptions().empty())
5220 return Result;
5221
5222 // If the backedge is guarded by a comparison with the pre-inc value the
5223 // addrec is safe. Also, if the entry is guarded by a comparison with the
5224 // start value and the backedge is guarded by a comparison with the post-inc
5225 // value, the addrec is safe.
5227 const SCEV *OverflowLimit =
5228 getSignedOverflowLimitForStep(Step, &Pred, this);
5229 if (OverflowLimit &&
5230 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5231 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5232 Result = setFlags(Result, SCEV::FlagNSW);
5233 }
5234 return Result;
5235}
5237ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5239
5240 if (AR->hasNoUnsignedWrap())
5241 return Result;
5242
5243 if (!AR->isAffine())
5244 return Result;
5245
5246 // This function can be expensive, only try to prove NUW once per AddRec.
5247 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5248 return Result;
5249
5250 const SCEV *Step = AR->getStepRecurrence(*this);
5251 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5252 const Loop *L = AR->getLoop();
5253
5254 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5255 // Note that this serves two purposes: It filters out loops that are
5256 // simply not analyzable, and it covers the case where this code is
5257 // being called from within backedge-taken count analysis, such that
5258 // attempting to ask for the backedge-taken count would likely result
5259 // in infinite recursion. In the later case, the analysis code will
5260 // cope with a conservative value, and it will take care to purge
5261 // that value once it has finished.
5262 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5263
5264 // Normally, in the cases we can prove no-overflow via a
5265 // backedge guarding condition, we can also compute a backedge
5266 // taken count for the loop. The exceptions are assumptions and
5267 // guards present in the loop -- SCEV is not great at exploiting
5268 // these to compute max backedge taken counts, but can still use
5269 // these to prove lack of overflow. Use this fact to avoid
5270 // doing extra work that may not pay off.
5271
5272 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5273 AC.assumptions().empty())
5274 return Result;
5275
5276 // If the backedge is guarded by a comparison with the pre-inc value the
5277 // addrec is safe. Also, if the entry is guarded by a comparison with the
5278 // start value and the backedge is guarded by a comparison with the post-inc
5279 // value, the addrec is safe.
5280 if (isKnownPositive(Step)) {
5281 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5282 getUnsignedRangeMax(Step));
5285 Result = setFlags(Result, SCEV::FlagNUW);
5286 }
5287 }
5288
5289 return Result;
5290}
5291
5292namespace {
5293
5294/// Represents an abstract binary operation. This may exist as a
5295/// normal instruction or constant expression, or may have been
5296/// derived from an expression tree.
5297struct BinaryOp {
5298 unsigned Opcode;
5299 Value *LHS;
5300 Value *RHS;
5301 bool IsNSW = false;
5302 bool IsNUW = false;
5303
5304 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5305 /// constant expression.
5306 Operator *Op = nullptr;
5307
5308 explicit BinaryOp(Operator *Op)
5309 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5310 Op(Op) {
5311 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5312 IsNSW = OBO->hasNoSignedWrap();
5313 IsNUW = OBO->hasNoUnsignedWrap();
5314 }
5315 }
5316
5317 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5318 bool IsNUW = false)
5319 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5320};
5321
5322} // end anonymous namespace
5323
5324/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5325static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5326 AssumptionCache &AC,
5327 const DominatorTree &DT,
5328 const Instruction *CxtI) {
5329 auto *Op = dyn_cast<Operator>(V);
5330 if (!Op)
5331 return std::nullopt;
5332
5333 // Implementation detail: all the cleverness here should happen without
5334 // creating new SCEV expressions -- our caller knowns tricks to avoid creating
5335 // SCEV expressions when possible, and we should not break that.
5336
5337 switch (Op->getOpcode()) {
5338 case Instruction::Add:
5339 case Instruction::Sub:
5340 case Instruction::Mul:
5341 case Instruction::UDiv:
5342 case Instruction::URem:
5343 case Instruction::And:
5344 case Instruction::AShr:
5345 case Instruction::Shl:
5346 return BinaryOp(Op);
5347
5348 case Instruction::Or: {
5349 // Convert or disjoint into add nuw nsw.
5350 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5351 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5352 /*IsNSW=*/true, /*IsNUW=*/true);
5353 return BinaryOp(Op);
5354 }
5355
5356 case Instruction::Xor:
5357 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5358 // If the RHS of the xor is a signmask, then this is just an add.
5359 // Instcombine turns add of signmask into xor as a strength reduction step.
5360 if (RHSC->getValue().isSignMask())
5361 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5362 // Binary `xor` is a bit-wise `add`.
5363 if (V->getType()->isIntegerTy(1))
5364 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5365 return BinaryOp(Op);
5366
5367 case Instruction::LShr:
5368 // Turn logical shift right of a constant into a unsigned divide.
5369 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5370 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5371
5372 // If the shift count is not less than the bitwidth, the result of
5373 // the shift is undefined. Don't try to analyze it, because the
5374 // resolution chosen here may differ from the resolution chosen in
5375 // other parts of the compiler.
5376 if (SA->getValue().ult(BitWidth)) {
5377 Constant *X =
5378 ConstantInt::get(SA->getContext(),
5379 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5380 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5381 }
5382 }
5383 return BinaryOp(Op);
5384
5385 case Instruction::ExtractValue: {
5386 auto *EVI = cast<ExtractValueInst>(Op);
5387 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5388 break;
5389
5390 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5391 if (!WO)
5392 break;
5393
5394 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5395 bool Signed = WO->isSigned();
5396 // TODO: Should add nuw/nsw flags for mul as well.
5397 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5398 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5399
5400 // Now that we know that all uses of the arithmetic-result component of
5401 // CI are guarded by the overflow check, we can go ahead and pretend
5402 // that the arithmetic is non-overflowing.
5403 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5404 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5405 }
5406
5407 default:
5408 break;
5409 }
5410
5411 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5412 // semantics as a Sub, return a binary sub expression.
5413 if (auto *II = dyn_cast<IntrinsicInst>(V))
5414 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5415 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5416
5417 return std::nullopt;
5418}
5419
5420/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5421/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5422/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5423/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5424/// follows one of the following patterns:
5425/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5426/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5427/// If the SCEV expression of \p Op conforms with one of the expected patterns
5428/// we return the type of the truncation operation, and indicate whether the
5429/// truncated type should be treated as signed/unsigned by setting
5430/// \p Signed to true/false, respectively.
5431static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5432 bool &Signed, ScalarEvolution &SE) {
5433 // The case where Op == SymbolicPHI (that is, with no type conversions on
5434 // the way) is handled by the regular add recurrence creating logic and
5435 // would have already been triggered in createAddRecForPHI. Reaching it here
5436 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5437 // because one of the other operands of the SCEVAddExpr updating this PHI is
5438 // not invariant).
5439 //
5440 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5441 // this case predicates that allow us to prove that Op == SymbolicPHI will
5442 // be added.
5443 if (Op == SymbolicPHI)
5444 return nullptr;
5445
5446 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5447 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5448 if (SourceBits != NewBits)
5449 return nullptr;
5450
5451 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5452 Signed = true;
5453 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5454 }
5455 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5456 Signed = false;
5457 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5458 }
5459 return nullptr;
5460}
5461
5462static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5463 if (!PN->getType()->isIntegerTy())
5464 return nullptr;
5465 const Loop *L = LI.getLoopFor(PN->getParent());
5466 if (!L || L->getHeader() != PN->getParent())
5467 return nullptr;
5468 return L;
5469}
5470
5471// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5472// computation that updates the phi follows the following pattern:
5473// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5474// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5475// If so, try to see if it can be rewritten as an AddRecExpr under some
5476// Predicates. If successful, return them as a pair. Also cache the results
5477// of the analysis.
5478//
5479// Example usage scenario:
5480// Say the Rewriter is called for the following SCEV:
5481// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5482// where:
5483// %X = phi i64 (%Start, %BEValue)
5484// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5485// and call this function with %SymbolicPHI = %X.
5486//
5487// The analysis will find that the value coming around the backedge has
5488// the following SCEV:
5489// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5490// Upon concluding that this matches the desired pattern, the function
5491// will return the pair {NewAddRec, SmallPredsVec} where:
5492// NewAddRec = {%Start,+,%Step}
5493// SmallPredsVec = {P1, P2, P3} as follows:
5494// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5495// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5496// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5497// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5498// under the predicates {P1,P2,P3}.
5499// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5500// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5501//
5502// TODO's:
5503//
5504// 1) Extend the Induction descriptor to also support inductions that involve
5505// casts: When needed (namely, when we are called in the context of the
5506// vectorizer induction analysis), a Set of cast instructions will be
5507// populated by this method, and provided back to isInductionPHI. This is
5508// needed to allow the vectorizer to properly record them to be ignored by
5509// the cost model and to avoid vectorizing them (otherwise these casts,
5510// which are redundant under the runtime overflow checks, will be
5511// vectorized, which can be costly).
5512//
5513// 2) Support additional induction/PHISCEV patterns: We also want to support
5514// inductions where the sext-trunc / zext-trunc operations (partly) occur
5515// after the induction update operation (the induction increment):
5516//
5517// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5518// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5519//
5520// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5521// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5522//
5523// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5524std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5525ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5527
5528 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5529 // return an AddRec expression under some predicate.
5530
5531 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5532 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5533 assert(L && "Expecting an integer loop header phi");
5534
5535 // The loop may have multiple entrances or multiple exits; we can analyze
5536 // this phi as an addrec if it has a unique entry value and a unique
5537 // backedge value.
5538 Value *BEValueV = nullptr, *StartValueV = nullptr;
5539 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5540 Value *V = PN->getIncomingValue(i);
5541 if (L->contains(PN->getIncomingBlock(i))) {
5542 if (!BEValueV) {
5543 BEValueV = V;
5544 } else if (BEValueV != V) {
5545 BEValueV = nullptr;
5546 break;
5547 }
5548 } else if (!StartValueV) {
5549 StartValueV = V;
5550 } else if (StartValueV != V) {
5551 StartValueV = nullptr;
5552 break;
5553 }
5554 }
5555 if (!BEValueV || !StartValueV)
5556 return std::nullopt;
5557
5558 const SCEV *BEValue = getSCEV(BEValueV);
5559
5560 // If the value coming around the backedge is an add with the symbolic
5561 // value we just inserted, possibly with casts that we can ignore under
5562 // an appropriate runtime guard, then we found a simple induction variable!
5563 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5564 if (!Add)
5565 return std::nullopt;
5566
5567 // If there is a single occurrence of the symbolic value, possibly
5568 // casted, replace it with a recurrence.
5569 unsigned FoundIndex = Add->getNumOperands();
5570 Type *TruncTy = nullptr;
5571 bool Signed;
5572 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5573 if ((TruncTy =
5574 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5575 if (FoundIndex == e) {
5576 FoundIndex = i;
5577 break;
5578 }
5579
5580 if (FoundIndex == Add->getNumOperands())
5581 return std::nullopt;
5582
5583 // Create an add with everything but the specified operand.
5585 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5586 if (i != FoundIndex)
5587 Ops.push_back(Add->getOperand(i));
5588 const SCEV *Accum = getAddExpr(Ops);
5589
5590 // The runtime checks will not be valid if the step amount is
5591 // varying inside the loop.
5592 if (!isLoopInvariant(Accum, L))
5593 return std::nullopt;
5594
5595 // *** Part2: Create the predicates
5596
5597 // Analysis was successful: we have a phi-with-cast pattern for which we
5598 // can return an AddRec expression under the following predicates:
5599 //
5600 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5601 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5602 // P2: An Equal predicate that guarantees that
5603 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5604 // P3: An Equal predicate that guarantees that
5605 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5606 //
5607 // As we next prove, the above predicates guarantee that:
5608 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5609 //
5610 //
5611 // More formally, we want to prove that:
5612 // Expr(i+1) = Start + (i+1) * Accum
5613 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5614 //
5615 // Given that:
5616 // 1) Expr(0) = Start
5617 // 2) Expr(1) = Start + Accum
5618 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5619 // 3) Induction hypothesis (step i):
5620 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5621 //
5622 // Proof:
5623 // Expr(i+1) =
5624 // = Start + (i+1)*Accum
5625 // = (Start + i*Accum) + Accum
5626 // = Expr(i) + Accum
5627 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5628 // :: from step i
5629 //
5630 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5631 //
5632 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5633 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5634 // + Accum :: from P3
5635 //
5636 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5637 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5638 //
5639 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5640 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5641 //
5642 // By induction, the same applies to all iterations 1<=i<n:
5643 //
5644
5645 // Create a truncated addrec for which we will add a no overflow check (P1).
5646 const SCEV *StartVal = getSCEV(StartValueV);
5647 const SCEV *PHISCEV =
5648 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5649 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5650
5651 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5652 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5653 // will be constant.
5654 //
5655 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5656 // add P1.
5657 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5661 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5662 Predicates.push_back(AddRecPred);
5663 }
5664
5665 // Create the Equal Predicates P2,P3:
5666
5667 // It is possible that the predicates P2 and/or P3 are computable at
5668 // compile time due to StartVal and/or Accum being constants.
5669 // If either one is, then we can check that now and escape if either P2
5670 // or P3 is false.
5671
5672 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5673 // for each of StartVal and Accum
5674 auto getExtendedExpr = [&](const SCEV *Expr,
5675 bool CreateSignExtend) -> const SCEV * {
5676 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5677 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5678 const SCEV *ExtendedExpr =
5679 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5680 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5681 return ExtendedExpr;
5682 };
5683
5684 // Given:
5685 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
5686 // = getExtendedExpr(Expr)
5687 // Determine whether the predicate P: Expr == ExtendedExpr
5688 // is known to be false at compile time
5689 auto PredIsKnownFalse = [&](const SCEV *Expr,
5690 const SCEV *ExtendedExpr) -> bool {
5691 return Expr != ExtendedExpr &&
5692 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5693 };
5694
5695 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5696 if (PredIsKnownFalse(StartVal, StartExtended)) {
5697 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5698 return std::nullopt;
5699 }
5700
5701 // The Step is always Signed (because the overflow checks are either
5702 // NSSW or NUSW)
5703 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5704 if (PredIsKnownFalse(Accum, AccumExtended)) {
5705 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5706 return std::nullopt;
5707 }
5708
5709 auto AppendPredicate = [&](const SCEV *Expr,
5710 const SCEV *ExtendedExpr) -> void {
5711 if (Expr != ExtendedExpr &&
5712 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5713 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5714 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5715 Predicates.push_back(Pred);
5716 }
5717 };
5718
5719 AppendPredicate(StartVal, StartExtended);
5720 AppendPredicate(Accum, AccumExtended);
5721
5722 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5723 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5724 // into NewAR if it will also add the runtime overflow checks specified in
5725 // Predicates.
5726 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5727
5728 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5729 std::make_pair(NewAR, Predicates);
5730 // Remember the result of the analysis for this SCEV at this location.
5731 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5732 return PredRewrite;
5733}
5734
5735std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5737 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5738 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5739 if (!L)
5740 return std::nullopt;
5741
5742 // Check to see if we already analyzed this PHI.
5743 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5744 if (I != PredicatedSCEVRewrites.end()) {
5745 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5746 I->second;
5747 // Analysis was done before and failed to create an AddRec:
5748 if (Rewrite.first == SymbolicPHI)
5749 return std::nullopt;
5750 // Analysis was done before and succeeded to create an AddRec under
5751 // a predicate:
5752 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5753 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5754 return Rewrite;
5755 }
5756
5757 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5758 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5759
5760 // Record in the cache that the analysis failed
5761 if (!Rewrite) {
5763 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5764 return std::nullopt;
5765 }
5766
5767 return Rewrite;
5768}
5769
5770// FIXME: This utility is currently required because the Rewriter currently
5771// does not rewrite this expression:
5772// {0, +, (sext ix (trunc iy to ix) to iy)}
5773// into {0, +, %step},
5774// even when the following Equal predicate exists:
5775// "%step == (sext ix (trunc iy to ix) to iy)".
5777 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5778 if (AR1 == AR2)
5779 return true;
5780
5781 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5782 if (Expr1 != Expr2 &&
5783 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5784 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5785 return false;
5786 return true;
5787 };
5788
5789 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5790 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5791 return false;
5792 return true;
5793}
5794
5795/// A helper function for createAddRecFromPHI to handle simple cases.
5796///
5797/// This function tries to find an AddRec expression for the simplest (yet most
5798/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5799/// If it fails, createAddRecFromPHI will use a more general, but slow,
5800/// technique for finding the AddRec expression.
5801const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5802 Value *BEValueV,
5803 Value *StartValueV) {
5804 const Loop *L = LI.getLoopFor(PN->getParent());
5805 assert(L && L->getHeader() == PN->getParent());
5806 assert(BEValueV && StartValueV);
5807
5808 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5809 if (!BO)
5810 return nullptr;
5811
5812 if (BO->Opcode != Instruction::Add)
5813 return nullptr;
5814
5815 const SCEV *Accum = nullptr;
5816 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5817 Accum = getSCEV(BO->RHS);
5818 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5819 Accum = getSCEV(BO->LHS);
5820
5821 if (!Accum)
5822 return nullptr;
5823
5825 if (BO->IsNUW)
5826 Flags = setFlags(Flags, SCEV::FlagNUW);
5827 if (BO->IsNSW)
5828 Flags = setFlags(Flags, SCEV::FlagNSW);
5829
5830 const SCEV *StartVal = getSCEV(StartValueV);
5831 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5832 insertValueToMap(PN, PHISCEV);
5833
5834 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5835 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5837 proveNoWrapViaConstantRanges(AR)));
5838 }
5839
5840 // We can add Flags to the post-inc expression only if we
5841 // know that it is *undefined behavior* for BEValueV to
5842 // overflow.
5843 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5844 assert(isLoopInvariant(Accum, L) &&
5845 "Accum is defined outside L, but is not invariant?");
5846 if (isAddRecNeverPoison(BEInst, L))
5847 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5848 }
5849
5850 return PHISCEV;
5851}
5852
5853const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5854 const Loop *L = LI.getLoopFor(PN->getParent());
5855 if (!L || L->getHeader() != PN->getParent())
5856 return nullptr;
5857
5858 // The loop may have multiple entrances or multiple exits; we can analyze
5859 // this phi as an addrec if it has a unique entry value and a unique
5860 // backedge value.
5861 Value *BEValueV = nullptr, *StartValueV = nullptr;
5862 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5863 Value *V = PN->getIncomingValue(i);
5864 if (L->contains(PN->getIncomingBlock(i))) {
5865 if (!BEValueV) {
5866 BEValueV = V;
5867 } else if (BEValueV != V) {
5868 BEValueV = nullptr;
5869 break;
5870 }
5871 } else if (!StartValueV) {
5872 StartValueV = V;
5873 } else if (StartValueV != V) {
5874 StartValueV = nullptr;
5875 break;
5876 }
5877 }
5878 if (!BEValueV || !StartValueV)
5879 return nullptr;
5880
5881 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5882 "PHI node already processed?");
5883
5884 // First, try to find AddRec expression without creating a fictitious symbolic
5885 // value for PN.
5886 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5887 return S;
5888
5889 // Handle PHI node value symbolically.
5890 const SCEV *SymbolicName = getUnknown(PN);
5891 insertValueToMap(PN, SymbolicName);
5892
5893 // Using this symbolic name for the PHI, analyze the value coming around
5894 // the back-edge.
5895 const SCEV *BEValue = getSCEV(BEValueV);
5896
5897 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5898 // has a special value for the first iteration of the loop.
5899
5900 // If the value coming around the backedge is an add with the symbolic
5901 // value we just inserted, then we found a simple induction variable!
5902 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5903 // If there is a single occurrence of the symbolic value, replace it
5904 // with a recurrence.
5905 unsigned FoundIndex = Add->getNumOperands();
5906 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5907 if (Add->getOperand(i) == SymbolicName)
5908 if (FoundIndex == e) {
5909 FoundIndex = i;
5910 break;
5911 }
5912
5913 if (FoundIndex != Add->getNumOperands()) {
5914 // Create an add with everything but the specified operand.
5916 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5917 if (i != FoundIndex)
5918 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5919 L, *this));
5920 const SCEV *Accum = getAddExpr(Ops);
5921
5922 // This is not a valid addrec if the step amount is varying each
5923 // loop iteration, but is not itself an addrec in this loop.
5924 if (isLoopInvariant(Accum, L) ||
5925 (isa<SCEVAddRecExpr>(Accum) &&
5926 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5928
5929 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5930 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5931 if (BO->IsNUW)
5932 Flags = setFlags(Flags, SCEV::FlagNUW);
5933 if (BO->IsNSW)
5934 Flags = setFlags(Flags, SCEV::FlagNSW);
5935 }
5936 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5937 if (GEP->getOperand(0) == PN) {
5938 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5939 // If the increment has any nowrap flags, then we know the address
5940 // space cannot be wrapped around.
5941 if (NW != GEPNoWrapFlags::none())
5942 Flags = setFlags(Flags, SCEV::FlagNW);
5943 // If the GEP is nuw or nusw with non-negative offset, we know that
5944 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5945 // offset is treated as signed, while the base is unsigned.
5946 if (NW.hasNoUnsignedWrap() ||
5948 Flags = setFlags(Flags, SCEV::FlagNUW);
5949 }
5950
5951 // We cannot transfer nuw and nsw flags from subtraction
5952 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5953 // for instance.
5954 }
5955
5956 const SCEV *StartVal = getSCEV(StartValueV);
5957 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5958
5959 // Okay, for the entire analysis of this edge we assumed the PHI
5960 // to be symbolic. We now need to go back and purge all of the
5961 // entries for the scalars that use the symbolic expression.
5962 forgetMemoizedResults(SymbolicName);
5963 insertValueToMap(PN, PHISCEV);
5964
5965 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5966 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5968 proveNoWrapViaConstantRanges(AR)));
5969 }
5970
5971 // We can add Flags to the post-inc expression only if we
5972 // know that it is *undefined behavior* for BEValueV to
5973 // overflow.
5974 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5975 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5976 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5977
5978 return PHISCEV;
5979 }
5980 }
5981 } else {
5982 // Otherwise, this could be a loop like this:
5983 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5984 // In this case, j = {1,+,1} and BEValue is j.
5985 // Because the other in-value of i (0) fits the evolution of BEValue
5986 // i really is an addrec evolution.
5987 //
5988 // We can generalize this saying that i is the shifted value of BEValue
5989 // by one iteration:
5990 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5991
5992 // Do not allow refinement in rewriting of BEValue.
5993 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5994 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5995 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5996 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5997 const SCEV *StartVal = getSCEV(StartValueV);
5998 if (Start == StartVal) {
5999 // Okay, for the entire analysis of this edge we assumed the PHI
6000 // to be symbolic. We now need to go back and purge all of the
6001 // entries for the scalars that use the symbolic expression.
6002 forgetMemoizedResults(SymbolicName);
6003 insertValueToMap(PN, Shifted);
6004 return Shifted;
6005 }
6006 }
6007 }
6008
6009 // Remove the temporary PHI node SCEV that has been inserted while intending
6010 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6011 // as it will prevent later (possibly simpler) SCEV expressions to be added
6012 // to the ValueExprMap.
6013 eraseValueFromMap(PN);
6014
6015 return nullptr;
6016}
6017
6018// Try to match a control flow sequence that branches out at BI and merges back
6019// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6020// match.
6022 Value *&C, Value *&LHS, Value *&RHS) {
6023 C = BI->getCondition();
6024
6025 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6026 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6027
6028 if (!LeftEdge.isSingleEdge())
6029 return false;
6030
6031 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
6032
6033 Use &LeftUse = Merge->getOperandUse(0);
6034 Use &RightUse = Merge->getOperandUse(1);
6035
6036 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6037 LHS = LeftUse;
6038 RHS = RightUse;
6039 return true;
6040 }
6041
6042 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6043 LHS = RightUse;
6044 RHS = LeftUse;
6045 return true;
6046 }
6047
6048 return false;
6049}
6050
6051const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6052 auto IsReachable =
6053 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6054 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6055 // Try to match
6056 //
6057 // br %cond, label %left, label %right
6058 // left:
6059 // br label %merge
6060 // right:
6061 // br label %merge
6062 // merge:
6063 // V = phi [ %x, %left ], [ %y, %right ]
6064 //
6065 // as "select %cond, %x, %y"
6066
6067 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6068 assert(IDom && "At least the entry block should dominate PN");
6069
6070 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6071 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6072
6073 if (BI && BI->isConditional() &&
6074 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6077 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6078 }
6079
6080 return nullptr;
6081}
6082
6083/// Returns SCEV for the first operand of a phi if all phi operands have
6084/// identical opcodes and operands
6085/// eg.
6086/// a: %add = %a + %b
6087/// br %c
6088/// b: %add1 = %a + %b
6089/// br %c
6090/// c: %phi = phi [%add, a], [%add1, b]
6091/// scev(%phi) => scev(%add)
6092const SCEV *
6093ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6094 BinaryOperator *CommonInst = nullptr;
6095 // Check if instructions are identical.
6096 for (Value *Incoming : PN->incoming_values()) {
6097 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6098 if (!IncomingInst)
6099 return nullptr;
6100 if (CommonInst) {
6101 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6102 return nullptr; // Not identical, give up
6103 } else {
6104 // Remember binary operator
6105 CommonInst = IncomingInst;
6106 }
6107 }
6108 if (!CommonInst)
6109 return nullptr;
6110
6111 // Check if SCEV exprs for instructions are identical.
6112 const SCEV *CommonSCEV = getSCEV(CommonInst);
6113 bool SCEVExprsIdentical =
6115 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6116 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6117}
6118
6119const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6120 if (const SCEV *S = createAddRecFromPHI(PN))
6121 return S;
6122
6123 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6124 // phi node for X.
6125 if (Value *V = simplifyInstruction(
6126 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6127 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6128 return getSCEV(V);
6129
6130 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6131 return S;
6132
6133 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6134 return S;
6135
6136 // If it's not a loop phi, we can't handle it yet.
6137 return getUnknown(PN);
6138}
6139
6140bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6141 SCEVTypes RootKind) {
6142 struct FindClosure {
6143 const SCEV *OperandToFind;
6144 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6145 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6146
6147 bool Found = false;
6148
6149 bool canRecurseInto(SCEVTypes Kind) const {
6150 // We can only recurse into the SCEV expression of the same effective type
6151 // as the type of our root SCEV expression, and into zero-extensions.
6152 return RootKind == Kind || NonSequentialRootKind == Kind ||
6153 scZeroExtend == Kind;
6154 };
6155
6156 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6157 : OperandToFind(OperandToFind), RootKind(RootKind),
6158 NonSequentialRootKind(
6160 RootKind)) {}
6161
6162 bool follow(const SCEV *S) {
6163 Found = S == OperandToFind;
6164
6165 return !isDone() && canRecurseInto(S->getSCEVType());
6166 }
6167
6168 bool isDone() const { return Found; }
6169 };
6170
6171 FindClosure FC(OperandToFind, RootKind);
6172 visitAll(Root, FC);
6173 return FC.Found;
6174}
6175
6176std::optional<const SCEV *>
6177ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6178 ICmpInst *Cond,
6179 Value *TrueVal,
6180 Value *FalseVal) {
6181 // Try to match some simple smax or umax patterns.
6182 auto *ICI = Cond;
6183
6184 Value *LHS = ICI->getOperand(0);
6185 Value *RHS = ICI->getOperand(1);
6186
6187 switch (ICI->getPredicate()) {
6188 case ICmpInst::ICMP_SLT:
6189 case ICmpInst::ICMP_SLE:
6190 case ICmpInst::ICMP_ULT:
6191 case ICmpInst::ICMP_ULE:
6192 std::swap(LHS, RHS);
6193 [[fallthrough]];
6194 case ICmpInst::ICMP_SGT:
6195 case ICmpInst::ICMP_SGE:
6196 case ICmpInst::ICMP_UGT:
6197 case ICmpInst::ICMP_UGE:
6198 // a > b ? a+x : b+x -> max(a, b)+x
6199 // a > b ? b+x : a+x -> min(a, b)+x
6201 bool Signed = ICI->isSigned();
6202 const SCEV *LA = getSCEV(TrueVal);
6203 const SCEV *RA = getSCEV(FalseVal);
6204 const SCEV *LS = getSCEV(LHS);
6205 const SCEV *RS = getSCEV(RHS);
6206 if (LA->getType()->isPointerTy()) {
6207 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6208 // Need to make sure we can't produce weird expressions involving
6209 // negated pointers.
6210 if (LA == LS && RA == RS)
6211 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6212 if (LA == RS && RA == LS)
6213 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6214 }
6215 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6216 if (Op->getType()->isPointerTy()) {
6219 return Op;
6220 }
6221 if (Signed)
6222 Op = getNoopOrSignExtend(Op, Ty);
6223 else
6224 Op = getNoopOrZeroExtend(Op, Ty);
6225 return Op;
6226 };
6227 LS = CoerceOperand(LS);
6228 RS = CoerceOperand(RS);
6230 break;
6231 const SCEV *LDiff = getMinusSCEV(LA, LS);
6232 const SCEV *RDiff = getMinusSCEV(RA, RS);
6233 if (LDiff == RDiff)
6234 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6235 LDiff);
6236 LDiff = getMinusSCEV(LA, RS);
6237 RDiff = getMinusSCEV(RA, LS);
6238 if (LDiff == RDiff)
6239 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6240 LDiff);
6241 }
6242 break;
6243 case ICmpInst::ICMP_NE:
6244 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6245 std::swap(TrueVal, FalseVal);
6246 [[fallthrough]];
6247 case ICmpInst::ICMP_EQ:
6248 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6251 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6252 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6253 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6254 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6255 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6256 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6257 return getAddExpr(getUMaxExpr(X, C), Y);
6258 }
6259 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6260 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6261 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6262 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6264 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6265 const SCEV *X = getSCEV(LHS);
6266 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6267 X = ZExt->getOperand();
6268 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6269 const SCEV *FalseValExpr = getSCEV(FalseVal);
6270 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6271 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6272 /*Sequential=*/true);
6273 }
6274 }
6275 break;
6276 default:
6277 break;
6278 }
6279
6280 return std::nullopt;
6281}
6282
6283static std::optional<const SCEV *>
6285 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6286 assert(CondExpr->getType()->isIntegerTy(1) &&
6287 TrueExpr->getType() == FalseExpr->getType() &&
6288 TrueExpr->getType()->isIntegerTy(1) &&
6289 "Unexpected operands of a select.");
6290
6291 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6292 // --> C + (umin_seq cond, x - C)
6293 //
6294 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6295 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6296 // --> C + (umin_seq ~cond, x - C)
6297
6298 // FIXME: while we can't legally model the case where both of the hands
6299 // are fully variable, we only require that the *difference* is constant.
6300 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6301 return std::nullopt;
6302
6303 const SCEV *X, *C;
6304 if (isa<SCEVConstant>(TrueExpr)) {
6305 CondExpr = SE->getNotSCEV(CondExpr);
6306 X = FalseExpr;
6307 C = TrueExpr;
6308 } else {
6309 X = TrueExpr;
6310 C = FalseExpr;
6311 }
6312 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6313 /*Sequential=*/true));
6314}
6315
6316static std::optional<const SCEV *>
6318 Value *FalseVal) {
6319 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6320 return std::nullopt;
6321
6322 const auto *SECond = SE->getSCEV(Cond);
6323 const auto *SETrue = SE->getSCEV(TrueVal);
6324 const auto *SEFalse = SE->getSCEV(FalseVal);
6325 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6326}
6327
6328const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6329 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6330 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6331 assert(TrueVal->getType() == FalseVal->getType() &&
6332 V->getType() == TrueVal->getType() &&
6333 "Types of select hands and of the result must match.");
6334
6335 // For now, only deal with i1-typed `select`s.
6336 if (!V->getType()->isIntegerTy(1))
6337 return getUnknown(V);
6338
6339 if (std::optional<const SCEV *> S =
6340 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6341 return *S;
6342
6343 return getUnknown(V);
6344}
6345
6346const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6347 Value *TrueVal,
6348 Value *FalseVal) {
6349 // Handle "constant" branch or select. This can occur for instance when a
6350 // loop pass transforms an inner loop and moves on to process the outer loop.
6351 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6352 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6353
6354 if (auto *I = dyn_cast<Instruction>(V)) {
6355 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6356 if (std::optional<const SCEV *> S =
6357 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6358 TrueVal, FalseVal))
6359 return *S;
6360 }
6361 }
6362
6363 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6364}
6365
6366/// Expand GEP instructions into add and multiply operations. This allows them
6367/// to be analyzed by regular SCEV code.
6368const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6369 assert(GEP->getSourceElementType()->isSized() &&
6370 "GEP source element type must be sized");
6371
6373 for (Value *Index : GEP->indices())
6374 IndexExprs.push_back(getSCEV(Index));
6375 return getGEPExpr(GEP, IndexExprs);
6376}
6377
6378APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6379 const Instruction *CtxI) {
6380 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6381 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6382 return TrailingZeros >= BitWidth
6384 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6385 };
6386 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6387 // The result is GCD of all operands results.
6388 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6389 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6391 Res, getConstantMultiple(N->getOperand(I), CtxI));
6392 return Res;
6393 };
6394
6395 switch (S->getSCEVType()) {
6396 case scConstant:
6397 return cast<SCEVConstant>(S)->getAPInt();
6398 case scPtrToAddr:
6399 case scPtrToInt:
6400 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6401 case scUDivExpr:
6402 case scVScale:
6403 return APInt(BitWidth, 1);
6404 case scTruncate: {
6405 // Only multiples that are a power of 2 will hold after truncation.
6406 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6407 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6408 return GetShiftedByZeros(TZ);
6409 }
6410 case scZeroExtend: {
6411 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6412 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6413 }
6414 case scSignExtend: {
6415 // Only multiples that are a power of 2 will hold after sext.
6416 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6417 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6418 return GetShiftedByZeros(TZ);
6419 }
6420 case scMulExpr: {
6421 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6422 if (M->hasNoUnsignedWrap()) {
6423 // The result is the product of all operand results.
6424 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6425 for (const SCEV *Operand : M->operands().drop_front())
6426 Res = Res * getConstantMultiple(Operand, CtxI);
6427 return Res;
6428 }
6429
6430 // If there are no wrap guarentees, find the trailing zeros, which is the
6431 // sum of trailing zeros for all its operands.
6432 uint32_t TZ = 0;
6433 for (const SCEV *Operand : M->operands())
6434 TZ += getMinTrailingZeros(Operand, CtxI);
6435 return GetShiftedByZeros(TZ);
6436 }
6437 case scAddExpr:
6438 case scAddRecExpr: {
6439 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6440 if (N->hasNoUnsignedWrap())
6441 return GetGCDMultiple(N);
6442 // Find the trailing bits, which is the minimum of its operands.
6443 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6444 for (const SCEV *Operand : N->operands().drop_front())
6445 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6446 return GetShiftedByZeros(TZ);
6447 }
6448 case scUMaxExpr:
6449 case scSMaxExpr:
6450 case scUMinExpr:
6451 case scSMinExpr:
6453 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6454 case scUnknown: {
6455 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6456 // the point their underlying IR instruction has been defined. If CtxI was
6457 // not provided, use:
6458 // * the first instruction in the entry block if it is an argument
6459 // * the instruction itself otherwise.
6460 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6461 if (!CtxI) {
6462 if (isa<Argument>(U->getValue()))
6463 CtxI = &*F.getEntryBlock().begin();
6464 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6465 CtxI = I;
6466 }
6467 unsigned Known =
6468 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6469 .countMinTrailingZeros();
6470 return GetShiftedByZeros(Known);
6471 }
6472 case scCouldNotCompute:
6473 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6474 }
6475 llvm_unreachable("Unknown SCEV kind!");
6476}
6477
6479 const Instruction *CtxI) {
6480 // Skip looking up and updating the cache if there is a context instruction,
6481 // as the result will only be valid in the specified context.
6482 if (CtxI)
6483 return getConstantMultipleImpl(S, CtxI);
6484
6485 auto I = ConstantMultipleCache.find(S);
6486 if (I != ConstantMultipleCache.end())
6487 return I->second;
6488
6489 APInt Result = getConstantMultipleImpl(S, CtxI);
6490 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6491 assert(InsertPair.second && "Should insert a new key");
6492 return InsertPair.first->second;
6493}
6494
6496 APInt Multiple = getConstantMultiple(S);
6497 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6498}
6499
6501 const Instruction *CtxI) {
6502 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6503 (unsigned)getTypeSizeInBits(S->getType()));
6504}
6505
6506/// Helper method to assign a range to V from metadata present in the IR.
6507static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6509 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6510 return getConstantRangeFromMetadata(*MD);
6511 if (const auto *CB = dyn_cast<CallBase>(V))
6512 if (std::optional<ConstantRange> Range = CB->getRange())
6513 return Range;
6514 }
6515 if (auto *A = dyn_cast<Argument>(V))
6516 if (std::optional<ConstantRange> Range = A->getRange())
6517 return Range;
6518
6519 return std::nullopt;
6520}
6521
6523 SCEV::NoWrapFlags Flags) {
6524 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6525 AddRec->setNoWrapFlags(Flags);
6526 UnsignedRanges.erase(AddRec);
6527 SignedRanges.erase(AddRec);
6528 ConstantMultipleCache.erase(AddRec);
6529 }
6530}
6531
6532ConstantRange ScalarEvolution::
6533getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6534 const DataLayout &DL = getDataLayout();
6535
6536 unsigned BitWidth = getTypeSizeInBits(U->getType());
6537 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6538
6539 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6540 // use information about the trip count to improve our available range. Note
6541 // that the trip count independent cases are already handled by known bits.
6542 // WARNING: The definition of recurrence used here is subtly different than
6543 // the one used by AddRec (and thus most of this file). Step is allowed to
6544 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6545 // and other addrecs in the same loop (for non-affine addrecs). The code
6546 // below intentionally handles the case where step is not loop invariant.
6547 auto *P = dyn_cast<PHINode>(U->getValue());
6548 if (!P)
6549 return FullSet;
6550
6551 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6552 // even the values that are not available in these blocks may come from them,
6553 // and this leads to false-positive recurrence test.
6554 for (auto *Pred : predecessors(P->getParent()))
6555 if (!DT.isReachableFromEntry(Pred))
6556 return FullSet;
6557
6558 BinaryOperator *BO;
6559 Value *Start, *Step;
6560 if (!matchSimpleRecurrence(P, BO, Start, Step))
6561 return FullSet;
6562
6563 // If we found a recurrence in reachable code, we must be in a loop. Note
6564 // that BO might be in some subloop of L, and that's completely okay.
6565 auto *L = LI.getLoopFor(P->getParent());
6566 assert(L && L->getHeader() == P->getParent());
6567 if (!L->contains(BO->getParent()))
6568 // NOTE: This bailout should be an assert instead. However, asserting
6569 // the condition here exposes a case where LoopFusion is querying SCEV
6570 // with malformed loop information during the midst of the transform.
6571 // There doesn't appear to be an obvious fix, so for the moment bailout
6572 // until the caller issue can be fixed. PR49566 tracks the bug.
6573 return FullSet;
6574
6575 // TODO: Extend to other opcodes such as mul, and div
6576 switch (BO->getOpcode()) {
6577 default:
6578 return FullSet;
6579 case Instruction::AShr:
6580 case Instruction::LShr:
6581 case Instruction::Shl:
6582 break;
6583 };
6584
6585 if (BO->getOperand(0) != P)
6586 // TODO: Handle the power function forms some day.
6587 return FullSet;
6588
6589 unsigned TC = getSmallConstantMaxTripCount(L);
6590 if (!TC || TC >= BitWidth)
6591 return FullSet;
6592
6593 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6594 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6595 assert(KnownStart.getBitWidth() == BitWidth &&
6596 KnownStep.getBitWidth() == BitWidth);
6597
6598 // Compute total shift amount, being careful of overflow and bitwidths.
6599 auto MaxShiftAmt = KnownStep.getMaxValue();
6600 APInt TCAP(BitWidth, TC-1);
6601 bool Overflow = false;
6602 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6603 if (Overflow)
6604 return FullSet;
6605
6606 switch (BO->getOpcode()) {
6607 default:
6608 llvm_unreachable("filtered out above");
6609 case Instruction::AShr: {
6610 // For each ashr, three cases:
6611 // shift = 0 => unchanged value
6612 // saturation => 0 or -1
6613 // other => a value closer to zero (of the same sign)
6614 // Thus, the end value is closer to zero than the start.
6615 auto KnownEnd = KnownBits::ashr(KnownStart,
6616 KnownBits::makeConstant(TotalShift));
6617 if (KnownStart.isNonNegative())
6618 // Analogous to lshr (simply not yet canonicalized)
6619 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6620 KnownStart.getMaxValue() + 1);
6621 if (KnownStart.isNegative())
6622 // End >=u Start && End <=s Start
6623 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6624 KnownEnd.getMaxValue() + 1);
6625 break;
6626 }
6627 case Instruction::LShr: {
6628 // For each lshr, three cases:
6629 // shift = 0 => unchanged value
6630 // saturation => 0
6631 // other => a smaller positive number
6632 // Thus, the low end of the unsigned range is the last value produced.
6633 auto KnownEnd = KnownBits::lshr(KnownStart,
6634 KnownBits::makeConstant(TotalShift));
6635 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6636 KnownStart.getMaxValue() + 1);
6637 }
6638 case Instruction::Shl: {
6639 // Iff no bits are shifted out, value increases on every shift.
6640 auto KnownEnd = KnownBits::shl(KnownStart,
6641 KnownBits::makeConstant(TotalShift));
6642 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6643 return ConstantRange(KnownStart.getMinValue(),
6644 KnownEnd.getMaxValue() + 1);
6645 break;
6646 }
6647 };
6648 return FullSet;
6649}
6650
6651const ConstantRange &
6652ScalarEvolution::getRangeRefIter(const SCEV *S,
6653 ScalarEvolution::RangeSignHint SignHint) {
6654 DenseMap<const SCEV *, ConstantRange> &Cache =
6655 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6656 : SignedRanges;
6658 SmallPtrSet<const SCEV *, 8> Seen;
6659
6660 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6661 // SCEVUnknown PHI node.
6662 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6663 if (!Seen.insert(Expr).second)
6664 return;
6665 if (Cache.contains(Expr))
6666 return;
6667 switch (Expr->getSCEVType()) {
6668 case scUnknown:
6669 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6670 break;
6671 [[fallthrough]];
6672 case scConstant:
6673 case scVScale:
6674 case scTruncate:
6675 case scZeroExtend:
6676 case scSignExtend:
6677 case scPtrToAddr:
6678 case scPtrToInt:
6679 case scAddExpr:
6680 case scMulExpr:
6681 case scUDivExpr:
6682 case scAddRecExpr:
6683 case scUMaxExpr:
6684 case scSMaxExpr:
6685 case scUMinExpr:
6686 case scSMinExpr:
6688 WorkList.push_back(Expr);
6689 break;
6690 case scCouldNotCompute:
6691 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6692 }
6693 };
6694 AddToWorklist(S);
6695
6696 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6697 for (unsigned I = 0; I != WorkList.size(); ++I) {
6698 const SCEV *P = WorkList[I];
6699 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6700 // If it is not a `SCEVUnknown`, just recurse into operands.
6701 if (!UnknownS) {
6702 for (const SCEV *Op : P->operands())
6703 AddToWorklist(Op);
6704 continue;
6705 }
6706 // `SCEVUnknown`'s require special treatment.
6707 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6708 if (!PendingPhiRangesIter.insert(P).second)
6709 continue;
6710 for (auto &Op : reverse(P->operands()))
6711 AddToWorklist(getSCEV(Op));
6712 }
6713 }
6714
6715 if (!WorkList.empty()) {
6716 // Use getRangeRef to compute ranges for items in the worklist in reverse
6717 // order. This will force ranges for earlier operands to be computed before
6718 // their users in most cases.
6719 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6720 getRangeRef(P, SignHint);
6721
6722 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6723 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6724 PendingPhiRangesIter.erase(P);
6725 }
6726 }
6727
6728 return getRangeRef(S, SignHint, 0);
6729}
6730
6731/// Determine the range for a particular SCEV. If SignHint is
6732/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6733/// with a "cleaner" unsigned (resp. signed) representation.
6734const ConstantRange &ScalarEvolution::getRangeRef(
6735 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6736 DenseMap<const SCEV *, ConstantRange> &Cache =
6737 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6738 : SignedRanges;
6740 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6742
6743 // See if we've computed this range already.
6745 if (I != Cache.end())
6746 return I->second;
6747
6748 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6749 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6750
6751 // Switch to iteratively computing the range for S, if it is part of a deeply
6752 // nested expression.
6754 return getRangeRefIter(S, SignHint);
6755
6756 unsigned BitWidth = getTypeSizeInBits(S->getType());
6757 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6758 using OBO = OverflowingBinaryOperator;
6759
6760 // If the value has known zeros, the maximum value will have those known zeros
6761 // as well.
6762 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6763 APInt Multiple = getNonZeroConstantMultiple(S);
6764 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6765 if (!Remainder.isZero())
6766 ConservativeResult =
6767 ConstantRange(APInt::getMinValue(BitWidth),
6768 APInt::getMaxValue(BitWidth) - Remainder + 1);
6769 }
6770 else {
6771 uint32_t TZ = getMinTrailingZeros(S);
6772 if (TZ != 0) {
6773 ConservativeResult = ConstantRange(
6775 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6776 }
6777 }
6778
6779 switch (S->getSCEVType()) {
6780 case scConstant:
6781 llvm_unreachable("Already handled above.");
6782 case scVScale:
6783 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6784 case scTruncate: {
6785 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6786 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6787 return setRange(
6788 Trunc, SignHint,
6789 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6790 }
6791 case scZeroExtend: {
6792 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6793 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6794 return setRange(
6795 ZExt, SignHint,
6796 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6797 }
6798 case scSignExtend: {
6799 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6800 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6801 return setRange(
6802 SExt, SignHint,
6803 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6804 }
6805 case scPtrToAddr:
6806 case scPtrToInt: {
6807 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6808 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6809 return setRange(Cast, SignHint, X);
6810 }
6811 case scAddExpr: {
6812 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6813 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6814 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6815 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6816 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6817 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6818 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6819 ConservativeResult =
6820 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6821 }
6822 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6823 unsigned WrapType = OBO::AnyWrap;
6824 if (Add->hasNoSignedWrap())
6825 WrapType |= OBO::NoSignedWrap;
6826 if (Add->hasNoUnsignedWrap())
6827 WrapType |= OBO::NoUnsignedWrap;
6828 for (const SCEV *Op : drop_begin(Add->operands()))
6829 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6830 RangeType);
6831 return setRange(Add, SignHint,
6832 ConservativeResult.intersectWith(X, RangeType));
6833 }
6834 case scMulExpr: {
6835 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6836 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6837 for (const SCEV *Op : drop_begin(Mul->operands()))
6838 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6839 return setRange(Mul, SignHint,
6840 ConservativeResult.intersectWith(X, RangeType));
6841 }
6842 case scUDivExpr: {
6843 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6844 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6845 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6846 return setRange(UDiv, SignHint,
6847 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6848 }
6849 case scAddRecExpr: {
6850 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6851 // If there's no unsigned wrap, the value will never be less than its
6852 // initial value.
6853 if (AddRec->hasNoUnsignedWrap()) {
6854 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6855 if (!UnsignedMinValue.isZero())
6856 ConservativeResult = ConservativeResult.intersectWith(
6857 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6858 }
6859
6860 // If there's no signed wrap, and all the operands except initial value have
6861 // the same sign or zero, the value won't ever be:
6862 // 1: smaller than initial value if operands are non negative,
6863 // 2: bigger than initial value if operands are non positive.
6864 // For both cases, value can not cross signed min/max boundary.
6865 if (AddRec->hasNoSignedWrap()) {
6866 bool AllNonNeg = true;
6867 bool AllNonPos = true;
6868 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6869 if (!isKnownNonNegative(AddRec->getOperand(i)))
6870 AllNonNeg = false;
6871 if (!isKnownNonPositive(AddRec->getOperand(i)))
6872 AllNonPos = false;
6873 }
6874 if (AllNonNeg)
6875 ConservativeResult = ConservativeResult.intersectWith(
6878 RangeType);
6879 else if (AllNonPos)
6880 ConservativeResult = ConservativeResult.intersectWith(
6882 getSignedRangeMax(AddRec->getStart()) +
6883 1),
6884 RangeType);
6885 }
6886
6887 // TODO: non-affine addrec
6888 if (AddRec->isAffine()) {
6889 const SCEV *MaxBEScev =
6891 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6892 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6893
6894 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6895 // MaxBECount's active bits are all <= AddRec's bit width.
6896 if (MaxBECount.getBitWidth() > BitWidth &&
6897 MaxBECount.getActiveBits() <= BitWidth)
6898 MaxBECount = MaxBECount.trunc(BitWidth);
6899 else if (MaxBECount.getBitWidth() < BitWidth)
6900 MaxBECount = MaxBECount.zext(BitWidth);
6901
6902 if (MaxBECount.getBitWidth() == BitWidth) {
6903 auto RangeFromAffine = getRangeForAffineAR(
6904 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6905 ConservativeResult =
6906 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6907
6908 auto RangeFromFactoring = getRangeViaFactoring(
6909 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6910 ConservativeResult =
6911 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6912 }
6913 }
6914
6915 // Now try symbolic BE count and more powerful methods.
6917 const SCEV *SymbolicMaxBECount =
6919 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6920 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6921 AddRec->hasNoSelfWrap()) {
6922 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6923 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6924 ConservativeResult =
6925 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6926 }
6927 }
6928 }
6929
6930 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6931 }
6932 case scUMaxExpr:
6933 case scSMaxExpr:
6934 case scUMinExpr:
6935 case scSMinExpr:
6936 case scSequentialUMinExpr: {
6938 switch (S->getSCEVType()) {
6939 case scUMaxExpr:
6940 ID = Intrinsic::umax;
6941 break;
6942 case scSMaxExpr:
6943 ID = Intrinsic::smax;
6944 break;
6945 case scUMinExpr:
6947 ID = Intrinsic::umin;
6948 break;
6949 case scSMinExpr:
6950 ID = Intrinsic::smin;
6951 break;
6952 default:
6953 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6954 }
6955
6956 const auto *NAry = cast<SCEVNAryExpr>(S);
6957 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6958 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6959 X = X.intrinsic(
6960 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6961 return setRange(S, SignHint,
6962 ConservativeResult.intersectWith(X, RangeType));
6963 }
6964 case scUnknown: {
6965 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6966 Value *V = U->getValue();
6967
6968 // Check if the IR explicitly contains !range metadata.
6969 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6970 if (MDRange)
6971 ConservativeResult =
6972 ConservativeResult.intersectWith(*MDRange, RangeType);
6973
6974 // Use facts about recurrences in the underlying IR. Note that add
6975 // recurrences are AddRecExprs and thus don't hit this path. This
6976 // primarily handles shift recurrences.
6977 auto CR = getRangeForUnknownRecurrence(U);
6978 ConservativeResult = ConservativeResult.intersectWith(CR);
6979
6980 // See if ValueTracking can give us a useful range.
6981 const DataLayout &DL = getDataLayout();
6982 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6983 if (Known.getBitWidth() != BitWidth)
6984 Known = Known.zextOrTrunc(BitWidth);
6985
6986 // ValueTracking may be able to compute a tighter result for the number of
6987 // sign bits than for the value of those sign bits.
6988 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6989 if (U->getType()->isPointerTy()) {
6990 // If the pointer size is larger than the index size type, this can cause
6991 // NS to be larger than BitWidth. So compensate for this.
6992 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6993 int ptrIdxDiff = ptrSize - BitWidth;
6994 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6995 NS -= ptrIdxDiff;
6996 }
6997
6998 if (NS > 1) {
6999 // If we know any of the sign bits, we know all of the sign bits.
7000 if (!Known.Zero.getHiBits(NS).isZero())
7001 Known.Zero.setHighBits(NS);
7002 if (!Known.One.getHiBits(NS).isZero())
7003 Known.One.setHighBits(NS);
7004 }
7005
7006 if (Known.getMinValue() != Known.getMaxValue() + 1)
7007 ConservativeResult = ConservativeResult.intersectWith(
7008 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7009 RangeType);
7010 if (NS > 1)
7011 ConservativeResult = ConservativeResult.intersectWith(
7012 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7013 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7014 RangeType);
7015
7016 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7017 // Strengthen the range if the underlying IR value is a
7018 // global/alloca/heap allocation using the size of the object.
7019 bool CanBeNull, CanBeFreed;
7020 uint64_t DerefBytes =
7021 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7022 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7023 // The highest address the object can start is DerefBytes bytes before
7024 // the end (unsigned max value). If this value is not a multiple of the
7025 // alignment, the last possible start value is the next lowest multiple
7026 // of the alignment. Note: The computations below cannot overflow,
7027 // because if they would there's no possible start address for the
7028 // object.
7029 APInt MaxVal =
7030 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7031 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7032 uint64_t Rem = MaxVal.urem(Align);
7033 MaxVal -= APInt(BitWidth, Rem);
7034 APInt MinVal = APInt::getZero(BitWidth);
7035 if (llvm::isKnownNonZero(V, DL))
7036 MinVal = Align;
7037 ConservativeResult = ConservativeResult.intersectWith(
7038 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7039 }
7040 }
7041
7042 // A range of Phi is a subset of union of all ranges of its input.
7043 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7044 // Make sure that we do not run over cycled Phis.
7045 if (PendingPhiRanges.insert(Phi).second) {
7046 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7047
7048 for (const auto &Op : Phi->operands()) {
7049 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7050 RangeFromOps = RangeFromOps.unionWith(OpRange);
7051 // No point to continue if we already have a full set.
7052 if (RangeFromOps.isFullSet())
7053 break;
7054 }
7055 ConservativeResult =
7056 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7057 bool Erased = PendingPhiRanges.erase(Phi);
7058 assert(Erased && "Failed to erase Phi properly?");
7059 (void)Erased;
7060 }
7061 }
7062
7063 // vscale can't be equal to zero
7064 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7065 if (II->getIntrinsicID() == Intrinsic::vscale) {
7066 ConstantRange Disallowed = APInt::getZero(BitWidth);
7067 ConservativeResult = ConservativeResult.difference(Disallowed);
7068 }
7069
7070 return setRange(U, SignHint, std::move(ConservativeResult));
7071 }
7072 case scCouldNotCompute:
7073 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7074 }
7075
7076 return setRange(S, SignHint, std::move(ConservativeResult));
7077}
7078
7079// Given a StartRange, Step and MaxBECount for an expression compute a range of
7080// values that the expression can take. Initially, the expression has a value
7081// from StartRange and then is changed by Step up to MaxBECount times. Signed
7082// argument defines if we treat Step as signed or unsigned.
7084 const ConstantRange &StartRange,
7085 const APInt &MaxBECount,
7086 bool Signed) {
7087 unsigned BitWidth = Step.getBitWidth();
7088 assert(BitWidth == StartRange.getBitWidth() &&
7089 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7090 // If either Step or MaxBECount is 0, then the expression won't change, and we
7091 // just need to return the initial range.
7092 if (Step == 0 || MaxBECount == 0)
7093 return StartRange;
7094
7095 // If we don't know anything about the initial value (i.e. StartRange is
7096 // FullRange), then we don't know anything about the final range either.
7097 // Return FullRange.
7098 if (StartRange.isFullSet())
7099 return ConstantRange::getFull(BitWidth);
7100
7101 // If Step is signed and negative, then we use its absolute value, but we also
7102 // note that we're moving in the opposite direction.
7103 bool Descending = Signed && Step.isNegative();
7104
7105 if (Signed)
7106 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7107 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7108 // This equations hold true due to the well-defined wrap-around behavior of
7109 // APInt.
7110 Step = Step.abs();
7111
7112 // Check if Offset is more than full span of BitWidth. If it is, the
7113 // expression is guaranteed to overflow.
7114 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7115 return ConstantRange::getFull(BitWidth);
7116
7117 // Offset is by how much the expression can change. Checks above guarantee no
7118 // overflow here.
7119 APInt Offset = Step * MaxBECount;
7120
7121 // Minimum value of the final range will match the minimal value of StartRange
7122 // if the expression is increasing and will be decreased by Offset otherwise.
7123 // Maximum value of the final range will match the maximal value of StartRange
7124 // if the expression is decreasing and will be increased by Offset otherwise.
7125 APInt StartLower = StartRange.getLower();
7126 APInt StartUpper = StartRange.getUpper() - 1;
7127 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7128 : (StartUpper + std::move(Offset));
7129
7130 // It's possible that the new minimum/maximum value will fall into the initial
7131 // range (due to wrap around). This means that the expression can take any
7132 // value in this bitwidth, and we have to return full range.
7133 if (StartRange.contains(MovedBoundary))
7134 return ConstantRange::getFull(BitWidth);
7135
7136 APInt NewLower =
7137 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7138 APInt NewUpper =
7139 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7140 NewUpper += 1;
7141
7142 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7143 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7144}
7145
7146ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7147 const SCEV *Step,
7148 const APInt &MaxBECount) {
7149 assert(getTypeSizeInBits(Start->getType()) ==
7150 getTypeSizeInBits(Step->getType()) &&
7151 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7152 "mismatched bit widths");
7153
7154 // First, consider step signed.
7155 ConstantRange StartSRange = getSignedRange(Start);
7156 ConstantRange StepSRange = getSignedRange(Step);
7157
7158 // If Step can be both positive and negative, we need to find ranges for the
7159 // maximum absolute step values in both directions and union them.
7160 ConstantRange SR = getRangeForAffineARHelper(
7161 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7163 StartSRange, MaxBECount,
7164 /* Signed = */ true));
7165
7166 // Next, consider step unsigned.
7167 ConstantRange UR = getRangeForAffineARHelper(
7168 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7169 /* Signed = */ false);
7170
7171 // Finally, intersect signed and unsigned ranges.
7173}
7174
7175ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7176 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7177 ScalarEvolution::RangeSignHint SignHint) {
7178 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7179 assert(AddRec->hasNoSelfWrap() &&
7180 "This only works for non-self-wrapping AddRecs!");
7181 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7182 const SCEV *Step = AddRec->getStepRecurrence(*this);
7183 // Only deal with constant step to save compile time.
7184 if (!isa<SCEVConstant>(Step))
7185 return ConstantRange::getFull(BitWidth);
7186 // Let's make sure that we can prove that we do not self-wrap during
7187 // MaxBECount iterations. We need this because MaxBECount is a maximum
7188 // iteration count estimate, and we might infer nw from some exit for which we
7189 // do not know max exit count (or any other side reasoning).
7190 // TODO: Turn into assert at some point.
7191 if (getTypeSizeInBits(MaxBECount->getType()) >
7192 getTypeSizeInBits(AddRec->getType()))
7193 return ConstantRange::getFull(BitWidth);
7194 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7195 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7196 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7197 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7198 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7199 MaxItersWithoutWrap))
7200 return ConstantRange::getFull(BitWidth);
7201
7202 ICmpInst::Predicate LEPred =
7204 ICmpInst::Predicate GEPred =
7206 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7207
7208 // We know that there is no self-wrap. Let's take Start and End values and
7209 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7210 // the iteration. They either lie inside the range [Min(Start, End),
7211 // Max(Start, End)] or outside it:
7212 //
7213 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7214 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7215 //
7216 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7217 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7218 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7219 // Start <= End and step is positive, or Start >= End and step is negative.
7220 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7221 ConstantRange StartRange = getRangeRef(Start, SignHint);
7222 ConstantRange EndRange = getRangeRef(End, SignHint);
7223 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7224 // If they already cover full iteration space, we will know nothing useful
7225 // even if we prove what we want to prove.
7226 if (RangeBetween.isFullSet())
7227 return RangeBetween;
7228 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7229 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7230 : RangeBetween.isWrappedSet();
7231 if (IsWrappedSet)
7232 return ConstantRange::getFull(BitWidth);
7233
7234 if (isKnownPositive(Step) &&
7235 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7236 return RangeBetween;
7237 if (isKnownNegative(Step) &&
7238 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7239 return RangeBetween;
7240 return ConstantRange::getFull(BitWidth);
7241}
7242
7243ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7244 const SCEV *Step,
7245 const APInt &MaxBECount) {
7246 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7247 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7248
7249 unsigned BitWidth = MaxBECount.getBitWidth();
7250 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7251 getTypeSizeInBits(Step->getType()) == BitWidth &&
7252 "mismatched bit widths");
7253
7254 struct SelectPattern {
7255 Value *Condition = nullptr;
7256 APInt TrueValue;
7257 APInt FalseValue;
7258
7259 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7260 const SCEV *S) {
7261 std::optional<unsigned> CastOp;
7262 APInt Offset(BitWidth, 0);
7263
7265 "Should be!");
7266
7267 // Peel off a constant offset. In the future we could consider being
7268 // smarter here and handle {Start+Step,+,Step} too.
7269 const APInt *Off;
7270 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7271 Offset = *Off;
7272
7273 // Peel off a cast operation
7274 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7275 CastOp = SCast->getSCEVType();
7276 S = SCast->getOperand();
7277 }
7278
7279 using namespace llvm::PatternMatch;
7280
7281 auto *SU = dyn_cast<SCEVUnknown>(S);
7282 const APInt *TrueVal, *FalseVal;
7283 if (!SU ||
7284 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7285 m_APInt(FalseVal)))) {
7286 Condition = nullptr;
7287 return;
7288 }
7289
7290 TrueValue = *TrueVal;
7291 FalseValue = *FalseVal;
7292
7293 // Re-apply the cast we peeled off earlier
7294 if (CastOp)
7295 switch (*CastOp) {
7296 default:
7297 llvm_unreachable("Unknown SCEV cast type!");
7298
7299 case scTruncate:
7300 TrueValue = TrueValue.trunc(BitWidth);
7301 FalseValue = FalseValue.trunc(BitWidth);
7302 break;
7303 case scZeroExtend:
7304 TrueValue = TrueValue.zext(BitWidth);
7305 FalseValue = FalseValue.zext(BitWidth);
7306 break;
7307 case scSignExtend:
7308 TrueValue = TrueValue.sext(BitWidth);
7309 FalseValue = FalseValue.sext(BitWidth);
7310 break;
7311 }
7312
7313 // Re-apply the constant offset we peeled off earlier
7314 TrueValue += Offset;
7315 FalseValue += Offset;
7316 }
7317
7318 bool isRecognized() { return Condition != nullptr; }
7319 };
7320
7321 SelectPattern StartPattern(*this, BitWidth, Start);
7322 if (!StartPattern.isRecognized())
7323 return ConstantRange::getFull(BitWidth);
7324
7325 SelectPattern StepPattern(*this, BitWidth, Step);
7326 if (!StepPattern.isRecognized())
7327 return ConstantRange::getFull(BitWidth);
7328
7329 if (StartPattern.Condition != StepPattern.Condition) {
7330 // We don't handle this case today; but we could, by considering four
7331 // possibilities below instead of two. I'm not sure if there are cases where
7332 // that will help over what getRange already does, though.
7333 return ConstantRange::getFull(BitWidth);
7334 }
7335
7336 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7337 // construct arbitrary general SCEV expressions here. This function is called
7338 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7339 // say) can end up caching a suboptimal value.
7340
7341 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7342 // C2352 and C2512 (otherwise it isn't needed).
7343
7344 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7345 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7346 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7347 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7348
7349 ConstantRange TrueRange =
7350 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7351 ConstantRange FalseRange =
7352 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7353
7354 return TrueRange.unionWith(FalseRange);
7355}
7356
7357SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7358 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7359 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7360
7361 // Return early if there are no flags to propagate to the SCEV.
7363 if (BinOp->hasNoUnsignedWrap())
7365 if (BinOp->hasNoSignedWrap())
7367 if (Flags == SCEV::FlagAnyWrap)
7368 return SCEV::FlagAnyWrap;
7369
7370 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7371}
7372
7373const Instruction *
7374ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7375 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7376 return &*AddRec->getLoop()->getHeader()->begin();
7377 if (auto *U = dyn_cast<SCEVUnknown>(S))
7378 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7379 return I;
7380 return nullptr;
7381}
7382
7383const Instruction *
7384ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7385 bool &Precise) {
7386 Precise = true;
7387 // Do a bounded search of the def relation of the requested SCEVs.
7388 SmallPtrSet<const SCEV *, 16> Visited;
7390 auto pushOp = [&](const SCEV *S) {
7391 if (!Visited.insert(S).second)
7392 return;
7393 // Threshold of 30 here is arbitrary.
7394 if (Visited.size() > 30) {
7395 Precise = false;
7396 return;
7397 }
7398 Worklist.push_back(S);
7399 };
7400
7401 for (const auto *S : Ops)
7402 pushOp(S);
7403
7404 const Instruction *Bound = nullptr;
7405 while (!Worklist.empty()) {
7406 auto *S = Worklist.pop_back_val();
7407 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7408 if (!Bound || DT.dominates(Bound, DefI))
7409 Bound = DefI;
7410 } else {
7411 for (const auto *Op : S->operands())
7412 pushOp(Op);
7413 }
7414 }
7415 return Bound ? Bound : &*F.getEntryBlock().begin();
7416}
7417
7418const Instruction *
7419ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7420 bool Discard;
7421 return getDefiningScopeBound(Ops, Discard);
7422}
7423
7424bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7425 const Instruction *B) {
7426 if (A->getParent() == B->getParent() &&
7428 B->getIterator()))
7429 return true;
7430
7431 auto *BLoop = LI.getLoopFor(B->getParent());
7432 if (BLoop && BLoop->getHeader() == B->getParent() &&
7433 BLoop->getLoopPreheader() == A->getParent() &&
7435 A->getParent()->end()) &&
7436 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7437 B->getIterator()))
7438 return true;
7439 return false;
7440}
7441
7442bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7443 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7444 visitAll(Op, PC);
7445 return PC.MaybePoison.empty();
7446}
7447
7448bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7449 return !SCEVExprContains(Op, [this](const SCEV *S) {
7450 const SCEV *Op1;
7451 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7452 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7453 // is a non-zero constant, we have to assume the UDiv may be UB.
7454 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7455 });
7456}
7457
7458bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7459 // Only proceed if we can prove that I does not yield poison.
7461 return false;
7462
7463 // At this point we know that if I is executed, then it does not wrap
7464 // according to at least one of NSW or NUW. If I is not executed, then we do
7465 // not know if the calculation that I represents would wrap. Multiple
7466 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7467 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7468 // derived from other instructions that map to the same SCEV. We cannot make
7469 // that guarantee for cases where I is not executed. So we need to find a
7470 // upper bound on the defining scope for the SCEV, and prove that I is
7471 // executed every time we enter that scope. When the bounding scope is a
7472 // loop (the common case), this is equivalent to proving I executes on every
7473 // iteration of that loop.
7475 for (const Use &Op : I->operands()) {
7476 // I could be an extractvalue from a call to an overflow intrinsic.
7477 // TODO: We can do better here in some cases.
7478 if (isSCEVable(Op->getType()))
7479 SCEVOps.push_back(getSCEV(Op));
7480 }
7481 auto *DefI = getDefiningScopeBound(SCEVOps);
7482 return isGuaranteedToTransferExecutionTo(DefI, I);
7483}
7484
7485bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7486 // If we know that \c I can never be poison period, then that's enough.
7487 if (isSCEVExprNeverPoison(I))
7488 return true;
7489
7490 // If the loop only has one exit, then we know that, if the loop is entered,
7491 // any instruction dominating that exit will be executed. If any such
7492 // instruction would result in UB, the addrec cannot be poison.
7493 //
7494 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7495 // also handles uses outside the loop header (they just need to dominate the
7496 // single exit).
7497
7498 auto *ExitingBB = L->getExitingBlock();
7499 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7500 return false;
7501
7502 SmallPtrSet<const Value *, 16> KnownPoison;
7504
7505 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7506 // things that are known to be poison under that assumption go on the
7507 // Worklist.
7508 KnownPoison.insert(I);
7509 Worklist.push_back(I);
7510
7511 while (!Worklist.empty()) {
7512 const Instruction *Poison = Worklist.pop_back_val();
7513
7514 for (const Use &U : Poison->uses()) {
7515 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7516 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7517 DT.dominates(PoisonUser->getParent(), ExitingBB))
7518 return true;
7519
7520 if (propagatesPoison(U) && L->contains(PoisonUser))
7521 if (KnownPoison.insert(PoisonUser).second)
7522 Worklist.push_back(PoisonUser);
7523 }
7524 }
7525
7526 return false;
7527}
7528
7529ScalarEvolution::LoopProperties
7530ScalarEvolution::getLoopProperties(const Loop *L) {
7531 using LoopProperties = ScalarEvolution::LoopProperties;
7532
7533 auto Itr = LoopPropertiesCache.find(L);
7534 if (Itr == LoopPropertiesCache.end()) {
7535 auto HasSideEffects = [](Instruction *I) {
7536 if (auto *SI = dyn_cast<StoreInst>(I))
7537 return !SI->isSimple();
7538
7539 if (I->mayThrow())
7540 return true;
7541
7542 // Non-volatile memset / memcpy do not count as side-effect for forward
7543 // progress.
7544 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7545 return false;
7546
7547 return I->mayWriteToMemory();
7548 };
7549
7550 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7551 /*HasNoSideEffects*/ true};
7552
7553 for (auto *BB : L->getBlocks())
7554 for (auto &I : *BB) {
7556 LP.HasNoAbnormalExits = false;
7557 if (HasSideEffects(&I))
7558 LP.HasNoSideEffects = false;
7559 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7560 break; // We're already as pessimistic as we can get.
7561 }
7562
7563 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7564 assert(InsertPair.second && "We just checked!");
7565 Itr = InsertPair.first;
7566 }
7567
7568 return Itr->second;
7569}
7570
7572 // A mustprogress loop without side effects must be finite.
7573 // TODO: The check used here is very conservative. It's only *specific*
7574 // side effects which are well defined in infinite loops.
7575 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7576}
7577
7578const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7579 // Worklist item with a Value and a bool indicating whether all operands have
7580 // been visited already.
7583
7584 Stack.emplace_back(V, true);
7585 Stack.emplace_back(V, false);
7586 while (!Stack.empty()) {
7587 auto E = Stack.pop_back_val();
7588 Value *CurV = E.getPointer();
7589
7590 if (getExistingSCEV(CurV))
7591 continue;
7592
7594 const SCEV *CreatedSCEV = nullptr;
7595 // If all operands have been visited already, create the SCEV.
7596 if (E.getInt()) {
7597 CreatedSCEV = createSCEV(CurV);
7598 } else {
7599 // Otherwise get the operands we need to create SCEV's for before creating
7600 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7601 // just use it.
7602 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7603 }
7604
7605 if (CreatedSCEV) {
7606 insertValueToMap(CurV, CreatedSCEV);
7607 } else {
7608 // Queue CurV for SCEV creation, followed by its's operands which need to
7609 // be constructed first.
7610 Stack.emplace_back(CurV, true);
7611 for (Value *Op : Ops)
7612 Stack.emplace_back(Op, false);
7613 }
7614 }
7615
7616 return getExistingSCEV(V);
7617}
7618
7619const SCEV *
7620ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7621 if (!isSCEVable(V->getType()))
7622 return getUnknown(V);
7623
7624 if (Instruction *I = dyn_cast<Instruction>(V)) {
7625 // Don't attempt to analyze instructions in blocks that aren't
7626 // reachable. Such instructions don't matter, and they aren't required
7627 // to obey basic rules for definitions dominating uses which this
7628 // analysis depends on.
7629 if (!DT.isReachableFromEntry(I->getParent()))
7630 return getUnknown(PoisonValue::get(V->getType()));
7631 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7632 return getConstant(CI);
7633 else if (isa<GlobalAlias>(V))
7634 return getUnknown(V);
7635 else if (!isa<ConstantExpr>(V))
7636 return getUnknown(V);
7637
7639 if (auto BO =
7641 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7642 switch (BO->Opcode) {
7643 case Instruction::Add:
7644 case Instruction::Mul: {
7645 // For additions and multiplications, traverse add/mul chains for which we
7646 // can potentially create a single SCEV, to reduce the number of
7647 // get{Add,Mul}Expr calls.
7648 do {
7649 if (BO->Op) {
7650 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7651 Ops.push_back(BO->Op);
7652 break;
7653 }
7654 }
7655 Ops.push_back(BO->RHS);
7656 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7658 if (!NewBO ||
7659 (BO->Opcode == Instruction::Add &&
7660 (NewBO->Opcode != Instruction::Add &&
7661 NewBO->Opcode != Instruction::Sub)) ||
7662 (BO->Opcode == Instruction::Mul &&
7663 NewBO->Opcode != Instruction::Mul)) {
7664 Ops.push_back(BO->LHS);
7665 break;
7666 }
7667 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7668 // requires a SCEV for the LHS.
7669 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7670 auto *I = dyn_cast<Instruction>(BO->Op);
7671 if (I && programUndefinedIfPoison(I)) {
7672 Ops.push_back(BO->LHS);
7673 break;
7674 }
7675 }
7676 BO = NewBO;
7677 } while (true);
7678 return nullptr;
7679 }
7680 case Instruction::Sub:
7681 case Instruction::UDiv:
7682 case Instruction::URem:
7683 break;
7684 case Instruction::AShr:
7685 case Instruction::Shl:
7686 case Instruction::Xor:
7687 if (!IsConstArg)
7688 return nullptr;
7689 break;
7690 case Instruction::And:
7691 case Instruction::Or:
7692 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7693 return nullptr;
7694 break;
7695 case Instruction::LShr:
7696 return getUnknown(V);
7697 default:
7698 llvm_unreachable("Unhandled binop");
7699 break;
7700 }
7701
7702 Ops.push_back(BO->LHS);
7703 Ops.push_back(BO->RHS);
7704 return nullptr;
7705 }
7706
7707 switch (U->getOpcode()) {
7708 case Instruction::Trunc:
7709 case Instruction::ZExt:
7710 case Instruction::SExt:
7711 case Instruction::PtrToAddr:
7712 case Instruction::PtrToInt:
7713 Ops.push_back(U->getOperand(0));
7714 return nullptr;
7715
7716 case Instruction::BitCast:
7717 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7718 Ops.push_back(U->getOperand(0));
7719 return nullptr;
7720 }
7721 return getUnknown(V);
7722
7723 case Instruction::SDiv:
7724 case Instruction::SRem:
7725 Ops.push_back(U->getOperand(0));
7726 Ops.push_back(U->getOperand(1));
7727 return nullptr;
7728
7729 case Instruction::GetElementPtr:
7730 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7731 "GEP source element type must be sized");
7732 llvm::append_range(Ops, U->operands());
7733 return nullptr;
7734
7735 case Instruction::IntToPtr:
7736 return getUnknown(V);
7737
7738 case Instruction::PHI:
7739 // Keep constructing SCEVs' for phis recursively for now.
7740 return nullptr;
7741
7742 case Instruction::Select: {
7743 // Check if U is a select that can be simplified to a SCEVUnknown.
7744 auto CanSimplifyToUnknown = [this, U]() {
7745 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7746 return false;
7747
7748 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7749 if (!ICI)
7750 return false;
7751 Value *LHS = ICI->getOperand(0);
7752 Value *RHS = ICI->getOperand(1);
7753 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7754 ICI->getPredicate() == CmpInst::ICMP_NE) {
7756 return true;
7757 } else if (getTypeSizeInBits(LHS->getType()) >
7758 getTypeSizeInBits(U->getType()))
7759 return true;
7760 return false;
7761 };
7762 if (CanSimplifyToUnknown())
7763 return getUnknown(U);
7764
7765 llvm::append_range(Ops, U->operands());
7766 return nullptr;
7767 break;
7768 }
7769 case Instruction::Call:
7770 case Instruction::Invoke:
7771 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7772 Ops.push_back(RV);
7773 return nullptr;
7774 }
7775
7776 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7777 switch (II->getIntrinsicID()) {
7778 case Intrinsic::abs:
7779 Ops.push_back(II->getArgOperand(0));
7780 return nullptr;
7781 case Intrinsic::umax:
7782 case Intrinsic::umin:
7783 case Intrinsic::smax:
7784 case Intrinsic::smin:
7785 case Intrinsic::usub_sat:
7786 case Intrinsic::uadd_sat:
7787 Ops.push_back(II->getArgOperand(0));
7788 Ops.push_back(II->getArgOperand(1));
7789 return nullptr;
7790 case Intrinsic::start_loop_iterations:
7791 case Intrinsic::annotation:
7792 case Intrinsic::ptr_annotation:
7793 Ops.push_back(II->getArgOperand(0));
7794 return nullptr;
7795 default:
7796 break;
7797 }
7798 }
7799 break;
7800 }
7801
7802 return nullptr;
7803}
7804
7805const SCEV *ScalarEvolution::createSCEV(Value *V) {
7806 if (!isSCEVable(V->getType()))
7807 return getUnknown(V);
7808
7809 if (Instruction *I = dyn_cast<Instruction>(V)) {
7810 // Don't attempt to analyze instructions in blocks that aren't
7811 // reachable. Such instructions don't matter, and they aren't required
7812 // to obey basic rules for definitions dominating uses which this
7813 // analysis depends on.
7814 if (!DT.isReachableFromEntry(I->getParent()))
7815 return getUnknown(PoisonValue::get(V->getType()));
7816 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7817 return getConstant(CI);
7818 else if (isa<GlobalAlias>(V))
7819 return getUnknown(V);
7820 else if (!isa<ConstantExpr>(V))
7821 return getUnknown(V);
7822
7823 const SCEV *LHS;
7824 const SCEV *RHS;
7825
7827 if (auto BO =
7829 switch (BO->Opcode) {
7830 case Instruction::Add: {
7831 // The simple thing to do would be to just call getSCEV on both operands
7832 // and call getAddExpr with the result. However if we're looking at a
7833 // bunch of things all added together, this can be quite inefficient,
7834 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7835 // Instead, gather up all the operands and make a single getAddExpr call.
7836 // LLVM IR canonical form means we need only traverse the left operands.
7838 do {
7839 if (BO->Op) {
7840 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7841 AddOps.push_back(OpSCEV);
7842 break;
7843 }
7844
7845 // If a NUW or NSW flag can be applied to the SCEV for this
7846 // addition, then compute the SCEV for this addition by itself
7847 // with a separate call to getAddExpr. We need to do that
7848 // instead of pushing the operands of the addition onto AddOps,
7849 // since the flags are only known to apply to this particular
7850 // addition - they may not apply to other additions that can be
7851 // formed with operands from AddOps.
7852 const SCEV *RHS = getSCEV(BO->RHS);
7853 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7854 if (Flags != SCEV::FlagAnyWrap) {
7855 const SCEV *LHS = getSCEV(BO->LHS);
7856 if (BO->Opcode == Instruction::Sub)
7857 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7858 else
7859 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7860 break;
7861 }
7862 }
7863
7864 if (BO->Opcode == Instruction::Sub)
7865 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7866 else
7867 AddOps.push_back(getSCEV(BO->RHS));
7868
7869 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7871 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7872 NewBO->Opcode != Instruction::Sub)) {
7873 AddOps.push_back(getSCEV(BO->LHS));
7874 break;
7875 }
7876 BO = NewBO;
7877 } while (true);
7878
7879 return getAddExpr(AddOps);
7880 }
7881
7882 case Instruction::Mul: {
7884 do {
7885 if (BO->Op) {
7886 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7887 MulOps.push_back(OpSCEV);
7888 break;
7889 }
7890
7891 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7892 if (Flags != SCEV::FlagAnyWrap) {
7893 LHS = getSCEV(BO->LHS);
7894 RHS = getSCEV(BO->RHS);
7895 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7896 break;
7897 }
7898 }
7899
7900 MulOps.push_back(getSCEV(BO->RHS));
7901 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7903 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7904 MulOps.push_back(getSCEV(BO->LHS));
7905 break;
7906 }
7907 BO = NewBO;
7908 } while (true);
7909
7910 return getMulExpr(MulOps);
7911 }
7912 case Instruction::UDiv:
7913 LHS = getSCEV(BO->LHS);
7914 RHS = getSCEV(BO->RHS);
7915 return getUDivExpr(LHS, RHS);
7916 case Instruction::URem:
7917 LHS = getSCEV(BO->LHS);
7918 RHS = getSCEV(BO->RHS);
7919 return getURemExpr(LHS, RHS);
7920 case Instruction::Sub: {
7922 if (BO->Op)
7923 Flags = getNoWrapFlagsFromUB(BO->Op);
7924 LHS = getSCEV(BO->LHS);
7925 RHS = getSCEV(BO->RHS);
7926 return getMinusSCEV(LHS, RHS, Flags);
7927 }
7928 case Instruction::And:
7929 // For an expression like x&255 that merely masks off the high bits,
7930 // use zext(trunc(x)) as the SCEV expression.
7931 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7932 if (CI->isZero())
7933 return getSCEV(BO->RHS);
7934 if (CI->isMinusOne())
7935 return getSCEV(BO->LHS);
7936 const APInt &A = CI->getValue();
7937
7938 // Instcombine's ShrinkDemandedConstant may strip bits out of
7939 // constants, obscuring what would otherwise be a low-bits mask.
7940 // Use computeKnownBits to compute what ShrinkDemandedConstant
7941 // knew about to reconstruct a low-bits mask value.
7942 unsigned LZ = A.countl_zero();
7943 unsigned TZ = A.countr_zero();
7944 unsigned BitWidth = A.getBitWidth();
7945 KnownBits Known(BitWidth);
7946 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7947
7948 APInt EffectiveMask =
7949 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7950 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7951 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7952 const SCEV *LHS = getSCEV(BO->LHS);
7953 const SCEV *ShiftedLHS = nullptr;
7954 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7955 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7956 // For an expression like (x * 8) & 8, simplify the multiply.
7957 unsigned MulZeros = OpC->getAPInt().countr_zero();
7958 unsigned GCD = std::min(MulZeros, TZ);
7959 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7961 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7962 append_range(MulOps, LHSMul->operands().drop_front());
7963 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7964 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7965 }
7966 }
7967 if (!ShiftedLHS)
7968 ShiftedLHS = getUDivExpr(LHS, MulCount);
7969 return getMulExpr(
7971 getTruncateExpr(ShiftedLHS,
7972 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7973 BO->LHS->getType()),
7974 MulCount);
7975 }
7976 }
7977 // Binary `and` is a bit-wise `umin`.
7978 if (BO->LHS->getType()->isIntegerTy(1)) {
7979 LHS = getSCEV(BO->LHS);
7980 RHS = getSCEV(BO->RHS);
7981 return getUMinExpr(LHS, RHS);
7982 }
7983 break;
7984
7985 case Instruction::Or:
7986 // Binary `or` is a bit-wise `umax`.
7987 if (BO->LHS->getType()->isIntegerTy(1)) {
7988 LHS = getSCEV(BO->LHS);
7989 RHS = getSCEV(BO->RHS);
7990 return getUMaxExpr(LHS, RHS);
7991 }
7992 break;
7993
7994 case Instruction::Xor:
7995 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7996 // If the RHS of xor is -1, then this is a not operation.
7997 if (CI->isMinusOne())
7998 return getNotSCEV(getSCEV(BO->LHS));
7999
8000 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8001 // This is a variant of the check for xor with -1, and it handles
8002 // the case where instcombine has trimmed non-demanded bits out
8003 // of an xor with -1.
8004 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8005 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8006 if (LBO->getOpcode() == Instruction::And &&
8007 LCI->getValue() == CI->getValue())
8008 if (const SCEVZeroExtendExpr *Z =
8010 Type *UTy = BO->LHS->getType();
8011 const SCEV *Z0 = Z->getOperand();
8012 Type *Z0Ty = Z0->getType();
8013 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8014
8015 // If C is a low-bits mask, the zero extend is serving to
8016 // mask off the high bits. Complement the operand and
8017 // re-apply the zext.
8018 if (CI->getValue().isMask(Z0TySize))
8019 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8020
8021 // If C is a single bit, it may be in the sign-bit position
8022 // before the zero-extend. In this case, represent the xor
8023 // using an add, which is equivalent, and re-apply the zext.
8024 APInt Trunc = CI->getValue().trunc(Z0TySize);
8025 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8026 Trunc.isSignMask())
8027 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8028 UTy);
8029 }
8030 }
8031 break;
8032
8033 case Instruction::Shl:
8034 // Turn shift left of a constant amount into a multiply.
8035 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8036 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8037
8038 // If the shift count is not less than the bitwidth, the result of
8039 // the shift is undefined. Don't try to analyze it, because the
8040 // resolution chosen here may differ from the resolution chosen in
8041 // other parts of the compiler.
8042 if (SA->getValue().uge(BitWidth))
8043 break;
8044
8045 // We can safely preserve the nuw flag in all cases. It's also safe to
8046 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8047 // requires special handling. It can be preserved as long as we're not
8048 // left shifting by bitwidth - 1.
8049 auto Flags = SCEV::FlagAnyWrap;
8050 if (BO->Op) {
8051 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8052 if ((MulFlags & SCEV::FlagNSW) &&
8053 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8055 if (MulFlags & SCEV::FlagNUW)
8057 }
8058
8059 ConstantInt *X = ConstantInt::get(
8060 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8061 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8062 }
8063 break;
8064
8065 case Instruction::AShr:
8066 // AShr X, C, where C is a constant.
8067 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8068 if (!CI)
8069 break;
8070
8071 Type *OuterTy = BO->LHS->getType();
8072 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8073 // If the shift count is not less than the bitwidth, the result of
8074 // the shift is undefined. Don't try to analyze it, because the
8075 // resolution chosen here may differ from the resolution chosen in
8076 // other parts of the compiler.
8077 if (CI->getValue().uge(BitWidth))
8078 break;
8079
8080 if (CI->isZero())
8081 return getSCEV(BO->LHS); // shift by zero --> noop
8082
8083 uint64_t AShrAmt = CI->getZExtValue();
8084 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8085
8086 Operator *L = dyn_cast<Operator>(BO->LHS);
8087 const SCEV *AddTruncateExpr = nullptr;
8088 ConstantInt *ShlAmtCI = nullptr;
8089 const SCEV *AddConstant = nullptr;
8090
8091 if (L && L->getOpcode() == Instruction::Add) {
8092 // X = Shl A, n
8093 // Y = Add X, c
8094 // Z = AShr Y, m
8095 // n, c and m are constants.
8096
8097 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8098 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8099 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8100 if (AddOperandCI) {
8101 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8102 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8103 // since we truncate to TruncTy, the AddConstant should be of the
8104 // same type, so create a new Constant with type same as TruncTy.
8105 // Also, the Add constant should be shifted right by AShr amount.
8106 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8107 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8108 // we model the expression as sext(add(trunc(A), c << n)), since the
8109 // sext(trunc) part is already handled below, we create a
8110 // AddExpr(TruncExp) which will be used later.
8111 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8112 }
8113 }
8114 } else if (L && L->getOpcode() == Instruction::Shl) {
8115 // X = Shl A, n
8116 // Y = AShr X, m
8117 // Both n and m are constant.
8118
8119 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8120 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8121 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8122 }
8123
8124 if (AddTruncateExpr && ShlAmtCI) {
8125 // We can merge the two given cases into a single SCEV statement,
8126 // incase n = m, the mul expression will be 2^0, so it gets resolved to
8127 // a simpler case. The following code handles the two cases:
8128 //
8129 // 1) For a two-shift sext-inreg, i.e. n = m,
8130 // use sext(trunc(x)) as the SCEV expression.
8131 //
8132 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8133 // expression. We already checked that ShlAmt < BitWidth, so
8134 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8135 // ShlAmt - AShrAmt < Amt.
8136 const APInt &ShlAmt = ShlAmtCI->getValue();
8137 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8138 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8139 ShlAmtCI->getZExtValue() - AShrAmt);
8140 const SCEV *CompositeExpr =
8141 getMulExpr(AddTruncateExpr, getConstant(Mul));
8142 if (L->getOpcode() != Instruction::Shl)
8143 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8144
8145 return getSignExtendExpr(CompositeExpr, OuterTy);
8146 }
8147 }
8148 break;
8149 }
8150 }
8151
8152 switch (U->getOpcode()) {
8153 case Instruction::Trunc:
8154 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8155
8156 case Instruction::ZExt:
8157 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8158
8159 case Instruction::SExt:
8160 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8162 // The NSW flag of a subtract does not always survive the conversion to
8163 // A + (-1)*B. By pushing sign extension onto its operands we are much
8164 // more likely to preserve NSW and allow later AddRec optimisations.
8165 //
8166 // NOTE: This is effectively duplicating this logic from getSignExtend:
8167 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8168 // but by that point the NSW information has potentially been lost.
8169 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8170 Type *Ty = U->getType();
8171 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8172 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8173 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8174 }
8175 }
8176 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8177
8178 case Instruction::BitCast:
8179 // BitCasts are no-op casts so we just eliminate the cast.
8180 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8181 return getSCEV(U->getOperand(0));
8182 break;
8183
8184 case Instruction::PtrToAddr: {
8185 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8186 if (isa<SCEVCouldNotCompute>(IntOp))
8187 return getUnknown(V);
8188 return IntOp;
8189 }
8190
8191 case Instruction::PtrToInt: {
8192 // Pointer to integer cast is straight-forward, so do model it.
8193 const SCEV *Op = getSCEV(U->getOperand(0));
8194 Type *DstIntTy = U->getType();
8195 // But only if effective SCEV (integer) type is wide enough to represent
8196 // all possible pointer values.
8197 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8198 if (isa<SCEVCouldNotCompute>(IntOp))
8199 return getUnknown(V);
8200 return IntOp;
8201 }
8202 case Instruction::IntToPtr:
8203 // Just don't deal with inttoptr casts.
8204 return getUnknown(V);
8205
8206 case Instruction::SDiv:
8207 // If both operands are non-negative, this is just an udiv.
8208 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8209 isKnownNonNegative(getSCEV(U->getOperand(1))))
8210 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8211 break;
8212
8213 case Instruction::SRem:
8214 // If both operands are non-negative, this is just an urem.
8215 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8216 isKnownNonNegative(getSCEV(U->getOperand(1))))
8217 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8218 break;
8219
8220 case Instruction::GetElementPtr:
8221 return createNodeForGEP(cast<GEPOperator>(U));
8222
8223 case Instruction::PHI:
8224 return createNodeForPHI(cast<PHINode>(U));
8225
8226 case Instruction::Select:
8227 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8228 U->getOperand(2));
8229
8230 case Instruction::Call:
8231 case Instruction::Invoke:
8232 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8233 return getSCEV(RV);
8234
8235 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8236 switch (II->getIntrinsicID()) {
8237 case Intrinsic::abs:
8238 return getAbsExpr(
8239 getSCEV(II->getArgOperand(0)),
8240 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8241 case Intrinsic::umax:
8242 LHS = getSCEV(II->getArgOperand(0));
8243 RHS = getSCEV(II->getArgOperand(1));
8244 return getUMaxExpr(LHS, RHS);
8245 case Intrinsic::umin:
8246 LHS = getSCEV(II->getArgOperand(0));
8247 RHS = getSCEV(II->getArgOperand(1));
8248 return getUMinExpr(LHS, RHS);
8249 case Intrinsic::smax:
8250 LHS = getSCEV(II->getArgOperand(0));
8251 RHS = getSCEV(II->getArgOperand(1));
8252 return getSMaxExpr(LHS, RHS);
8253 case Intrinsic::smin:
8254 LHS = getSCEV(II->getArgOperand(0));
8255 RHS = getSCEV(II->getArgOperand(1));
8256 return getSMinExpr(LHS, RHS);
8257 case Intrinsic::usub_sat: {
8258 const SCEV *X = getSCEV(II->getArgOperand(0));
8259 const SCEV *Y = getSCEV(II->getArgOperand(1));
8260 const SCEV *ClampedY = getUMinExpr(X, Y);
8261 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8262 }
8263 case Intrinsic::uadd_sat: {
8264 const SCEV *X = getSCEV(II->getArgOperand(0));
8265 const SCEV *Y = getSCEV(II->getArgOperand(1));
8266 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8267 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8268 }
8269 case Intrinsic::start_loop_iterations:
8270 case Intrinsic::annotation:
8271 case Intrinsic::ptr_annotation:
8272 // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is
8273 // just eqivalent to the first operand for SCEV purposes.
8274 return getSCEV(II->getArgOperand(0));
8275 case Intrinsic::vscale:
8276 return getVScale(II->getType());
8277 default:
8278 break;
8279 }
8280 }
8281 break;
8282 }
8283
8284 return getUnknown(V);
8285}
8286
8287//===----------------------------------------------------------------------===//
8288// Iteration Count Computation Code
8289//
8290
8292 if (isa<SCEVCouldNotCompute>(ExitCount))
8293 return getCouldNotCompute();
8294
8295 auto *ExitCountType = ExitCount->getType();
8296 assert(ExitCountType->isIntegerTy());
8297 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8298 1 + ExitCountType->getScalarSizeInBits());
8299 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8300}
8301
8303 Type *EvalTy,
8304 const Loop *L) {
8305 if (isa<SCEVCouldNotCompute>(ExitCount))
8306 return getCouldNotCompute();
8307
8308 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8309 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8310
8311 auto CanAddOneWithoutOverflow = [&]() {
8312 ConstantRange ExitCountRange =
8313 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8314 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8315 return true;
8316
8317 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8318 getMinusOne(ExitCount->getType()));
8319 };
8320
8321 // If we need to zero extend the backedge count, check if we can add one to
8322 // it prior to zero extending without overflow. Provided this is safe, it
8323 // allows better simplification of the +1.
8324 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8325 return getZeroExtendExpr(
8326 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8327
8328 // Get the total trip count from the count by adding 1. This may wrap.
8329 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8330}
8331
8332static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8333 if (!ExitCount)
8334 return 0;
8335
8336 ConstantInt *ExitConst = ExitCount->getValue();
8337
8338 // Guard against huge trip counts.
8339 if (ExitConst->getValue().getActiveBits() > 32)
8340 return 0;
8341
8342 // In case of integer overflow, this returns 0, which is correct.
8343 return ((unsigned)ExitConst->getZExtValue()) + 1;
8344}
8345
8347 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8348 return getConstantTripCount(ExitCount);
8349}
8350
8351unsigned
8353 const BasicBlock *ExitingBlock) {
8354 assert(ExitingBlock && "Must pass a non-null exiting block!");
8355 assert(L->isLoopExiting(ExitingBlock) &&
8356 "Exiting block must actually branch out of the loop!");
8357 const SCEVConstant *ExitCount =
8358 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8359 return getConstantTripCount(ExitCount);
8360}
8361
8363 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8364
8365 const auto *MaxExitCount =
8366 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8368 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8369}
8370
8372 SmallVector<BasicBlock *, 8> ExitingBlocks;
8373 L->getExitingBlocks(ExitingBlocks);
8374
8375 std::optional<unsigned> Res;
8376 for (auto *ExitingBB : ExitingBlocks) {
8377 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8378 if (!Res)
8379 Res = Multiple;
8380 Res = std::gcd(*Res, Multiple);
8381 }
8382 return Res.value_or(1);
8383}
8384
8386 const SCEV *ExitCount) {
8387 if (isa<SCEVCouldNotCompute>(ExitCount))
8388 return 1;
8389
8390 // Get the trip count
8391 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8392
8393 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8394 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8395 // the greatest power of 2 divisor less than 2^32.
8396 return Multiple.getActiveBits() > 32
8397 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8398 : (unsigned)Multiple.getZExtValue();
8399}
8400
8401/// Returns the largest constant divisor of the trip count of this loop as a
8402/// normal unsigned value, if possible. This means that the actual trip count is
8403/// always a multiple of the returned value (don't forget the trip count could
8404/// very well be zero as well!).
8405///
8406/// Returns 1 if the trip count is unknown or not guaranteed to be the
8407/// multiple of a constant (which is also the case if the trip count is simply
8408/// constant, use getSmallConstantTripCount for that case), Will also return 1
8409/// if the trip count is very large (>= 2^32).
8410///
8411/// As explained in the comments for getSmallConstantTripCount, this assumes
8412/// that control exits the loop via ExitingBlock.
8413unsigned
8415 const BasicBlock *ExitingBlock) {
8416 assert(ExitingBlock && "Must pass a non-null exiting block!");
8417 assert(L->isLoopExiting(ExitingBlock) &&
8418 "Exiting block must actually branch out of the loop!");
8419 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8420 return getSmallConstantTripMultiple(L, ExitCount);
8421}
8422
8424 const BasicBlock *ExitingBlock,
8425 ExitCountKind Kind) {
8426 switch (Kind) {
8427 case Exact:
8428 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8429 case SymbolicMaximum:
8430 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8431 case ConstantMaximum:
8432 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8433 };
8434 llvm_unreachable("Invalid ExitCountKind!");
8435}
8436
8438 const Loop *L, const BasicBlock *ExitingBlock,
8440 switch (Kind) {
8441 case Exact:
8442 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8443 Predicates);
8444 case SymbolicMaximum:
8445 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8446 Predicates);
8447 case ConstantMaximum:
8448 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8449 Predicates);
8450 };
8451 llvm_unreachable("Invalid ExitCountKind!");
8452}
8453
8456 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8457}
8458
8460 ExitCountKind Kind) {
8461 switch (Kind) {
8462 case Exact:
8463 return getBackedgeTakenInfo(L).getExact(L, this);
8464 case ConstantMaximum:
8465 return getBackedgeTakenInfo(L).getConstantMax(this);
8466 case SymbolicMaximum:
8467 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8468 };
8469 llvm_unreachable("Invalid ExitCountKind!");
8470}
8471
8474 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8475}
8476
8479 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8480}
8481
8483 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8484}
8485
8486/// Push PHI nodes in the header of the given loop onto the given Worklist.
8487static void PushLoopPHIs(const Loop *L,
8490 BasicBlock *Header = L->getHeader();
8491
8492 // Push all Loop-header PHIs onto the Worklist stack.
8493 for (PHINode &PN : Header->phis())
8494 if (Visited.insert(&PN).second)
8495 Worklist.push_back(&PN);
8496}
8497
8498ScalarEvolution::BackedgeTakenInfo &
8499ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8500 auto &BTI = getBackedgeTakenInfo(L);
8501 if (BTI.hasFullInfo())
8502 return BTI;
8503
8504 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8505
8506 if (!Pair.second)
8507 return Pair.first->second;
8508
8509 BackedgeTakenInfo Result =
8510 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8511
8512 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8513}
8514
8515ScalarEvolution::BackedgeTakenInfo &
8516ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8517 // Initially insert an invalid entry for this loop. If the insertion
8518 // succeeds, proceed to actually compute a backedge-taken count and
8519 // update the value. The temporary CouldNotCompute value tells SCEV
8520 // code elsewhere that it shouldn't attempt to request a new
8521 // backedge-taken count, which could result in infinite recursion.
8522 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8523 BackedgeTakenCounts.try_emplace(L);
8524 if (!Pair.second)
8525 return Pair.first->second;
8526
8527 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8528 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8529 // must be cleared in this scope.
8530 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8531
8532 // Now that we know more about the trip count for this loop, forget any
8533 // existing SCEV values for PHI nodes in this loop since they are only
8534 // conservative estimates made without the benefit of trip count
8535 // information. This invalidation is not necessary for correctness, and is
8536 // only done to produce more precise results.
8537 if (Result.hasAnyInfo()) {
8538 // Invalidate any expression using an addrec in this loop.
8540 auto LoopUsersIt = LoopUsers.find(L);
8541 if (LoopUsersIt != LoopUsers.end())
8542 append_range(ToForget, LoopUsersIt->second);
8543 forgetMemoizedResults(ToForget);
8544
8545 // Invalidate constant-evolved loop header phis.
8546 for (PHINode &PN : L->getHeader()->phis())
8547 ConstantEvolutionLoopExitValue.erase(&PN);
8548 }
8549
8550 // Re-lookup the insert position, since the call to
8551 // computeBackedgeTakenCount above could result in a
8552 // recusive call to getBackedgeTakenInfo (on a different
8553 // loop), which would invalidate the iterator computed
8554 // earlier.
8555 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8556}
8557
8559 // This method is intended to forget all info about loops. It should
8560 // invalidate caches as if the following happened:
8561 // - The trip counts of all loops have changed arbitrarily
8562 // - Every llvm::Value has been updated in place to produce a different
8563 // result.
8564 BackedgeTakenCounts.clear();
8565 PredicatedBackedgeTakenCounts.clear();
8566 BECountUsers.clear();
8567 LoopPropertiesCache.clear();
8568 ConstantEvolutionLoopExitValue.clear();
8569 ValueExprMap.clear();
8570 ValuesAtScopes.clear();
8571 ValuesAtScopesUsers.clear();
8572 LoopDispositions.clear();
8573 BlockDispositions.clear();
8574 UnsignedRanges.clear();
8575 SignedRanges.clear();
8576 ExprValueMap.clear();
8577 HasRecMap.clear();
8578 ConstantMultipleCache.clear();
8579 PredicatedSCEVRewrites.clear();
8580 FoldCache.clear();
8581 FoldCacheUser.clear();
8582}
8583void ScalarEvolution::visitAndClearUsers(
8587 while (!Worklist.empty()) {
8588 Instruction *I = Worklist.pop_back_val();
8589 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8590 continue;
8591
8593 ValueExprMap.find_as(static_cast<Value *>(I));
8594 if (It != ValueExprMap.end()) {
8595 eraseValueFromMap(It->first);
8596 ToForget.push_back(It->second);
8597 if (PHINode *PN = dyn_cast<PHINode>(I))
8598 ConstantEvolutionLoopExitValue.erase(PN);
8599 }
8600
8601 PushDefUseChildren(I, Worklist, Visited);
8602 }
8603}
8604
8606 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8610
8611 // Iterate over all the loops and sub-loops to drop SCEV information.
8612 while (!LoopWorklist.empty()) {
8613 auto *CurrL = LoopWorklist.pop_back_val();
8614
8615 // Drop any stored trip count value.
8616 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8617 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8618
8619 // Drop information about predicated SCEV rewrites for this loop.
8620 for (auto I = PredicatedSCEVRewrites.begin();
8621 I != PredicatedSCEVRewrites.end();) {
8622 std::pair<const SCEV *, const Loop *> Entry = I->first;
8623 if (Entry.second == CurrL)
8624 PredicatedSCEVRewrites.erase(I++);
8625 else
8626 ++I;
8627 }
8628
8629 auto LoopUsersItr = LoopUsers.find(CurrL);
8630 if (LoopUsersItr != LoopUsers.end())
8631 llvm::append_range(ToForget, LoopUsersItr->second);
8632
8633 // Drop information about expressions based on loop-header PHIs.
8634 PushLoopPHIs(CurrL, Worklist, Visited);
8635 visitAndClearUsers(Worklist, Visited, ToForget);
8636
8637 LoopPropertiesCache.erase(CurrL);
8638 // Forget all contained loops too, to avoid dangling entries in the
8639 // ValuesAtScopes map.
8640 LoopWorklist.append(CurrL->begin(), CurrL->end());
8641 }
8642 forgetMemoizedResults(ToForget);
8643}
8644
8646 forgetLoop(L->getOutermostLoop());
8647}
8648
8651 if (!I) return;
8652
8653 // Drop information about expressions based on loop-header PHIs.
8657 Worklist.push_back(I);
8658 Visited.insert(I);
8659 visitAndClearUsers(Worklist, Visited, ToForget);
8660
8661 forgetMemoizedResults(ToForget);
8662}
8663
8665 if (!isSCEVable(V->getType()))
8666 return;
8667
8668 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8669 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8670 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8671 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8672 if (const SCEV *S = getExistingSCEV(V)) {
8673 struct InvalidationRootCollector {
8674 Loop *L;
8676
8677 InvalidationRootCollector(Loop *L) : L(L) {}
8678
8679 bool follow(const SCEV *S) {
8680 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8681 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8682 if (L->contains(I))
8683 Roots.push_back(S);
8684 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8685 if (L->contains(AddRec->getLoop()))
8686 Roots.push_back(S);
8687 }
8688 return true;
8689 }
8690 bool isDone() const { return false; }
8691 };
8692
8693 InvalidationRootCollector C(L);
8694 visitAll(S, C);
8695 forgetMemoizedResults(C.Roots);
8696 }
8697
8698 // Also perform the normal invalidation.
8699 forgetValue(V);
8700}
8701
8702void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8703
8705 // Unless a specific value is passed to invalidation, completely clear both
8706 // caches.
8707 if (!V) {
8708 BlockDispositions.clear();
8709 LoopDispositions.clear();
8710 return;
8711 }
8712
8713 if (!isSCEVable(V->getType()))
8714 return;
8715
8716 const SCEV *S = getExistingSCEV(V);
8717 if (!S)
8718 return;
8719
8720 // Invalidate the block and loop dispositions cached for S. Dispositions of
8721 // S's users may change if S's disposition changes (i.e. a user may change to
8722 // loop-invariant, if S changes to loop invariant), so also invalidate
8723 // dispositions of S's users recursively.
8724 SmallVector<const SCEV *, 8> Worklist = {S};
8726 while (!Worklist.empty()) {
8727 const SCEV *Curr = Worklist.pop_back_val();
8728 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8729 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8730 if (!LoopDispoRemoved && !BlockDispoRemoved)
8731 continue;
8732 auto Users = SCEVUsers.find(Curr);
8733 if (Users != SCEVUsers.end())
8734 for (const auto *User : Users->second)
8735 if (Seen.insert(User).second)
8736 Worklist.push_back(User);
8737 }
8738}
8739
8740/// Get the exact loop backedge taken count considering all loop exits. A
8741/// computable result can only be returned for loops with all exiting blocks
8742/// dominating the latch. howFarToZero assumes that the limit of each loop test
8743/// is never skipped. This is a valid assumption as long as the loop exits via
8744/// that test. For precise results, it is the caller's responsibility to specify
8745/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8746const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8747 const Loop *L, ScalarEvolution *SE,
8749 // If any exits were not computable, the loop is not computable.
8750 if (!isComplete() || ExitNotTaken.empty())
8751 return SE->getCouldNotCompute();
8752
8753 const BasicBlock *Latch = L->getLoopLatch();
8754 // All exiting blocks we have collected must dominate the only backedge.
8755 if (!Latch)
8756 return SE->getCouldNotCompute();
8757
8758 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8759 // count is simply a minimum out of all these calculated exit counts.
8761 for (const auto &ENT : ExitNotTaken) {
8762 const SCEV *BECount = ENT.ExactNotTaken;
8763 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8764 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8765 "We should only have known counts for exiting blocks that dominate "
8766 "latch!");
8767
8768 Ops.push_back(BECount);
8769
8770 if (Preds)
8771 append_range(*Preds, ENT.Predicates);
8772
8773 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8774 "Predicate should be always true!");
8775 }
8776
8777 // If an earlier exit exits on the first iteration (exit count zero), then
8778 // a later poison exit count should not propagate into the result. This are
8779 // exactly the semantics provided by umin_seq.
8780 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8781}
8782
8783const ScalarEvolution::ExitNotTakenInfo *
8784ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8785 const BasicBlock *ExitingBlock,
8786 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8787 for (const auto &ENT : ExitNotTaken)
8788 if (ENT.ExitingBlock == ExitingBlock) {
8789 if (ENT.hasAlwaysTruePredicate())
8790 return &ENT;
8791 else if (Predicates) {
8792 append_range(*Predicates, ENT.Predicates);
8793 return &ENT;
8794 }
8795 }
8796
8797 return nullptr;
8798}
8799
8800/// getConstantMax - Get the constant max backedge taken count for the loop.
8801const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8802 ScalarEvolution *SE,
8803 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8804 if (!getConstantMax())
8805 return SE->getCouldNotCompute();
8806
8807 for (const auto &ENT : ExitNotTaken)
8808 if (!ENT.hasAlwaysTruePredicate()) {
8809 if (!Predicates)
8810 return SE->getCouldNotCompute();
8811 append_range(*Predicates, ENT.Predicates);
8812 }
8813
8814 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8815 isa<SCEVConstant>(getConstantMax())) &&
8816 "No point in having a non-constant max backedge taken count!");
8817 return getConstantMax();
8818}
8819
8820const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8821 const Loop *L, ScalarEvolution *SE,
8822 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8823 if (!SymbolicMax) {
8824 // Form an expression for the maximum exit count possible for this loop. We
8825 // merge the max and exact information to approximate a version of
8826 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8827 // constants.
8829
8830 for (const auto &ENT : ExitNotTaken) {
8831 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8832 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8833 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8834 "We should only have known counts for exiting blocks that "
8835 "dominate latch!");
8836 ExitCounts.push_back(ExitCount);
8837 if (Predicates)
8838 append_range(*Predicates, ENT.Predicates);
8839
8840 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8841 "Predicate should be always true!");
8842 }
8843 }
8844 if (ExitCounts.empty())
8845 SymbolicMax = SE->getCouldNotCompute();
8846 else
8847 SymbolicMax =
8848 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8849 }
8850 return SymbolicMax;
8851}
8852
8853bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8854 ScalarEvolution *SE) const {
8855 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8856 return !ENT.hasAlwaysTruePredicate();
8857 };
8858 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8859}
8860
8863
8865 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8866 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8870 // If we prove the max count is zero, so is the symbolic bound. This happens
8871 // in practice due to differences in a) how context sensitive we've chosen
8872 // to be and b) how we reason about bounds implied by UB.
8873 if (ConstantMaxNotTaken->isZero()) {
8874 this->ExactNotTaken = E = ConstantMaxNotTaken;
8875 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8876 }
8877
8880 "Exact is not allowed to be less precise than Constant Max");
8883 "Exact is not allowed to be less precise than Symbolic Max");
8886 "Symbolic Max is not allowed to be less precise than Constant Max");
8889 "No point in having a non-constant max backedge taken count!");
8891 for (const auto PredList : PredLists)
8892 for (const auto *P : PredList) {
8893 if (SeenPreds.contains(P))
8894 continue;
8895 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8896 SeenPreds.insert(P);
8897 Predicates.push_back(P);
8898 }
8899 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8900 "Backedge count should be int");
8902 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8903 "Max backedge count should be int");
8904}
8905
8913
8914/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8915/// computable exit into a persistent ExitNotTakenInfo array.
8916ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8918 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8919 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8920 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8921
8922 ExitNotTaken.reserve(ExitCounts.size());
8923 std::transform(ExitCounts.begin(), ExitCounts.end(),
8924 std::back_inserter(ExitNotTaken),
8925 [&](const EdgeExitInfo &EEI) {
8926 BasicBlock *ExitBB = EEI.first;
8927 const ExitLimit &EL = EEI.second;
8928 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8929 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8930 EL.Predicates);
8931 });
8932 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8933 isa<SCEVConstant>(ConstantMax)) &&
8934 "No point in having a non-constant max backedge taken count!");
8935}
8936
8937/// Compute the number of times the backedge of the specified loop will execute.
///
/// Every exiting block of \p L is analyzed with computeExitLimit, except
/// exits whose branch condition is a constant that never leaves the loop.
/// The resulting BackedgeTakenInfo carries:
///  - one entry per exit whose symbolic max count is known,
///  - whether an exact count was found for every analyzed exit,
///  - a constant max backedge-taken count combined from the exits' constant
///    max counts (see step 2 below), and
///  - a MaxOrZero flag, set only for a loop with a single exiting block.
/// \p AllowPredicates is forwarded to computeExitLimit and recorded with the
/// BECountUsers entries.
8938ScalarEvolution::BackedgeTakenInfo
8939ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8940 bool AllowPredicates) {
8941 SmallVector<BasicBlock *, 8> ExitingBlocks;
8942 L->getExitingBlocks(ExitingBlocks);
8943
8944 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8945
// NOTE(review): this listing jumps from original line 8945 to 8947. ExitCounts
// is used below (emplace_back, range-for, std::move) but its declaration is not
// visible here; presumably it is a container of EdgeExitInfo. Confirm against
// the real source.
8947 bool CouldComputeBECount = true;
8948 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
// nullptr means "no exit of this kind seen yet"; CouldNotCompute is a
// distinct, sticky state for MayExitMaxBECount (see the else-if below).
8949 const SCEV *MustExitMaxBECount = nullptr;
8950 const SCEV *MayExitMaxBECount = nullptr;
8951 bool MustExitMaxOrZero = false;
8952 bool IsOnlyExit = ExitingBlocks.size() == 1;
8953
8954 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8955 // and compute maxBECount.
8956 // Do a union of all the predicates here.
8957 for (BasicBlock *ExitBB : ExitingBlocks) {
8958 // We canonicalize untaken exits to br (constant), ignore them so that
8959 // proving an exit untaken doesn't negatively impact our ability to reason
8960 // about the loop as whole.
8961 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8962 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8963 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
// The exit is never taken when it exits on true but the condition is
// constant false, or exits on false but the condition is constant true.
8964 if (ExitIfTrue == CI->isZero())
8965 continue;
8966 }
8967
8968 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8969
8970 assert((AllowPredicates || EL.Predicates.empty()) &&
8971 "Predicated exit limit when predicates are not allowed!");
8972
8973 // 1. For each exit that can be computed, add an entry to ExitCounts.
8974 // CouldComputeBECount is true only if all exits can be computed.
8975 if (EL.ExactNotTaken != getCouldNotCompute())
8976 ++NumExitCountsComputed;
8977 else
8978 // We couldn't compute an exact value for this exit, so
8979 // we won't be able to compute an exact value for the loop.
8980 CouldComputeBECount = false;
8981 // Remember exit count if either exact or symbolic is known. Because
8982 // Exact always implies symbolic, only check symbolic.
8983 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8984 ExitCounts.emplace_back(ExitBB, EL);
8985 else {
8986 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8987 "Exact is known but symbolic isn't?");
8988 ++NumExitCountsNotComputed;
8989 }
8990
8991 // 2. Derive the loop's MaxBECount from each exit's max number of
8992 // non-exiting iterations. Partition the loop exits into two kinds:
8993 // LoopMustExits and LoopMayExits.
8994 //
8995 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8996 // is a LoopMayExit. If any computable LoopMustExit is found, then
8997 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8998 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8999 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9000 // any
9001 // computable EL.ConstantMaxNotTaken.
9002 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9003 DT.dominates(ExitBB, Latch)) {
// Only the first must-exit contributes MaxOrZero; later ones only
// tighten the bound through umin.
9004 if (!MustExitMaxBECount) {
9005 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9006 MustExitMaxOrZero = EL.MaxOrZero;
9007 } else {
9008 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9009 EL.ConstantMaxNotTaken);
9010 }
9011 } else if (MayExitMaxBECount != getCouldNotCompute()) {
// Once MayExitMaxBECount becomes CouldNotCompute, this guard keeps it
// there for the rest of the loop.
9012 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9013 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9014 else {
9015 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9016 EL.ConstantMaxNotTaken);
9017 }
9018 }
9019 }
// Prefer the must-exit bound; fall back to the may-exit bound, then to
// CouldNotCompute when no exit provided either.
9020 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9021 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9022 // The loop backedge will be taken the maximum or zero times if there's
9023 // a single exit that must be taken the maximum or zero times.
9024 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9025
9026 // Remember which SCEVs are used in exit limits for invalidation purposes.
9027 // We only care about non-constant SCEVs here, so we can ignore
9028 // EL.ConstantMaxNotTaken
9029 // and MaxBECount, which must be SCEVConstant.
9030 for (const auto &Pair : ExitCounts) {
9031 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9032 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9033 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9034 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9035 {L, AllowPredicates});
9036 }
9037 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9038 MaxBECount, MaxOrZero);
9039}
9040
9041ScalarEvolution::ExitLimit
9042ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9043 bool IsOnlyExit, bool AllowPredicates) {
9044 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9045 // If our exiting block does not dominate the latch, then its connection with
9046 // loop's exit limit may be far from trivial.
9047 const BasicBlock *Latch = L->getLoopLatch();
9048 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9049 return getCouldNotCompute();
9050
9051 Instruction *Term = ExitingBlock->getTerminator();
9052 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
9053 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
9054 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9055 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9056 "It should have one successor in loop and one exit block!");
9057 // Proceed to the next level to examine the exit condition expression.
9058 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9059 /*ControlsOnlyExit=*/IsOnlyExit,
9060 AllowPredicates);
9061 }
9062
9063 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9064 // For switch, make sure that there is a single exit from the loop.
9065 BasicBlock *Exit = nullptr;
9066 for (auto *SBB : successors(ExitingBlock))
9067 if (!L->contains(SBB)) {
9068 if (Exit) // Multiple exit successors.
9069 return getCouldNotCompute();
9070 Exit = SBB;
9071 }
9072 assert(Exit && "Exiting block must have at least one exit");
9073 return computeExitLimitFromSingleExitSwitch(
9074 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9075 }
9076
9077 return getCouldNotCompute();
9078}
9079
// NOTE(review): original line 9080 (return type and function name) is missing
// from this listing. The parameter list matches the
// computeExitLimitFromCond(L, Cond, ExitIfTrue, ControlsOnlyExit,
// AllowPredicates) call made from computeExitLimit, so this is presumably that
// uncached entry point. Confirm against the real source.
// Builds a fresh ExitLimitCache for (L, ExitIfTrue, AllowPredicates), which
// lives only for this query, and delegates to computeExitLimitFromCondCached.
9081 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9082 bool AllowPredicates) {
9083 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9084 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9085 ControlsOnlyExit, AllowPredicates);
9086}
9087
9088std::optional<ScalarEvolution::ExitLimit>
9089ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9090 bool ExitIfTrue, bool ControlsOnlyExit,
9091 bool AllowPredicates) {
9092 (void)this->L;
9093 (void)this->ExitIfTrue;
9094 (void)this->AllowPredicates;
9095
9096 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9097 this->AllowPredicates == AllowPredicates &&
9098 "Variance in assumed invariant key components!");
9099 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9100 if (Itr == TripCountMap.end())
9101 return std::nullopt;
9102 return Itr->second;
9103}
9104
9105void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9106 bool ExitIfTrue,
9107 bool ControlsOnlyExit,
9108 bool AllowPredicates,
9109 const ExitLimit &EL) {
9110 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9111 this->AllowPredicates == AllowPredicates &&
9112 "Variance in assumed invariant key components!");
9113
9114 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9115 assert(InsertResult.second && "Expected successful insertion!");
9116 (void)InsertResult;
9117 (void)ExitIfTrue;
9118}
9119
9120ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9121 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9122 bool ControlsOnlyExit, bool AllowPredicates) {
9123
9124 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9125 AllowPredicates))
9126 return *MaybeEL;
9127
9128 ExitLimit EL = computeExitLimitFromCondImpl(
9129 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9130 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9131 return EL;
9132}
9133
/// Uncached worker behind computeExitLimitFromCondCached.
///
/// Strategies are tried in this order:
///  1. logical and/or conditions (computeExitLimitFromCondFromBinOp);
///  2. integer comparisons. The first attempt does not pass AllowPredicates;
///     it is retried with predicates only when \p AllowPredicates is set and
///     the first result lacks full info;
///  3. constant conditions;
///  4. the overflow bit of an x.with.overflow intrinsic with a constant RHS,
///     rewritten as an equivalent icmp;
///  5. computeExitCountExhaustively as the last resort.
9134ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9135 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9136 bool ControlsOnlyExit, bool AllowPredicates) {
9137 // Handle BinOp conditions (And, Or).
9138 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9139 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9140 return *LimitFromBinOp;
9141
9142 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9143 // Proceed to the next level to examine the icmp.
9144 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9145 ExitLimit EL =
9146 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9147 if (EL.hasFullInfo() || !AllowPredicates)
9148 return EL;
9149
9150 // Try again, but use SCEV predicates this time.
9151 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9152 ControlsOnlyExit,
9153 /*AllowPredicates=*/true);
9154 }
9155
9156 // Check for a constant condition. These are normally stripped out by
9157 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9158 // preserve the CFG and is temporarily leaving constant conditions
9159 // in place.
9160 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
// The exit is never taken when it exits on true but CI is 0, or exits on
// false but CI is nonzero.
9161 if (ExitIfTrue == !CI->getZExtValue())
9162 // The backedge is always taken.
9163 return getCouldNotCompute();
9164 // The backedge is never taken.
9165 return getZero(CI->getType());
9166 }
9167
9168 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9169 // with a constant step, we can form an equivalent icmp predicate and figure
9170 // out how many iterations will be taken before we exit.
9171 const WithOverflowInst *WO;
9172 const APInt *C;
9173 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9174 match(WO->getRHS(), m_APInt(C))) {
9175 ConstantRange NWR =
// NOTE(review): original line 9176 is missing from this listing, so the
// start of NWR's initializer is not visible. The next line supplies the
// no-wrap kind, which suggests a no-wrap region for WO with RHS *C. Confirm.
9177 WO->getNoWrapKind());
9178 CmpInst::Predicate Pred;
9179 APInt NewRHSC, Offset;
9180 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
// Pred describes the no-overflow region. When the loop exits on "no
// overflow" (ExitIfTrue is false), it keeps running while the inverse holds.
9181 if (!ExitIfTrue)
9182 Pred = ICmpInst::getInversePredicate(Pred);
9183 auto *LHS = getSCEV(WO->getLHS());
9184 if (Offset != 0)
// NOTE(review): original line 9185, the statement guarded by this `if`, is
// missing from this listing; presumably it applies Offset to LHS. Confirm.
9186 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9187 ControlsOnlyExit, AllowPredicates);
9188 if (EL.hasAnyInfo())
9189 return EL;
9190 }
9191
9192 // If it's not an integer or pointer comparison then compute it the hard way.
9193 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9194}
9195
9196std::optional<ScalarEvolution::ExitLimit>
9197ScalarEvolution::computeExitLimitFromCondFromBinOp(
9198 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9199 bool ControlsOnlyExit, bool AllowPredicates) {
9200 // Check if the controlling expression for this loop is an And or Or.
9201 Value *Op0, *Op1;
9202 bool IsAnd = false;
9203 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9204 IsAnd = true;
9205 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9206 IsAnd = false;
9207 else
9208 return std::nullopt;
9209
9210 // EitherMayExit is true in these two cases:
9211 // br (and Op0 Op1), loop, exit
9212 // br (or Op0 Op1), exit, loop
9213 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9214 ExitLimit EL0 = computeExitLimitFromCondCached(
9215 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9216 AllowPredicates);
9217 ExitLimit EL1 = computeExitLimitFromCondCached(
9218 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9219 AllowPredicates);
9220
9221 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9222 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9223 if (isa<ConstantInt>(Op1))
9224 return Op1 == NeutralElement ? EL0 : EL1;
9225 if (isa<ConstantInt>(Op0))
9226 return Op0 == NeutralElement ? EL1 : EL0;
9227
9228 const SCEV *BECount = getCouldNotCompute();
9229 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9230 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9231 if (EitherMayExit) {
9232 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9233 // Both conditions must be same for the loop to continue executing.
9234 // Choose the less conservative count.
9235 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9236 EL1.ExactNotTaken != getCouldNotCompute()) {
9237 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9238 UseSequentialUMin);
9239 }
9240 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9241 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9242 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9243 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9244 else
9245 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9246 EL1.ConstantMaxNotTaken);
9247 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9248 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9249 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9250 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9251 else
9252 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9253 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9254 } else {
9255 // Both conditions must be same at the same time for the loop to exit.
9256 // For now, be conservative.
9257 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9258 BECount = EL0.ExactNotTaken;
9259 }
9260
9261 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9262 // to be more aggressive when computing BECount than when computing
9263 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9264 // and
9265 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9266 // EL1.ConstantMaxNotTaken to not.
9267 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9268 !isa<SCEVCouldNotCompute>(BECount))
9269 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9270 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9271 SymbolicMaxBECount =
9272 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9273 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9274 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9275}
9276
9277ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9278 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9279 bool AllowPredicates) {
9280 // If the condition was exit on true, convert the condition to exit on false
9281 CmpPredicate Pred;
9282 if (!ExitIfTrue)
9283 Pred = ExitCond->getCmpPredicate();
9284 else
9285 Pred = ExitCond->getInverseCmpPredicate();
9286 const ICmpInst::Predicate OriginalPred = Pred;
9287
9288 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9289 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9290
9291 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9292 AllowPredicates);
9293 if (EL.hasAnyInfo())
9294 return EL;
9295
9296 auto *ExhaustiveCount =
9297 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9298
9299 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9300 return ExhaustiveCount;
9301
9302 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9303 ExitCond->getOperand(1), L, OriginalPred);
9304}
/// Compute the exit limit for an exit that keeps the loop running while
/// `LHS Pred RHS` holds. The ICmpInst overload normalizes to this form before
/// calling here. Steps visible below:
///  1. Evaluate both operands at the scope of L, and move a loop-invariant LHS
///     to the RHS.
///  2. Simplify the operands with SimplifyICmpOperands.
///  3. For an addrec of L compared against a constant, ask the addrec how many
///     iterations it stays inside the exact icmp region.
///  4. When this exit controls a finite loop and RHS is invariant, try to
///     strengthen the addrec's wrap flags: NW, then NSW/NUW for slt/ult with a
///     positive step.
///  5. Dispatch on the predicate:
///     - ne: howFarToZero;
///     - eq: howFarToNonZero;
///     - lt and le: howManyLessThans;
///     - gt: howManyGreaterThans. ge reaches it only when
///       EnableFiniteLoopControl is set, the exit controls a finite loop and
///       RHS is invariant.
/// Returns CouldNotCompute when no step yields any information.
9305ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9306 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9307 bool ControlsOnlyExit, bool AllowPredicates) {
9308
9309 // Try to evaluate any dependencies out of the loop.
9310 LHS = getSCEVAtScope(LHS, L);
9311 RHS = getSCEVAtScope(RHS, L);
9312
9313 // At this point, we would like to compute how many iterations of the
9314 // loop the predicate will return true for these inputs.
9315 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9316 // If there is a loop-invariant, force it into the RHS.
9317 std::swap(LHS, RHS);
// NOTE(review): original line 9318 is missing from this listing. Presumably
// it swaps Pred to match the swapped operands. Confirm against the real
// source.
9319 }
9320
9321 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
// NOTE(review): original line 9322, the remainder of the initializer above,
// is missing from this listing.
9323 // Simplify the operands before analyzing them.
9324 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9325
9326 // If we have a comparison of a chrec against a constant, try to use value
9327 // ranges to answer this query.
9328 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9329 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9330 if (AddRec->getLoop() == L) {
9331 // Form the constant range.
9332 ConstantRange CompRange =
9333 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9334
9335 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9336 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9337 }
9338
9339 // If this loop must exit based on this condition (or execute undefined
9340 // behaviour), see if we can improve wrap flags. This is essentially
9341 // a must execute style proof.
9342 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9343 // If we can prove the test sequence produced must repeat the same values
9344 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9345 // because if it did, we'd have an infinite (undefined) loop.
9346 // TODO: We can peel off any functions which are invertible *in L*. Loop
9347 // invariant terms are effectively constants for our purposes here.
9348 auto *InnerLHS = LHS;
9349 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9350 InnerLHS = ZExt->getOperand();
9351 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9352 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9353 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9354 /*OrNegative=*/true)) {
// This mutates the uniqued addrec in place (via const_cast), so the
// stronger flags are visible to every user of AR.
9355 auto Flags = AR->getNoWrapFlags();
9356 Flags = setFlags(Flags, SCEV::FlagNW);
9357 SmallVector<const SCEV *> Operands{AR->operands()};
9358 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9359 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9360 }
9361
9362 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9363 // From no-self-wrap, this follows trivially from the fact that every
9364 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9365 // last value before (un)signed wrap. Since we know that last value
9366 // didn't exit, nor will any smaller one.
9367 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9368 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9369 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9370 AR && AR->getLoop() == L && AR->isAffine() &&
9371 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9372 isKnownPositive(AR->getStepRecurrence(*this))) {
9373 auto Flags = AR->getNoWrapFlags();
9374 Flags = setFlags(Flags, WrapType);
9375 SmallVector<const SCEV*> Operands{AR->operands()};
9376 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9377 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9378 }
9379 }
9380 }
9381
9382 switch (Pred) {
9383 case ICmpInst::ICMP_NE: { // while (X != Y)
9384 // Convert to: while (X-Y != 0)
9385 if (LHS->getType()->isPointerTy()) {
// NOTE(review): original lines 9386-9387 are missing from this listing;
// only the `return LHS;` that follows them is visible. Presumably they
// convert the pointer LHS to an integer and bail out on failure. Confirm.
9388 return LHS;
9389 }
9390 if (RHS->getType()->isPointerTy()) {
// NOTE(review): original lines 9391-9392 are missing from this listing
// (presumably the RHS counterpart of the pointer handling above).
9393 return RHS;
9394 }
9395 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9396 AllowPredicates);
9397 if (EL.hasAnyInfo())
9398 return EL;
9399 break;
9400 }
9401 case ICmpInst::ICMP_EQ: { // while (X == Y)
9402 // Convert to: while (X-Y == 0)
9403 if (LHS->getType()->isPointerTy()) {
// NOTE(review): original lines 9404-9405 are missing from this listing.
9406 return LHS;
9407 }
9408 if (RHS->getType()->isPointerTy()) {
// NOTE(review): original lines 9409-9410 are missing from this listing.
9411 return RHS;
9412 }
9413 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9414 if (EL.hasAnyInfo()) return EL;
9415 break;
9416 }
9417 case ICmpInst::ICMP_SLE:
9418 case ICmpInst::ICMP_ULE:
9419 // Since the loop is finite, an invariant RHS cannot include the boundary
9420 // value, otherwise it would loop forever.
9421 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9422 !isLoopInvariant(RHS, L)) {
9423 // Otherwise, perform the addition in a wider type, to avoid overflow.
9424 // If the LHS is an addrec with the appropriate nowrap flag, the
9425 // extension will be sunk into it and the exit count can be analyzed.
9426 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9427 if (!OldType)
9428 break;
9429 // Prefer doubling the bitwidth over adding a single bit to make it more
9430 // likely that we use a legal type.
9431 auto *NewType =
9432 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9433 if (ICmpInst::isSigned(Pred)) {
9434 LHS = getSignExtendExpr(LHS, NewType);
9435 RHS = getSignExtendExpr(RHS, NewType);
9436 } else {
9437 LHS = getZeroExtendExpr(LHS, NewType);
9438 RHS = getZeroExtendExpr(RHS, NewType);
9439 }
9440 }
// NOTE(review): original line 9441 is missing from this listing. The
// comment above refers to "the addition", which is presumably performed on
// RHS here to turn the non-strict predicate into the strict case below.
// Confirm.
9442 [[fallthrough]];
9443 case ICmpInst::ICMP_SLT:
9444 case ICmpInst::ICMP_ULT: { // while (X < Y)
9445 bool IsSigned = ICmpInst::isSigned(Pred);
9446 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9447 AllowPredicates);
9448 if (EL.hasAnyInfo())
9449 return EL;
9450 break;
9451 }
9452 case ICmpInst::ICMP_SGE:
9453 case ICmpInst::ICMP_UGE:
9454 // Since the loop is finite, an invariant RHS cannot include the boundary
9455 // value, otherwise it would loop forever.
9456 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9457 !isLoopInvariant(RHS, L))
9458 break;
// NOTE(review): original line 9459 is missing from this listing
// (presumably the RHS adjustment that makes ge equivalent to gt). Confirm.
9460 [[fallthrough]];
9461 case ICmpInst::ICMP_SGT:
9462 case ICmpInst::ICMP_UGT: { // while (X > Y)
9463 bool IsSigned = ICmpInst::isSigned(Pred);
9464 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9465 AllowPredicates);
9466 if (EL.hasAnyInfo())
9467 return EL;
9468 break;
9469 }
9470 default:
9471 break;
9472 }
9473
9474 return getCouldNotCompute();
9475}
9476
9477ScalarEvolution::ExitLimit
9478ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9479 SwitchInst *Switch,
9480 BasicBlock *ExitingBlock,
9481 bool ControlsOnlyExit) {
9482 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9483
9484 // Give up if the exit is the default dest of a switch.
9485 if (Switch->getDefaultDest() == ExitingBlock)
9486 return getCouldNotCompute();
9487
9488 assert(L->contains(Switch->getDefaultDest()) &&
9489 "Default case must not exit the loop!");
9490 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9491 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9492
9493 // while (X != Y) --> while (X-Y != 0)
9494 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9495 if (EL.hasAnyInfo())
9496 return EL;
9497
9498 return getCouldNotCompute();
9499}
9500
// Evaluate the add recurrence AddRec at the constant iteration number C and
// return the folded ConstantInt. The result of evaluateAtIteration is expected
// to fold to a SCEVConstant; cast<> enforces this.
// NOTE(review): original lines 9502 (the function name and its leading
// parameters, presumably AddRec and C given their uses below) and 9506 (the
// start of the assertion whose message appears on 9507) are missing from this
// listing.
9501static ConstantInt *
9503 ScalarEvolution &SE) {
9504 const SCEV *InVal = SE.getConstant(C);
9505 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9507 "Evaluation of SCEV at constant didn't fold correctly?");
9508 return cast<SCEVConstant>(Val)->getValue();
9509}
9510
/// Bound the trip count of a loop whose exit compares a shift recurrence with
/// the constant \p RHSV.
///
/// The recognized recurrence is a header PHI whose latch value is a
/// positive-constant shift of the PHI itself. \p LHS may also be one more shift
/// of the same kind applied to that PHI. Such recurrences stabilize: ashr goes
/// to 0 or -1 depending on the known sign of the start value, and lshr/shl go
/// to 0. If \p Pred (the condition under which the backedge is taken) is false
/// at the stable value, the loop is finite. The result is then a max-only
/// ExitLimit whose bound is derived from the bit width of RHS.
/// Returns CouldNotCompute when RHSV is not a ConstantInt, the loop lacks a
/// latch or a loop predecessor, or no recurrence is recognized.
9511ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9512 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9513 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9514 if (!RHS)
9515 return getCouldNotCompute();
9516
9517 const BasicBlock *Latch = L->getLoopLatch();
9518 if (!Latch)
9519 return getCouldNotCompute();
9520
9521 const BasicBlock *Predecessor = L->getLoopPredecessor();
9522 if (!Predecessor)
9523 return getCouldNotCompute();
9524
9525 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9526 // Return LHS in OutLHS and shift_op in OutOpCode.
9527 auto MatchPositiveShift =
9528 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9529
9530 using namespace PatternMatch;
9531
9532 ConstantInt *ShiftAmt;
9533 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9534 OutOpCode = Instruction::LShr;
9535 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9536 OutOpCode = Instruction::AShr;
9537 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9538 OutOpCode = Instruction::Shl;
9539 else
9540 return false;
9541
9542 return ShiftAmt->getValue().isStrictlyPositive();
9543 };
9544
9545 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9546 //
9547 // loop:
9548 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9549 // %iv.shifted = lshr i32 %iv, <positive constant>
9550 //
9551 // Return true on a successful match. Return the corresponding PHI node (%iv
9552 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
// NOTE(review): the lambda's parameter V is never read. It is shadowed by the
// inner `Value *V` below, and the body reads and rewrites the captured LHS
// instead. The only call site passes LHS, so the behavior is the same, but the
// parameter is misleading.
9553 auto MatchShiftRecurrence =
9554 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9555 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9556
9557 {
// NOTE(review): original line 9558 is missing from this listing; it
// presumably declares OpC, which is used below.
9559 Value *V;
9560
9561 // If we encounter a shift instruction, "peel off" the shift operation,
9562 // and remember that we did so. Later when we inspect %iv's backedge
9563 // value, we will make sure that the backedge value uses the same
9564 // operation.
9565 //
9566 // Note: the peeled shift operation does not have to be the same
9567 // instruction as the one feeding into the PHI's backedge value. We only
9568 // really care about it being the same *kind* of shift instruction --
9569 // that's all that is required for our later inferences to hold.
9570 if (MatchPositiveShift(LHS, V, OpC)) {
9571 PostShiftOpCode = OpC;
9572 LHS = V;
9573 }
9574 }
9575
9576 PNOut = dyn_cast<PHINode>(LHS);
9577 if (!PNOut || PNOut->getParent() != L->getHeader())
9578 return false;
9579
9580 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9581 Value *OpLHS;
9582
9583 return
9584 // The backedge value for the PHI node must be a shift by a positive
9585 // amount
9586 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9587
9588 // of the PHI node itself
9589 OpLHS == PNOut &&
9590
9591 // and the kind of shift should match the kind of shift we peeled
9592 // off, if any.
9593 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9594 };
9595
9596 PHINode *PN;
// NOTE(review): original line 9597 is missing from this listing; it presumably
// declares OpCode, which is used below.
9598 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9599 return getCouldNotCompute();
9600
9601 const DataLayout &DL = getDataLayout();
9602
9603 // The key rationale for this optimization is that for some kinds of shift
9604 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9605 // within a finite number of iterations. If the condition guarding the
9606 // backedge (in the sense that the backedge is taken if the condition is true)
9607 // is false for the value the shift recurrence stabilizes to, then we know
9608 // that the backedge is taken only a finite number of times.
9609
9610 ConstantInt *StableValue = nullptr;
9611 switch (OpCode) {
9612 default:
9613 llvm_unreachable("Impossible case!");
9614
9615 case Instruction::AShr: {
9616 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9617 // bitwidth(K) iterations.
// The sign of the start value must be known from its known bits,
// computed at the predecessor's terminator; otherwise give up.
9618 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9619 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9620 Predecessor->getTerminator(), &DT);
9621 auto *Ty = cast<IntegerType>(RHS->getType());
9622 if (Known.isNonNegative())
9623 StableValue = ConstantInt::get(Ty, 0);
9624 else if (Known.isNegative())
9625 StableValue = ConstantInt::get(Ty, -1, true);
9626 else
9627 return getCouldNotCompute();
9628
9629 break;
9630 }
9631 case Instruction::LShr:
9632 case Instruction::Shl:
9633 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9634 // stabilize to 0 in at most bitwidth(K) iterations.
9635 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9636 break;
9637 }
9638
9639 auto *Result =
9640 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9641 assert(Result->getType()->isIntegerTy(1) &&
9642 "Otherwise cannot be an operand to a branch instruction");
9643
// The backedge is not taken once the recurrence has stabilized, so the trip
// count is bounded.
9644 if (Result->isNullValue()) {
9645 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9646 const SCEV *UpperBound =
// NOTE(review): original line 9647, the initializer of UpperBound, is
// missing from this listing; presumably it is a constant built from
// BitWidth. Confirm.
9648 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9649 }
9650
9651 return getCouldNotCompute();
9652}
9653
9654/// Return true if we can constant fold an instruction of the specified type,
9655/// assuming that all operands were constants.
///
/// A call is foldable only when it has a known callee (getCalledFunction) and
/// canConstantFoldCallTo accepts that call/callee pair. Anything not matched
/// by the accepting test at the top returns false.
9656static bool CanConstantFold(const Instruction *I) {
// NOTE(review): original lines 9657-9659 (the condition guarding the
// `return true;` below) are missing from this listing. Presumably it is a list
// of instruction kinds that are always foldable. Confirm.
9660 return true;
9661
9662 if (const CallInst *CI = dyn_cast<CallInst>(I))
9663 if (const Function *F = CI->getCalledFunction())
9664 return canConstantFoldCallTo(CI, F);
9665 return false;
9666}
9667
9668/// Determine whether this instruction can constant evolve within this loop
9669/// assuming its operands can all constant evolve.
9670static bool canConstantEvolve(Instruction *I, const Loop *L) {
9671 // An instruction outside of the loop can't be derived from a loop PHI.
9672 if (!L->contains(I)) return false;
9673
9674 if (isa<PHINode>(I)) {
9675 // We don't currently keep track of the control flow needed to evaluate
9676 // PHIs, so we cannot handle PHIs inside of loops.
9677 return L->getHeader() == I->getParent();
9678 }
9679
9680 // If we won't be able to constant fold this expression even if the operands
9681 // are constants, bail early.
9682 return CanConstantFold(I);
9683}
9684
9685/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9686/// recursing through each instruction operand until reaching a loop header phi.
///
/// Constant operands are skipped. Every other operand must be an instruction
/// that passes canConstantEvolve and must lead back to one and the same PHI;
/// otherwise nullptr is returned. Results for visited non-PHI operands are
/// memoized in PHIMap, including nullptr results. Depth is incremented on each
/// recursive step.
9687static PHINode *
// NOTE(review): original lines 9688-9689 (function name and the leading
// parameters, presumably UseInst, L and PHIMap given their uses below) are
// missing from this listing.
9690 unsigned Depth) {
// NOTE(review): original line 9691, the condition guarding the
// `return nullptr;` below, is missing from this listing. Given the Depth
// parameter, it is presumably a recursion-depth limit. Confirm.
9692 return nullptr;
9693
9694 // Otherwise, we can evaluate this instruction if all of its operands are
9695 // constant or derived from a PHI node themselves.
9696 PHINode *PHI = nullptr;
9697 for (Value *Op : UseInst->operands()) {
9698 if (isa<Constant>(Op)) continue;
9699
// NOTE(review): original line 9700 is missing from this listing; it
// presumably defines OpInst from Op.
9701 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9702
9703 PHINode *P = dyn_cast<PHINode>(OpInst);
9704 if (!P)
9705 // If this operand is already visited, reuse the prior result.
9706 // We may have P != PHI if this is the deepest point at which the
9707 // inconsistent paths meet.
9708 P = PHIMap.lookup(OpInst);
9709 if (!P) {
9710 // Recurse and memoize the results, whether a phi is found or not.
9711 // This recursive call invalidates pointers into PHIMap.
9712 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9713 PHIMap[OpInst] = P;
9714 }
9715 if (!P)
9716 return nullptr; // Not evolving from PHI
9717 if (PHI && PHI != P)
9718 return nullptr; // Evolving from multiple different PHIs.
9719 PHI = P;
9720 }
9721 // This is an expression evolving from a constant PHI!
9722 return PHI;
9723}
9724
9725/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9726/// in the loop that V is derived from. We allow arbitrary operations along the
9727/// way, but the operands of an operation must either be constants or a value
9728/// derived from a constant PHI. If this expression does not fit with these
9729/// constraints, return null.
///
/// A qualifying PHI is returned directly. Any other qualifying instruction is
/// resolved through getConstantEvolvingPHIOperands, starting at depth 0.
// NOTE(review): original lines 9730-9731 (the signature and, presumably, the
// definition of I from V) are missing from this listing.
9732 if (!I || !canConstantEvolve(I, L)) return nullptr;
9733
9734 if (PHINode *PN = dyn_cast<PHINode>(I))
9735 return PN;
9736
9737 // Record non-constant instructions contained by the loop.
// NOTE(review): original line 9738 (presumably the declaration of PHIMap,
// which is used below) is missing from this listing.
9739 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9740}
9741
9742/// EvaluateExpression - Given an expression that passes the
9743/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9744/// in the loop has the value PHIVal. If we can't fold this expression for some
9745/// reason, return null.
///
/// NOTE(review): the visible body never refers to a "PHIVal". PHI values are
/// read from Vals: a header PHI that is not mapped in Vals makes the
/// evaluation fail. The sentence above looks stale; confirm against the
/// missing parameter lines 9746-9747.
/// Operand results are memoized back into Vals, including nullptr for
/// operands that failed to fold.
// NOTE(review): original lines 9746-9747 (the function name and its leading
// parameters) are missing from this listing.
9748 const DataLayout &DL,
9749 const TargetLibraryInfo *TLI) {
9750 // Convenient constant check, but redundant for recursive calls.
9751 if (Constant *C = dyn_cast<Constant>(V)) return C;
// NOTE(review): original line 9752 (presumably the definition of I from V)
// is missing from this listing.
9753 if (!I) return nullptr;
9754
9755 if (Constant *C = Vals.lookup(I)) return C;
9756
9757 // An instruction inside the loop depends on a value outside the loop that we
9758 // weren't given a mapping for, or a value such as a call inside the loop.
9759 if (!canConstantEvolve(I, L)) return nullptr;
9760
9761 // An unmapped PHI can be due to a branch or another loop inside this loop,
9762 // or due to this not being the initial iteration through a loop where we
9763 // couldn't compute the evolution of this particular PHI last time.
9764 if (isa<PHINode>(I)) return nullptr;
9765
9766 std::vector<Constant*> Operands(I->getNumOperands());
9767
9768 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9769 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9770 if (!Operand) {
// Non-instruction operands must already be constants.
9771 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9772 if (!Operands[i]) return nullptr;
9773 continue;
9774 }
9775 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9776 Vals[Operand] = C;
9777 if (!C) return nullptr;
9778 Operands[i] = C;
9779 }
9780
9781 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9782 /*AllowNonDeterministic=*/false);
9783}
9784
9785
9786// If every incoming value to PN except the one for BB is a specific Constant,
9787// return that, else return nullptr.
//
// Incoming entries from BB are ignored. If every entry comes from BB, the
// result is also nullptr. Any non-Constant incoming value, or two different
// Constants, yields nullptr.
// NOTE(review): original line 9788, the signature, is missing from this
// listing. The function is called below as getOtherIncomingValue(&PHI, Latch).
9789 Constant *IncomingVal = nullptr;
9790
9791 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9792 if (PN->getIncomingBlock(i) == BB)
9793 continue;
9794
9795 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9796 if (!CurrentVal)
9797 return nullptr;
9798
// The first constant seen is recorded; any later, different constant means
// there is no single value.
9799 if (IncomingVal != CurrentVal) {
9800 if (IncomingVal)
9801 return nullptr;
9802 IncomingVal = CurrentVal;
9803 }
9804 }
9805
9806 return IncomingVal;
9807}
9808
9809/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9810/// in the header of its containing loop, we know the loop executes a
9811/// constant number of times, and the PHI node is just a recurrence
9812/// involving constants, fold it.
9813Constant *
9814ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9815 const APInt &BEs,
9816 const Loop *L) {
9817 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9818 if (!Inserted)
9819 return I->second;
9820
9822 return nullptr; // Not going to evaluate it.
9823
9824 Constant *&RetVal = I->second;
9825
9826 DenseMap<Instruction *, Constant *> CurrentIterVals;
9827 BasicBlock *Header = L->getHeader();
9828 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9829
9830 BasicBlock *Latch = L->getLoopLatch();
9831 if (!Latch)
9832 return nullptr;
9833
9834 for (PHINode &PHI : Header->phis()) {
9835 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9836 CurrentIterVals[&PHI] = StartCST;
9837 }
9838 if (!CurrentIterVals.count(PN))
9839 return RetVal = nullptr;
9840
9841 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9842
9843 // Execute the loop symbolically to determine the exit value.
9844 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9845 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9846
9847 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9848 unsigned IterationNum = 0;
9849 const DataLayout &DL = getDataLayout();
9850 for (; ; ++IterationNum) {
9851 if (IterationNum == NumIterations)
9852 return RetVal = CurrentIterVals[PN]; // Got exit value!
9853
9854 // Compute the value of the PHIs for the next iteration.
9855 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9856 DenseMap<Instruction *, Constant *> NextIterVals;
9857 Constant *NextPHI =
9858 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9859 if (!NextPHI)
9860 return nullptr; // Couldn't evaluate!
9861 NextIterVals[PN] = NextPHI;
9862
9863 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9864
9865 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9866 // cease to be able to evaluate one of them or if they stop evolving,
9867 // because that doesn't necessarily prevent us from computing PN.
9869 for (const auto &I : CurrentIterVals) {
9870 PHINode *PHI = dyn_cast<PHINode>(I.first);
9871 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9872 PHIsToCompute.emplace_back(PHI, I.second);
9873 }
9874 // We use two distinct loops because EvaluateExpression may invalidate any
9875 // iterators into CurrentIterVals.
9876 for (const auto &I : PHIsToCompute) {
9877 PHINode *PHI = I.first;
9878 Constant *&NextPHI = NextIterVals[PHI];
9879 if (!NextPHI) { // Not already computed.
9880 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9881 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9882 }
9883 if (NextPHI != I.second)
9884 StoppedEvolving = false;
9885 }
9886
9887 // If all entries in CurrentIterVals == NextIterVals then we can stop
9888 // iterating, the loop can't continue to change.
9889 if (StoppedEvolving)
9890 return RetVal = CurrentIterVals[PN];
9891
9892 CurrentIterVals.swap(NextIterVals);
9893 }
9894}
9895
9896const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9897 Value *Cond,
9898 bool ExitWhen) {
9899 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9900 if (!PN) return getCouldNotCompute();
9901
9902 // If the loop is canonicalized, the PHI will have exactly two entries.
9903 // That's the only form we support here.
9904 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9905
9906 DenseMap<Instruction *, Constant *> CurrentIterVals;
9907 BasicBlock *Header = L->getHeader();
9908 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9909
9910 BasicBlock *Latch = L->getLoopLatch();
9911 assert(Latch && "Should follow from NumIncomingValues == 2!");
9912
9913 for (PHINode &PHI : Header->phis()) {
9914 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9915 CurrentIterVals[&PHI] = StartCST;
9916 }
9917 if (!CurrentIterVals.count(PN))
9918 return getCouldNotCompute();
9919
9920 // Okay, we find a PHI node that defines the trip count of this loop. Execute
9921 // the loop symbolically to determine when the condition gets a value of
9922 // "ExitWhen".
9923 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9924 const DataLayout &DL = getDataLayout();
9925 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9926 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9927 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9928
9929 // Couldn't symbolically evaluate.
9930 if (!CondVal) return getCouldNotCompute();
9931
9932 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9933 ++NumBruteForceTripCountsComputed;
9934 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9935 }
9936
9937 // Update all the PHI nodes for the next iteration.
9938 DenseMap<Instruction *, Constant *> NextIterVals;
9939
9940 // Create a list of which PHIs we need to compute. We want to do this before
9941 // calling EvaluateExpression on them because that may invalidate iterators
9942 // into CurrentIterVals.
9943 SmallVector<PHINode *, 8> PHIsToCompute;
9944 for (const auto &I : CurrentIterVals) {
9945 PHINode *PHI = dyn_cast<PHINode>(I.first);
9946 if (!PHI || PHI->getParent() != Header) continue;
9947 PHIsToCompute.push_back(PHI);
9948 }
9949 for (PHINode *PHI : PHIsToCompute) {
9950 Constant *&NextPHI = NextIterVals[PHI];
9951 if (NextPHI) continue; // Already computed!
9952
9953 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9954 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9955 }
9956 CurrentIterVals.swap(NextIterVals);
9957 }
9958
9959 // Too many iterations were needed to evaluate.
9960 return getCouldNotCompute();
9961}
9962
9963const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9965 ValuesAtScopes[V];
9966 // Check to see if we've folded this expression at this loop before.
9967 for (auto &LS : Values)
9968 if (LS.first == L)
9969 return LS.second ? LS.second : V;
9970
9971 Values.emplace_back(L, nullptr);
9972
9973 // Otherwise compute it.
9974 const SCEV *C = computeSCEVAtScope(V, L);
9975 for (auto &LS : reverse(ValuesAtScopes[V]))
9976 if (LS.first == L) {
9977 LS.second = C;
9978 if (!isa<SCEVConstant>(C))
9979 ValuesAtScopesUsers[C].push_back({L, V});
9980 break;
9981 }
9982 return C;
9983}
9984
9985/// This builds up a Constant using the ConstantExpr interface. That way, we
9986/// will return Constants for objects which aren't represented by a
9987/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9988/// Returns NULL if the SCEV isn't representable as a Constant.
9990 switch (V->getSCEVType()) {
9991 case scCouldNotCompute:
9992 case scAddRecExpr:
9993 case scVScale:
9994 return nullptr;
9995 case scConstant:
9996 return cast<SCEVConstant>(V)->getValue();
9997 case scUnknown:
9998 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9999 case scPtrToAddr: {
10001 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10002 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10003
10004 return nullptr;
10005 }
10006 case scPtrToInt: {
10008 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10009 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10010
10011 return nullptr;
10012 }
10013 case scTruncate: {
10015 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10016 return ConstantExpr::getTrunc(CastOp, ST->getType());
10017 return nullptr;
10018 }
10019 case scAddExpr: {
10020 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10021 Constant *C = nullptr;
10022 for (const SCEV *Op : SA->operands()) {
10024 if (!OpC)
10025 return nullptr;
10026 if (!C) {
10027 C = OpC;
10028 continue;
10029 }
10030 assert(!C->getType()->isPointerTy() &&
10031 "Can only have one pointer, and it must be last");
10032 if (OpC->getType()->isPointerTy()) {
10033 // The offsets have been converted to bytes. We can add bytes using
10034 // an i8 GEP.
10035 C = ConstantExpr::getPtrAdd(OpC, C);
10036 } else {
10037 C = ConstantExpr::getAdd(C, OpC);
10038 }
10039 }
10040 return C;
10041 }
10042 case scMulExpr:
10043 case scSignExtend:
10044 case scZeroExtend:
10045 case scUDivExpr:
10046 case scSMaxExpr:
10047 case scUMaxExpr:
10048 case scSMinExpr:
10049 case scUMinExpr:
10051 return nullptr;
10052 }
10053 llvm_unreachable("Unknown SCEV kind!");
10054}
10055
10056const SCEV *
10057ScalarEvolution::getWithOperands(const SCEV *S,
10058 SmallVectorImpl<const SCEV *> &NewOps) {
10059 switch (S->getSCEVType()) {
10060 case scTruncate:
10061 case scZeroExtend:
10062 case scSignExtend:
10063 case scPtrToAddr:
10064 case scPtrToInt:
10065 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10066 case scAddRecExpr: {
10067 auto *AddRec = cast<SCEVAddRecExpr>(S);
10068 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10069 }
10070 case scAddExpr:
10071 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10072 case scMulExpr:
10073 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10074 case scUDivExpr:
10075 return getUDivExpr(NewOps[0], NewOps[1]);
10076 case scUMaxExpr:
10077 case scSMaxExpr:
10078 case scUMinExpr:
10079 case scSMinExpr:
10080 return getMinMaxExpr(S->getSCEVType(), NewOps);
10082 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10083 case scConstant:
10084 case scVScale:
10085 case scUnknown:
10086 return S;
10087 case scCouldNotCompute:
10088 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10089 }
10090 llvm_unreachable("Unknown SCEV kind!");
10091}
10092
10093const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10094 switch (V->getSCEVType()) {
10095 case scConstant:
10096 case scVScale:
10097 return V;
10098 case scAddRecExpr: {
10099 // If this is a loop recurrence for a loop that does not contain L, then we
10100 // are dealing with the final value computed by the loop.
10101 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10102 // First, attempt to evaluate each operand.
10103 // Avoid performing the look-up in the common case where the specified
10104 // expression has no loop-variant portions.
10105 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10106 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10107 if (OpAtScope == AddRec->getOperand(i))
10108 continue;
10109
10110 // Okay, at least one of these operands is loop variant but might be
10111 // foldable. Build a new instance of the folded commutative expression.
10113 NewOps.reserve(AddRec->getNumOperands());
10114 append_range(NewOps, AddRec->operands().take_front(i));
10115 NewOps.push_back(OpAtScope);
10116 for (++i; i != e; ++i)
10117 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10118
10119 const SCEV *FoldedRec = getAddRecExpr(
10120 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10121 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10122 // The addrec may be folded to a nonrecurrence, for example, if the
10123 // induction variable is multiplied by zero after constant folding. Go
10124 // ahead and return the folded value.
10125 if (!AddRec)
10126 return FoldedRec;
10127 break;
10128 }
10129
10130 // If the scope is outside the addrec's loop, evaluate it by using the
10131 // loop exit value of the addrec.
10132 if (!AddRec->getLoop()->contains(L)) {
10133 // To evaluate this recurrence, we need to know how many times the AddRec
10134 // loop iterates. Compute this now.
10135 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10136 if (BackedgeTakenCount == getCouldNotCompute())
10137 return AddRec;
10138
10139 // Then, evaluate the AddRec.
10140 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10141 }
10142
10143 return AddRec;
10144 }
10145 case scTruncate:
10146 case scZeroExtend:
10147 case scSignExtend:
10148 case scPtrToAddr:
10149 case scPtrToInt:
10150 case scAddExpr:
10151 case scMulExpr:
10152 case scUDivExpr:
10153 case scUMaxExpr:
10154 case scSMaxExpr:
10155 case scUMinExpr:
10156 case scSMinExpr:
10157 case scSequentialUMinExpr: {
10158 ArrayRef<const SCEV *> Ops = V->operands();
10159 // Avoid performing the look-up in the common case where the specified
10160 // expression has no loop-variant portions.
10161 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10162 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10163 if (OpAtScope != Ops[i]) {
10164 // Okay, at least one of these operands is loop variant but might be
10165 // foldable. Build a new instance of the folded commutative expression.
10167 NewOps.reserve(Ops.size());
10168 append_range(NewOps, Ops.take_front(i));
10169 NewOps.push_back(OpAtScope);
10170
10171 for (++i; i != e; ++i) {
10172 OpAtScope = getSCEVAtScope(Ops[i], L);
10173 NewOps.push_back(OpAtScope);
10174 }
10175
10176 return getWithOperands(V, NewOps);
10177 }
10178 }
10179 // If we got here, all operands are loop invariant.
10180 return V;
10181 }
10182 case scUnknown: {
10183 // If this instruction is evolved from a constant-evolving PHI, compute the
10184 // exit value from the loop without using SCEVs.
10185 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10187 if (!I)
10188 return V; // This is some other type of SCEVUnknown, just return it.
10189
10190 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10191 const Loop *CurrLoop = this->LI[I->getParent()];
10192 // Looking for loop exit value.
10193 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10194 PN->getParent() == CurrLoop->getHeader()) {
10195 // Okay, there is no closed form solution for the PHI node. Check
10196 // to see if the loop that contains it has a known backedge-taken
10197 // count. If so, we may be able to force computation of the exit
10198 // value.
10199 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10200 // This trivial case can show up in some degenerate cases where
10201 // the incoming IR has not yet been fully simplified.
10202 if (BackedgeTakenCount->isZero()) {
10203 Value *InitValue = nullptr;
10204 bool MultipleInitValues = false;
10205 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10206 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10207 if (!InitValue)
10208 InitValue = PN->getIncomingValue(i);
10209 else if (InitValue != PN->getIncomingValue(i)) {
10210 MultipleInitValues = true;
10211 break;
10212 }
10213 }
10214 }
10215 if (!MultipleInitValues && InitValue)
10216 return getSCEV(InitValue);
10217 }
10218 // Do we have a loop invariant value flowing around the backedge
10219 // for a loop which must execute the backedge?
10220 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10221 isKnownNonZero(BackedgeTakenCount) &&
10222 PN->getNumIncomingValues() == 2) {
10223
10224 unsigned InLoopPred =
10225 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10226 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10227 if (CurrLoop->isLoopInvariant(BackedgeVal))
10228 return getSCEV(BackedgeVal);
10229 }
10230 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10231 // Okay, we know how many times the containing loop executes. If
10232 // this is a constant evolving PHI node, get the final value at
10233 // the specified iteration number.
10234 Constant *RV =
10235 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10236 if (RV)
10237 return getSCEV(RV);
10238 }
10239 }
10240 }
10241
10242 // Okay, this is an expression that we cannot symbolically evaluate
10243 // into a SCEV. Check to see if it's possible to symbolically evaluate
10244 // the arguments into constants, and if so, try to constant propagate the
10245 // result. This is particularly useful for computing loop exit values.
10246 if (!CanConstantFold(I))
10247 return V; // This is some other type of SCEVUnknown, just return it.
10248
10249 SmallVector<Constant *, 4> Operands;
10250 Operands.reserve(I->getNumOperands());
10251 bool MadeImprovement = false;
10252 for (Value *Op : I->operands()) {
10253 if (Constant *C = dyn_cast<Constant>(Op)) {
10254 Operands.push_back(C);
10255 continue;
10256 }
10257
10258 // If any of the operands is non-constant and if they are
10259 // non-integer and non-pointer, don't even try to analyze them
10260 // with scev techniques.
10261 if (!isSCEVable(Op->getType()))
10262 return V;
10263
10264 const SCEV *OrigV = getSCEV(Op);
10265 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10266 MadeImprovement |= OrigV != OpV;
10267
10269 if (!C)
10270 return V;
10271 assert(C->getType() == Op->getType() && "Type mismatch");
10272 Operands.push_back(C);
10273 }
10274
10275 // Check to see if getSCEVAtScope actually made an improvement.
10276 if (!MadeImprovement)
10277 return V; // This is some other type of SCEVUnknown, just return it.
10278
10279 Constant *C = nullptr;
10280 const DataLayout &DL = getDataLayout();
10281 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10282 /*AllowNonDeterministic=*/false);
10283 if (!C)
10284 return V;
10285 return getSCEV(C);
10286 }
10287 case scCouldNotCompute:
10288 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10289 }
10290 llvm_unreachable("Unknown SCEV type!");
10291}
10292
10294 return getSCEVAtScope(getSCEV(V), L);
10295}
10296
10297const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10299 return stripInjectiveFunctions(ZExt->getOperand());
10301 return stripInjectiveFunctions(SExt->getOperand());
10302 return S;
10303}
10304
10305/// Finds the minimum unsigned root of the following equation:
10306///
10307/// A * X = B (mod N)
10308///
10309/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10310/// A and B isn't important.
10311///
10312/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10313static const SCEV *
10316 ScalarEvolution &SE, const Loop *L) {
10317 uint32_t BW = A.getBitWidth();
10318 assert(BW == SE.getTypeSizeInBits(B->getType()));
10319 assert(A != 0 && "A must be non-zero.");
10320
10321 // 1. D = gcd(A, N)
10322 //
10323 // The gcd of A and N may have only one prime factor: 2. The number of
10324 // trailing zeros in A is its multiplicity
10325 uint32_t Mult2 = A.countr_zero();
10326 // D = 2^Mult2
10327
10328 // 2. Check if B is divisible by D.
10329 //
10330 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10331 // is not less than multiplicity of this prime factor for D.
10332 unsigned MinTZ = SE.getMinTrailingZeros(B);
10333 // Try again with the terminator of the loop predecessor for context-specific
10334 // result, if MinTZ s too small.
10335 if (MinTZ < Mult2 && L->getLoopPredecessor())
10336 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10337 if (MinTZ < Mult2) {
10338 // Check if we can prove there's no remainder using URem.
10339 const SCEV *URem =
10340 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10341 const SCEV *Zero = SE.getZero(B->getType());
10342 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10343 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10344 if (!Predicates)
10345 return SE.getCouldNotCompute();
10346
10347 // Avoid adding a predicate that is known to be false.
10348 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10349 return SE.getCouldNotCompute();
10350 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10351 }
10352 }
10353
10354 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10355 // modulo (N / D).
10356 //
10357 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10358 // (N / D) in general. The inverse itself always fits into BW bits, though,
10359 // so we immediately truncate it.
10360 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10361 APInt I = AD.multiplicativeInverse().zext(BW);
10362
10363 // 4. Compute the minimum unsigned root of the equation:
10364 // I * (B / D) mod (N / D)
10365 // To simplify the computation, we factor out the divide by D:
10366 // (I * B mod N) / D
10367 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10368 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10369}
10370
10371/// For a given quadratic addrec, generate coefficients of the corresponding
10372/// quadratic equation, multiplied by a common value to ensure that they are
10373/// integers.
10374/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10375/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10376/// were multiplied by, and BitWidth is the bit width of the original addrec
10377/// coefficients.
10378/// This function returns std::nullopt if the addrec coefficients are not
10379/// compile- time constants.
10380static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10382 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10383 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10384 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10385 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10386 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10387 << *AddRec << '\n');
10388
10389 // We currently can only solve this if the coefficients are constants.
10390 if (!LC || !MC || !NC) {
10391 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10392 return std::nullopt;
10393 }
10394
10395 APInt L = LC->getAPInt();
10396 APInt M = MC->getAPInt();
10397 APInt N = NC->getAPInt();
10398 assert(!N.isZero() && "This is not a quadratic addrec");
10399
10400 unsigned BitWidth = LC->getAPInt().getBitWidth();
10401 unsigned NewWidth = BitWidth + 1;
10402 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10403 << BitWidth << '\n');
10404 // The sign-extension (as opposed to a zero-extension) here matches the
10405 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10406 N = N.sext(NewWidth);
10407 M = M.sext(NewWidth);
10408 L = L.sext(NewWidth);
10409
10410 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10411 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10412 // L+M, L+2M+N, L+3M+3N, ...
10413 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10414 //
10415 // The equation Acc = 0 is then
10416 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10417 // In a quadratic form it becomes:
10418 // N n^2 + (2M-N) n + 2L = 0.
10419
10420 APInt A = N;
10421 APInt B = 2 * M - A;
10422 APInt C = 2 * L;
10423 APInt T = APInt(NewWidth, 2);
10424 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10425 << "x + " << C << ", coeff bw: " << NewWidth
10426 << ", multiplied by " << T << '\n');
10427 return std::make_tuple(A, B, C, T, BitWidth);
10428}
10429
10430/// Helper function to compare optional APInts:
10431/// (a) if X and Y both exist, return min(X, Y),
10432/// (b) if neither X nor Y exist, return std::nullopt,
10433/// (c) if exactly one of X and Y exists, return that value.
10434static std::optional<APInt> MinOptional(std::optional<APInt> X,
10435 std::optional<APInt> Y) {
10436 if (X && Y) {
10437 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10438 APInt XW = X->sext(W);
10439 APInt YW = Y->sext(W);
10440 return XW.slt(YW) ? *X : *Y;
10441 }
10442 if (!X && !Y)
10443 return std::nullopt;
10444 return X ? *X : *Y;
10445}
10446
10447/// Helper function to truncate an optional APInt to a given BitWidth.
10448/// When solving addrec-related equations, it is preferable to return a value
10449/// that has the same bit width as the original addrec's coefficients. If the
10450/// solution fits in the original bit width, truncate it (except for i1).
10451/// Returning a value of a different bit width may inhibit some optimizations.
10452///
10453/// In general, a solution to a quadratic equation generated from an addrec
10454/// may require BW+1 bits, where BW is the bit width of the addrec's
10455/// coefficients. The reason is that the coefficients of the quadratic
10456/// equation are BW+1 bits wide (to avoid truncation when converting from
10457/// the addrec to the equation).
10458static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10459 unsigned BitWidth) {
10460 if (!X)
10461 return std::nullopt;
10462 unsigned W = X->getBitWidth();
10464 return X->trunc(BitWidth);
10465 return X;
10466}
10467
10468/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10469/// iterations. The values L, M, N are assumed to be signed, and they
10470/// should all have the same bit widths.
10471/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10472/// where BW is the bit width of the addrec's coefficients.
10473/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10474/// returned as such, otherwise the bit width of the returned value may
10475/// be greater than BW.
10476///
10477/// This function returns std::nullopt if
10478/// (a) the addrec coefficients are not constant, or
10479/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10480/// like x^2 = 5, no integer solutions exist, in other cases an integer
10481/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10482static std::optional<APInt>
10484 APInt A, B, C, M;
10485 unsigned BitWidth;
10486 auto T = GetQuadraticEquation(AddRec);
10487 if (!T)
10488 return std::nullopt;
10489
10490 std::tie(A, B, C, M, BitWidth) = *T;
10491 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10492 std::optional<APInt> X =
10494 if (!X)
10495 return std::nullopt;
10496
10497 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10498 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10499 if (!V->isZero())
10500 return std::nullopt;
10501
10502 return TruncIfPossible(X, BitWidth);
10503}
10504
10505/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10506/// iterations. The values M, N are assumed to be signed, and they
10507/// should all have the same bit widths.
10508/// Find the least n such that c(n) does not belong to the given range,
10509/// while c(n-1) does.
10510///
10511/// This function returns std::nullopt if
10512/// (a) the addrec coefficients are not constant, or
10513/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10514/// bounds of the range.
10515static std::optional<APInt>
10517 const ConstantRange &Range, ScalarEvolution &SE) {
10518 assert(AddRec->getOperand(0)->isZero() &&
10519 "Starting value of addrec should be 0");
10520 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10521 << Range << ", addrec " << *AddRec << '\n');
10522 // This case is handled in getNumIterationsInRange. Here we can assume that
10523 // we start in the range.
10524 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10525 "Addrec's initial value should be in range");
10526
10527 APInt A, B, C, M;
10528 unsigned BitWidth;
10529 auto T = GetQuadraticEquation(AddRec);
10530 if (!T)
10531 return std::nullopt;
10532
10533 // Be careful about the return value: there can be two reasons for not
10534 // returning an actual number. First, if no solutions to the equations
10535 // were found, and second, if the solutions don't leave the given range.
10536 // The first case means that the actual solution is "unknown", the second
10537 // means that it's known, but not valid. If the solution is unknown, we
10538 // cannot make any conclusions.
10539 // Return a pair: the optional solution and a flag indicating if the
10540 // solution was found.
10541 auto SolveForBoundary =
10542 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10543 // Solve for signed overflow and unsigned overflow, pick the lower
10544 // solution.
10545 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10546 << Bound << " (before multiplying by " << M << ")\n");
10547 Bound *= M; // The quadratic equation multiplier.
10548
10549 std::optional<APInt> SO;
10550 if (BitWidth > 1) {
10551 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10552 "signed overflow\n");
10554 }
10555 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10556 "unsigned overflow\n");
10557 std::optional<APInt> UO =
10559
10560 auto LeavesRange = [&] (const APInt &X) {
10561 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10562 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10563 if (Range.contains(V0->getValue()))
10564 return false;
10565 // X should be at least 1, so X-1 is non-negative.
10566 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10567 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10568 if (Range.contains(V1->getValue()))
10569 return true;
10570 return false;
10571 };
10572
10573 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10574 // can be a solution, but the function failed to find it. We cannot treat it
10575 // as "no solution".
10576 if (!SO || !UO)
10577 return {std::nullopt, false};
10578
10579 // Check the smaller value first to see if it leaves the range.
10580 // At this point, both SO and UO must have values.
10581 std::optional<APInt> Min = MinOptional(SO, UO);
10582 if (LeavesRange(*Min))
10583 return { Min, true };
10584 std::optional<APInt> Max = Min == SO ? UO : SO;
10585 if (LeavesRange(*Max))
10586 return { Max, true };
10587
10588 // Solutions were found, but were eliminated, hence the "true".
10589 return {std::nullopt, true};
10590 };
10591
10592 std::tie(A, B, C, M, BitWidth) = *T;
10593 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10594 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10595 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10596 auto SL = SolveForBoundary(Lower);
10597 auto SU = SolveForBoundary(Upper);
10598 // If any of the solutions was unknown, no meaninigful conclusions can
10599 // be made.
10600 if (!SL.second || !SU.second)
10601 return std::nullopt;
10602
10603 // Claim: The correct solution is not some value between Min and Max.
10604 //
10605 // Justification: Assuming that Min and Max are different values, one of
10606 // them is when the first signed overflow happens, the other is when the
10607 // first unsigned overflow happens. Crossing the range boundary is only
10608 // possible via an overflow (treating 0 as a special case of it, modeling
10609 // an overflow as crossing k*2^W for some k).
10610 //
10611 // The interesting case here is when Min was eliminated as an invalid
10612 // solution, but Max was not. The argument is that if there was another
10613 // overflow between Min and Max, it would also have been eliminated if
10614 // it was considered.
10615 //
10616 // For a given boundary, it is possible to have two overflows of the same
10617 // type (signed/unsigned) without having the other type in between: this
10618 // can happen when the vertex of the parabola is between the iterations
10619 // corresponding to the overflows. This is only possible when the two
10620 // overflows cross k*2^W for the same k. In such case, if the second one
10621 // left the range (and was the first one to do so), the first overflow
10622 // would have to enter the range, which would mean that either we had left
10623 // the range before or that we started outside of it. Both of these cases
10624 // are contradictions.
10625 //
10626 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10627 // solution is not some value between the Max for this boundary and the
10628 // Min of the other boundary.
10629 //
10630 // Justification: Assume that we had such Max_A and Min_B corresponding
10631 // to range boundaries A and B and such that Max_A < Min_B. If there was
10632 // a solution between Max_A and Min_B, it would have to be caused by an
10633 // overflow corresponding to either A or B. It cannot correspond to B,
10634 // since Min_B is the first occurrence of such an overflow. If it
10635 // corresponded to A, it would have to be either a signed or an unsigned
10636 // overflow that is larger than both eliminated overflows for A. But
10637 // between the eliminated overflows and this overflow, the values would
10638 // cover the entire value space, thus crossing the other boundary, which
10639 // is a contradiction.
10640
10641 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10642}
10643
/// Compute the backedge-taken count of loop L for an exit test that has been
/// reduced to "V != 0" (see the comment at the top of the body).
///
/// Cases handled below: a constant V; a quadratic integer AddRec, solved
/// exactly via SolveQuadraticAddRecExact; an affine AddRec whose step is 1 or
/// -1; an affine, no-self-wrap AddRec when ControlsOnlyExit is set and the
/// loop has no abnormal exits (unsigned division); and otherwise an affine
/// AddRec with a constant non-zero step, solved by
/// SolveLinEquationWithOverflow. When AllowPredicates is true, runtime
/// predicates gathered in Predicates are returned as part of the ExitLimit.
10644ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10645 const Loop *L,
10646 bool ControlsOnlyExit,
10647 bool AllowPredicates) {
10648
10649 // This is only used for loops with a "x != y" exit test. The exit condition
10650 // is now expressed as a single expression, V = x-y. So the exit test is
10651 // effectively V != 0. We know and take advantage of the fact that this
10652 // expression is only used in a comparison against zero.
10653
// NOTE(review): `Predicates` is used below, but its declaration (source line
// 10654) is missing from this listing. Confirm against the original file.
10655 // If the value is a constant
10656 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10657 // If the value is already zero, the branch will execute zero times.
10658 if (C->getValue()->isZero()) return C;
10659 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10660 }
10661
10662 const SCEVAddRecExpr *AddRec =
10663 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10664
10665 if (!AddRec && AllowPredicates)
10666 // Try to make this an AddRec using runtime tests, in the first X
10667 // iterations of this loop, where X is the SCEV expression found by the
10668 // algorithm below.
10669 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10670
// Only recurrences over this very loop are analyzed.
10671 if (!AddRec || AddRec->getLoop() != L)
10672 return getCouldNotCompute();
10673
10674 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10675 // the quadratic equation to solve it.
10676 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10677 // We can only use this value if the chrec ends up with an exact zero
10678 // value at this index. When solving for "X*X != 5", for example, we
10679 // should not accept a root of 2.
10680 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10681 const auto *R = cast<SCEVConstant>(getConstant(*S));
10682 return ExitLimit(R, R, R, false, Predicates);
10683 }
10684 return getCouldNotCompute();
10685 }
10686
10687 // Otherwise we can only handle this if it is affine.
10688 if (!AddRec->isAffine())
10689 return getCouldNotCompute();
10690
10691 // If this is an affine expression, the execution count of this branch is
10692 // the minimum unsigned root of the following equation:
10693 //
10694 // Start + Step*N = 0 (mod 2^BW)
10695 //
10696 // equivalent to:
10697 //
10698 // Step*N = -Start (mod 2^BW)
10699 //
10700 // where BW is the common bit width of Start and Step.
10701
10702 // Get the initial value for the loop.
10703 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10704 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10705
10706 if (!isLoopInvariant(Step, L))
10707 return getCouldNotCompute();
10708
10709 LoopGuards Guards = LoopGuards::collect(L, *this);
10710 // Specialize step for this loop so we get context sensitive facts below.
10711 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10712
10713 // For positive steps (counting up until unsigned overflow):
10714 // N = -Start/Step (as unsigned)
10715 // For negative steps (counting down to zero):
10716 // N = Start/-Step
10717 // First compute the unsigned distance from zero in the direction of Step.
// A step whose sign cannot be determined (even with loop guards) is rejected.
10718 bool CountDown = isKnownNegative(StepWLG);
10719 if (!CountDown && !isKnownNonNegative(StepWLG))
10720 return getCouldNotCompute();
10721
10722 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10723 // Handle unitary steps, which cannot wraparound.
10724 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10725 // N = Distance (as unsigned)
10726
10727 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
// Constant max: the tighter of the guarded and unguarded unsigned maxima.
10728 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10729 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10730
10731 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10732 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10733 // case, and see if we can improve the bound.
10734 //
10735 // Explicitly handling this here is necessary because getUnsignedRange
10736 // isn't context-sensitive; it doesn't know that we only care about the
10737 // range inside the loop.
10738 const SCEV *Zero = getZero(Distance->getType());
10739 const SCEV *One = getOne(Distance->getType());
10740 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10741 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10742 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10743 // as "unsigned_max(Distance + 1) - 1".
10744 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10745 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10746 }
10747 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10748 Predicates);
10749 }
10750
10751 // If the condition controls loop exit (the loop exits only if the expression
10752 // is true) and the addition is no-wrap we can use unsigned divide to
10753 // compute the backedge count. In this case, the step may not divide the
10754 // distance, but we don't care because if the condition is "missed" the loop
10755 // will have undefined behavior due to wrapping.
10756 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10757 loopHasNoAbnormalExits(AddRec->getLoop())) {
10758
10759 // If the stride is zero and the start is non-zero, the loop must be
10760 // infinite. In C++, most loops are finite by assumption, in which case the
10761 // step being zero implies UB must execute if the loop is entered.
10762 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10763 !isKnownNonZero(StepWLG))
10764 return getCouldNotCompute();
10765
10766 const SCEV *Exact =
10767 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10768 const SCEV *ConstantMax = getCouldNotCompute();
10769 if (Exact != getCouldNotCompute()) {
10770 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
// NOTE(review): the right-hand side of the assignment below (source line
// 10772) is missing from this listing. It presumably builds a constant from
// MaxInt, possibly clamped like M further down. Confirm.
10771 ConstantMax =
10773 }
// When Exact is CouldNotCompute, ConstantMax is still CouldNotCompute too,
// so SymbolicMax always ends up equal to Exact.
10774 const SCEV *SymbolicMax =
10775 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10776 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10777 }
10778
10779 // Solve the general equation.
10780 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10781 if (!StepC || StepC->getValue()->isZero())
10782 return getCouldNotCompute();
10783 const SCEV *E = SolveLinEquationWithOverflow(
10784 StepC->getAPInt(), getNegativeSCEV(Start),
10785 AllowPredicates ? &Predicates : nullptr, *this, L);
10786
// M: constant max, the tighter of the guarded and unguarded unsigned maxima
// of E (or E itself when E could not be computed).
10787 const SCEV *M = E;
10788 if (E != getCouldNotCompute()) {
10789 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10790 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10791 }
// When E is CouldNotCompute, M is E as well, so S always equals E.
10792 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10793 return ExitLimit(E, M, S, false, Predicates);
10794}
10795
10796ScalarEvolution::ExitLimit
10797ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10798 // Loops that look like: while (X == 0) are very strange indeed. We don't
10799 // handle them yet except for the trivial case. This could be expanded in the
10800 // future as needed.
10801
10802 // If the value is a constant, check to see if it is known to be non-zero
10803 // already. If so, the backedge will execute zero times.
10804 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10805 if (!C->getValue()->isZero())
10806 return getZero(C->getType());
10807 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10808 }
10809
10810 // We could implement others, but I really doubt anyone writes loops like
10811 // this, and if they did, they would already be constant folded.
10812 return getCouldNotCompute();
10813}
10814
10815std::pair<const BasicBlock *, const BasicBlock *>
10816ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10817 const {
10818 // If the block has a unique predecessor, then there is no path from the
10819 // predecessor to the block that does not go through the direct edge
10820 // from the predecessor to the block.
10821 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10822 return {Pred, BB};
10823
10824 // A loop's header is defined to be a block that dominates the loop.
10825 // If the header has a unique predecessor outside the loop, it must be
10826 // a block that has exactly one successor that can reach the loop.
10827 if (const Loop *L = LI.getLoopFor(BB))
10828 return {L->getLoopPredecessor(), L->getHeader()};
10829
10830 return {nullptr, BB};
10831}
10832
10833/// SCEV structural equivalence is usually sufficient for testing whether two
10834/// expressions are equal, however for the purposes of looking for a condition
10835/// guarding a loop, it can be useful to be a little more general, since a
10836/// front-end may have replicated the controlling expression.
10837static bool HasSameValue(const SCEV *A, const SCEV *B) {
10838 // Quick check to see if they are the same SCEV.
10839 if (A == B) return true;
10840
10841 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10842 // Not all instructions that are "identical" compute the same value. For
10843 // instance, two distinct alloca instructions allocating the same type are
10844 // identical and do not read memory; but compute distinct values.
10845 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10846 };
10847
10848 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10849 // two different instructions with the same value. Check for this case.
10850 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10851 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10852 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10853 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10854 if (ComputesEqualValues(AI, BI))
10855 return true;
10856
10857 // Otherwise assume they may have a different value.
10858 return false;
10859}
10860
10861static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10862 const SCEV *Op0, *Op1;
10863 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
10864 return false;
10865 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10866 LHS = Op1;
10867 return true;
10868 }
10869 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10870 LHS = Op0;
10871 return true;
10872 }
10873 return false;
10874}
10875
/// Canonicalize the comparison "LHS Pred RHS" in place, returning true if
/// Pred, LHS or RHS was changed. Recurses with Depth + 1 while changes are
/// still being made, and bails out once Depth reaches 3.
///
/// NOTE(review): this listing is missing several source lines of this
/// function: the first signature line (10876), the first statement of
/// TrivialCase (10882), the statements following each operand swap (10919,
/// 10930), the definition of ExactCR (10943), line 11008, line 11057, and the
/// trailing argument lines of several getAddExpr calls. Confirm against the
/// original file before relying on the text below.
10877 const SCEV *&RHS, unsigned Depth) {
10878 bool Changed = false;
10879 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10880 // '0 != 0'.
// Only the Pred assignment is visible below; the operand rewrite described
// above is on a line missing from this listing.
10881 auto TrivialCase = [&](bool TriviallyTrue) {
10883 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10884 return true;
10885 };
10886 // If we hit the max recursion limit bail out.
10887 if (Depth >= 3)
10888 return false;
10889
10890 const SCEV *NewLHS, *NewRHS;
10891 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10892 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10893 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10894 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10895
10896 // (X * vscale) pred (Y * vscale) ==> X pred Y
10897 // when both multiples are NSW.
10898 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10899 // when both multiples are NUW.
10900 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10901 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10902 !ICmpInst::isSigned(Pred))) {
10903 LHS = NewLHS;
10904 RHS = NewRHS;
10905 Changed = true;
10906 }
10907 }
10908
10909 // Canonicalize a constant to the right side.
10910 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10911 // Check for both operands constant.
10912 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10913 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10914 return TrivialCase(false);
10915 return TrivialCase(true);
10916 }
10917 // Otherwise swap the operands to put the constant on the right.
10918 std::swap(LHS, RHS);
// NOTE(review): source line 10919 is missing. Swapping the operands of a
// relational compare also requires swapping Pred; presumably that line does
// so. Confirm.
10920 Changed = true;
10921 }
10922
10923 // If we're comparing an addrec with a value which is loop-invariant in the
10924 // addrec's loop, put the addrec on the left. Also make a dominance check,
10925 // as both operands could be addrecs loop-invariant in each other's loop.
10926 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10927 const Loop *L = AR->getLoop();
10928 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10929 std::swap(LHS, RHS);
// NOTE(review): source line 10930 is missing; presumably it swaps Pred to
// match the operand swap above. Confirm.
10931 Changed = true;
10932 }
10933 }
10934
10935 // If there's a constant operand, canonicalize comparisons with boundary
10936 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10937 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10938 const APInt &RA = RC->getAPInt();
10939
10940 bool SimplifiedByConstantRange = false;
10941
10942 if (!ICmpInst::isEquality(Pred)) {
// NOTE(review): ExactCR's definition (source line 10943) is missing. From its
// uses it is presumably the exact range of LHS values satisfying "Pred RA".
// A full set means always true; an empty set means always false. Confirm.
10944 if (ExactCR.isFullSet())
10945 return TrivialCase(true);
10946 if (ExactCR.isEmptySet())
10947 return TrivialCase(false);
10948
10949 APInt NewRHS;
10950 CmpInst::Predicate NewPred;
10951 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10952 ICmpInst::isEquality(NewPred)) {
10953 // We were able to convert an inequality to an equality.
10954 Pred = NewPred;
10955 RHS = getConstant(NewRHS);
10956 Changed = SimplifiedByConstantRange = true;
10957 }
10958 }
10959
10960 if (!SimplifiedByConstantRange) {
10961 switch (Pred) {
10962 default:
10963 break;
10964 case ICmpInst::ICMP_EQ:
10965 case ICmpInst::ICMP_NE:
10966 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10967 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10968 Changed = true;
10969 break;
10970
10971 // The "Should have been caught earlier!" messages refer to the fact
10972 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10973 // should have fired on the corresponding cases, and canonicalized the
10974 // check to trivial case.
10975
10976 case ICmpInst::ICMP_UGE:
10977 assert(!RA.isMinValue() && "Should have been caught earlier!");
10978 Pred = ICmpInst::ICMP_UGT;
10979 RHS = getConstant(RA - 1);
10980 Changed = true;
10981 break;
10982 case ICmpInst::ICMP_ULE:
10983 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10984 Pred = ICmpInst::ICMP_ULT;
10985 RHS = getConstant(RA + 1);
10986 Changed = true;
10987 break;
10988 case ICmpInst::ICMP_SGE:
10989 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10990 Pred = ICmpInst::ICMP_SGT;
10991 RHS = getConstant(RA - 1);
10992 Changed = true;
10993 break;
10994 case ICmpInst::ICMP_SLE:
10995 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10996 Pred = ICmpInst::ICMP_SLT;
10997 RHS = getConstant(RA + 1);
10998 Changed = true;
10999 break;
11000 }
11001 }
11002 }
11003
11004 // Check for obvious equality.
11005 if (HasSameValue(LHS, RHS)) {
11006 if (ICmpInst::isTrueWhenEqual(Pred))
11007 return TrivialCase(true);
11009 return TrivialCase(false);
11010 }
11011
11012 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11013 // adding or subtracting 1 from one of the operands.
// Each adjustment is only made when the range check shows the +1/-1 cannot
// step past the type's signed/unsigned extreme.
// NOTE(review): the trailing argument lines (source 11018, 11023, 11031,
// 11036, 11044, 11065) of several getAddExpr calls below are missing here.
11014 switch (Pred) {
11015 case ICmpInst::ICMP_SLE:
11016 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11017 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11019 Pred = ICmpInst::ICMP_SLT;
11020 Changed = true;
11021 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11022 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11024 Pred = ICmpInst::ICMP_SLT;
11025 Changed = true;
11026 }
11027 break;
11028 case ICmpInst::ICMP_SGE:
11029 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11030 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11032 Pred = ICmpInst::ICMP_SGT;
11033 Changed = true;
11034 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11035 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11037 Pred = ICmpInst::ICMP_SGT;
11038 Changed = true;
11039 }
11040 break;
11041 case ICmpInst::ICMP_ULE:
11042 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11043 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11045 Pred = ICmpInst::ICMP_ULT;
11046 Changed = true;
11047 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11048 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11049 Pred = ICmpInst::ICMP_ULT;
11050 Changed = true;
11051 }
11052 break;
11053 case ICmpInst::ICMP_UGE:
11054 // If RHS is an op we can fold the -1, try that first.
11055 // Otherwise prefer LHS to preserve the nuw flag.
// NOTE(review): source line 11057 (the middle of this condition) is missing.
// The cast<SCEVNAryExpr>(RHS) below is only valid if that line restricts RHS
// to an n-ary expression. Confirm.
11056 if ((isa<SCEVConstant>(RHS) ||
11058 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11059 !getUnsignedRangeMin(RHS).isMinValue()) {
11060 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11061 Pred = ICmpInst::ICMP_UGT;
11062 Changed = true;
11063 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11064 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11066 Pred = ICmpInst::ICMP_UGT;
11067 Changed = true;
11068 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11069 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11070 Pred = ICmpInst::ICMP_UGT;
11071 Changed = true;
11072 }
11073 break;
11074 default:
11075 break;
11076 }
11077
11078 // TODO: More simplifications are possible here.
11079
11080 // Recursively simplify until we either hit a recursion limit or nothing
11081 // changes.
// The recursive result is ignored because Changed is already true here.
11082 if (Changed)
11083 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11084
11085 return Changed;
11086}
11087
// NOTE(review): the signature line of this function (source 11088) is missing
// from this listing; presumably this is isKnownNegative(S). It reports S as
// known negative when even the maximum of its signed range is negative.
11089 return getSignedRangeMax(S).isNegative();
11090}
11091
11095
// NOTE(review): the signature line of this function (source 11096) is missing
// from this listing; presumably this is isKnownNonNegative(S). It reports S as
// known non-negative when the minimum of its signed range is not negative.
11097 return !getSignedRangeMin(S).isNegative();
11098}
11099
11103
// Body of isKnownNonZero(S); the signature line (source 11104) is missing from
// this listing. S is known non-zero when the minimum of its unsigned range is
// not 0. A sign extension is non-zero exactly when its operand is, so that
// case recurses into the operand.
11105 // Query push down for cases where the unsigned range is
11106 // less than sufficient.
11107 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11108 return isKnownNonZero(SExt->getOperand(0));
11109 return getUnsignedRangeMin(S) != 0;
11110}
11111
// NOTE(review): the first signature line (source 11112) is missing from this
// listing; from the parameters OrZero/OrNegative this is presumably
// isKnownToBeAPowerOfTwo(S, OrZero, OrNegative).
// Accepts a constant power of two (or, with OrNegative, a negated power of
// two), vscale, or a multiply whose operands all satisfy that test. For the
// multiply, S must additionally be known non-zero unless OrZero is set.
11113 bool OrNegative) {
11114 auto NonRecursive = [OrNegative](const SCEV *S) {
11115 if (auto *C = dyn_cast<SCEVConstant>(S))
11116 return C->getAPInt().isPowerOf2() ||
11117 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11118
11119 // vscale is a power-of-two.
11120 return isa<SCEVVScale>(S);
11121 };
11122
11123 if (NonRecursive(S))
11124 return true;
11125
11126 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11127 if (!Mul)
11128 return false;
// A product of powers of two is itself a power of two unless it wraps to
// zero, hence the non-zero check when OrZero is not set.
11129 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11130}
11131
// Decide whether S is a multiple of M. Returns false for M == 0 and true for
// M == 1. AddRecs are checked structurally, constants directly, and anything
// else via "S urem M == 0". If that cannot be decided at compile time, a
// runtime predicate is appended to Assumptions (unless an existing assumption
// already implies it) and true is returned.
// NOTE(review): the first signature line (11132), the tail of the parameter
// list (11134) and the definition of P (11170) are missing from this listing.
11133 const SCEV *S, uint64_t M,
11135 if (M == 0)
11136 return false;
11137 if (M == 1)
11138 return true;
11139
11140 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11141 // starts with a multiple of M and at every iteration step S only adds
11142 // multiples of M.
11143 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11144 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11145 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11146
11147 // For a constant, check that "S % M == 0".
11148 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11149 APInt C = Cst->getAPInt();
11150 return C.urem(M) == 0;
11151 }
11152
11153 // TODO: Also check other SCEV expressions, e.g., SCEVMulExpr, etc.
// NOTE(review): AddRecs are already handled above, so the original TODO
// example ("i.e., SCEVAddRecExpr") looked stale. Confirm which expressions are
// meant.
11154
11155 // Basic tests have failed.
11156 // Check "S % M == 0" at compile time and record runtime Assumptions.
// NOTE(review): STy is not null-checked; this assumes S has integer type.
11157 auto *STy = dyn_cast<IntegerType>(S->getType());
11158 const SCEV *SmodM =
11159 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11160 const SCEV *Zero = getZero(STy);
11161
11162 // Check whether "S % M == 0" is known at compile time.
11163 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11164 return true;
11165
11166 // Check whether "S % M != 0" is known at compile time.
11167 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11168 return false;
11169
11171
11172 // Detect redundant predicates.
11173 for (auto *A : Assumptions)
11174 if (A->implies(P, *this))
11175 return true;
11176
11177 // Only record non-redundant predicates.
11178 Assumptions.push_back(P);
11179 return true;
11180}
11181
// NOTE(review): the signature (source 11182) and the second disjunct (source
// 11184) of this function are missing from this listing. Only the
// "both known non-negative" case is visible; the other case is presumably
// "both known negative". Confirm.
11183 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11185}
11186
// Split S into {value on entry to loop L, post-increment value in loop L}.
// If the entry value cannot be computed, both halves are CouldNotCompute.
// NOTE(review): the line carrying the function name and parameters (source
// 11188) is missing from this listing; callers pass (Loop, SCEV).
11187std::pair<const SCEV *, const SCEV *>
11189 // Compute SCEV on entry of loop L.
11190 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11191 if (Start == getCouldNotCompute())
11192 return { Start, Start };
11193 // Compute post increment SCEV for loop L.
11194 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11195 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11196 return { Start, PostInc };
11197}
11198
// Prove "LHS Pred RHS" by induction over the most dominated loop used by
// either side. The predicate must hold on loop entry (init values) and along
// the backedge (post-increment values).
// NOTE(review): the first signature line (11199) and the declaration of
// LoopsUsed (11202) are missing from this listing.
11200 const SCEV *RHS) {
11201 // First collect all loops.
11203 getUsedLoops(LHS, LoopsUsed);
11204 getUsedLoops(RHS, LoopsUsed);
11205
11206 if (LoopsUsed.empty())
11207 return false;
11208
11209 // Domination relationship must be a linear order on collected loops.
11210#ifndef NDEBUG
11211 for (const auto *L1 : LoopsUsed)
11212 for (const auto *L2 : LoopsUsed)
11213 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11214 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11215 "Domination relationship is not a linear order");
11216#endif
11217
// MDL: the loop whose header is dominated by all other collected headers,
// i.e. the maximum under the properlyDominates ordering.
11218 const Loop *MDL =
11219 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11220 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11221 });
11222
11223 // Get init and post increment value for LHS.
11224 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11225 // if LHS contains unknown non-invariant SCEV then bail out.
11226 if (SplitLHS.first == getCouldNotCompute())
11227 return false;
11228 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11229 // Get init and post increment value for RHS.
11230 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11231 // if RHS contains unknown non-invariant SCEV then bail out.
11232 if (SplitRHS.first == getCouldNotCompute())
11233 return false;
11234 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11235 // It is possible that init SCEV contains an invariant load but it does
11236 // not dominate MDL and is not available at MDL loop entry, so we should
11237 // check it here.
11238 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11239 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11240 return false;
11241
11242 // It seems backedge guard check is faster than entry one so in some cases
11243 // it can speed up whole estimation by short circuit
11244 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11245 SplitRHS.second) &&
11246 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11247}
11248
// Presumably isKnownPredicate(Pred, LHS, RHS); the first signature line
// (source 11249) is missing from this listing. The operands are canonicalized
// first, then induction, splitting and simple non-recursive reasoning are
// tried in that order.
11250 const SCEV *RHS) {
11251 // Canonicalize the inputs first.
11252 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11253
11254 if (isKnownViaInduction(Pred, LHS, RHS))
11255 return true;
11256
11257 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11258 return true;
11259
11260 // Otherwise see what can be done with some simple reasoning.
11261 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11262}
11263
// Evaluate "LHS Pred RHS" without context: true if known true, false if known
// false, std::nullopt otherwise.
// NOTE(review): source lines 11264 (signature start) and 11269 are missing.
// As shown, "return false" directly follows "return true"; the missing line
// presumably guards it with a check of the inverse predicate. Confirm.
11265 const SCEV *LHS,
11266 const SCEV *RHS) {
11267 if (isKnownPredicate(Pred, LHS, RHS))
11268 return true;
11270 return false;
11271 return std::nullopt;
11272}
11273
// Presumably isKnownPredicateAt(Pred, LHS, RHS, CtxI); the first signature
// line (source 11274) is missing from this listing. Holds if the predicate is
// known globally or is guaranteed on entry to CtxI's block. CtxI is
// dereferenced, so it must be non-null.
11275 const SCEV *RHS,
11276 const Instruction *CtxI) {
11277 // TODO: Analyze guards and assumes from Context's block.
11278 return isKnownPredicate(Pred, LHS, RHS) ||
11279 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11280}
11281
// Evaluate "LHS Pred RHS" at CtxI: first without context, then using the
// conditions guarding entry to CtxI's block (for Pred and for its inverse).
// NOTE(review): source lines 11283 (function name / first parameters) and
// 11291 (the start of the second guard check) are missing from this listing.
11282std::optional<bool>
11284 const SCEV *RHS, const Instruction *CtxI) {
11285 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11286 if (KnownWithoutContext)
11287 return KnownWithoutContext;
11288
11289 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11290 return true;
11292 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11293 return false;
11294 return std::nullopt;
11295}
11296
// The signature start (source 11297) is missing from this listing; presumably
// isKnownOnEveryIteration. Holds if "Start(LHS) Pred RHS" is guaranteed on
// entry to LHS's loop and "PostInc(LHS) Pred RHS" is guaranteed on its
// backedge.
11298 const SCEVAddRecExpr *LHS,
11299 const SCEV *RHS) {
11300 const Loop *L = LHS->getLoop();
11301 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11302 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11303}
11304
// Wrapper around getMonotonicPredicateTypeImpl that, in assert builds, checks
// that swapping the predicate flips the reported monotonicity.
// NOTE(review): the line with the function name (source 11306) is missing
// from this listing.
11305std::optional<ScalarEvolution::MonotonicPredicateType>
11307 ICmpInst::Predicate Pred) {
11308 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11309
11310#ifndef NDEBUG
11311 // Verify an invariant: swapping the predicate (e.g. < to >) should turn a
11312 // monotonically increasing change to a monotonically decreasing one, and vice versa.
// ResultSwapped is dereferenced unchecked. This assumes the Impl accepts the
// swapped predicate whenever it accepts Pred (the swap keeps signedness).
11313 if (Result) {
11314 auto ResultSwapped =
11315 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11316
11317 assert(*ResultSwapped != *Result &&
11318 "monotonicity should flip as we flip the predicate");
11319 }
11320#endif
11321
11322 return Result;
11323}
11324
// Classify "LHS Pred <invariant>" as monotonically increasing or decreasing
// over the iterations of LHS's loop, or std::nullopt if that cannot be shown.
// Requires a relational predicate, the matching no-wrap flag on LHS, and a
// step of known sign.
11325std::optional<ScalarEvolution::MonotonicPredicateType>
11326ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11327 ICmpInst::Predicate Pred) {
11328 // A zero step value for LHS means the induction variable is essentially a
11329 // loop invariant value. We don't really depend on the predicate actually
11330 // flipping from false to true (for increasing predicates, and the other way
11331 // around for decreasing predicates), all we care about is that *if* the
11332 // predicate changes then it only changes from false to true.
11333 //
11334 // A zero step value in itself is not very useful, but there may be places
11335 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11336 // as general as possible.
11337
11338 // Only handle LE/LT/GE/GT predicates.
11339 if (!ICmpInst::isRelational(Pred))
11340 return std::nullopt;
11341
11342 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11343 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11344 "Should be greater or less!");
11345
11346 // Check that AR does not wrap.
11347 if (ICmpInst::isUnsigned(Pred)) {
11348 if (!LHS->hasNoUnsignedWrap())
11349 return std::nullopt;
// NOTE(review): source line 11350 is missing from this listing. As shown, the
// unsigned case would fall through to the isSigned assert below, so the
// missing line presumably returns the monotonicity here. Confirm.
11351 }
11352 assert(ICmpInst::isSigned(Pred) &&
11353 "Relational predicate is either signed or unsigned!");
11354 if (!LHS->hasNoSignedWrap())
11355 return std::nullopt;
11356
11357 const SCEV *Step = LHS->getStepRecurrence(*this);
11358
// NOTE(review): the return statements after each of the two checks below
// (source 11360 and 11363) are missing from this listing. They presumably
// return increasing/decreasing depending on IsGreater. Confirm.
11359 if (isKnownNonNegative(Step))
11361
11362 if (isKnownNonPositive(Step))
11364
11365 return std::nullopt;
11366}
11367
// Try to replace the loop-varying comparison "LHS Pred RHS" in loop L with an
// equivalent loop-invariant one. This is done by monotonicity and, failing
// that, for unsigned </<= by a context-based argument at CtxI.
// NOTE(review): source lines 11369 (function name / first parameter), 11378
// (after the operand swap; presumably swaps Pred), 11405 (definition of
// Increasing), and 11409 / 11445 (the start of the two return statements) are
// missing from this listing. Confirm against the original file.
11368std::optional<ScalarEvolution::LoopInvariantPredicate>
11370 const SCEV *RHS, const Loop *L,
11371 const Instruction *CtxI) {
11372 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11373 if (!isLoopInvariant(RHS, L)) {
11374 if (!isLoopInvariant(LHS, L))
11375 return std::nullopt;
11376
11377 std::swap(LHS, RHS);
11379 }
11380
11381 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11382 if (!ArLHS || ArLHS->getLoop() != L)
11383 return std::nullopt;
11384
11385 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11386 if (!MonotonicType)
11387 return std::nullopt;
11388 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11389 // true as the loop iterates, and the backedge is control dependent on
11390 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11391 //
11392 // * if the predicate was false in the first iteration then the predicate
11393 // is never evaluated again, since the loop exits without taking the
11394 // backedge.
11395 // * if the predicate was true in the first iteration then it will
11396 // continue to be true for all future iterations since it is
11397 // monotonically increasing.
11398 //
11399 // For both the above possibilities, we can replace the loop varying
11400 // predicate with its value on the first iteration of the loop (which is
11401 // loop invariant).
11402 //
11403 // A similar reasoning applies for a monotonically decreasing predicate, by
11404 // replacing true with false and false with true in the above two bullets.
11406 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11407
11408 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11410 RHS);
11411
11412 if (!CtxI)
11413 return std::nullopt;
11414 // Try to prove via context.
11415 // TODO: Support other cases.
11416 switch (Pred) {
11417 default:
11418 break;
11419 case ICmpInst::ICMP_ULE:
11420 case ICmpInst::ICMP_ULT: {
11421 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11422 // Given preconditions
11423 // (1) ArLHS does not cross the border of positive and negative parts of
11424 // range because of:
11425 // - Positive step; (TODO: lift this limitation)
11426 // - nuw - does not cross zero boundary;
11427 // - nsw - does not cross SINT_MAX boundary;
11428 // (2) ArLHS <s RHS
11429 // (3) RHS >=s 0
11430 // we can replace the loop variant ArLHS <u RHS condition with loop
11431 // invariant Start(ArLHS) <u RHS.
11432 //
11433 // Because of (1) there are two options:
11434 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11435 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11436 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11437 // Because of (2) ArLHS <u RHS is trivially true.
11438 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11439 // We can strengthen this to Start(ArLHS) <u RHS.
// SignFlippedPred is the signed counterpart (slt/sle) used for precondition (2).
11440 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11441 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11442 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11443 isKnownNonNegative(RHS) &&
11444 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11446 RHS);
11447 }
11448 }
11449
11450 return std::nullopt;
11451}
11452
// Try the Impl with MaxIter first. If MaxIter is a umin, retry with each umin
// operand in turn and return the first success.
// NOTE(review): source lines 11454 (function name), 11457 and 11467 (the
// start of the two Impl calls) are missing from this listing.
11453std::optional<ScalarEvolution::LoopInvariantPredicate>
11455 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11456 const Instruction *CtxI, const SCEV *MaxIter) {
11458 Pred, LHS, RHS, L, CtxI, MaxIter))
11459 return LIP;
11460 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11461 // Number of iterations expressed as UMIN isn't always great for expressing
11462 // the value on the last iteration. If the straightforward approach didn't
11463 // work, try the following trick: if a predicate is invariant for X, it
11464 // is also invariant for umin(X, ...). So try to find something that works
11465 // among subexpressions of MaxIter expressed as umin.
11466 for (auto *Op : UMin->operands())
11468 Pred, LHS, RHS, L, CtxI, Op))
11469 return LIP;
11470 return std::nullopt;
11471}
11472
// Single-bound worker for the query above. It tries to replace the
// loop-varying comparison (Pred, LHS, RHS) with the comparison
// (Pred, Start, RHS), where Start is the start value of an addrec of L and
// RHS is loop-invariant. MaxIter is used as the last iteration to check.
// Only addrecs whose step is exactly +1 or -1 are handled. Returns
// std::nullopt when any of the required facts cannot be proven.
11473std::optional<ScalarEvolution::LoopInvariantPredicate>
11475 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11476 const Instruction *CtxI, const SCEV *MaxIter) {
11477 // Try to prove the following set of facts:
11478 // - The predicate is monotonic in the iteration space.
11479 // - If the check does not fail on the 1st iteration:
11480 // - No overflow will happen during first MaxIter iterations;
11481 // - It will not fail on the MaxIter'th iteration.
11482 // If the check does fail on the 1st iteration, we leave the loop and no
11483 // other checks matter.
11484
11485 // If there is a loop-invariant operand, force it into the RHS, otherwise
// bail out.
11486 if (!isLoopInvariant(RHS, L)) {
11487 if (!isLoopInvariant(LHS, L))
11488 return std::nullopt;
11489
11490 std::swap(LHS, RHS);
11492 }
11493
11494 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11495 if (!AR || AR->getLoop() != L)
11496 return std::nullopt;
11497
11498 // Even if both are valid, we need to consistently choose the unsigned or the
11499 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11500 // predicate.
11501 Pred = Pred.dropSameSign();
11502
11503 // The predicate must be relational (i.e. <, <=, >=, >).
11504 if (!ICmpInst::isRelational(Pred))
11505 return std::nullopt;
11506
11507 // TODO: Support steps other than +/- 1.
11508 const SCEV *Step = AR->getStepRecurrence(*this);
11509 auto *One = getOne(Step->getType());
11510 auto *MinusOne = getNegativeSCEV(One);
11511 if (Step != One && Step != MinusOne)
11512 return std::nullopt;
11513
11514 // Type mismatch here means that MaxIter is potentially larger than max
11515 // unsigned value in start type, which means we cannot prove no wrap for the
11516 // indvar.
11517 if (AR->getType() != MaxIter->getType())
11518 return std::nullopt;
11519
11520 // Value of IV on suggested last iteration.
11521 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11522 // Does it still meet the requirement?
11523 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11524 return std::nullopt;
11525 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11526 // not exceed max unsigned value of this type), this effectively proves
11527 // that there is no wrap during the iteration. To prove that there is no
11528 // signed/unsigned wrap, we need to check that
11529 // Start <= Last for step = 1 or Start >= Last for step = -1.
11530 ICmpInst::Predicate NoOverflowPred =
11532 if (Step == MinusOne)
11533 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11534 const SCEV *Start = AR->getStart();
11535 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11536 return std::nullopt;
11537
11538 // Everything is fine.
11539 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11540}
11541
11542bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11543 const SCEV *LHS,
11544 const SCEV *RHS) {
11545 if (HasSameValue(LHS, RHS))
11546 return ICmpInst::isTrueWhenEqual(Pred);
11547
11548 auto CheckRange = [&](bool IsSigned) {
11549 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11550 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11551 return RangeLHS.icmp(Pred, RangeRHS);
11552 };
11553
11554 // The check at the top of the function catches the case where the values are
11555 // known to be equal.
11556 if (Pred == CmpInst::ICMP_EQ)
11557 return false;
11558
11559 if (Pred == CmpInst::ICMP_NE) {
11560 if (CheckRange(true) || CheckRange(false))
11561 return true;
11562 auto *Diff = getMinusSCEV(LHS, RHS);
11563 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11564 }
11565
11566 return CheckRange(CmpInst::isSigned(Pred));
11567}
11568
11569bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11570 const SCEV *LHS,
11571 const SCEV *RHS) {
11572 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11573 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11574 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11575 // OutC1 and OutC2.
11576 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11577 APInt &OutC1, APInt &OutC2,
11578 SCEV::NoWrapFlags ExpectedFlags) {
11579 const SCEV *XNonConstOp, *XConstOp;
11580 const SCEV *YNonConstOp, *YConstOp;
11581 SCEV::NoWrapFlags XFlagsPresent;
11582 SCEV::NoWrapFlags YFlagsPresent;
11583
11584 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11585 XConstOp = getZero(X->getType());
11586 XNonConstOp = X;
11587 XFlagsPresent = ExpectedFlags;
11588 }
11589 if (!isa<SCEVConstant>(XConstOp))
11590 return false;
11591
11592 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11593 YConstOp = getZero(Y->getType());
11594 YNonConstOp = Y;
11595 YFlagsPresent = ExpectedFlags;
11596 }
11597
11598 if (YNonConstOp != XNonConstOp)
11599 return false;
11600
11601 if (!isa<SCEVConstant>(YConstOp))
11602 return false;
11603
11604 // When matching ADDs with NUW flags (and unsigned predicates), only the
11605 // second ADD (with the larger constant) requires NUW.
11606 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11607 return false;
11608 if (ExpectedFlags != SCEV::FlagNUW &&
11609 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11610 return false;
11611 }
11612
11613 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11614 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11615
11616 return true;
11617 };
11618
11619 APInt C1;
11620 APInt C2;
11621
11622 switch (Pred) {
11623 default:
11624 break;
11625
11626 case ICmpInst::ICMP_SGE:
11627 std::swap(LHS, RHS);
11628 [[fallthrough]];
11629 case ICmpInst::ICMP_SLE:
11630 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11631 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11632 return true;
11633
11634 break;
11635
11636 case ICmpInst::ICMP_SGT:
11637 std::swap(LHS, RHS);
11638 [[fallthrough]];
11639 case ICmpInst::ICMP_SLT:
11640 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11641 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11642 return true;
11643
11644 break;
11645
11646 case ICmpInst::ICMP_UGE:
11647 std::swap(LHS, RHS);
11648 [[fallthrough]];
11649 case ICmpInst::ICMP_ULE:
11650 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11651 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11652 return true;
11653
11654 break;
11655
11656 case ICmpInst::ICMP_UGT:
11657 std::swap(LHS, RHS);
11658 [[fallthrough]];
11659 case ICmpInst::ICMP_ULT:
11660 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11661 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11662 return true;
11663 break;
11664 }
11665
11666 return false;
11667}
11668
// Tries to prove LHS u< RHS by reducing it to signed facts (see the comment
// in the body). Only ICMP_ULT is handled. At most one activation of this
// function may be on the stack at a time, which is enforced through
// ProvingSplitPredicate.
11669bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11670 const SCEV *LHS,
11671 const SCEV *RHS) {
11672 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11673 return false;
11674
11675 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
// on the stack can result in exponential time complexity.
11677 SaveAndRestore Restore(ProvingSplitPredicate, true);
11678
11679 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11680 //
11681 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11682 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11683 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11684 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11685 // use isKnownPredicate later if needed.
11686 return isKnownNonNegative(RHS) &&
11689}
11690
// Returns true if some instruction in BB supplies a Condition that implies
// (Pred, LHS, RHS). Returns false immediately when HasGuards is false.
// NOTE(review): given HasGuards and the function name, the matched
// instruction is presumably a call to llvm.experimental.guard -- confirm.
11691bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11692 const SCEV *LHS, const SCEV *RHS) {
11693 // No need to even try if we know the module has no guards.
11694 if (!HasGuards)
11695 return false;
11696
11697 return any_of(*BB, [&](const Instruction &I) {
11698 using namespace llvm::PatternMatch;
11699
11700 Value *Condition;
11702 m_Value(Condition))) &&
11703 isImpliedCond(Pred, LHS, RHS, Condition, false);
11704 });
11705}
11706
11707/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11708/// protected by a conditional between LHS and RHS. This is used to
11709/// eliminate casts.
/// Returns true trivially for a null loop or a loop whose header is
/// unreachable; returns false for a loop without a unique latch unless
/// non-recursive reasoning already proves the predicate.
11711 CmpPredicate Pred,
11712 const SCEV *LHS,
11713 const SCEV *RHS) {
11714 // Interpret a null as meaning no loop, where there is obviously no guard
11715 // (interprocedural conditions notwithstanding). Do not bother about
11716 // unreachable loops.
11717 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11718 return true;
11719
11720 if (VerifyIR)
11721 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11722 "This cannot be done on broken IR!");
11723
11724
11725 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11726 return true;
11727
11728 BasicBlock *Latch = L->getLoopLatch();
11729 if (!Latch)
11730 return false;
11731
// Use the latch's own conditional branch. When the header is not successor
// 0, the loop continues on the false edge, so the condition is queried
// inverted.
11732 BranchInst *LoopContinuePredicate =
11734 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11735 isImpliedCond(Pred, LHS, RHS,
11736 LoopContinuePredicate->getCondition(),
11737 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11738 return true;
11739
11740 // We don't want more than one activation of the following loops on the stack
11741 // -- that can lead to O(n!) time complexity.
11742 if (WalkingBEDominatingConds)
11743 return false;
11744
11745 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11746
11747 // See if we can exploit a trip count to prove the predicate.
11748 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11749 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11750 if (LatchBECount != getCouldNotCompute()) {
11751 // We know that Latch branches back to the loop header exactly
11752 // LatchBECount times. This means the backedge condition at Latch is
11753 // equivalent to "{0,+,1} u< LatchBECount".
11754 Type *Ty = LatchBECount->getType();
11755 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11756 const SCEV *LoopCounter =
11757 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11758 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11759 LatchBECount))
11760 return true;
11761 }
11762
11763 // Check conditions due to any @llvm.assume intrinsics.
11764 for (auto &AssumeVH : AC.assumptions()) {
11765 if (!AssumeVH)
11766 continue;
11767 auto *CI = cast<CallInst>(AssumeVH);
11768 if (!DT.dominates(CI, Latch->getTerminator()))
11769 continue;
11770
11771 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11772 return true;
11773 }
11774
11775 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11776 return true;
11777
// Walk the dominator tree from the latch up to (but not including) the
// header. For each block, check its guards and the branch on its unique
// incoming edge.
11778 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11779 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11780 assert(DTN && "should reach the loop header before reaching the root!");
11781
11782 BasicBlock *BB = DTN->getBlock();
11783 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11784 return true;
11785
11786 BasicBlock *PBB = BB->getSinglePredecessor();
11787 if (!PBB)
11788 continue;
11789
11790 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11791 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11792 continue;
11793
11794 Value *Condition = ContinuePredicate->getCondition();
11795
11796 // If we have an edge `E` within the loop body that dominates the only
11797 // latch, the condition guarding `E` also guards the backedge. This
11798 // reasoning works only for loops with a single latch.
11799
11800 BasicBlockEdge DominatingEdge(PBB, BB);
11801 if (DominatingEdge.isSingleEdge()) {
11802 // We're constructively (and conservatively) enumerating edges within the
11803 // loop body that dominate the latch. The dominator tree better agree
11804 // with us on this:
11805 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11806
11807 if (isImpliedCond(Pred, LHS, RHS, Condition,
11808 BB != ContinuePredicate->getSuccessor(0)))
11809 return true;
11810 }
11811 }
11812
11813 return false;
11814}
11815
// Returns true if (Pred, LHS, RHS) is known to hold on entry to BB, using:
// - conditional branches along the chain of predecessors with unique
//   successors leading to BB;
// - dominating @llvm.assume calls;
// - dominating @llvm.experimental.guard calls.
// Unreachable blocks trivially return true.
11817 CmpPredicate Pred,
11818 const SCEV *LHS,
11819 const SCEV *RHS) {
11820 // Do not bother proving facts for unreachable code.
11821 if (!DT.isReachableFromEntry(BB))
11822 return true;
11823 if (VerifyIR)
11824 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11825 "This cannot be done on broken IR!");
11826
11827 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11828 // the facts (a >= b && a != b) separately. A typical situation is when the
11829 // non-strict comparison is known from ranges and non-equality is known from
11830 // dominating predicates. If we are proving strict comparison, we always try
11831 // to prove non-equality and non-strict comparison separately.
11832 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11833 const bool ProvingStrictComparison =
11834 Pred != NonStrictPredicate.dropSameSign();
11835 bool ProvedNonStrictComparison = false;
11836 bool ProvedNonEquality = false;
11837
// The two halves may be established by different proof sources. Each flag,
// once set, persists across calls, so a half is never re-proven.
11838 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11839 if (!ProvedNonStrictComparison)
11840 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11841 if (!ProvedNonEquality)
11842 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11843 if (ProvedNonStrictComparison && ProvedNonEquality)
11844 return true;
11845 return false;
11846 };
11847
11848 if (ProvingStrictComparison) {
11849 auto ProofFn = [&](CmpPredicate P) {
11850 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11851 };
11852 if (SplitAndProve(ProofFn))
11853 return true;
11854 }
11855
11856 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11857 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11858 const Instruction *CtxI = &BB->front();
11859 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11860 return true;
11861 if (ProvingStrictComparison) {
11862 auto ProofFn = [&](CmpPredicate P) {
11863 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11864 };
11865 if (SplitAndProve(ProofFn))
11866 return true;
11867 }
11868 return false;
11869 };
11870
11871 // Starting at the block's predecessor, climb up the predecessor chain, as long
11872 // as there are predecessors that can be found that have unique successors
11873 // leading to the original block.
// A loop header's predecessors include its latch(es), so for a header we
// start from the loop predecessor and consider only the entry edge.
11874 const Loop *ContainingLoop = LI.getLoopFor(BB);
11875 const BasicBlock *PredBB;
11876 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11877 PredBB = ContainingLoop->getLoopPredecessor();
11878 else
11879 PredBB = BB->getSinglePredecessor();
11880 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11881 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11882 const BranchInst *BlockEntryPredicate =
11883 dyn_cast<BranchInst>(Pair.first->getTerminator());
11884 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11885 continue;
11886
11887 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11888 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11889 return true;
11890 }
11891
11892 // Check conditions due to any @llvm.assume intrinsics.
11893 for (auto &AssumeVH : AC.assumptions()) {
11894 if (!AssumeVH)
11895 continue;
11896 auto *CI = cast<CallInst>(AssumeVH);
11897 if (!DT.dominates(CI, BB))
11898 continue;
11899
11900 if (ProveViaCond(CI->getArgOperand(0), false))
11901 return true;
11902 }
11903
11904 // Check conditions due to any @llvm.experimental.guard intrinsics.
11905 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11906 F.getParent(), Intrinsic::experimental_guard);
11907 if (GuardDecl)
11908 for (const auto *GU : GuardDecl->users())
11909 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11910 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11911 if (ProveViaCond(Guard->getArgOperand(0), false))
11912 return true;
11913 return false;
11914}
11915
// Returns true if (Pred, LHS, RHS) is known to hold on entry to loop L. It
// first tries non-recursive reasoning, then the conditions guarding entry to
// L's header. A null loop yields false.
11917 const SCEV *LHS,
11918 const SCEV *RHS) {
11919 // Interpret a null as meaning no loop, where there is obviously no guard
11920 // (interprocedural conditions notwithstanding).
11921 if (!L)
11922 return false;
11923
11924 // Both LHS and RHS must be available at loop entry.
11926 "LHS is not available at Loop Entry");
11928 "RHS is not available at Loop Entry");
11929
11930 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11931 return true;
11932
11933 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11934}
11935
11936bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11937 const SCEV *RHS,
11938 const Value *FoundCondValue, bool Inverse,
11939 const Instruction *CtxI) {
11940 // False conditions implies anything. Do not bother analyzing it further.
11941 if (FoundCondValue ==
11942 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11943 return true;
11944
11945 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11946 return false;
11947
11948 llvm::scope_exit ClearOnExit(
11949 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
11950
11951 // Recursively handle And and Or conditions.
11952 const Value *Op0, *Op1;
11953 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11954 if (!Inverse)
11955 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11956 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11957 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11958 if (Inverse)
11959 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11960 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11961 }
11962
11963 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11964 if (!ICI) return false;
11965
11966 // Now that we found a conditional branch that dominates the loop or controls
11967 // the loop latch. Check to see if it is the comparison we are looking for.
11968 CmpPredicate FoundPred;
11969 if (Inverse)
11970 FoundPred = ICI->getInverseCmpPredicate();
11971 else
11972 FoundPred = ICI->getCmpPredicate();
11973
11974 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11975 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11976
11977 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11978}
11979
// Returns true if FoundLHS FoundPred FoundRHS implies LHS Pred RHS.
// Operands of different bit widths are balanced first:
// - If the found condition is wider and non-signed, the query is first tried
//   in the narrow type when both found operands are known to fit.
// - Otherwise the narrower side is sign- or zero-extended, matching the
//   signedness of that side's predicate.
// Either way, the work is delegated to isImpliedCondBalancedTypes.
// Pointer-typed operands that would need extension make the query fail.
11980bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11981 const SCEV *RHS, CmpPredicate FoundPred,
11982 const SCEV *FoundLHS, const SCEV *FoundRHS,
11983 const Instruction *CtxI) {
11984 // Balance the types.
11985 if (getTypeSizeInBits(LHS->getType()) <
11986 getTypeSizeInBits(FoundLHS->getType())) {
11987 // For unsigned and equality predicates, try to prove that both found
11988 // operands fit into narrow unsigned range. If so, try to prove facts in
11989 // narrow types.
11990 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11991 !FoundRHS->getType()->isPointerTy()) {
11992 auto *NarrowType = LHS->getType();
11993 auto *WideType = FoundLHS->getType();
11994 auto BitWidth = getTypeSizeInBits(NarrowType);
11995 const SCEV *MaxValue = getZeroExtendExpr(
11997 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11998 MaxValue) &&
11999 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12000 MaxValue)) {
12001 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12002 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12003 // We cannot preserve samesign after truncation.
12004 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12005 TruncFoundLHS, TruncFoundRHS, CtxI))
12006 return true;
12007 }
12008 }
12009
// Otherwise widen the query operands to the found condition's type, using
// the extension that matches the query predicate's signedness.
12010 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12011 return false;
12012 if (CmpInst::isSigned(Pred)) {
12013 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12014 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12015 } else {
12016 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12017 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12018 }
12019 } else if (getTypeSizeInBits(LHS->getType()) >
12020 getTypeSizeInBits(FoundLHS->getType())) {
12021 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12022 return false;
12023 if (CmpInst::isSigned(FoundPred)) {
12024 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12025 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12026 } else {
12027 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12028 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12029 }
12030 }
12031 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12032 FoundRHS, CtxI);
12033}
12034
// Returns true if FoundLHS FoundPred FoundRHS implies LHS Pred RHS, where the
// two comparisons have operands of the same bit width. Both comparisons are
// first canonicalized with SimplifyICmpOperands. The function then tries, in
// order:
// - matching or swapped predicates;
// - predicates that differ only in signedness;
// - range sharpening from an NE-against-constant fact;
// - EQ/NE special cases;
// - finally, a constant-range based check.
12035bool ScalarEvolution::isImpliedCondBalancedTypes(
12036 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12037 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
12039 getTypeSizeInBits(FoundLHS->getType()) &&
12040 "Types should be balanced!");
12041 // Canonicalize the query to match the way instcombine will have
12042 // canonicalized the comparison.
12043 if (SimplifyICmpOperands(Pred, LHS, RHS))
12044 if (LHS == RHS)
12045 return CmpInst::isTrueWhenEqual(Pred);
// If the found condition reduces to comparing a value with itself, it is
// either always false (which implies anything) or always true (which gives
// no information).
12046 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12047 if (FoundLHS == FoundRHS)
12048 return CmpInst::isFalseWhenEqual(FoundPred);
12049
12050 // Check to see if we can make the LHS or RHS match.
12051 if (LHS == FoundRHS || RHS == FoundLHS) {
12052 if (isa<SCEVConstant>(RHS)) {
12053 std::swap(FoundLHS, FoundRHS);
12054 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12055 } else {
12056 std::swap(LHS, RHS);
12058 }
12059 }
12060
12061 // Check whether the found predicate is the same as the desired predicate.
12062 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12063 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12064
12065 // Check whether swapping the found predicate makes it the same as the
12066 // desired predicate.
12067 if (auto P = CmpPredicate::getMatching(
12068 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12069 // We can write the implication
12070 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12071 // using one of the following ways:
12072 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12073 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12074 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12075 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12076 // Forms 1. and 2. require swapping the operands of one condition. Don't
12077 // do this if it would break canonical constant/addrec ordering.
12079 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12080 LHS, FoundLHS, FoundRHS, CtxI);
12081 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12082 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12083
12084 // There's no clear preference between forms 3. and 4., try both. Avoid
12085 // forming getNotSCEV of pointer values as the resulting subtract is
12086 // not legal.
12087 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12088 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12089 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12090 FoundRHS, CtxI))
12091 return true;
12092
12093 if (!FoundLHS->getType()->isPointerTy() &&
12094 !FoundRHS->getType()->isPointerTy() &&
12095 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12096 getNotSCEV(FoundRHS), CtxI))
12097 return true;
12098
12099 return false;
12100 }
12101
12102 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12104 assert(P1 != P2 && "Handled earlier!");
12105 return CmpInst::isRelational(P2) &&
12107 };
12108 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12109 // Unsigned comparison is the same as signed comparison when both the
12110 // operands are non-negative or negative.
12111 if (haveSameSign(FoundLHS, FoundRHS))
12112 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12113 // Create local copies that we can freely swap and canonicalize our
12114 // conditions to "le/lt".
12115 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12116 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12117 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12118 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12119 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12120 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12121 std::swap(CanonicalLHS, CanonicalRHS);
12122 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12123 }
12124 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12125 "Must be!");
12126 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12127 ICmpInst::isLE(CanonicalFoundPred)) &&
12128 "Must be!");
12129 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12130 // Use implication:
12131 // x <u y && y >=s 0 --> x <s y.
12132 // If we can prove the left part, the right part is also proven.
12133 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12134 CanonicalRHS, CanonicalFoundLHS,
12135 CanonicalFoundRHS);
12136 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12137 // Use implication:
12138 // x <s y && y <s 0 --> x <u y.
12139 // If we can prove the left part, the right part is also proven.
12140 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12141 CanonicalRHS, CanonicalFoundLHS,
12142 CanonicalFoundRHS);
12143 }
12144
12145 // Check if we can make progress by sharpening ranges.
12146 if (FoundPred == ICmpInst::ICMP_NE &&
12147 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12148
12149 const SCEVConstant *C = nullptr;
12150 const SCEV *V = nullptr;
12151
12152 if (isa<SCEVConstant>(FoundLHS)) {
12153 C = cast<SCEVConstant>(FoundLHS);
12154 V = FoundRHS;
12155 } else {
12156 C = cast<SCEVConstant>(FoundRHS);
12157 V = FoundLHS;
12158 }
12159
12160 // The guarding predicate tells us that C != V. If the known range
12161 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12162 // range we consider has to correspond to same signedness as the
12163 // predicate we're interested in folding.
12164
12165 APInt Min = ICmpInst::isSigned(Pred) ?
12167
12168 if (Min == C->getAPInt()) {
12169 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12170 // This is true even if (Min + 1) wraps around -- in case of
12171 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12172
12173 APInt SharperMin = Min + 1;
12174
12175 switch (Pred) {
12176 case ICmpInst::ICMP_SGE:
12177 case ICmpInst::ICMP_UGE:
12178 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12179 // RHS, we're done.
12180 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12181 CtxI))
12182 return true;
12183 [[fallthrough]];
12184
12185 case ICmpInst::ICMP_SGT:
12186 case ICmpInst::ICMP_UGT:
12187 // We know from the range information that (V `Pred` Min ||
12188 // V == Min). We know from the guarding condition that !(V
12189 // == Min). This gives us
12190 //
12191 // V `Pred` Min || V == Min && !(V == Min)
12192 // => V `Pred` Min
12193 //
12194 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12195
12196 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12197 return true;
12198 break;
12199
12200 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS`
// and `RHS >= LHS` respectively.
12201 case ICmpInst::ICMP_SLE:
12202 case ICmpInst::ICMP_ULE:
12203 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12204 LHS, V, getConstant(SharperMin), CtxI))
12205 return true;
12206 [[fallthrough]];
12207
12208 case ICmpInst::ICMP_SLT:
12209 case ICmpInst::ICMP_ULT:
12210 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12211 LHS, V, getConstant(Min), CtxI))
12212 return true;
12213 break;
12214
12215 default:
12216 // No change
12217 break;
12218 }
12219 }
12220 }
12221
12222 // Check whether the actual condition is beyond sufficient.
12223 if (FoundPred == ICmpInst::ICMP_EQ)
12224 if (ICmpInst::isTrueWhenEqual(Pred))
12225 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12226 return true;
12227 if (Pred == ICmpInst::ICMP_NE)
12228 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12229 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12230 return true;
12231
12232 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12233 return true;
12234
12235 // Otherwise assume the worst.
12236 return false;
12237}
12238
12239bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12240 const SCEV *&L, const SCEV *&R,
12241 SCEV::NoWrapFlags &Flags) {
12242 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12243 return false;
12244
12245 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12246 return true;
12247}
12248
// Tries to compute More - Less as a constant without building a subtraction
// SCEV. Three reductions are used:
// - addrecs of the same loop with identical affine steps;
// - a common constant multiplier;
// - add expressions whose non-constant operands cancel.
// Returns std::nullopt after 8 simplification rounds, or as soon as the
// difference cannot be reduced to a constant.
12249std::optional<APInt>
12251 // We avoid subtracting expressions here because this function is usually
12252 // fairly deep in the call stack (i.e. is called many times).
12253
12254 unsigned BW = getTypeSizeInBits(More->getType());
12255 APInt Diff(BW, 0);
12256 APInt DiffMul(BW, 1);
// Diff accumulates the constant difference found so far. DiffMul is the
// product of the common constant factors stripped from both sides, so
// constants discovered later are scaled by it.
12257 // Try various simplifications to reduce the difference to a constant. Limit
12258 // the number of allowed simplifications to keep compile-time low.
12259 for (unsigned I = 0; I < 8; ++I) {
12260 if (More == Less)
12261 return Diff;
12262
12263 // Reduce addrecs with identical steps to their start value.
12265 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12266 const auto *MAR = cast<SCEVAddRecExpr>(More);
12267
12268 if (LAR->getLoop() != MAR->getLoop())
12269 return std::nullopt;
12270
12271 // We look at affine expressions only; not for correctness but to keep
12272 // getStepRecurrence cheap.
12273 if (!LAR->isAffine() || !MAR->isAffine())
12274 return std::nullopt;
12275
12276 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12277 return std::nullopt;
12278
12279 Less = LAR->getStart();
12280 More = MAR->getStart();
12281 continue;
12282 }
12283
12284 // Try to match a common constant multiply.
12285 auto MatchConstMul =
12286 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12287 const APInt *C;
12288 const SCEV *Op;
12289 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12290 return {{Op, *C}};
12291 return std::nullopt;
12292 };
12293 if (auto MatchedMore = MatchConstMul(More)) {
12294 if (auto MatchedLess = MatchConstMul(Less)) {
12295 if (MatchedMore->second == MatchedLess->second) {
12296 More = MatchedMore->first;
12297 Less = MatchedLess->first;
12298 DiffMul *= MatchedMore->second;
12299 continue;
12300 }
12301 }
12302 }
12303
12304 // Try to cancel out common factors in two add expressions.
12306 auto Add = [&](const SCEV *S, int Mul) {
12307 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12308 if (Mul == 1) {
12309 Diff += C->getAPInt() * DiffMul;
12310 } else {
12311 assert(Mul == -1);
12312 Diff -= C->getAPInt() * DiffMul;
12313 }
12314 } else
12315 Multiplicity[S] += Mul;
12316 };
12317 auto Decompose = [&](const SCEV *S, int Mul) {
12318 if (isa<SCEVAddExpr>(S)) {
12319 for (const SCEV *Op : S->operands())
12320 Add(Op, Mul);
12321 } else
12322 Add(S, Mul);
12323 };
12324 Decompose(More, 1);
12325 Decompose(Less, -1);
12326
12327 // Check whether all the non-constants cancel out, or reduce to new
12328 // More/Less values.
12329 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12330 for (const auto &[S, Mul] : Multiplicity) {
12331 if (Mul == 0)
12332 continue;
12333 if (Mul == 1) {
12334 if (NewMore)
12335 return std::nullopt;
12336 NewMore = S;
12337 } else if (Mul == -1) {
12338 if (NewLess)
12339 return std::nullopt;
12340 NewLess = S;
12341 } else
12342 return std::nullopt;
12343 }
12344
12345 // Values stayed the same, no point in trying further.
12346 if (NewMore == More || NewLess == Less)
12347 return std::nullopt;
12348
12349 More = NewMore;
12350 Less = NewLess;
12351
12352 // Reduced to constant.
12353 if (!More && !Less)
12354 return Diff;
12355
12356 // Left with variable on only one side, bail out.
12357 if (!More || !Less)
12358 return std::nullopt;
12359 }
12360
12361 // Did not reduce to constant.
12362 return std::nullopt;
12363}
12364
12365bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12366 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12367 const SCEV *FoundRHS, const Instruction *CtxI) {
12368 // Try to recognize the following pattern:
12369 //
12370 // FoundRHS = ...
12371 // ...
12372 // loop:
12373 // FoundLHS = {Start,+,W}
12374 // context_bb: // Basic block from the same loop
12375 // known(Pred, FoundLHS, FoundRHS)
12376 //
12377 // If some predicate is known in the context of a loop, it is also known on
12378 // each iteration of this loop, including the first iteration. Therefore, in
12379 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12380 // prove the original pred using this fact.
12381 if (!CtxI)
12382 return false;
12383 const BasicBlock *ContextBB = CtxI->getParent();
12384 // Make sure AR varies in the context block.
12385 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12386 const Loop *L = AR->getLoop();
12387 // Make sure that context belongs to the loop and executes on 1st iteration
12388 // (if it ever executes at all).
12389 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12390 return false;
12391 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12392 return false;
12393 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12394 }
12395
12396 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12397 const Loop *L = AR->getLoop();
12398 // Make sure that context belongs to the loop and executes on 1st iteration
12399 // (if it ever executes at all).
12400 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12401 return false;
12402 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12403 return false;
12404 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12405 }
12406
12407 return false;
12408}
12409
12410bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12411 const SCEV *LHS,
12412 const SCEV *RHS,
12413 const SCEV *FoundLHS,
12414 const SCEV *FoundRHS) {
12415 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12416 return false;
12417
12418 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12419 if (!AddRecLHS)
12420 return false;
12421
12422 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12423 if (!AddRecFoundLHS)
12424 return false;
12425
12426 // We'd like to let SCEV reason about control dependencies, so we constrain
12427 // both the inequalities to be about add recurrences on the same loop. This
12428 // way we can use isLoopEntryGuardedByCond later.
12429
12430 const Loop *L = AddRecFoundLHS->getLoop();
12431 if (L != AddRecLHS->getLoop())
12432 return false;
12433
12434 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12435 //
12436 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12437 // ... (2)
12438 //
12439 // Informal proof for (2), assuming (1) [*]:
12440 //
12441 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12442 //
12443 // Then
12444 //
12445 // FoundLHS s< FoundRHS s< INT_MIN - C
12446 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12447 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12448 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12449 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12450 // <=> FoundLHS + C s< FoundRHS + C
12451 //
12452 // [*]: (1) can be proved by ruling out overflow.
12453 //
12454 // [**]: This can be proved by analyzing all the four possibilities:
12455 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12456 // (A s>= 0, B s>= 0).
12457 //
12458 // Note:
12459 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12460 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12461 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12462 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12463 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12464 // C)".
12465
12466 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12467 if (!LDiff)
12468 return false;
12469 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12470 if (!RDiff || *LDiff != *RDiff)
12471 return false;
12472
12473 if (LDiff->isMinValue())
12474 return true;
12475
12476 APInt FoundRHSLimit;
12477
12478 if (Pred == CmpInst::ICMP_ULT) {
12479 FoundRHSLimit = -(*RDiff);
12480 } else {
12481 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12482 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12483 }
12484
12485 // Try to prove (1) or (2), as needed.
12486 return isAvailableAtLoopEntry(FoundRHS, L) &&
12487 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12488 getConstant(FoundRHSLimit));
12489}
12490
12491bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12492 const SCEV *RHS, const SCEV *FoundLHS,
12493 const SCEV *FoundRHS, unsigned Depth) {
12494 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12495
12496 llvm::scope_exit ClearOnExit([&]() {
12497 if (LPhi) {
12498 bool Erased = PendingMerges.erase(LPhi);
12499 assert(Erased && "Failed to erase LPhi!");
12500 (void)Erased;
12501 }
12502 if (RPhi) {
12503 bool Erased = PendingMerges.erase(RPhi);
12504 assert(Erased && "Failed to erase RPhi!");
12505 (void)Erased;
12506 }
12507 });
12508
12509 // Find respective Phis and check that they are not being pending.
12510 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12511 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12512 if (!PendingMerges.insert(Phi).second)
12513 return false;
12514 LPhi = Phi;
12515 }
12516 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12517 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12518 // If we detect a loop of Phi nodes being processed by this method, for
12519 // example:
12520 //
12521 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12522 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12523 //
12524 // we don't want to deal with a case that complex, so return conservative
12525 // answer false.
12526 if (!PendingMerges.insert(Phi).second)
12527 return false;
12528 RPhi = Phi;
12529 }
12530
12531 // If none of LHS, RHS is a Phi, nothing to do here.
12532 if (!LPhi && !RPhi)
12533 return false;
12534
12535 // If there is a SCEVUnknown Phi we are interested in, make it left.
12536 if (!LPhi) {
12537 std::swap(LHS, RHS);
12538 std::swap(FoundLHS, FoundRHS);
12539 std::swap(LPhi, RPhi);
12541 }
12542
12543 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12544 const BasicBlock *LBB = LPhi->getParent();
12545 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12546
12547 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12548 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12549 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12550 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12551 };
12552
12553 if (RPhi && RPhi->getParent() == LBB) {
12554 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12555 // If we compare two Phis from the same block, and for each entry block
12556 // the predicate is true for incoming values from this block, then the
12557 // predicate is also true for the Phis.
12558 for (const BasicBlock *IncBB : predecessors(LBB)) {
12559 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12560 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12561 if (!ProvedEasily(L, R))
12562 return false;
12563 }
12564 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12565 // Case two: RHS is also a Phi from the same basic block, and it is an
12566 // AddRec. It means that there is a loop which has both AddRec and Unknown
12567 // PHIs, for it we can compare incoming values of AddRec from above the loop
12568 // and latch with their respective incoming values of LPhi.
12569 // TODO: Generalize to handle loops with many inputs in a header.
12570 if (LPhi->getNumIncomingValues() != 2) return false;
12571
12572 auto *RLoop = RAR->getLoop();
12573 auto *Predecessor = RLoop->getLoopPredecessor();
12574 assert(Predecessor && "Loop with AddRec with no predecessor?");
12575 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12576 if (!ProvedEasily(L1, RAR->getStart()))
12577 return false;
12578 auto *Latch = RLoop->getLoopLatch();
12579 assert(Latch && "Loop with AddRec with no latch?");
12580 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12581 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12582 return false;
12583 } else {
12584 // In all other cases go over inputs of LHS and compare each of them to RHS,
12585 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12586 // At this point RHS is either a non-Phi, or it is a Phi from some block
12587 // different from LBB.
12588 for (const BasicBlock *IncBB : predecessors(LBB)) {
12589 // Check that RHS is available in this block.
12590 if (!dominates(RHS, IncBB))
12591 return false;
12592 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12593 // Make sure L does not refer to a value from a potentially previous
12594 // iteration of a loop.
12595 if (!properlyDominates(L, LBB))
12596 return false;
12597 // Addrecs are considered to properly dominate their loop, so are missed
12598 // by the previous check. Discard any values that have computable
12599 // evolution in this loop.
12600 if (auto *Loop = LI.getLoopFor(LBB))
12601 if (hasComputableLoopEvolution(L, Loop))
12602 return false;
12603 if (!ProvedEasily(L, RHS))
12604 return false;
12605 }
12606 }
12607 return true;
12608}
12609
12610bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12611 const SCEV *LHS,
12612 const SCEV *RHS,
12613 const SCEV *FoundLHS,
12614 const SCEV *FoundRHS) {
12615 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12616 // sure that we are dealing with same LHS.
12617 if (RHS == FoundRHS) {
12618 std::swap(LHS, RHS);
12619 std::swap(FoundLHS, FoundRHS);
12621 }
12622 if (LHS != FoundLHS)
12623 return false;
12624
12625 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12626 if (!SUFoundRHS)
12627 return false;
12628
12629 Value *Shiftee, *ShiftValue;
12630
12631 using namespace PatternMatch;
12632 if (match(SUFoundRHS->getValue(),
12633 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12634 auto *ShifteeS = getSCEV(Shiftee);
12635 // Prove one of the following:
12636 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12637 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12638 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12639 // ---> LHS <s RHS
12640 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12641 // ---> LHS <=s RHS
12642 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12643 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12644 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12645 if (isKnownNonNegative(ShifteeS))
12646 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12647 }
12648
12649 return false;
12650}
12651
12652bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12653 const SCEV *RHS,
12654 const SCEV *FoundLHS,
12655 const SCEV *FoundRHS,
12656 const Instruction *CtxI) {
12657 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12658 FoundRHS) ||
12659 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12660 FoundRHS) ||
12661 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12662 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12663 CtxI) ||
12664 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12665}
12666
12667/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12668template <typename MinMaxExprType>
12669static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12670 const SCEV *Candidate) {
12671 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12672 if (!MinMaxExpr)
12673 return false;
12674
12675 return is_contained(MinMaxExpr->operands(), Candidate);
12676}
12677
12679 CmpPredicate Pred, const SCEV *LHS,
12680 const SCEV *RHS) {
12681 // If both sides are affine addrecs for the same loop, with equal
12682 // steps, and we know the recurrences don't wrap, then we only
12683 // need to check the predicate on the starting values.
12684
12685 if (!ICmpInst::isRelational(Pred))
12686 return false;
12687
12688 const SCEV *LStart, *RStart, *Step;
12689 const Loop *L;
12690 if (!match(LHS,
12691 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12693 m_SpecificLoop(L))))
12694 return false;
12699 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12700 return false;
12701
12702 return SE.isKnownPredicate(Pred, LStart, RStart);
12703}
12704
12705/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12706/// expression?
12708 const SCEV *LHS, const SCEV *RHS) {
12709 switch (Pred) {
12710 default:
12711 return false;
12712
12713 case ICmpInst::ICMP_SGE:
12714 std::swap(LHS, RHS);
12715 [[fallthrough]];
12716 case ICmpInst::ICMP_SLE:
12717 return
12718 // min(A, ...) <= A
12720 // A <= max(A, ...)
12722
12723 case ICmpInst::ICMP_UGE:
12724 std::swap(LHS, RHS);
12725 [[fallthrough]];
12726 case ICmpInst::ICMP_ULE:
12727 return
12728 // min(A, ...) <= A
12729 // FIXME: what about umin_seq?
12731 // A <= max(A, ...)
12733 }
12734
12735 llvm_unreachable("covered switch fell through?!");
12736}
12737
12738bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12739 const SCEV *RHS,
12740 const SCEV *FoundLHS,
12741 const SCEV *FoundRHS,
12742 unsigned Depth) {
12745 "LHS and RHS have different sizes?");
12746 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12747 getTypeSizeInBits(FoundRHS->getType()) &&
12748 "FoundLHS and FoundRHS have different sizes?");
12749 // We want to avoid hurting the compile time with analysis of too big trees.
12751 return false;
12752
12753 // We only want to work with GT comparison so far.
12754 if (ICmpInst::isLT(Pred)) {
12756 std::swap(LHS, RHS);
12757 std::swap(FoundLHS, FoundRHS);
12758 }
12759
12761
12762 // For unsigned, try to reduce it to corresponding signed comparison.
12763 if (P == ICmpInst::ICMP_UGT)
12764 // We can replace unsigned predicate with its signed counterpart if all
12765 // involved values are non-negative.
12766 // TODO: We could have better support for unsigned.
12767 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12768 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12769 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12770 // use this fact to prove that LHS and RHS are non-negative.
12771 const SCEV *MinusOne = getMinusOne(LHS->getType());
12772 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12773 FoundRHS) &&
12774 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12775 FoundRHS))
12777 }
12778
12779 if (P != ICmpInst::ICMP_SGT)
12780 return false;
12781
12782 auto GetOpFromSExt = [&](const SCEV *S) {
12783 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12784 return Ext->getOperand();
12785 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12786 // the constant in some cases.
12787 return S;
12788 };
12789
12790 // Acquire values from extensions.
12791 auto *OrigLHS = LHS;
12792 auto *OrigFoundLHS = FoundLHS;
12793 LHS = GetOpFromSExt(LHS);
12794 FoundLHS = GetOpFromSExt(FoundLHS);
12795
12796 // Is the SGT predicate can be proved trivially or using the found context.
12797 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12798 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12799 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12800 FoundRHS, Depth + 1);
12801 };
12802
12803 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12804 // We want to avoid creation of any new non-constant SCEV. Since we are
12805 // going to compare the operands to RHS, we should be certain that we don't
12806 // need any size extensions for this. So let's decline all cases when the
12807 // sizes of types of LHS and RHS do not match.
12808 // TODO: Maybe try to get RHS from sext to catch more cases?
12810 return false;
12811
12812 // Should not overflow.
12813 if (!LHSAddExpr->hasNoSignedWrap())
12814 return false;
12815
12816 auto *LL = LHSAddExpr->getOperand(0);
12817 auto *LR = LHSAddExpr->getOperand(1);
12818 auto *MinusOne = getMinusOne(RHS->getType());
12819
12820 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12821 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12822 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12823 };
12824 // Try to prove the following rule:
12825 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12826 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12827 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12828 return true;
12829 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12830 Value *LL, *LR;
12831 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12832
12833 using namespace llvm::PatternMatch;
12834
12835 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12836 // Rules for division.
12837 // We are going to perform some comparisons with Denominator and its
12838 // derivative expressions. In general case, creating a SCEV for it may
12839 // lead to a complex analysis of the entire graph, and in particular it
12840 // can request trip count recalculation for the same loop. This would
12841 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12842 // this, we only want to create SCEVs that are constants in this section.
12843 // So we bail if Denominator is not a constant.
12844 if (!isa<ConstantInt>(LR))
12845 return false;
12846
12847 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12848
12849 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12850 // then a SCEV for the numerator already exists and matches with FoundLHS.
12851 auto *Numerator = getExistingSCEV(LL);
12852 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12853 return false;
12854
12855 // Make sure that the numerator matches with FoundLHS and the denominator
12856 // is positive.
12857 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12858 return false;
12859
12860 auto *DTy = Denominator->getType();
12861 auto *FRHSTy = FoundRHS->getType();
12862 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12863 // One of types is a pointer and another one is not. We cannot extend
12864 // them properly to a wider type, so let us just reject this case.
12865 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12866 // to avoid this check.
12867 return false;
12868
12869 // Given that:
12870 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12871 auto *WTy = getWiderType(DTy, FRHSTy);
12872 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12873 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12874
12875 // Try to prove the following rule:
12876 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12877 // For example, given that FoundLHS > 2. It means that FoundLHS is at
12878 // least 3. If we divide it by Denominator < 4, we will have at least 1.
12879 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12880 if (isKnownNonPositive(RHS) &&
12881 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12882 return true;
12883
12884 // Try to prove the following rule:
12885 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12886 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12887 // If we divide it by Denominator > 2, then:
12888 // 1. If FoundLHS is negative, then the result is 0.
12889 // 2. If FoundLHS is non-negative, then the result is non-negative.
12890 // Anyways, the result is non-negative.
12891 auto *MinusOne = getMinusOne(WTy);
12892 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12893 if (isKnownNegative(RHS) &&
12894 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12895 return true;
12896 }
12897 }
12898
12899 // If our expression contained SCEVUnknown Phis, and we split it down and now
12900 // need to prove something for them, try to prove the predicate for every
12901 // possible incoming values of those Phis.
12902 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12903 return true;
12904
12905 return false;
12906}
12907
12909 const SCEV *RHS) {
12910 // zext x u<= sext x, sext x s<= zext x
12911 const SCEV *Op;
12912 switch (Pred) {
12913 case ICmpInst::ICMP_SGE:
12914 std::swap(LHS, RHS);
12915 [[fallthrough]];
12916 case ICmpInst::ICMP_SLE: {
12917 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12918 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12920 }
12921 case ICmpInst::ICMP_UGE:
12922 std::swap(LHS, RHS);
12923 [[fallthrough]];
12924 case ICmpInst::ICMP_ULE: {
12925 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12926 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12928 }
12929 default:
12930 return false;
12931 };
12932 llvm_unreachable("unhandled case");
12933}
12934
12935bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12936 const SCEV *LHS,
12937 const SCEV *RHS) {
12938 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12939 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12940 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12941 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12942 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12943}
12944
12945bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12946 const SCEV *LHS,
12947 const SCEV *RHS,
12948 const SCEV *FoundLHS,
12949 const SCEV *FoundRHS) {
12950 switch (Pred) {
12951 default:
12952 llvm_unreachable("Unexpected CmpPredicate value!");
12953 case ICmpInst::ICMP_EQ:
12954 case ICmpInst::ICMP_NE:
12955 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12956 return true;
12957 break;
12958 case ICmpInst::ICMP_SLT:
12959 case ICmpInst::ICMP_SLE:
12960 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12961 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12962 return true;
12963 break;
12964 case ICmpInst::ICMP_SGT:
12965 case ICmpInst::ICMP_SGE:
12966 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12967 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12968 return true;
12969 break;
12970 case ICmpInst::ICMP_ULT:
12971 case ICmpInst::ICMP_ULE:
12972 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12973 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12974 return true;
12975 break;
12976 case ICmpInst::ICMP_UGT:
12977 case ICmpInst::ICMP_UGE:
12978 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12979 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12980 return true;
12981 break;
12982 }
12983
12984 // Maybe it can be proved via operations?
12985 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12986 return true;
12987
12988 return false;
12989}
12990
12991bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12992 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12993 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12994 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12995 // The restriction on `FoundRHS` be lifted easily -- it exists only to
12996 // reduce the compile time impact of this optimization.
12997 return false;
12998
12999 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13000 if (!Addend)
13001 return false;
13002
13003 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13004
13005 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13006 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13007 ConstantRange FoundLHSRange =
13008 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13009
13010 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13011 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13012
13013 // We can also compute the range of values for `LHS` that satisfy the
13014 // consequent, "`LHS` `Pred` `RHS`":
13015 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13016 // The antecedent implies the consequent if every value of `LHS` that
13017 // satisfies the antecedent also satisfies the consequent.
13018 return LHSRange.icmp(Pred, ConstRHS);
13019}
13020
13021bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13022 bool IsSigned) {
13023 assert(isKnownPositive(Stride) && "Positive stride expected!");
13024
13025 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13026 const SCEV *One = getOne(Stride->getType());
13027
13028 if (IsSigned) {
13029 APInt MaxRHS = getSignedRangeMax(RHS);
13030 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13031 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13032
13033 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13034 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13035 }
13036
13037 APInt MaxRHS = getUnsignedRangeMax(RHS);
13038 APInt MaxValue = APInt::getMaxValue(BitWidth);
13039 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13040
13041 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13042 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13043}
13044
13045bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13046 bool IsSigned) {
13047
13048 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13049 const SCEV *One = getOne(Stride->getType());
13050
13051 if (IsSigned) {
13052 APInt MinRHS = getSignedRangeMin(RHS);
13053 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13054 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13055
13056 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13057 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13058 }
13059
13060 APInt MinRHS = getUnsignedRangeMin(RHS);
13061 APInt MinValue = APInt::getMinValue(BitWidth);
13062 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13063
13064 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13065 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13066}
13067
13069 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13070 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13071 // expression fixes the case of N=0.
13072 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13073 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13074 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13075}
13076
13077const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13078 const SCEV *Stride,
13079 const SCEV *End,
13080 unsigned BitWidth,
13081 bool IsSigned) {
13082 // The logic in this function assumes we can represent a positive stride.
13083 // If we can't, the backedge-taken count must be zero.
13084 if (IsSigned && BitWidth == 1)
13085 return getZero(Stride->getType());
13086
13087 // This code below only been closely audited for negative strides in the
13088 // unsigned comparison case, it may be correct for signed comparison, but
13089 // that needs to be established.
13090 if (IsSigned && isKnownNegative(Stride))
13091 return getCouldNotCompute();
13092
13093 // Calculate the maximum backedge count based on the range of values
13094 // permitted by Start, End, and Stride.
13095 APInt MinStart =
13096 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13097
13098 APInt MinStride =
13099 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13100
13101 // We assume either the stride is positive, or the backedge-taken count
13102 // is zero. So force StrideForMaxBECount to be at least one.
13103 APInt One(BitWidth, 1);
13104 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13105 : APIntOps::umax(One, MinStride);
13106
13107 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13108 : APInt::getMaxValue(BitWidth);
13109 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13110
13111 // Although End can be a MAX expression we estimate MaxEnd considering only
13112 // the case End = RHS of the loop termination condition. This is safe because
13113 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13114 // taken count.
13115 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13116 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13117
13118 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13119 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13120 : APIntOps::umax(MaxEnd, MinStart);
13121
13122 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13123 getConstant(StrideForMaxBECount) /* Step */);
13124}
13125
13127ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13128 const Loop *L, bool IsSigned,
13129 bool ControlsOnlyExit, bool AllowPredicates) {
13131
13133 bool PredicatedIV = false;
13134 if (!IV) {
13135 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13136 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13137 if (AR && AR->getLoop() == L && AR->isAffine()) {
13138 auto canProveNUW = [&]() {
13139 // We can use the comparison to infer no-wrap flags only if it fully
13140 // controls the loop exit.
13141 if (!ControlsOnlyExit)
13142 return false;
13143
13144 if (!isLoopInvariant(RHS, L))
13145 return false;
13146
13147 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13148 // We need the sequence defined by AR to strictly increase in the
13149 // unsigned integer domain for the logic below to hold.
13150 return false;
13151
13152 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13153 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13154 // If RHS <=u Limit, then there must exist a value V in the sequence
13155 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13156 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13157 // overflow occurs. This limit also implies that a signed comparison
13158 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13159 // the high bits on both sides must be zero.
13160 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13161 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13162 Limit = Limit.zext(OuterBitWidth);
13163 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13164 };
13165 auto Flags = AR->getNoWrapFlags();
13166 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13167 Flags = setFlags(Flags, SCEV::FlagNUW);
13168
13169 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13170 if (AR->hasNoUnsignedWrap()) {
13171 // Emulate what getZeroExtendExpr would have done during construction
13172 // if we'd been able to infer the fact just above at that time.
13173 const SCEV *Step = AR->getStepRecurrence(*this);
13174 Type *Ty = ZExt->getType();
13175 auto *S = getAddRecExpr(
13177 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13179 }
13180 }
13181 }
13182 }
13183
13184
13185 if (!IV && AllowPredicates) {
13186 // Try to make this an AddRec using runtime tests, in the first X
13187 // iterations of this loop, where X is the SCEV expression found by the
13188 // algorithm below.
13189 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13190 PredicatedIV = true;
13191 }
13192
13193 // Avoid weird loops
13194 if (!IV || IV->getLoop() != L || !IV->isAffine())
13195 return getCouldNotCompute();
13196
13197 // A precondition of this method is that the condition being analyzed
13198 // reaches an exiting branch which dominates the latch. Given that, we can
13199 // assume that an increment which violates the nowrap specification and
13200 // produces poison must cause undefined behavior when the resulting poison
13201 // value is branched upon and thus we can conclude that the backedge is
13202 // taken no more often than would be required to produce that poison value.
13203 // Note that a well defined loop can exit on the iteration which violates
13204 // the nowrap specification if there is another exit (either explicit or
13205 // implicit/exceptional) which causes the loop to execute before the
13206 // exiting instruction we're analyzing would trigger UB.
13207 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13208 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13210
13211 const SCEV *Stride = IV->getStepRecurrence(*this);
13212
13213 bool PositiveStride = isKnownPositive(Stride);
13214
13215 // Avoid negative or zero stride values.
13216 if (!PositiveStride) {
13217 // We can compute the correct backedge taken count for loops with unknown
13218 // strides if we can prove that the loop is not an infinite loop with side
13219 // effects. Here's the loop structure we are trying to handle -
13220 //
13221 // i = start
13222 // do {
13223 // A[i] = i;
13224 // i += s;
13225 // } while (i < end);
13226 //
13227 // The backedge taken count for such loops is evaluated as -
13228 // (max(end, start + stride) - start - 1) /u stride
13229 //
13230 // The additional preconditions that we need to check to prove correctness
13231 // of the above formula is as follows -
13232 //
13233 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13234 // NoWrap flag).
13235 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13236 // no side effects within the loop)
13237 // c) loop has a single static exit (with no abnormal exits)
13238 //
13239 // Precondition a) implies that if the stride is negative, this is a single
13240 // trip loop. The backedge taken count formula reduces to zero in this case.
13241 //
13242 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13243 // then a zero stride means the backedge can't be taken without executing
13244 // undefined behavior.
13245 //
13246 // The positive stride case is the same as isKnownPositive(Stride) returning
13247 // true (original behavior of the function).
13248 //
13249 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13251 return getCouldNotCompute();
13252
13253 if (!isKnownNonZero(Stride)) {
13254 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13255 // if it might eventually be greater than start and if so, on which
13256 // iteration. We can't even produce a useful upper bound.
13257 if (!isLoopInvariant(RHS, L))
13258 return getCouldNotCompute();
13259
13260 // We allow a potentially zero stride, but we need to divide by stride
13261 // below. Since the loop can't be infinite and this check must control
13262 // the sole exit, we can infer the exit must be taken on the first
13263 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13264 // we know the numerator in the divides below must be zero, so we can
13265 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13266 // and produce the right result.
13267 // FIXME: Handle the case where Stride is poison?
13268 auto wouldZeroStrideBeUB = [&]() {
13269 // Proof by contradiction. Suppose the stride were zero. If we can
13270 // prove that the backedge *is* taken on the first iteration, then since
13271 // we know this condition controls the sole exit, we must have an
13272 // infinite loop. We can't have a (well defined) infinite loop per
13273 // check just above.
13274 // Note: The (Start - Stride) term is used to get the start' term from
13275 // (start' + stride,+,stride). Remember that we only care about the
13276 // result of this expression when stride == 0 at runtime.
13277 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13278 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13279 };
13280 if (!wouldZeroStrideBeUB()) {
13281 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13282 }
13283 }
13284 } else if (!NoWrap) {
13285 // Avoid proven overflow cases: this will ensure that the backedge taken
13286 // count will not generate any unsigned overflow.
13287 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13288 return getCouldNotCompute();
13289 }
13290
13291 // On all paths just preceeding, we established the following invariant:
13292 // IV can be assumed not to overflow up to and including the exiting
13293 // iteration. We proved this in one of two ways:
13294 // 1) We can show overflow doesn't occur before the exiting iteration
13295 // 1a) canIVOverflowOnLT, and b) step of one
13296 // 2) We can show that if overflow occurs, the loop must execute UB
13297 // before any possible exit.
13298 // Note that we have not yet proved RHS invariant (in general).
13299
13300 const SCEV *Start = IV->getStart();
13301
13302 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13303 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13304 // Use integer-typed versions for actual computation; we can't subtract
13305 // pointers in general.
13306 const SCEV *OrigStart = Start;
13307 const SCEV *OrigRHS = RHS;
13308 if (Start->getType()->isPointerTy()) {
13310 if (isa<SCEVCouldNotCompute>(Start))
13311 return Start;
13312 }
13313 if (RHS->getType()->isPointerTy()) {
13316 return RHS;
13317 }
13318
13319 const SCEV *End = nullptr, *BECount = nullptr,
13320 *BECountIfBackedgeTaken = nullptr;
13321 if (!isLoopInvariant(RHS, L)) {
13322 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13323 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13324 RHSAddRec->getNoWrapFlags()) {
13325 // The structure of loop we are trying to calculate backedge count of:
13326 //
13327 // left = left_start
13328 // right = right_start
13329 //
13330 // while(left < right){
13331 // ... do something here ...
13332 // left += s1; // stride of left is s1 (s1 > 0)
13333 // right += s2; // stride of right is s2 (s2 < 0)
13334 // }
13335 //
13336
13337 const SCEV *RHSStart = RHSAddRec->getStart();
13338 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13339
13340 // If Stride - RHSStride is positive and does not overflow, we can write
13341 // backedge count as ->
13342 // ceil((End - Start) /u (Stride - RHSStride))
13343 // Where, End = max(RHSStart, Start)
13344
13345 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13346 if (isKnownNegative(RHSStride) &&
13347 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13348 RHSStride)) {
13349
13350 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13351 if (isKnownPositive(Denominator)) {
13352 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13353 : getUMaxExpr(RHSStart, Start);
13354
13355 // We can do this because End >= Start, as End = max(RHSStart, Start)
13356 const SCEV *Delta = getMinusSCEV(End, Start);
13357
13358 BECount = getUDivCeilSCEV(Delta, Denominator);
13359 BECountIfBackedgeTaken =
13360 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13361 }
13362 }
13363 }
13364 if (BECount == nullptr) {
13365 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13366 // given the start, stride and max value for the end bound of the
13367 // loop (RHS), and the fact that IV does not overflow (which is
13368 // checked above).
13369 const SCEV *MaxBECount = computeMaxBECountForLT(
13370 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13371 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13372 MaxBECount, false /*MaxOrZero*/, Predicates);
13373 }
13374 } else {
13375 // We use the expression (max(End,Start)-Start)/Stride to describe the
13376 // backedge count, as if the backedge is taken at least once
13377 // max(End,Start) is End and so the result is as above, and if not
13378 // max(End,Start) is Start so we get a backedge count of zero.
13379 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13380 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13381 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13382 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13383 // Can we prove (max(RHS,Start) > Start - Stride?
13384 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13385 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13386 // In this case, we can use a refined formula for computing backedge
13387 // taken count. The general formula remains:
13388 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13389 // We want to use the alternate formula:
13390 // "((End - 1) - (Start - Stride)) /u Stride"
13391 // Let's do a quick case analysis to show these are equivalent under
13392 // our precondition that max(RHS,Start) > Start - Stride.
13393 // * For RHS <= Start, the backedge-taken count must be zero.
13394 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13395 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13396 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13397 // of Stride. For 0 stride, we've use umin(1,Stride) above,
13398 // reducing this to the stride of 1 case.
13399 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13400 // Stride".
13401 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13402 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13403 // "((RHS - (Start - Stride) - 1) /u Stride".
13404 // Our preconditions trivially imply no overflow in that form.
13405 const SCEV *MinusOne = getMinusOne(Stride->getType());
13406 const SCEV *Numerator =
13407 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13408 BECount = getUDivExpr(Numerator, Stride);
13409 }
13410
13411 if (!BECount) {
13412 auto canProveRHSGreaterThanEqualStart = [&]() {
13413 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13414 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13415 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13416
13417 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13418 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13419 return true;
13420
13421 // (RHS > Start - 1) implies RHS >= Start.
13422 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13423 // "Start - 1" doesn't overflow.
13424 // * For signed comparison, if Start - 1 does overflow, it's equal
13425 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13426 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13427 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13428 //
13429 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13430 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13431 auto *StartMinusOne =
13432 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13433 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13434 };
13435
13436 // If we know that RHS >= Start in the context of loop, then we know
13437 // that max(RHS, Start) = RHS at this point.
13438 if (canProveRHSGreaterThanEqualStart()) {
13439 End = RHS;
13440 } else {
13441 // If RHS < Start, the backedge will be taken zero times. So in
13442 // general, we can write the backedge-taken count as:
13443 //
13444 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13445 //
13446 // We convert it to the following to make it more convenient for SCEV:
13447 //
13448 // ceil(max(RHS, Start) - Start) / Stride
13449 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13450
13451 // See what would happen if we assume the backedge is taken. This is
13452 // used to compute MaxBECount.
13453 BECountIfBackedgeTaken =
13454 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13455 }
13456
13457 // At this point, we know:
13458 //
13459 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13460 // 2. The index variable doesn't overflow.
13461 //
13462 // Therefore, we know N exists such that
13463 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13464 // doesn't overflow.
13465 //
13466 // Using this information, try to prove whether the addition in
13467 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13468 const SCEV *One = getOne(Stride->getType());
13469 bool MayAddOverflow = [&] {
13470 if (isKnownToBeAPowerOfTwo(Stride)) {
13471 // Suppose Stride is a power of two, and Start/End are unsigned
13472 // integers. Let UMAX be the largest representable unsigned
13473 // integer.
13474 //
13475 // By the preconditions of this function, we know
13476 // "(Start + Stride * N) >= End", and this doesn't overflow.
13477 // As a formula:
13478 //
13479 // End <= (Start + Stride * N) <= UMAX
13480 //
13481 // Subtracting Start from all the terms:
13482 //
13483 // End - Start <= Stride * N <= UMAX - Start
13484 //
13485 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13486 //
13487 // End - Start <= Stride * N <= UMAX
13488 //
13489 // Stride * N is a multiple of Stride. Therefore,
13490 //
13491 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13492 //
13493 // Since Stride is a power of two, UMAX + 1 is divisible by
13494 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13495 // write:
13496 //
13497 // End - Start <= Stride * N <= UMAX - Stride - 1
13498 //
13499 // Dropping the middle term:
13500 //
13501 // End - Start <= UMAX - Stride - 1
13502 //
13503 // Adding Stride - 1 to both sides:
13504 //
13505 // (End - Start) + (Stride - 1) <= UMAX
13506 //
13507 // In other words, the addition doesn't have unsigned overflow.
13508 //
13509 // A similar proof works if we treat Start/End as signed values.
13510 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13511 // to use signed max instead of unsigned max. Note that we're
13512 // trying to prove a lack of unsigned overflow in either case.
13513 return false;
13514 }
13515 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13516 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13517 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13518 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13519 // 1 <s End.
13520 //
13521 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13522 // End.
13523 return false;
13524 }
13525 return true;
13526 }();
13527
13528 const SCEV *Delta = getMinusSCEV(End, Start);
13529 if (!MayAddOverflow) {
13530 // floor((D + (S - 1)) / S)
13531 // We prefer this formulation if it's legal because it's fewer
13532 // operations.
13533 BECount =
13534 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13535 } else {
13536 BECount = getUDivCeilSCEV(Delta, Stride);
13537 }
13538 }
13539 }
13540
13541 const SCEV *ConstantMaxBECount;
13542 bool MaxOrZero = false;
13543 if (isa<SCEVConstant>(BECount)) {
13544 ConstantMaxBECount = BECount;
13545 } else if (BECountIfBackedgeTaken &&
13546 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13547 // If we know exactly how many times the backedge will be taken if it's
13548 // taken at least once, then the backedge count will either be that or
13549 // zero.
13550 ConstantMaxBECount = BECountIfBackedgeTaken;
13551 MaxOrZero = true;
13552 } else {
13553 ConstantMaxBECount = computeMaxBECountForLT(
13554 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13555 }
13556
13557 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13558 !isa<SCEVCouldNotCompute>(BECount))
13559 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13560
13561 const SCEV *SymbolicMaxBECount =
13562 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13563 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13564 Predicates);
13565}
13566
13567ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13568 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13569 bool ControlsOnlyExit, bool AllowPredicates) {
13571 // We handle only IV > Invariant
13572 if (!isLoopInvariant(RHS, L))
13573 return getCouldNotCompute();
13574
13575 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13576 if (!IV && AllowPredicates)
13577 // Try to make this an AddRec using runtime tests, in the first X
13578 // iterations of this loop, where X is the SCEV expression found by the
13579 // algorithm below.
13580 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13581
13582 // Avoid weird loops
13583 if (!IV || IV->getLoop() != L || !IV->isAffine())
13584 return getCouldNotCompute();
13585
13586 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13587 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13589
13590 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13591
13592 // Avoid negative or zero stride values
13593 if (!isKnownPositive(Stride))
13594 return getCouldNotCompute();
13595
13596 // Avoid proven overflow cases: this will ensure that the backedge taken count
13597 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13598 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13599 // behaviors like the case of C language.
13600 if (!Stride->isOne() && !NoWrap)
13601 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13602 return getCouldNotCompute();
13603
13604 const SCEV *Start = IV->getStart();
13605 const SCEV *End = RHS;
13606 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13607 // If we know that Start >= RHS in the context of loop, then we know that
13608 // min(RHS, Start) = RHS at this point.
13610 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13611 End = RHS;
13612 else
13613 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13614 }
13615
13616 if (Start->getType()->isPointerTy()) {
13618 if (isa<SCEVCouldNotCompute>(Start))
13619 return Start;
13620 }
13621 if (End->getType()->isPointerTy()) {
13622 End = getLosslessPtrToIntExpr(End);
13623 if (isa<SCEVCouldNotCompute>(End))
13624 return End;
13625 }
13626
13627 // Compute ((Start - End) + (Stride - 1)) / Stride.
13628 // FIXME: This can overflow. Holding off on fixing this for now;
13629 // howManyGreaterThans will hopefully be gone soon.
13630 const SCEV *One = getOne(Stride->getType());
13631 const SCEV *BECount = getUDivExpr(
13632 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13633
13634 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13636
13637 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13638 : getUnsignedRangeMin(Stride);
13639
13640 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13641 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13642 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13643
13644 // Although End can be a MIN expression we estimate MinEnd considering only
13645 // the case End = RHS. This is safe because in the other case (Start - End)
13646 // is zero, leading to a zero maximum backedge taken count.
13647 APInt MinEnd =
13648 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13649 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13650
13651 const SCEV *ConstantMaxBECount =
13652 isa<SCEVConstant>(BECount)
13653 ? BECount
13654 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13655 getConstant(MinStride));
13656
13657 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13658 ConstantMaxBECount = BECount;
13659 const SCEV *SymbolicMaxBECount =
13660 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13661
13662 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13663 Predicates);
13664}
13665
13667 ScalarEvolution &SE) const {
13668 if (Range.isFullSet()) // Infinite loop.
13669 return SE.getCouldNotCompute();
13670
13671 // If the start is a non-zero constant, shift the range to simplify things.
13672 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13673 if (!SC->getValue()->isZero()) {
13675 Operands[0] = SE.getZero(SC->getType());
13676 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13678 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13679 return ShiftedAddRec->getNumIterationsInRange(
13680 Range.subtract(SC->getAPInt()), SE);
13681 // This is strange and shouldn't happen.
13682 return SE.getCouldNotCompute();
13683 }
13684
13685 // The only time we can solve this is when we have all constant indices.
13686 // Otherwise, we cannot determine the overflow conditions.
13687 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13688 return SE.getCouldNotCompute();
13689
13690 // Okay at this point we know that all elements of the chrec are constants and
13691 // that the start element is zero.
13692
13693 // First check to see if the range contains zero. If not, the first
13694 // iteration exits.
13695 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13696 if (!Range.contains(APInt(BitWidth, 0)))
13697 return SE.getZero(getType());
13698
13699 if (isAffine()) {
13700 // If this is an affine expression then we have this situation:
13701 // Solve {0,+,A} in Range === Ax in Range
13702
13703 // We know that zero is in the range. If A is positive then we know that
13704 // the upper value of the range must be the first possible exit value.
13705 // If A is negative then the lower of the range is the last possible loop
13706 // value. Also note that we already checked for a full range.
13707 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13708 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13709
13710 // The exit value should be (End+A)/A.
13711 APInt ExitVal = (End + A).udiv(A);
13712 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13713
13714 // Evaluate at the exit value. If we really did fall out of the valid
13715 // range, then we computed our trip count, otherwise wrap around or other
13716 // things must have happened.
13717 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13718 if (Range.contains(Val->getValue()))
13719 return SE.getCouldNotCompute(); // Something strange happened
13720
13721 // Ensure that the previous value is in the range.
13722 assert(Range.contains(
13724 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13725 "Linear scev computation is off in a bad way!");
13726 return SE.getConstant(ExitValue);
13727 }
13728
13729 if (isQuadratic()) {
13730 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13731 return SE.getConstant(*S);
13732 }
13733
13734 return SE.getCouldNotCompute();
13735}
13736
13737const SCEVAddRecExpr *
13739 assert(getNumOperands() > 1 && "AddRec with zero step?");
13740 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13741 // but in this case we cannot guarantee that the value returned will be an
13742 // AddRec because SCEV does not have a fixed point where it stops
13743 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13744 // may happen if we reach arithmetic depth limit while simplifying. So we
13745 // construct the returned value explicitly.
13747 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13748 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13749 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13750 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13751 // We know that the last operand is not a constant zero (otherwise it would
13752 // have been popped out earlier). This guarantees us that if the result has
13753 // the same last operand, then it will also not be popped out, meaning that
13754 // the returned value will be an AddRec.
13755 const SCEV *Last = getOperand(getNumOperands() - 1);
13756 assert(!Last->isZero() && "Recurrency with zero step?");
13757 Ops.push_back(Last);
13760}
13761
13762// Return true when S contains at least an undef value.
13764 return SCEVExprContains(
13765 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13766}
13767
13768// Return true when S contains a value that is a nullptr.
13770 return SCEVExprContains(S, [](const SCEV *S) {
13771 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13772 return SU->getValue() == nullptr;
13773 return false;
13774 });
13775}
13776
13777/// Return the size of an element read or written by Inst.
13779 Type *Ty;
13780 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13781 Ty = Store->getValueOperand()->getType();
13782 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13783 Ty = Load->getType();
13784 else
13785 return nullptr;
13786
13788 return getSizeOfExpr(ETy, Ty);
13789}
13790
13791//===----------------------------------------------------------------------===//
13792// SCEVCallbackVH Class Implementation
13793//===----------------------------------------------------------------------===//
13794
13796 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13797 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13798 SE->ConstantEvolutionLoopExitValue.erase(PN);
13799 SE->eraseValueFromMap(getValPtr());
13800 // this now dangles!
13801}
13802
13803void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13804 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13805
13806 // Forget all the expressions associated with users of the old value,
13807 // so that future queries will recompute the expressions using the new
13808 // value.
13809 SE->forgetValue(getValPtr());
13810 // this now dangles!
13811}
13812
13813ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13814 : CallbackVH(V), SE(se) {}
13815
13816//===----------------------------------------------------------------------===//
13817// ScalarEvolution Class Implementation
13818//===----------------------------------------------------------------------===//
13819
13822 LoopInfo &LI)
13823 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13824 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13825 LoopDispositions(64), BlockDispositions(64) {
13826 // To use guards for proving predicates, we need to scan every instruction in
13827 // relevant basic blocks, and not just terminators. Doing this is a waste of
13828 // time if the IR does not actually contain any calls to
13829 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13830 //
13831 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13832 // to _add_ guards to the module when there weren't any before, and wants
13833 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13834 // efficient in lieu of being smart in that rather obscure case.
13835
13836 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13837 F.getParent(), Intrinsic::experimental_guard);
13838 HasGuards = GuardDecl && !GuardDecl->use_empty();
13839}
13840
13842 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13843 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13844 ValueExprMap(std::move(Arg.ValueExprMap)),
13845 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13846 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13847 PendingMerges(std::move(Arg.PendingMerges)),
13848 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13849 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13850 PredicatedBackedgeTakenCounts(
13851 std::move(Arg.PredicatedBackedgeTakenCounts)),
13852 BECountUsers(std::move(Arg.BECountUsers)),
13853 ConstantEvolutionLoopExitValue(
13854 std::move(Arg.ConstantEvolutionLoopExitValue)),
13855 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13856 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13857 LoopDispositions(std::move(Arg.LoopDispositions)),
13858 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13859 BlockDispositions(std::move(Arg.BlockDispositions)),
13860 SCEVUsers(std::move(Arg.SCEVUsers)),
13861 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13862 SignedRanges(std::move(Arg.SignedRanges)),
13863 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13864 UniquePreds(std::move(Arg.UniquePreds)),
13865 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13866 LoopUsers(std::move(Arg.LoopUsers)),
13867 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13868 FirstUnknown(Arg.FirstUnknown) {
13869 Arg.FirstUnknown = nullptr;
13870}
13871
13873 // Iterate through all the SCEVUnknown instances and call their
13874 // destructors, so that they release their references to their values.
13875 for (SCEVUnknown *U = FirstUnknown; U;) {
13876 SCEVUnknown *Tmp = U;
13877 U = U->Next;
13878 Tmp->~SCEVUnknown();
13879 }
13880 FirstUnknown = nullptr;
13881
13882 ExprValueMap.clear();
13883 ValueExprMap.clear();
13884 HasRecMap.clear();
13885 BackedgeTakenCounts.clear();
13886 PredicatedBackedgeTakenCounts.clear();
13887
13888 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13889 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13890 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13891 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13892 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13893}
13894
13898
13899/// When printing a top-level SCEV for trip counts, it's helpful to include
13900/// a type for constants which are otherwise hard to disambiguate.
13901static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13902 if (isa<SCEVConstant>(S))
13903 OS << *S->getType() << " ";
13904 OS << *S;
13905}
13906
13908 const Loop *L) {
13909 // Print all inner loops first
13910 for (Loop *I : *L)
13911 PrintLoopInfo(OS, SE, I);
13912
13913 OS << "Loop ";
13914 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13915 OS << ": ";
13916
13917 SmallVector<BasicBlock *, 8> ExitingBlocks;
13918 L->getExitingBlocks(ExitingBlocks);
13919 if (ExitingBlocks.size() != 1)
13920 OS << "<multiple exits> ";
13921
13922 auto *BTC = SE->getBackedgeTakenCount(L);
13923 if (!isa<SCEVCouldNotCompute>(BTC)) {
13924 OS << "backedge-taken count is ";
13925 PrintSCEVWithTypeHint(OS, BTC);
13926 } else
13927 OS << "Unpredictable backedge-taken count.";
13928 OS << "\n";
13929
13930 if (ExitingBlocks.size() > 1)
13931 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13932 OS << " exit count for " << ExitingBlock->getName() << ": ";
13933 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13934 PrintSCEVWithTypeHint(OS, EC);
13935 if (isa<SCEVCouldNotCompute>(EC)) {
13936 // Retry with predicates.
13938 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13939 if (!isa<SCEVCouldNotCompute>(EC)) {
13940 OS << "\n predicated exit count for " << ExitingBlock->getName()
13941 << ": ";
13942 PrintSCEVWithTypeHint(OS, EC);
13943 OS << "\n Predicates:\n";
13944 for (const auto *P : Predicates)
13945 P->print(OS, 4);
13946 }
13947 }
13948 OS << "\n";
13949 }
13950
13951 OS << "Loop ";
13952 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13953 OS << ": ";
13954
13955 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13956 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13957 OS << "constant max backedge-taken count is ";
13958 PrintSCEVWithTypeHint(OS, ConstantBTC);
13960 OS << ", actual taken count either this or zero.";
13961 } else {
13962 OS << "Unpredictable constant max backedge-taken count. ";
13963 }
13964
13965 OS << "\n"
13966 "Loop ";
13967 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13968 OS << ": ";
13969
13970 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13971 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13972 OS << "symbolic max backedge-taken count is ";
13973 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13975 OS << ", actual taken count either this or zero.";
13976 } else {
13977 OS << "Unpredictable symbolic max backedge-taken count. ";
13978 }
13979 OS << "\n";
13980
13981 if (ExitingBlocks.size() > 1)
13982 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13983 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13984 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13986 PrintSCEVWithTypeHint(OS, ExitBTC);
13987 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13988 // Retry with predicates.
13990 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13992 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13993 OS << "\n predicated symbolic max exit count for "
13994 << ExitingBlock->getName() << ": ";
13995 PrintSCEVWithTypeHint(OS, ExitBTC);
13996 OS << "\n Predicates:\n";
13997 for (const auto *P : Predicates)
13998 P->print(OS, 4);
13999 }
14000 }
14001 OS << "\n";
14002 }
14003
14005 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14006 if (PBT != BTC) {
14007 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14008 OS << "Loop ";
14009 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14010 OS << ": ";
14011 if (!isa<SCEVCouldNotCompute>(PBT)) {
14012 OS << "Predicated backedge-taken count is ";
14013 PrintSCEVWithTypeHint(OS, PBT);
14014 } else
14015 OS << "Unpredictable predicated backedge-taken count.";
14016 OS << "\n";
14017 OS << " Predicates:\n";
14018 for (const auto *P : Preds)
14019 P->print(OS, 4);
14020 }
14021 Preds.clear();
14022
14023 auto *PredConstantMax =
14025 if (PredConstantMax != ConstantBTC) {
14026 assert(!Preds.empty() &&
14027 "different predicated constant max BTC but no predicates");
14028 OS << "Loop ";
14029 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14030 OS << ": ";
14031 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14032 OS << "Predicated constant max backedge-taken count is ";
14033 PrintSCEVWithTypeHint(OS, PredConstantMax);
14034 } else
14035 OS << "Unpredictable predicated constant max backedge-taken count.";
14036 OS << "\n";
14037 OS << " Predicates:\n";
14038 for (const auto *P : Preds)
14039 P->print(OS, 4);
14040 }
14041 Preds.clear();
14042
14043 auto *PredSymbolicMax =
14045 if (SymbolicBTC != PredSymbolicMax) {
14046 assert(!Preds.empty() &&
14047 "Different predicated symbolic max BTC, but no predicates");
14048 OS << "Loop ";
14049 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14050 OS << ": ";
14051 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14052 OS << "Predicated symbolic max backedge-taken count is ";
14053 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14054 } else
14055 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14056 OS << "\n";
14057 OS << " Predicates:\n";
14058 for (const auto *P : Preds)
14059 P->print(OS, 4);
14060 }
14061
14063 OS << "Loop ";
14064 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14065 OS << ": ";
14066 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14067 }
14068}
14069
14070namespace llvm {
14071// Note: these overloaded operators need to be in the llvm namespace for them
14072// to be resolved correctly. If we put them outside the llvm namespace, the
14073//
14074// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14075//
14076// code below "breaks" and starts printing raw enum values as opposed to the
14077// string values.
14080 switch (LD) {
14082 OS << "Variant";
14083 break;
14085 OS << "Invariant";
14086 break;
14088 OS << "Computable";
14089 break;
14090 }
14091 return OS;
14092}
14093
14096 switch (BD) {
14098 OS << "DoesNotDominate";
14099 break;
14101 OS << "Dominates";
14102 break;
14104 OS << "ProperlyDominates";
14105 break;
14106 }
14107 return OS;
14108}
14109} // namespace llvm
14110
14112 // ScalarEvolution's implementation of the print method is to print
14113 // out SCEV values of all instructions that are interesting. Doing
14114 // this potentially causes it to create new SCEV objects though,
14115 // which technically conflicts with the const qualifier. This isn't
14116 // observable from outside the class though, so casting away the
14117 // const isn't dangerous.
14118 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14119
14120 if (ClassifyExpressions) {
14121 OS << "Classifying expressions for: ";
14122 F.printAsOperand(OS, /*PrintType=*/false);
14123 OS << "\n";
14124 for (Instruction &I : instructions(F))
14125 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14126 OS << I << '\n';
14127 OS << " --> ";
14128 const SCEV *SV = SE.getSCEV(&I);
14129 SV->print(OS);
14130 if (!isa<SCEVCouldNotCompute>(SV)) {
14131 OS << " U: ";
14132 SE.getUnsignedRange(SV).print(OS);
14133 OS << " S: ";
14134 SE.getSignedRange(SV).print(OS);
14135 }
14136
14137 const Loop *L = LI.getLoopFor(I.getParent());
14138
14139 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14140 if (AtUse != SV) {
14141 OS << " --> ";
14142 AtUse->print(OS);
14143 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14144 OS << " U: ";
14145 SE.getUnsignedRange(AtUse).print(OS);
14146 OS << " S: ";
14147 SE.getSignedRange(AtUse).print(OS);
14148 }
14149 }
14150
14151 if (L) {
14152 OS << "\t\t" "Exits: ";
14153 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14154 if (!SE.isLoopInvariant(ExitValue, L)) {
14155 OS << "<<Unknown>>";
14156 } else {
14157 OS << *ExitValue;
14158 }
14159
14160 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14161 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14162 OS << LS;
14163 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14164 OS << ": " << SE.getLoopDisposition(SV, Iter);
14165 }
14166
14167 for (const auto *InnerL : depth_first(L)) {
14168 if (InnerL == L)
14169 continue;
14170 OS << LS;
14171 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14172 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14173 }
14174
14175 OS << " }";
14176 }
14177
14178 OS << "\n";
14179 }
14180 }
14181
14182 OS << "Determining loop execution counts for: ";
14183 F.printAsOperand(OS, /*PrintType=*/false);
14184 OS << "\n";
14185 for (Loop *I : LI)
14186 PrintLoopInfo(OS, &SE, I);
14187}
14188
14191 auto &Values = LoopDispositions[S];
14192 for (auto &V : Values) {
14193 if (V.getPointer() == L)
14194 return V.getInt();
14195 }
14196 Values.emplace_back(L, LoopVariant);
14197 LoopDisposition D = computeLoopDisposition(S, L);
14198 auto &Values2 = LoopDispositions[S];
14199 for (auto &V : llvm::reverse(Values2)) {
14200 if (V.getPointer() == L) {
14201 V.setInt(D);
14202 break;
14203 }
14204 }
14205 return D;
14206}
14207
14209ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14210 switch (S->getSCEVType()) {
14211 case scConstant:
14212 case scVScale:
14213 return LoopInvariant;
14214 case scAddRecExpr: {
14215 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14216
14217 // If L is the addrec's loop, it's computable.
14218 if (AR->getLoop() == L)
14219 return LoopComputable;
14220
14221 // Add recurrences are never invariant in the function-body (null loop).
14222 if (!L)
14223 return LoopVariant;
14224
14225 // Everything that is not defined at loop entry is variant.
14226 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14227 return LoopVariant;
14228 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14229 " dominate the contained loop's header?");
14230
14231 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14232 if (AR->getLoop()->contains(L))
14233 return LoopInvariant;
14234
14235 // This recurrence is variant w.r.t. L if any of its operands
14236 // are variant.
14237 for (const auto *Op : AR->operands())
14238 if (!isLoopInvariant(Op, L))
14239 return LoopVariant;
14240
14241 // Otherwise it's loop-invariant.
14242 return LoopInvariant;
14243 }
14244 case scTruncate:
14245 case scZeroExtend:
14246 case scSignExtend:
14247 case scPtrToAddr:
14248 case scPtrToInt:
14249 case scAddExpr:
14250 case scMulExpr:
14251 case scUDivExpr:
14252 case scUMaxExpr:
14253 case scSMaxExpr:
14254 case scUMinExpr:
14255 case scSMinExpr:
14256 case scSequentialUMinExpr: {
14257 bool HasVarying = false;
14258 for (const auto *Op : S->operands()) {
14260 if (D == LoopVariant)
14261 return LoopVariant;
14262 if (D == LoopComputable)
14263 HasVarying = true;
14264 }
14265 return HasVarying ? LoopComputable : LoopInvariant;
14266 }
14267 case scUnknown:
14268 // All non-instruction values are loop invariant. All instructions are loop
14269 // invariant if they are not contained in the specified loop.
14270 // Instructions are never considered invariant in the function body
14271 // (null loop) because they are defined within the "loop".
14272 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14273 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14274 return LoopInvariant;
14275 case scCouldNotCompute:
14276 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14277 }
14278 llvm_unreachable("Unknown SCEV kind!");
14279}
14280
14282 return getLoopDisposition(S, L) == LoopInvariant;
14283}
14284
14286 return getLoopDisposition(S, L) == LoopComputable;
14287}
14288
14291 auto &Values = BlockDispositions[S];
14292 for (auto &V : Values) {
14293 if (V.getPointer() == BB)
14294 return V.getInt();
14295 }
14296 Values.emplace_back(BB, DoesNotDominateBlock);
14297 BlockDisposition D = computeBlockDisposition(S, BB);
14298 auto &Values2 = BlockDispositions[S];
14299 for (auto &V : llvm::reverse(Values2)) {
14300 if (V.getPointer() == BB) {
14301 V.setInt(D);
14302 break;
14303 }
14304 }
14305 return D;
14306}
14307
14309ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14310 switch (S->getSCEVType()) {
14311 case scConstant:
14312 case scVScale:
14314 case scAddRecExpr: {
14315 // This uses a "dominates" query instead of "properly dominates" query
14316 // to test for proper dominance too, because the instruction which
14317 // produces the addrec's value is a PHI, and a PHI effectively properly
14318 // dominates its entire containing block.
14319 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14320 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14321 return DoesNotDominateBlock;
14322
14323 // Fall through into SCEVNAryExpr handling.
14324 [[fallthrough]];
14325 }
14326 case scTruncate:
14327 case scZeroExtend:
14328 case scSignExtend:
14329 case scPtrToAddr:
14330 case scPtrToInt:
14331 case scAddExpr:
14332 case scMulExpr:
14333 case scUDivExpr:
14334 case scUMaxExpr:
14335 case scSMaxExpr:
14336 case scUMinExpr:
14337 case scSMinExpr:
14338 case scSequentialUMinExpr: {
14339 bool Proper = true;
14340 for (const SCEV *NAryOp : S->operands()) {
14342 if (D == DoesNotDominateBlock)
14343 return DoesNotDominateBlock;
14344 if (D == DominatesBlock)
14345 Proper = false;
14346 }
14347 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14348 }
14349 case scUnknown:
14350 if (Instruction *I =
14351 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14352 if (I->getParent() == BB)
14353 return DominatesBlock;
14354 if (DT.properlyDominates(I->getParent(), BB))
14356 return DoesNotDominateBlock;
14357 }
14359 case scCouldNotCompute:
14360 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14361 }
14362 llvm_unreachable("Unknown SCEV kind!");
14363}
14364
14365bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14366 return getBlockDisposition(S, BB) >= DominatesBlock;
14367}
14368
14371}
14372
14373bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14374 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14375}
14376
14377void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14378 bool Predicated) {
14379 auto &BECounts =
14380 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14381 auto It = BECounts.find(L);
14382 if (It != BECounts.end()) {
14383 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14384 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14385 if (!isa<SCEVConstant>(S)) {
14386 auto UserIt = BECountUsers.find(S);
14387 assert(UserIt != BECountUsers.end());
14388 UserIt->second.erase({L, Predicated});
14389 }
14390 }
14391 }
14392 BECounts.erase(It);
14393 }
14394}
14395
14396void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14397 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14398 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14399
14400 while (!Worklist.empty()) {
14401 const SCEV *Curr = Worklist.pop_back_val();
14402 auto Users = SCEVUsers.find(Curr);
14403 if (Users != SCEVUsers.end())
14404 for (const auto *User : Users->second)
14405 if (ToForget.insert(User).second)
14406 Worklist.push_back(User);
14407 }
14408
14409 for (const auto *S : ToForget)
14410 forgetMemoizedResultsImpl(S);
14411
14412 for (auto I = PredicatedSCEVRewrites.begin();
14413 I != PredicatedSCEVRewrites.end();) {
14414 std::pair<const SCEV *, const Loop *> Entry = I->first;
14415 if (ToForget.count(Entry.first))
14416 PredicatedSCEVRewrites.erase(I++);
14417 else
14418 ++I;
14419 }
14420}
14421
14422void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14423 LoopDispositions.erase(S);
14424 BlockDispositions.erase(S);
14425 UnsignedRanges.erase(S);
14426 SignedRanges.erase(S);
14427 HasRecMap.erase(S);
14428 ConstantMultipleCache.erase(S);
14429
14430 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14431 UnsignedWrapViaInductionTried.erase(AR);
14432 SignedWrapViaInductionTried.erase(AR);
14433 }
14434
14435 auto ExprIt = ExprValueMap.find(S);
14436 if (ExprIt != ExprValueMap.end()) {
14437 for (Value *V : ExprIt->second) {
14438 auto ValueIt = ValueExprMap.find_as(V);
14439 if (ValueIt != ValueExprMap.end())
14440 ValueExprMap.erase(ValueIt);
14441 }
14442 ExprValueMap.erase(ExprIt);
14443 }
14444
14445 auto ScopeIt = ValuesAtScopes.find(S);
14446 if (ScopeIt != ValuesAtScopes.end()) {
14447 for (const auto &Pair : ScopeIt->second)
14448 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14449 llvm::erase(ValuesAtScopesUsers[Pair.second],
14450 std::make_pair(Pair.first, S));
14451 ValuesAtScopes.erase(ScopeIt);
14452 }
14453
14454 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14455 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14456 for (const auto &Pair : ScopeUserIt->second)
14457 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14458 ValuesAtScopesUsers.erase(ScopeUserIt);
14459 }
14460
14461 auto BEUsersIt = BECountUsers.find(S);
14462 if (BEUsersIt != BECountUsers.end()) {
14463 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14464 auto Copy = BEUsersIt->second;
14465 for (const auto &Pair : Copy)
14466 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14467 BECountUsers.erase(BEUsersIt);
14468 }
14469
14470 auto FoldUser = FoldCacheUser.find(S);
14471 if (FoldUser != FoldCacheUser.end())
14472 for (auto &KV : FoldUser->second)
14473 FoldCache.erase(KV);
14474 FoldCacheUser.erase(S);
14475}
14476
14477void
14478ScalarEvolution::getUsedLoops(const SCEV *S,
14479 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14480 struct FindUsedLoops {
14481 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14482 : LoopsUsed(LoopsUsed) {}
14483 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14484 bool follow(const SCEV *S) {
14485 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14486 LoopsUsed.insert(AR->getLoop());
14487 return true;
14488 }
14489
14490 bool isDone() const { return false; }
14491 };
14492
14493 FindUsedLoops F(LoopsUsed);
14494 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14495}
14496
14497void ScalarEvolution::getReachableBlocks(
14500 Worklist.push_back(&F.getEntryBlock());
14501 while (!Worklist.empty()) {
14502 BasicBlock *BB = Worklist.pop_back_val();
14503 if (!Reachable.insert(BB).second)
14504 continue;
14505
14506 Value *Cond;
14507 BasicBlock *TrueBB, *FalseBB;
14508 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14509 m_BasicBlock(FalseBB)))) {
14510 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14511 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14512 continue;
14513 }
14514
14515 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14516 const SCEV *L = getSCEV(Cmp->getOperand(0));
14517 const SCEV *R = getSCEV(Cmp->getOperand(1));
14518 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14519 Worklist.push_back(TrueBB);
14520 continue;
14521 }
14522 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14523 R)) {
14524 Worklist.push_back(FalseBB);
14525 continue;
14526 }
14527 }
14528 }
14529
14530 append_range(Worklist, successors(BB));
14531 }
14532}
14533
14535 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14536 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14537
14538 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14539
14540 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
14541 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14542 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14543
14544 const SCEV *visitConstant(const SCEVConstant *Constant) {
14545 return SE.getConstant(Constant->getAPInt());
14546 }
14547
14548 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14549 return SE.getUnknown(Expr->getValue());
14550 }
14551
14552 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14553 return SE.getCouldNotCompute();
14554 }
14555 };
14556
14557 SCEVMapper SCM(SE2);
14558 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14559 SE2.getReachableBlocks(ReachableBlocks, F);
14560
14561 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14562 if (containsUndefs(Old) || containsUndefs(New)) {
14563 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14564 // not propagate undef aggressively). This means we can (and do) fail
14565 // verification in cases where a transform makes a value go from "undef"
14566 // to "undef+1" (say). The transform is fine, since in both cases the
14567 // result is "undef", but SCEV thinks the value increased by 1.
14568 return nullptr;
14569 }
14570
14571 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14572 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14573 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14574 return nullptr;
14575
14576 return Delta;
14577 };
14578
14579 while (!LoopStack.empty()) {
14580 auto *L = LoopStack.pop_back_val();
14581 llvm::append_range(LoopStack, *L);
14582
14583 // Only verify BECounts in reachable loops. For an unreachable loop,
14584 // any BECount is legal.
14585 if (!ReachableBlocks.contains(L->getHeader()))
14586 continue;
14587
14588 // Only verify cached BECounts. Computing new BECounts may change the
14589 // results of subsequent SCEV uses.
14590 auto It = BackedgeTakenCounts.find(L);
14591 if (It == BackedgeTakenCounts.end())
14592 continue;
14593
14594 auto *CurBECount =
14595 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14596 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14597
14598 if (CurBECount == SE2.getCouldNotCompute() ||
14599 NewBECount == SE2.getCouldNotCompute()) {
14600 // NB! This situation is legal, but is very suspicious -- whatever pass
14601 // change the loop to make a trip count go from could not compute to
14602 // computable or vice-versa *should have* invalidated SCEV. However, we
14603 // choose not to assert here (for now) since we don't want false
14604 // positives.
14605 continue;
14606 }
14607
14608 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14609 SE.getTypeSizeInBits(NewBECount->getType()))
14610 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14611 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14612 SE.getTypeSizeInBits(NewBECount->getType()))
14613 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14614
14615 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14616 if (Delta && !Delta->isZero()) {
14617 dbgs() << "Trip Count for " << *L << " Changed!\n";
14618 dbgs() << "Old: " << *CurBECount << "\n";
14619 dbgs() << "New: " << *NewBECount << "\n";
14620 dbgs() << "Delta: " << *Delta << "\n";
14621 std::abort();
14622 }
14623 }
14624
14625 // Collect all valid loops currently in LoopInfo.
14626 SmallPtrSet<Loop *, 32> ValidLoops;
14627 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14628 while (!Worklist.empty()) {
14629 Loop *L = Worklist.pop_back_val();
14630 if (ValidLoops.insert(L).second)
14631 Worklist.append(L->begin(), L->end());
14632 }
14633 for (const auto &KV : ValueExprMap) {
14634#ifndef NDEBUG
14635 // Check for SCEV expressions referencing invalid/deleted loops.
14636 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14637 assert(ValidLoops.contains(AR->getLoop()) &&
14638 "AddRec references invalid loop");
14639 }
14640#endif
14641
14642 // Check that the value is also part of the reverse map.
14643 auto It = ExprValueMap.find(KV.second);
14644 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14645 dbgs() << "Value " << *KV.first
14646 << " is in ValueExprMap but not in ExprValueMap\n";
14647 std::abort();
14648 }
14649
14650 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14651 if (!ReachableBlocks.contains(I->getParent()))
14652 continue;
14653 const SCEV *OldSCEV = SCM.visit(KV.second);
14654 const SCEV *NewSCEV = SE2.getSCEV(I);
14655 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14656 if (Delta && !Delta->isZero()) {
14657 dbgs() << "SCEV for value " << *I << " changed!\n"
14658 << "Old: " << *OldSCEV << "\n"
14659 << "New: " << *NewSCEV << "\n"
14660 << "Delta: " << *Delta << "\n";
14661 std::abort();
14662 }
14663 }
14664 }
14665
14666 for (const auto &KV : ExprValueMap) {
14667 for (Value *V : KV.second) {
14668 const SCEV *S = ValueExprMap.lookup(V);
14669 if (!S) {
14670 dbgs() << "Value " << *V
14671 << " is in ExprValueMap but not in ValueExprMap\n";
14672 std::abort();
14673 }
14674 if (S != KV.first) {
14675 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14676 << *KV.first << "\n";
14677 std::abort();
14678 }
14679 }
14680 }
14681
14682 // Verify integrity of SCEV users.
14683 for (const auto &S : UniqueSCEVs) {
14684 for (const auto *Op : S.operands()) {
14685 // We do not store dependencies of constants.
14686 if (isa<SCEVConstant>(Op))
14687 continue;
14688 auto It = SCEVUsers.find(Op);
14689 if (It != SCEVUsers.end() && It->second.count(&S))
14690 continue;
14691 dbgs() << "Use of operand " << *Op << " by user " << S
14692 << " is not being tracked!\n";
14693 std::abort();
14694 }
14695 }
14696
14697 // Verify integrity of ValuesAtScopes users.
14698 for (const auto &ValueAndVec : ValuesAtScopes) {
14699 const SCEV *Value = ValueAndVec.first;
14700 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14701 const Loop *L = LoopAndValueAtScope.first;
14702 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14703 if (!isa<SCEVConstant>(ValueAtScope)) {
14704 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14705 if (It != ValuesAtScopesUsers.end() &&
14706 is_contained(It->second, std::make_pair(L, Value)))
14707 continue;
14708 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14709 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14710 std::abort();
14711 }
14712 }
14713 }
14714
14715 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14716 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14717 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14718 const Loop *L = LoopAndValue.first;
14719 const SCEV *Value = LoopAndValue.second;
14721 auto It = ValuesAtScopes.find(Value);
14722 if (It != ValuesAtScopes.end() &&
14723 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14724 continue;
14725 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14726 << *ValueAtScope << " missing in ValuesAtScopes\n";
14727 std::abort();
14728 }
14729 }
14730
14731 // Verify integrity of BECountUsers.
14732 auto VerifyBECountUsers = [&](bool Predicated) {
14733 auto &BECounts =
14734 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14735 for (const auto &LoopAndBEInfo : BECounts) {
14736 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14737 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14738 if (!isa<SCEVConstant>(S)) {
14739 auto UserIt = BECountUsers.find(S);
14740 if (UserIt != BECountUsers.end() &&
14741 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14742 continue;
14743 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14744 << " missing from BECountUsers\n";
14745 std::abort();
14746 }
14747 }
14748 }
14749 }
14750 };
14751 VerifyBECountUsers(/* Predicated */ false);
14752 VerifyBECountUsers(/* Predicated */ true);
14753
14754 // Verify integrity of the loop disposition cache.
14755 for (auto &[S, Values] : LoopDispositions) {
14756 for (auto [Loop, CachedDisposition] : Values) {
14757 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14758 if (CachedDisposition != RecomputedDisposition) {
14759 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14760 << " is incorrect: cached " << CachedDisposition << ", actual "
14761 << RecomputedDisposition << "\n";
14762 std::abort();
14763 }
14764 }
14765 }
14766
14767 // Verify integrity of the block disposition cache.
14768 for (auto &[S, Values] : BlockDispositions) {
14769 for (auto [BB, CachedDisposition] : Values) {
14770 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14771 if (CachedDisposition != RecomputedDisposition) {
14772 dbgs() << "Cached disposition of " << *S << " for block %"
14773 << BB->getName() << " is incorrect: cached " << CachedDisposition
14774 << ", actual " << RecomputedDisposition << "\n";
14775 std::abort();
14776 }
14777 }
14778 }
14779
14780 // Verify FoldCache/FoldCacheUser caches.
14781 for (auto [FoldID, Expr] : FoldCache) {
14782 auto I = FoldCacheUser.find(Expr);
14783 if (I == FoldCacheUser.end()) {
14784 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14785 << "!\n";
14786 std::abort();
14787 }
14788 if (!is_contained(I->second, FoldID)) {
14789 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14790 std::abort();
14791 }
14792 }
14793 for (auto [Expr, IDs] : FoldCacheUser) {
14794 for (auto &FoldID : IDs) {
14795 const SCEV *S = FoldCache.lookup(FoldID);
14796 if (!S) {
14797 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14798 << "!\n";
14799 std::abort();
14800 }
14801 if (S != Expr) {
14802 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14803 << " != " << *Expr << "!\n";
14804 std::abort();
14805 }
14806 }
14807 }
14808
14809 // Verify that ConstantMultipleCache computations are correct. We check that
14810 // cached multiples and recomputed multiples are multiples of each other to
14811 // verify correctness. It is possible that a recomputed multiple is different
14812 // from the cached multiple due to strengthened no wrap flags or changes in
14813 // KnownBits computations.
14814 for (auto [S, Multiple] : ConstantMultipleCache) {
14815 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14816 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14817 Multiple.urem(RecomputedMultiple) != 0 &&
14818 RecomputedMultiple.urem(Multiple) != 0)) {
14819 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14820 << *S << " : Computed " << RecomputedMultiple
14821 << " but cache contains " << Multiple << "!\n";
14822 std::abort();
14823 }
14824 }
14825}
14826
14828 Function &F, const PreservedAnalyses &PA,
14829 FunctionAnalysisManager::Invalidator &Inv) {
14830 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14831 // of its dependencies is invalidated.
14832 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14833 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14834 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14835 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14836 Inv.invalidate<LoopAnalysis>(F, PA);
14837}
14838
// Out-of-line definition of the static AnalysisKey data member of
// ScalarEvolutionAnalysis.
14839AnalysisKey ScalarEvolutionAnalysis::Key;
14840
14843 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14844 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14845 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14846 auto &LI = AM.getResult<LoopAnalysis>(F);
14847 return ScalarEvolution(F, TLI, AC, DT, LI);
14848}
14849
14855
14858 // For compatibility with opt's -analyze feature under legacy pass manager
14859 // which was not ported to NPM. This keeps tests using
14860 // update_analyze_test_checks.py working.
14861 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14862 << F.getName() << "':\n";
14864 return PreservedAnalyses::all();
14865}
14866
14868 "Scalar Evolution Analysis", false, true)
14874 "Scalar Evolution Analysis", false, true)
14875
14877
14879
14881 SE.reset(new ScalarEvolution(
14883 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14885 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14886 return false;
14887}
14888
14890
14892 SE->print(OS);
14893}
14894
14896 if (!VerifySCEV)
14897 return;
14898
14899 SE->verify();
14900}
14901
14909
14911 const SCEV *RHS) {
14912 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14913}
14914
14915const SCEVPredicate *
14917 const SCEV *LHS, const SCEV *RHS) {
14919 assert(LHS->getType() == RHS->getType() &&
14920 "Type mismatch between LHS and RHS");
14921 // Unique this node based on the arguments
14922 ID.AddInteger(SCEVPredicate::P_Compare);
14923 ID.AddInteger(Pred);
14924 ID.AddPointer(LHS);
14925 ID.AddPointer(RHS);
14926 void *IP = nullptr;
14927 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14928 return S;
14929 SCEVComparePredicate *Eq = new (SCEVAllocator)
14930 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14931 UniquePreds.InsertNode(Eq, IP);
14932 return Eq;
14933}
14934
14936 const SCEVAddRecExpr *AR,
14939 // Unique this node based on the arguments
14940 ID.AddInteger(SCEVPredicate::P_Wrap);
14941 ID.AddPointer(AR);
14942 ID.AddInteger(AddedFlags);
14943 void *IP = nullptr;
14944 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14945 return S;
14946 auto *OF = new (SCEVAllocator)
14947 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14948 UniquePreds.InsertNode(OF, IP);
14949 return OF;
14950}
14951
14952namespace {
14953
14954class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14955public:
14956
14957 /// Rewrites \p S in the context of a loop L and the SCEV predication
14958 /// infrastructure.
14959 ///
14960 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14961 /// equivalences present in \p Pred.
14962 ///
14963 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14964 /// \p NewPreds such that the result will be an AddRecExpr.
14965 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14967 const SCEVPredicate *Pred) {
14968 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14969 return Rewriter.visit(S);
14970 }
14971
14972 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14973 if (Pred) {
14974 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14975 for (const auto *Pred : U->getPredicates())
14976 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14977 if (IPred->getLHS() == Expr &&
14978 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14979 return IPred->getRHS();
14980 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14981 if (IPred->getLHS() == Expr &&
14982 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14983 return IPred->getRHS();
14984 }
14985 }
14986 return convertToAddRecWithPreds(Expr);
14987 }
14988
14989 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14990 const SCEV *Operand = visit(Expr->getOperand());
14991 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14992 if (AR && AR->getLoop() == L && AR->isAffine()) {
14993 // This couldn't be folded because the operand didn't have the nuw
14994 // flag. Add the nusw flag as an assumption that we could make.
14995 const SCEV *Step = AR->getStepRecurrence(SE);
14996 Type *Ty = Expr->getType();
14997 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14998 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14999 SE.getSignExtendExpr(Step, Ty), L,
15000 AR->getNoWrapFlags());
15001 }
15002 return SE.getZeroExtendExpr(Operand, Expr->getType());
15003 }
15004
15005 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15006 const SCEV *Operand = visit(Expr->getOperand());
15007 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15008 if (AR && AR->getLoop() == L && AR->isAffine()) {
15009 // This couldn't be folded because the operand didn't have the nsw
15010 // flag. Add the nssw flag as an assumption that we could make.
15011 const SCEV *Step = AR->getStepRecurrence(SE);
15012 Type *Ty = Expr->getType();
15013 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15014 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15015 SE.getSignExtendExpr(Step, Ty), L,
15016 AR->getNoWrapFlags());
15017 }
15018 return SE.getSignExtendExpr(Operand, Expr->getType());
15019 }
15020
15021private:
15022 explicit SCEVPredicateRewriter(
15023 const Loop *L, ScalarEvolution &SE,
15024 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15025 const SCEVPredicate *Pred)
15026 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15027
15028 bool addOverflowAssumption(const SCEVPredicate *P) {
15029 if (!NewPreds) {
15030 // Check if we've already made this assumption.
15031 return Pred && Pred->implies(P, SE);
15032 }
15033 NewPreds->push_back(P);
15034 return true;
15035 }
15036
15037 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15039 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15040 return addOverflowAssumption(A);
15041 }
15042
15043 // If \p Expr represents a PHINode, we try to see if it can be represented
15044 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15045 // to add this predicate as a runtime overflow check, we return the AddRec.
15046 // If \p Expr does not meet these conditions (is not a PHI node, or we
15047 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15048 // return \p Expr.
15049 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15050 if (!isa<PHINode>(Expr->getValue()))
15051 return Expr;
15052 std::optional<
15053 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15054 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15055 if (!PredicatedRewrite)
15056 return Expr;
15057 for (const auto *P : PredicatedRewrite->second){
15058 // Wrap predicates from outer loops are not supported.
15059 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15060 if (L != WP->getExpr()->getLoop())
15061 return Expr;
15062 }
15063 if (!addOverflowAssumption(P))
15064 return Expr;
15065 }
15066 return PredicatedRewrite->first;
15067 }
15068
15069 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15070 const SCEVPredicate *Pred;
15071 const Loop *L;
15072};
15073
15074} // end anonymous namespace
15075
15076const SCEV *
15078 const SCEVPredicate &Preds) {
15079 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15080}
15081
15083 const SCEV *S, const Loop *L,
15086 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15087 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15088
15089 if (!AddRec)
15090 return nullptr;
15091
15092 // Check if any of the transformed predicates is known to be false. In that
15093 // case, it doesn't make sense to convert to a predicated AddRec, as the
15094 // versioned loop will never execute.
15095 for (const SCEVPredicate *Pred : TransformPreds) {
15096 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15097 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15098 continue;
15099
15100 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15101 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15102 if (isa<SCEVCouldNotCompute>(ExitCount))
15103 continue;
15104
15105 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15106 if (!Step->isOne())
15107 continue;
15108
15109 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15110 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15111 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15112 return nullptr;
15113 }
15114
15115 // Since the transformation was successful, we can now transfer the SCEV
15116 // predicates.
15117 Preds.append(TransformPreds.begin(), TransformPreds.end());
15118
15119 return AddRec;
15120}
15121
15122/// SCEV predicates
15126
15128 const ICmpInst::Predicate Pred,
15129 const SCEV *LHS, const SCEV *RHS)
15130 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15131 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15132 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15133}
15134
15136 ScalarEvolution &SE) const {
15137 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15138
15139 if (!Op)
15140 return false;
15141
15142 if (Pred != ICmpInst::ICMP_EQ)
15143 return false;
15144
15145 return Op->LHS == LHS && Op->RHS == RHS;
15146}
15147
15148bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15149
15151 if (Pred == ICmpInst::ICMP_EQ)
15152 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15153 else
15154 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15155 << *RHS << "\n";
15156
15157}
15158
15160 const SCEVAddRecExpr *AR,
15161 IncrementWrapFlags Flags)
15162 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15163
15164const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15165
15167 ScalarEvolution &SE) const {
15168 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15169 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15170 return false;
15171
15172 if (Op->AR == AR)
15173 return true;
15174
15175 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15177 return false;
15178
15179 const SCEV *Start = AR->getStart();
15180 const SCEV *OpStart = Op->AR->getStart();
15181 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15182 return false;
15183
15184 // Reject pointers to different address spaces.
15185 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15186 return false;
15187
15188 const SCEV *Step = AR->getStepRecurrence(SE);
15189 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15190 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15191 return false;
15192
15193 // If both steps are positive, this implies N, if N's start and step are
15194 // ULE/SLE (for NSUW/NSSW) than this'.
15195 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15196 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15197 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15198
15199 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15200 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15201 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15202 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15203 : SE.getNoopOrSignExtend(Start, WiderTy);
15205 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15206 SE.isKnownPredicate(Pred, OpStart, Start);
15207}
15208
15210 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15211 IncrementWrapFlags IFlags = Flags;
15212
15213 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15214 IFlags = clearFlags(IFlags, IncrementNSSW);
15215
15216 return IFlags == IncrementAnyWrap;
15217}
15218
15219void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15220 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15222 OS << "<nusw>";
15224 OS << "<nssw>";
15225 OS << "\n";
15226}
15227
15230 ScalarEvolution &SE) {
15231 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15232 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15233
15234 // We can safely transfer the NSW flag as NSSW.
15235 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15236 ImpliedFlags = IncrementNSSW;
15237
15238 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15239 // If the increment is positive, the SCEV NUW flag will also imply the
15240 // WrapPredicate NUSW flag.
15241 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15242 if (Step->getValue()->getValue().isNonNegative())
15243 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15244 }
15245
15246 return ImpliedFlags;
15247}
15248
15249/// Union predicates don't get cached so create a dummy set ID for it.
15251 ScalarEvolution &SE)
15252 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15253 for (const auto *P : Preds)
15254 add(P, SE);
15255}
15256
15258 return all_of(Preds,
15259 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15260}
15261
15263 ScalarEvolution &SE) const {
15264 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15265 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15266 return this->implies(I, SE);
15267 });
15268
15269 return any_of(Preds,
15270 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15271}
15272
15274 for (const auto *Pred : Preds)
15275 Pred->print(OS, Depth);
15276}
15277
15278void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15279 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15280 for (const auto *Pred : Set->Preds)
15281 add(Pred, SE);
15282 return;
15283 }
15284
15285 // Implication checks are quadratic in the number of predicates. Stop doing
15286 // them if there are many predicates, as they should be too expensive to use
15287 // anyway at that point.
15288 bool CheckImplies = Preds.size() < 16;
15289
15290 // Only add predicate if it is not already implied by this union predicate.
15291 if (CheckImplies && implies(N, SE))
15292 return;
15293
15294 // Build a new vector containing the current predicates, except the ones that
15295 // are implied by the new predicate N.
15297 for (auto *P : Preds) {
15298 if (CheckImplies && N->implies(P, SE))
15299 continue;
15300 PrunedPreds.push_back(P);
15301 }
15302 Preds = std::move(PrunedPreds);
15303 Preds.push_back(N);
15304}
15305
15307 Loop &L)
15308 : SE(SE), L(L) {
15310 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15311}
15312
15315 for (const auto *Op : Ops)
15316 // We do not expect that forgetting cached data for SCEVConstants will ever
15317 // open any prospects for sharpening or introduce any correctness issues,
15318 // so we don't bother storing their dependencies.
15319 if (!isa<SCEVConstant>(Op))
15320 SCEVUsers[Op].insert(User);
15321}
15322
15324 const SCEV *Expr = SE.getSCEV(V);
15325 return getPredicatedSCEV(Expr);
15326}
15327
15329 RewriteEntry &Entry = RewriteMap[Expr];
15330
15331 // If we already have an entry and the version matches, return it.
15332 if (Entry.second && Generation == Entry.first)
15333 return Entry.second;
15334
15335 // We found an entry but it's stale. Rewrite the stale entry
15336 // according to the current predicate.
15337 if (Entry.second)
15338 Expr = Entry.second;
15339
15340 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15341 Entry = {Generation, NewSCEV};
15342
15343 return NewSCEV;
15344}
15345
15347 if (!BackedgeCount) {
15349 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15350 for (const auto *P : Preds)
15351 addPredicate(*P);
15352 }
15353 return BackedgeCount;
15354}
15355
15357 if (!SymbolicMaxBackedgeCount) {
15359 SymbolicMaxBackedgeCount =
15360 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15361 for (const auto *P : Preds)
15362 addPredicate(*P);
15363 }
15364 return SymbolicMaxBackedgeCount;
15365}
15366
15368 if (!SmallConstantMaxTripCount) {
15370 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15371 for (const auto *P : Preds)
15372 addPredicate(*P);
15373 }
15374 return *SmallConstantMaxTripCount;
15375}
15376
15378 if (Preds->implies(&Pred, SE))
15379 return;
15380
15381 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15382 NewPreds.push_back(&Pred);
15383 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15384 updateGeneration();
15385}
15386
15388 return *Preds;
15389}
15390
15391void PredicatedScalarEvolution::updateGeneration() {
15392 // If the generation number wrapped recompute everything.
15393 if (++Generation == 0) {
15394 for (auto &II : RewriteMap) {
15395 const SCEV *Rewritten = II.second.second;
15396 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15397 }
15398 }
15399}
15400
15403 const SCEV *Expr = getSCEV(V);
15404 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15405
15406 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15407
15408 // Clear the statically implied flags.
15409 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15410 addPredicate(*SE.getWrapPredicate(AR, Flags));
15411
15412 auto II = FlagsMap.insert({V, Flags});
15413 if (!II.second)
15414 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15415}
15416
15419 const SCEV *Expr = getSCEV(V);
15420 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15421
15423 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15424
15425 auto II = FlagsMap.find(V);
15426
15427 if (II != FlagsMap.end())
15428 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15429
15431}
15432
15434 const SCEV *Expr = this->getSCEV(V);
15436 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15437
15438 if (!New)
15439 return nullptr;
15440
15441 for (const auto *P : NewPreds)
15442 addPredicate(*P);
15443
15444 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15445 return New;
15446}
15447
15450 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15451 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15452 SE)),
15453 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15454 for (auto I : Init.FlagsMap)
15455 FlagsMap.insert(I);
15456}
15457
15459 // For each block.
15460 for (auto *BB : L.getBlocks())
15461 for (auto &I : *BB) {
15462 if (!SE.isSCEVable(I.getType()))
15463 continue;
15464
15465 auto *Expr = SE.getSCEV(&I);
15466 auto II = RewriteMap.find(Expr);
15467
15468 if (II == RewriteMap.end())
15469 continue;
15470
15471 // Don't print things that are not interesting.
15472 if (II->second.second == Expr)
15473 continue;
15474
15475 OS.indent(Depth) << "[PSE]" << I << ":\n";
15476 OS.indent(Depth + 2) << *Expr << "\n";
15477 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15478 }
15479}
15480
15483 BasicBlock *Header = L->getHeader();
15484 BasicBlock *Pred = L->getLoopPredecessor();
15485 LoopGuards Guards(SE);
15486 if (!Pred)
15487 return Guards;
15489 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15490 return Guards;
15491}
15492
15493void ScalarEvolution::LoopGuards::collectFromPHI(
15497 unsigned Depth) {
15498 if (!SE.isSCEVable(Phi.getType()))
15499 return;
15500
15501 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15502 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15503 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15504 if (!VisitedBlocks.insert(InBlock).second)
15505 return {nullptr, scCouldNotCompute};
15506
15507 // Avoid analyzing unreachable blocks so that we don't get trapped
15508 // traversing cycles with ill-formed dominance or infinite cycles
15509 if (!SE.DT.isReachableFromEntry(InBlock))
15510 return {nullptr, scCouldNotCompute};
15511
15512 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15513 if (Inserted)
15514 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15515 Depth + 1);
15516 auto &RewriteMap = G->second.RewriteMap;
15517 if (RewriteMap.empty())
15518 return {nullptr, scCouldNotCompute};
15519 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15520 if (S == RewriteMap.end())
15521 return {nullptr, scCouldNotCompute};
15522 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15523 if (!SM)
15524 return {nullptr, scCouldNotCompute};
15525 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15526 return {C0, SM->getSCEVType()};
15527 return {nullptr, scCouldNotCompute};
15528 };
15529 auto MergeMinMaxConst = [](MinMaxPattern P1,
15530 MinMaxPattern P2) -> MinMaxPattern {
15531 auto [C1, T1] = P1;
15532 auto [C2, T2] = P2;
15533 if (!C1 || !C2 || T1 != T2)
15534 return {nullptr, scCouldNotCompute};
15535 switch (T1) {
15536 case scUMaxExpr:
15537 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15538 case scSMaxExpr:
15539 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15540 case scUMinExpr:
15541 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15542 case scSMinExpr:
15543 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15544 default:
15545 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15546 }
15547 };
15548 auto P = GetMinMaxConst(0);
15549 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15550 if (!P.first)
15551 break;
15552 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15553 }
15554 if (P.first) {
15555 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15557 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15558 Guards.RewriteMap.insert({LHS, RHS});
15559 }
15560}
15561
15562// Return a new SCEV that rounds \p Expr down to the closest number divisible
15563// by \p Divisor that is less than or equal to \p Expr. For now, only
15564// non-negative constant \p Expr values (and positive divisors) are handled.
15566 const APInt &DivisorVal,
15567 ScalarEvolution &SE) {
15568 const APInt *ExprVal;
15569 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15570 DivisorVal.isNonPositive())
15571 return Expr;
15572 APInt Rem = ExprVal->urem(DivisorVal);
15573 // return the SCEV: Expr - Expr % Divisor
15574 return SE.getConstant(*ExprVal - Rem);
15575}
15576
15577// Return a new SCEV that modifies \p Expr to the closest number divides by
15578// \p Divisor and greater or equal than Expr. For now, only handle constant
15579// Expr.
15580static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15581 const APInt &DivisorVal,
15582 ScalarEvolution &SE) {
15583 const APInt *ExprVal;
15584 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15585 DivisorVal.isNonPositive())
15586 return Expr;
15587 APInt Rem = ExprVal->urem(DivisorVal);
15588 if (Rem.isZero())
15589 return Expr;
15590 // return the SCEV: Expr + Divisor - Expr % Divisor
15591 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15592}
15593
15595 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15598 // If we have LHS == 0, check if LHS is computing a property of some unknown
15599 // SCEV %v which we can rewrite %v to express explicitly.
15601 return false;
15602 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15603 // explicitly express that.
15604 const SCEVUnknown *URemLHS = nullptr;
15605 const SCEV *URemRHS = nullptr;
15606 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15607 return false;
15608
15609 const SCEV *Multiple =
15610 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15611 DivInfo[URemLHS] = Multiple;
15612 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15613 Multiples[URemLHS] = C->getAPInt();
15614 return true;
15615}
15616
15617// Check if the condition is a divisibility guard (A % B == 0).
15618static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15619 ScalarEvolution &SE) {
15620 const SCEV *X, *Y;
15621 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15622}
15623
15624// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15625// recursively. This is done by aligning up/down the constant value to the
15626// Divisor.
15627static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15628 APInt Divisor,
15629 ScalarEvolution &SE) {
15630 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15631 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15632 // the non-constant operand and in \p LHS the constant operand.
15633 auto IsMinMaxSCEVWithNonNegativeConstant =
15634 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15635 const SCEV *&RHS) {
15636 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15637 if (MinMax->getNumOperands() != 2)
15638 return false;
15639 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15640 if (C->getAPInt().isNegative())
15641 return false;
15642 SCTy = MinMax->getSCEVType();
15643 LHS = MinMax->getOperand(0);
15644 RHS = MinMax->getOperand(1);
15645 return true;
15646 }
15647 }
15648 return false;
15649 };
15650
15651 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15652 SCEVTypes SCTy;
15653 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15654 MinMaxRHS))
15655 return MinMaxExpr;
15656 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15657 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15658 auto *DivisibleExpr =
15659 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15660 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15662 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15663 return SE.getMinMaxExpr(SCTy, Ops);
15664}
15665
15666void ScalarEvolution::LoopGuards::collectFromBlock(
15667 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15668 const BasicBlock *Block, const BasicBlock *Pred,
15669 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15670
15672
15673 SmallVector<const SCEV *> ExprsToRewrite;
15674 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15675 const SCEV *RHS,
15676 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15677 const LoopGuards &DivGuards) {
15678 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15679 // replacement SCEV which isn't directly implied by the structure of that
15680 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15681 // legal. See the scoping rules for flags in the header to understand why.
15682
15683 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15684 // create this form when combining two checks of the form (X u< C2 + C1) and
15685 // (X >=u C1).
15686 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15687 &ExprsToRewrite]() {
15688 const SCEVConstant *C1;
15689 const SCEVUnknown *LHSUnknown;
15690 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15691 if (!match(LHS,
15692 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15693 !C2)
15694 return false;
15695
15696 auto ExactRegion =
15697 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15698 .sub(C1->getAPInt());
15699
15700 // Bail out, unless we have a non-wrapping, monotonic range.
15701 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15702 return false;
15703 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15704 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15705 I->second = SE.getUMaxExpr(
15706 SE.getConstant(ExactRegion.getUnsignedMin()),
15707 SE.getUMinExpr(RewrittenLHS,
15708 SE.getConstant(ExactRegion.getUnsignedMax())));
15709 ExprsToRewrite.push_back(LHSUnknown);
15710 return true;
15711 };
15712 if (MatchRangeCheckIdiom())
15713 return;
15714
15715 // Do not apply information for constants or if RHS contains an AddRec.
15717 return;
15718
15719 // If RHS is SCEVUnknown, make sure the information is applied to it.
15721 std::swap(LHS, RHS);
15723 }
15724
15725 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15726 // and \p FromRewritten are the same (i.e. there has been no rewrite
15727 // registered for \p From), then puts this value in the list of rewritten
15728 // expressions.
15729 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15730 const SCEV *To) {
15731 if (From == FromRewritten)
15732 ExprsToRewrite.push_back(From);
15733 RewriteMap[From] = To;
15734 };
15735
15736 // Checks whether \p S has already been rewritten. In that case returns the
15737 // existing rewrite because we want to chain further rewrites onto the
15738 // already rewritten value. Otherwise returns \p S.
15739 auto GetMaybeRewritten = [&](const SCEV *S) {
15740 return RewriteMap.lookup_or(S, S);
15741 };
15742
15743 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15744 // Apply divisibility information when computing the constant multiple.
15745 const APInt &DividesBy =
15746 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15747
15748 // Collect rewrites for LHS and its transitive operands based on the
15749 // condition.
15750 // For min/max expressions, also apply the guard to its operands:
15751 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15752 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15753 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15754 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15755
15756 // We cannot express strict predicates in SCEV, so instead we replace them
15757 // with non-strict ones against plus or minus one of RHS depending on the
15758 // predicate.
15759 const SCEV *One = SE.getOne(RHS->getType());
15760 switch (Predicate) {
15761 case CmpInst::ICMP_ULT:
15762 if (RHS->getType()->isPointerTy())
15763 return;
15764 RHS = SE.getUMaxExpr(RHS, One);
15765 [[fallthrough]];
15766 case CmpInst::ICMP_SLT: {
15767 RHS = SE.getMinusSCEV(RHS, One);
15768 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15769 break;
15770 }
15771 case CmpInst::ICMP_UGT:
15772 case CmpInst::ICMP_SGT:
15773 RHS = SE.getAddExpr(RHS, One);
15774 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15775 break;
15776 case CmpInst::ICMP_ULE:
15777 case CmpInst::ICMP_SLE:
15778 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15779 break;
15780 case CmpInst::ICMP_UGE:
15781 case CmpInst::ICMP_SGE:
15782 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15783 break;
15784 default:
15785 break;
15786 }
15787
15789 SmallPtrSet<const SCEV *, 16> Visited;
15790
15791 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15792 append_range(Worklist, S->operands());
15793 };
15794
15795 while (!Worklist.empty()) {
15796 const SCEV *From = Worklist.pop_back_val();
15797 if (isa<SCEVConstant>(From))
15798 continue;
15799 if (!Visited.insert(From).second)
15800 continue;
15801 const SCEV *FromRewritten = GetMaybeRewritten(From);
15802 const SCEV *To = nullptr;
15803
15804 switch (Predicate) {
15805 case CmpInst::ICMP_ULT:
15806 case CmpInst::ICMP_ULE:
15807 To = SE.getUMinExpr(FromRewritten, RHS);
15808 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15809 EnqueueOperands(UMax);
15810 break;
15811 case CmpInst::ICMP_SLT:
15812 case CmpInst::ICMP_SLE:
15813 To = SE.getSMinExpr(FromRewritten, RHS);
15814 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15815 EnqueueOperands(SMax);
15816 break;
15817 case CmpInst::ICMP_UGT:
15818 case CmpInst::ICMP_UGE:
15819 To = SE.getUMaxExpr(FromRewritten, RHS);
15820 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15821 EnqueueOperands(UMin);
15822 break;
15823 case CmpInst::ICMP_SGT:
15824 case CmpInst::ICMP_SGE:
15825 To = SE.getSMaxExpr(FromRewritten, RHS);
15826 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15827 EnqueueOperands(SMin);
15828 break;
15829 case CmpInst::ICMP_EQ:
15831 To = RHS;
15832 break;
15833 case CmpInst::ICMP_NE:
15834 if (match(RHS, m_scev_Zero())) {
15835 const SCEV *OneAlignedUp =
15836 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15837 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15838 } else {
15839 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15840 // but creating the subtraction eagerly is expensive. Track the
15841 // inequalities in a separate map, and materialize the rewrite lazily
15842 // when encountering a suitable subtraction while re-writing.
15843 if (LHS->getType()->isPointerTy()) {
15847 break;
15848 }
15849 const SCEVConstant *C;
15850 const SCEV *A, *B;
15853 RHS = A;
15854 LHS = B;
15855 }
15856 if (LHS > RHS)
15857 std::swap(LHS, RHS);
15858 Guards.NotEqual.insert({LHS, RHS});
15859 continue;
15860 }
15861 break;
15862 default:
15863 break;
15864 }
15865
15866 if (To)
15867 AddRewrite(From, FromRewritten, To);
15868 }
15869 };
15870
15872 // First, collect information from assumptions dominating the loop.
15873 for (auto &AssumeVH : SE.AC.assumptions()) {
15874 if (!AssumeVH)
15875 continue;
15876 auto *AssumeI = cast<CallInst>(AssumeVH);
15877 if (!SE.DT.dominates(AssumeI, Block))
15878 continue;
15879 Terms.emplace_back(AssumeI->getOperand(0), true);
15880 }
15881
15882 // Second, collect information from llvm.experimental.guards dominating the loop.
15883 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15884 SE.F.getParent(), Intrinsic::experimental_guard);
15885 if (GuardDecl)
15886 for (const auto *GU : GuardDecl->users())
15887 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15888 if (Guard->getFunction() == Block->getParent() &&
15889 SE.DT.dominates(Guard, Block))
15890 Terms.emplace_back(Guard->getArgOperand(0), true);
15891
15892 // Third, collect conditions from dominating branches. Starting at the loop
15893 // predecessor, climb up the predecessor chain, as long as there are
15894 // predecessors that can be found that have unique successors leading to the
15895 // original header.
15896 // TODO: share this logic with isLoopEntryGuardedByCond.
15897 unsigned NumCollectedConditions = 0;
15899 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15900 for (; Pair.first;
15901 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15902 VisitedBlocks.insert(Pair.second);
15903 const BranchInst *LoopEntryPredicate =
15904 dyn_cast<BranchInst>(Pair.first->getTerminator());
15905 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15906 continue;
15907
15908 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15909 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15910 NumCollectedConditions++;
15911
15912 // If we are recursively collecting guards stop after 2
15913 // conditions to limit compile-time impact for now.
15914 if (Depth > 0 && NumCollectedConditions == 2)
15915 break;
15916 }
15917 // Finally, if we stopped climbing the predecessor chain because
15918 // there wasn't a unique one to continue, try to collect conditions
15919 // for PHINodes by recursively following all of their incoming
15920 // blocks and try to merge the found conditions to build a new one
15921 // for the Phi.
15922 if (Pair.second->hasNPredecessorsOrMore(2) &&
15924 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15925 for (auto &Phi : Pair.second->phis())
15926 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15927 }
15928
15929 // Now apply the information from the collected conditions to
15930 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15931 // earliest condition is processed first, except guards with divisibility
15932 // information, which are moved to the back. This ensures the SCEVs with the
15933 // shortest dependency chains are constructed first.
15935 GuardsToProcess;
15936 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15937 SmallVector<Value *, 8> Worklist;
15938 SmallPtrSet<Value *, 8> Visited;
15939 Worklist.push_back(Term);
15940 while (!Worklist.empty()) {
15941 Value *Cond = Worklist.pop_back_val();
15942 if (!Visited.insert(Cond).second)
15943 continue;
15944
15945 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15946 auto Predicate =
15947 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15948 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15949 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15950 // If LHS is a constant, apply information to the other expression.
15951 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15952 // can improve results.
15953 if (isa<SCEVConstant>(LHS)) {
15954 std::swap(LHS, RHS);
15956 }
15957 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
15958 continue;
15959 }
15960
15961 Value *L, *R;
15962 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15963 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15964 Worklist.push_back(L);
15965 Worklist.push_back(R);
15966 }
15967 }
15968 }
15969
15970 // Process divisibility guards in reverse order to populate DivGuards early.
15971 DenseMap<const SCEV *, APInt> Multiples;
15972 LoopGuards DivGuards(SE);
15973 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15974 if (!isDivisibilityGuard(LHS, RHS, SE))
15975 continue;
15976 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15977 Multiples, SE);
15978 }
15979
15980 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15981 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15982
15983 // Apply divisibility information last. This ensures it is applied to the
15984 // outermost expression after other rewrites for the given value.
15985 for (const auto &[K, Divisor] : Multiples) {
15986 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15987 Guards.RewriteMap[K] =
15989 Guards.rewrite(K), Divisor, SE),
15990 DivisorSCEV),
15991 DivisorSCEV);
15992 ExprsToRewrite.push_back(K);
15993 }
15994
15995 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15996 // the replacement expressions are contained in the ranges of the replaced
15997 // expressions.
15998 Guards.PreserveNUW = true;
15999 Guards.PreserveNSW = true;
16000 for (const SCEV *Expr : ExprsToRewrite) {
16001 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16002 Guards.PreserveNUW &=
16003 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16004 Guards.PreserveNSW &=
16005 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16006 }
16007
16008 // Now that all rewrite information is collected, rewrite the collected
16009 // expressions with the information in the map. This applies information to
16010 // sub-expressions.
16011 if (ExprsToRewrite.size() > 1) {
16012 for (const SCEV *Expr : ExprsToRewrite) {
16013 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16014 Guards.RewriteMap.erase(Expr);
16015 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16016 }
16017 }
16018}
16019
16021 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16022 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16023 /// replacement is loop invariant in the loop of the AddRec.
16024 class SCEVLoopGuardRewriter
16025 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16028
16030
16031 public:
16032 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16033 const ScalarEvolution::LoopGuards &Guards)
16034 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16035 NotEqual(Guards.NotEqual) {
16036 if (Guards.PreserveNUW)
16037 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16038 if (Guards.PreserveNSW)
16039 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16040 }
16041
16042 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16043
16044 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16045 return Map.lookup_or(Expr, Expr);
16046 }
16047
16048 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16049 if (const SCEV *S = Map.lookup(Expr))
16050 return S;
16051
16052 // If we didn't find the exact ZExt expr in the map, check if there's
16053 // an entry for a smaller ZExt we can use instead.
16054 Type *Ty = Expr->getType();
16055 const SCEV *Op = Expr->getOperand(0);
16056 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16057 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16058 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16059 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16060 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16061 if (const SCEV *S = Map.lookup(NarrowExt))
16062 return SE.getZeroExtendExpr(S, Ty);
16063 Bitwidth = Bitwidth / 2;
16064 }
16065
16067 Expr);
16068 }
16069
16070 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16071 if (const SCEV *S = Map.lookup(Expr))
16072 return S;
16074 Expr);
16075 }
16076
16077 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16078 if (const SCEV *S = Map.lookup(Expr))
16079 return S;
16081 }
16082
16083 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16084 if (const SCEV *S = Map.lookup(Expr))
16085 return S;
16087 }
16088
16089 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16090 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16091 // return UMax(S, 1).
16092 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16093 const SCEV *LHS, *RHS;
16094 if (MatchBinarySub(S, LHS, RHS)) {
16095 if (LHS > RHS)
16096 std::swap(LHS, RHS);
16097 if (NotEqual.contains({LHS, RHS})) {
16098 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16099 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16100 return SE.getUMaxExpr(OneAlignedUp, S);
16101 }
16102 }
16103 return nullptr;
16104 };
16105
16106 // Check if Expr itself is a subtraction pattern with guard info.
16107 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16108 return Rewritten;
16109
16110 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16111 // (Const + A + B). There may be guard info for A + B, and if so, apply
16112 // it.
16113 // TODO: Could more generally apply guards to Add sub-expressions.
16114 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16115 Expr->getNumOperands() == 3) {
16116 const SCEV *Add =
16117 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16118 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16119 return SE.getAddExpr(
16120 Expr->getOperand(0), Rewritten,
16121 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16122 if (const SCEV *S = Map.lookup(Add))
16123 return SE.getAddExpr(Expr->getOperand(0), S);
16124 }
16126 bool Changed = false;
16127 for (const auto *Op : Expr->operands()) {
16128 Operands.push_back(
16130 Changed |= Op != Operands.back();
16131 }
16132 // We are only replacing operands with equivalent values, so transfer the
16133 // flags from the original expression.
16134 return !Changed ? Expr
16135 : SE.getAddExpr(Operands,
16137 Expr->getNoWrapFlags(), FlagMask));
16138 }
16139
16140 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16142 bool Changed = false;
16143 for (const auto *Op : Expr->operands()) {
16144 Operands.push_back(
16146 Changed |= Op != Operands.back();
16147 }
16148 // We are only replacing operands with equivalent values, so transfer the
16149 // flags from the original expression.
16150 return !Changed ? Expr
16151 : SE.getMulExpr(Operands,
16153 Expr->getNoWrapFlags(), FlagMask));
16154 }
16155 };
16156
16157 if (RewriteMap.empty() && NotEqual.empty())
16158 return Expr;
16159
16160 SCEVLoopGuardRewriter Rewriter(SE, *this);
16161 return Rewriter.visit(Expr);
16162}
16163
16164const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16165 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16166}
16167
16169 const LoopGuards &Guards) {
16170 return Guards.rewrite(Expr);
16171}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred ALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1982
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1023
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1555
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1406
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1527
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1810
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1677
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1503
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1662
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1776
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1285
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:996
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:470
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:493
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1311
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:771
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:283
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:321
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:164
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:172
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:209
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:318
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:723
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:754
TypeSize getSizeInBits() const
Definition DataLayout.h:734
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:296
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:293
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:300
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:259
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2263
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2268
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2273
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2823
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2278
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2116
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff every element of A is in B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2208
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2111
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2200
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1746
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:408
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:364
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2012
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2088
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1917
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2019
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:317
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:202
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:148
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:132
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.