1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are modeled as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
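// (For example, building the expression (%a + %b) twice through
// ScalarEvolution::getAddExpr yields the same SCEV object both times, so
// clients can test structural equality with a simple pointer comparison.)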
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
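// For example, the canonical induction variable of
//   for (i = 0; i != n; ++i) ...
// is represented as the add recurrence {0,+,1} over that loop, while a PHI
// whose evolution we cannot model (say, one fed by a load) becomes a
// SCEVUnknown wrapping the Value itself.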
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
296 const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
303 const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
320 OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
332 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
346 OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
377 OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
384 switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
404 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
418 switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
436 return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
440 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
453 if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
 459 // Return true if the value is negative; this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
465
467 return S->getSCEVType() == scCouldNotCompute;
468}
469
472 ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
482 return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
488 // TODO: Avoid implicit trunc?
489 // See https://github.com/llvm/llvm-project/issues/112510.
490 return getConstant(
491 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
492}
493
496 ID.AddInteger(scVScale);
497 ID.AddPointer(Ty);
498 void *IP = nullptr;
499 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
500 return S;
501 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
502 UniqueSCEVs.InsertNode(S, IP);
503 return S;
504}
505
507 SCEV::NoWrapFlags Flags) {
508 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
509 if (EC.isScalable())
510 Res = getMulExpr(Res, getVScale(Ty), Flags);
511 return Res;
512}
513
515 const SCEV *op, Type *ty)
516 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
517
518SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
519 Type *ITy)
520 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
521 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
522 "Must be a non-bit-width-changing pointer-to-integer cast!");
523}
524
526 SCEVTypes SCEVTy, const SCEV *op,
527 Type *ty)
528 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
529
530SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
531 Type *ty)
533 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
534 "Cannot truncate non-integer value!");
535}
536
537SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
538 const SCEV *op, Type *ty)
540 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
541 "Cannot zero extend non-integer value!");
542}
543
544SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
545 const SCEV *op, Type *ty)
547 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
548 "Cannot sign extend non-integer value!");
549}
550
552 // Clear this SCEVUnknown from various maps.
553 SE->forgetMemoizedResults(this);
554
555 // Remove this SCEVUnknown from the uniquing map.
556 SE->UniqueSCEVs.RemoveNode(this);
557
558 // Release the value.
559 setValPtr(nullptr);
560}
561
562void SCEVUnknown::allUsesReplacedWith(Value *New) {
563 // Clear this SCEVUnknown from various maps.
564 SE->forgetMemoizedResults(this);
565
566 // Remove this SCEVUnknown from the uniquing map.
567 SE->UniqueSCEVs.RemoveNode(this);
568
569 // Replace the value pointer in case someone is still using this SCEVUnknown.
570 setValPtr(New);
571}
572
573//===----------------------------------------------------------------------===//
574// SCEV Utilities
575//===----------------------------------------------------------------------===//
576
577/// Compare the two values \p LV and \p RV in terms of their "complexity" where
578/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
579/// operands in SCEV expressions.
580static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
581 Value *RV, unsigned Depth) {
583 return 0;
584
585 // Order pointer values after integer values. This helps SCEVExpander form
586 // GEPs.
587 bool LIsPointer = LV->getType()->isPointerTy(),
588 RIsPointer = RV->getType()->isPointerTy();
589 if (LIsPointer != RIsPointer)
590 return (int)LIsPointer - (int)RIsPointer;
591
592 // Compare getValueID values.
593 unsigned LID = LV->getValueID(), RID = RV->getValueID();
594 if (LID != RID)
595 return (int)LID - (int)RID;
596
597 // Sort arguments by their position.
598 if (const auto *LA = dyn_cast<Argument>(LV)) {
599 const auto *RA = cast<Argument>(RV);
600 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
601 return (int)LArgNo - (int)RArgNo;
602 }
603
604 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
605 const auto *RGV = cast<GlobalValue>(RV);
606
607 if (auto L = LGV->getLinkage() - RGV->getLinkage())
608 return L;
609
610 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
611 auto LT = GV->getLinkage();
612 return !(GlobalValue::isPrivateLinkage(LT) ||
614 };
615
616 // Use the names to distinguish the two values, but only if the
617 // names are semantically important.
618 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
619 return LGV->getName().compare(RGV->getName());
620 }
621
622 // For instructions, compare their loop depth, and their operand count. This
623 // is pretty loose.
624 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
625 const auto *RInst = cast<Instruction>(RV);
626
627 // Compare loop depths.
628 const BasicBlock *LParent = LInst->getParent(),
629 *RParent = RInst->getParent();
630 if (LParent != RParent) {
631 unsigned LDepth = LI->getLoopDepth(LParent),
632 RDepth = LI->getLoopDepth(RParent);
633 if (LDepth != RDepth)
634 return (int)LDepth - (int)RDepth;
635 }
636
637 // Compare the number of operands.
638 unsigned LNumOps = LInst->getNumOperands(),
639 RNumOps = RInst->getNumOperands();
640 if (LNumOps != RNumOps)
641 return (int)LNumOps - (int)RNumOps;
642
643 for (unsigned Idx : seq(LNumOps)) {
644 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
645 RInst->getOperand(Idx), Depth + 1);
646 if (Result != 0)
647 return Result;
648 }
649 }
650
651 return 0;
652}
653
654// Return negative, zero, or positive, if LHS is less than, equal to, or greater
655// than RHS, respectively. A three-way result allows recursive comparisons to be
656// more efficient.
657// If the max analysis depth was reached, return std::nullopt, since we
658// cannot tell for sure how the two expressions compare.
659static std::optional<int>
660CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
661 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
662 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
663 if (LHS == RHS)
664 return 0;
665
666 // Primarily, sort the SCEVs by their getSCEVType().
667 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
668 if (LType != RType)
669 return (int)LType - (int)RType;
670
672 return std::nullopt;
673
674 // Aside from the getSCEVType() ordering, the particular ordering
675 // isn't very important except that it's beneficial to be consistent,
676 // so that (a + b) and (b + a) don't end up as different expressions.
677 switch (LType) {
678 case scUnknown: {
679 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
680 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
681
682 int X =
683 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
684 return X;
685 }
686
687 case scConstant: {
690
691 // Compare constant values.
692 const APInt &LA = LC->getAPInt();
693 const APInt &RA = RC->getAPInt();
694 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
695 if (LBitWidth != RBitWidth)
696 return (int)LBitWidth - (int)RBitWidth;
697 return LA.ult(RA) ? -1 : 1;
698 }
699
700 case scVScale: {
701 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
702 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
703 return LTy->getBitWidth() - RTy->getBitWidth();
704 }
705
706 case scAddRecExpr: {
709
710 // There is always a dominance between two recs that are used by one SCEV,
711 // so we can safely sort recs by loop header dominance. We require such
712 // order in getAddExpr.
713 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
714 if (LLoop != RLoop) {
715 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
716 assert(LHead != RHead && "Two loops share the same header?");
717 if (DT.dominates(LHead, RHead))
718 return 1;
719 assert(DT.dominates(RHead, LHead) &&
720 "No dominance between recurrences used by one SCEV?");
721 return -1;
722 }
723
724 [[fallthrough]];
725 }
726
727 case scTruncate:
728 case scZeroExtend:
729 case scSignExtend:
730 case scPtrToInt:
731 case scAddExpr:
732 case scMulExpr:
733 case scUDivExpr:
734 case scSMaxExpr:
735 case scUMaxExpr:
736 case scSMinExpr:
737 case scUMinExpr:
739 ArrayRef<const SCEV *> LOps = LHS->operands();
740 ArrayRef<const SCEV *> ROps = RHS->operands();
741
742 // Lexicographically compare n-ary-like expressions.
743 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
744 if (LNumOps != RNumOps)
745 return (int)LNumOps - (int)RNumOps;
746
747 for (unsigned i = 0; i != LNumOps; ++i) {
748 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
749 if (X != 0)
750 return X;
751 }
752 return 0;
753 }
754
756 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
757 }
758 llvm_unreachable("Unknown SCEV kind!");
759}
760
761/// Given a list of SCEV objects, order them by their complexity, and group
762/// objects of the same complexity together by value. When this routine is
763/// finished, we know that any duplicates in the vector are consecutive and that
764/// complexity is monotonically increasing.
765///
766/// Note that we take special precautions to ensure that we get deterministic
767/// results from this routine. In other words, we don't want the results of
768/// this to depend on where the addresses of various SCEV objects happened to
769/// land in memory.
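/// For example, given operands %a, (3 + %b), %a, 7, the constant 7 sorts to
/// the front and the two occurrences of %a end up adjacent, so callers such
/// as getAddExpr can fold the duplicate pair into 2 * %a with a single
/// forward scan over the sorted list.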
771 LoopInfo *LI, DominatorTree &DT) {
772 if (Ops.size() < 2) return; // Noop
773
774 // Whether LHS has provably less complexity than RHS.
775 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
776 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
777 return Complexity && *Complexity < 0;
778 };
779 if (Ops.size() == 2) {
780 // This is the common case, which also happens to be trivially simple.
781 // Special case it.
782 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
783 if (IsLessComplex(RHS, LHS))
784 std::swap(LHS, RHS);
785 return;
786 }
787
788 // Do the rough sort by complexity.
789 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
790 return IsLessComplex(LHS, RHS);
791 });
792
793 // Now that we are sorted by complexity, group elements of the same
794 // complexity. Note that this is, at worst, N^2, but the vector is likely to
795 // be extremely short in practice. Note that we take this approach because we
796 // do not want to depend on the addresses of the objects we are grouping.
797 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
798 const SCEV *S = Ops[i];
799 unsigned Complexity = S->getSCEVType();
800
801 // If there are any objects of the same complexity and same value as this
802 // one, group them.
803 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
804 if (Ops[j] == S) { // Found a duplicate.
805 // Move it to immediately after i'th element.
806 std::swap(Ops[i+1], Ops[j]);
807 ++i; // no need to rescan it.
808 if (i == e-2) return; // Done!
809 }
810 }
811 }
812}
813
814/// Returns true if \p Ops contains a huge SCEV (i.e. one whose subtree has at
815/// least HugeExprThreshold nodes).
817 return any_of(Ops, [](const SCEV *S) {
819 });
820}
821
822/// Performs a number of common optimizations on the passed \p Ops. If the
823/// whole expression reduces down to a single operand, it will be returned.
824///
825/// The following optimizations are performed:
826/// * Fold constants using the \p Fold function.
827/// * Remove identity constants satisfying \p IsIdentity.
828/// * If a constant satisfies \p IsAbsorber, return it.
829/// * Sort operands by complexity.
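/// For example, suppose a caller building an add passes Ops = {2, %x, 3}
/// with \p Fold performing addition, identity 0 and no absorber: the two
/// constants fold to 5, which is neither the identity nor an absorber, so it
/// is re-inserted at the front of the complexity-sorted list and the routine
/// returns nullptr with Ops = {5, %x}, leaving the caller to build the final
/// expression. Had the operands folded away entirely, the folded constant
/// itself would have been returned.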
830template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
831static const SCEV *
834 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
835 const SCEVConstant *Folded = nullptr;
836 for (unsigned Idx = 0; Idx < Ops.size();) {
837 const SCEV *Op = Ops[Idx];
838 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
839 if (!Folded)
840 Folded = C;
841 else
842 Folded = cast<SCEVConstant>(
843 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
844 Ops.erase(Ops.begin() + Idx);
845 continue;
846 }
847 ++Idx;
848 }
849
850 if (Ops.empty()) {
851 assert(Folded && "Must have folded value");
852 return Folded;
853 }
854
855 if (Folded && IsAbsorber(Folded->getAPInt()))
856 return Folded;
857
858 GroupByComplexity(Ops, &LI, DT);
859 if (Folded && !IsIdentity(Folded->getAPInt()))
860 Ops.insert(Ops.begin(), Folded);
861
862 return Ops.size() == 1 ? Ops[0] : nullptr;
863}
864
865//===----------------------------------------------------------------------===//
866// Simple SCEV method implementations
867//===----------------------------------------------------------------------===//
868
869/// Compute BC(It, K). The result has width W. Assume K > 0.
870static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
871 ScalarEvolution &SE,
872 Type *ResultTy) {
873 // Handle the simplest case efficiently.
874 if (K == 1)
875 return SE.getTruncateOrZeroExtend(It, ResultTy);
876
877 // We are using the following formula for BC(It, K):
878 //
879 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
880 //
 881 // Suppose W is the bitwidth of the return value. We must be prepared for
 882 // overflow. Hence, we must ensure that the result of our computation is
883 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
884 // safe in modular arithmetic.
885 //
886 // However, this code doesn't use exactly that formula; the formula it uses
887 // is something like the following, where T is the number of factors of 2 in
888 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
889 // exponentiation:
890 //
891 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
892 //
893 // This formula is trivially equivalent to the previous formula. However,
894 // this formula can be implemented much more efficiently. The trick is that
895 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
896 // arithmetic. To do exact division in modular arithmetic, all we have
897 // to do is multiply by the inverse. Therefore, this step can be done at
898 // width W.
899 //
900 // The next issue is how to safely do the division by 2^T. The way this
901 // is done is by doing the multiplication step at a width of at least W + T
902 // bits. This way, the bottom W+T bits of the product are accurate. Then,
903 // when we perform the division by 2^T (which is equivalent to a right shift
904 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
905 // truncated out after the division by 2^T.
906 //
907 // In comparison to just directly using the first formula, this technique
908 // is much more efficient; using the first formula requires W * K bits,
909// but this formula uses less than W + K bits. Also, the first formula requires
910 // a division step, whereas this formula only requires multiplies and shifts.
911 //
912 // It doesn't matter whether the subtraction step is done in the calculation
913 // width or the input iteration count's width; if the subtraction overflows,
914 // the result must be zero anyway. We prefer here to do it in the width of
915 // the induction variable because it helps a lot for certain cases; CodeGen
916 // isn't smart enough to ignore the overflow, which leads to much less
917 // efficient code if the width of the subtraction is wider than the native
918 // register width.
919 //
920 // (It's possible to not widen at all by pulling out factors of 2 before
921 // the multiplication; for example, K=2 can be calculated as
922 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
923 // extra arithmetic, so it's not an obvious win, and it gets
924 // much more complicated for K > 3.)
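  // As a small worked example, take K = 3 and W = 32: K! = 6 = 2^1 * 3, so
  // T = 1 and the odd factor is 3. The product It*(It-1)*(It-2) is formed at
  // width W + T = 33 bits, divided by 2^T = 2, truncated back to 32 bits,
  // and finally multiplied by the multiplicative inverse of 3 modulo 2^32 to
  // perform the exact division by the odd part of K!.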
925
926 // Protection from insane SCEVs; this bound is conservative,
927 // but it probably doesn't matter.
928 if (K > 1000)
929 return SE.getCouldNotCompute();
930
931 unsigned W = SE.getTypeSizeInBits(ResultTy);
932
933 // Calculate K! / 2^T and T; we divide out the factors of two before
 934 // multiplying when calculating K! / 2^T, to avoid overflow.
935 // Other overflow doesn't matter because we only care about the bottom
936 // W bits of the result.
937 APInt OddFactorial(W, 1);
938 unsigned T = 1;
939 for (unsigned i = 3; i <= K; ++i) {
940 unsigned TwoFactors = countr_zero(i);
941 T += TwoFactors;
942 OddFactorial *= (i >> TwoFactors);
943 }
944
945 // We need at least W + T bits for the multiplication step
946 unsigned CalculationBits = W + T;
947
948 // Calculate 2^T, at width T+W.
949 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
950
951 // Calculate the multiplicative inverse of K! / 2^T;
952 // this multiplication factor will perform the exact division by
953 // K! / 2^T.
954 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
955
956 // Calculate the product, at width T+W
957 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
958 CalculationBits);
959 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
960 for (unsigned i = 1; i != K; ++i) {
961 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
962 Dividend = SE.getMulExpr(Dividend,
963 SE.getTruncateOrZeroExtend(S, CalculationTy));
964 }
965
966 // Divide by 2^T
967 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
968
969 // Truncate the result, and divide by K! / 2^T.
970
971 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
972 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
973}
974
975/// Return the value of this chain of recurrences at the specified iteration
976/// number. We can evaluate this recurrence by multiplying each element in the
977/// chain by the binomial coefficient corresponding to it. In other words, we
978/// can evaluate {A,+,B,+,C,+,D} as:
979///
980/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
981///
982/// where BC(It, k) stands for binomial coefficient.
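/// For example, {0,+,1,+,1} evaluates to 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2)
/// = It + It*(It-1)/2 = It*(It+1)/2, i.e. the sum 1 + 2 + ... + It reached
/// after It backedge iterations.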
984 ScalarEvolution &SE) const {
985 return evaluateAtIteration(operands(), It, SE);
986}
987
988const SCEV *
990 const SCEV *It, ScalarEvolution &SE) {
991 assert(Operands.size() > 0);
992 const SCEV *Result = Operands[0];
993 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
994 // The computation is correct in the face of overflow provided that the
995 // multiplication is performed _after_ the evaluation of the binomial
996 // coefficient.
997 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
998 if (isa<SCEVCouldNotCompute>(Coeff))
999 return Coeff;
1000
1001 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1002 }
1003 return Result;
1004}
1005
1006//===----------------------------------------------------------------------===//
1007// SCEV Expression folder implementations
1008//===----------------------------------------------------------------------===//
1009
1010/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1011/// which computes a pointer-typed value, and rewrites the whole expression
1012/// tree so that *all* the computations are done on integers, and the only
1013/// pointer-typed operands in the expression are SCEVUnknown.
1014/// The CreatePtrCast callback is invoked to create the actual conversion
1015/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
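/// For example, given the pointer-typed expression ((4 * %n) + %base) and an
/// i64 target type, only %base remains pointer-typed once the add is taken
/// apart, so the result is ((4 * %n) + (ptrtoint ptr %base to i64)): the
/// integer arithmetic is left intact and the conversion is emitted only at
/// the SCEVUnknown leaf.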
1017 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
1019 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
1020 Type *TargetTy;
1021 ConversionFn CreatePtrCast;
1022
1023public:
1025 ConversionFn CreatePtrCast)
1026 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1027
1028 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1029 Type *TargetTy, ConversionFn CreatePtrCast) {
1030 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1031 return Rewriter.visit(Scev);
1032 }
1033
1034 const SCEV *visit(const SCEV *S) {
1035 Type *STy = S->getType();
1036 // If the expression is not pointer-typed, just keep it as-is.
1037 if (!STy->isPointerTy())
1038 return S;
1039 // Else, recursively sink the cast down into it.
1040 return Base::visit(S);
1041 }
1042
1043 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1044 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1045 // implementation drops.
1047 bool Changed = false;
1048 for (const auto *Op : Expr->operands()) {
1049 Operands.push_back(visit(Op));
1050 Changed |= Op != Operands.back();
1051 }
1052 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1053 }
1054
1055 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1057 bool Changed = false;
1058 for (const auto *Op : Expr->operands()) {
1059 Operands.push_back(visit(Op));
1060 Changed |= Op != Operands.back();
1061 }
1062 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1063 }
1064
1065 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1066 assert(Expr->getType()->isPointerTy() &&
1067 "Should only reach pointer-typed SCEVUnknown's.");
1068 // Perform some basic constant folding. If the operand of the cast is a
1069 // null pointer, don't create a cast SCEV expression (that will be left
1070 // as-is), but produce a zero constant.
1072 return SE.getZero(TargetTy);
1073 return CreatePtrCast(Expr);
1074 }
1075};
1076
1078 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1079
1080 // It isn't legal for optimizations to construct new ptrtoint expressions
1081 // for non-integral pointers.
1082 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1083 return getCouldNotCompute();
1084
1085 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1086
1087 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1088 // is sufficiently wide to represent all possible pointer values.
1089 // We could theoretically teach SCEV to truncate wider pointers, but
1090 // that isn't implemented for now.
1092 getDataLayout().getTypeSizeInBits(IntPtrTy))
1093 return getCouldNotCompute();
1094
1095 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1097 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
1099 ID.AddInteger(scPtrToInt);
1100 ID.AddPointer(U);
1101 void *IP = nullptr;
1102 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1103 return S;
1104 SCEV *S = new (SCEVAllocator)
1105 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1106 UniqueSCEVs.InsertNode(S, IP);
1107 registerUser(S, U);
1108 return static_cast<const SCEV *>(S);
1109 });
1110 assert(IntOp->getType()->isIntegerTy() &&
1111 "We must have succeeded in sinking the cast, "
1112 "and ending up with an integer-typed expression!");
1113 return IntOp;
1114}
1115
1117 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1118
1119 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1120 if (isa<SCEVCouldNotCompute>(IntOp))
1121 return IntOp;
1122
1123 return getTruncateOrZeroExtend(IntOp, Ty);
1124}
1125
1127 unsigned Depth) {
1128 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1129 "This is not a truncating conversion!");
1130 assert(isSCEVable(Ty) &&
1131 "This is not a conversion to a SCEVable type!");
1132 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1133 Ty = getEffectiveSCEVType(Ty);
1134
1136 ID.AddInteger(scTruncate);
1137 ID.AddPointer(Op);
1138 ID.AddPointer(Ty);
1139 void *IP = nullptr;
1140 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1141
1142 // Fold if the operand is constant.
1143 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1144 return getConstant(
1145 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1146
1147 // trunc(trunc(x)) --> trunc(x)
1149 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1150
1151 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1153 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1154
1155 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1157 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1158
1159 if (Depth > MaxCastDepth) {
1160 SCEV *S =
1161 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1162 UniqueSCEVs.InsertNode(S, IP);
1163 registerUser(S, Op);
1164 return S;
1165 }
1166
1167 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1168 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1169 // if after transforming we have at most one truncate, not counting truncates
1170 // that replace other casts.
1172 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1174 unsigned numTruncs = 0;
1175 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1176 ++i) {
1177 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1178 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1180 numTruncs++;
1181 Operands.push_back(S);
1182 }
1183 if (numTruncs < 2) {
1184 if (isa<SCEVAddExpr>(Op))
1185 return getAddExpr(Operands);
1186 if (isa<SCEVMulExpr>(Op))
1187 return getMulExpr(Operands);
1188 llvm_unreachable("Unexpected SCEV type for Op.");
1189 }
1190 // Although we checked in the beginning that ID is not in the cache, it is
 1191 // possible that ID was inserted into the cache by one of the recursive
 1192 // calls above. So if we find it, just return it.
1193 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1194 return S;
1195 }
1196
1197 // If the input value is a chrec scev, truncate the chrec's operands.
1198 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1200 for (const SCEV *Op : AddRec->operands())
1201 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1202 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1203 }
1204
1205 // Return zero if truncating to known zeros.
1206 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1207 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1208 return getZero(Ty);
1209
1210 // The cast wasn't folded; create an explicit cast node. We can reuse
1211 // the existing insert position since if we get here, we won't have
1212 // made any changes which would invalidate it.
1213 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1214 Op, Ty);
1215 UniqueSCEVs.InsertNode(S, IP);
1216 registerUser(S, Op);
1217 return S;
1218}
1219
1220// Get the limit of a recurrence such that incrementing by Step cannot cause
1221// signed overflow as long as the value of the recurrence within the
1222// loop does not exceed this limit before incrementing.
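// For example, for an i8 recurrence whose Step is known positive with a
// signed range max of 1, the limit is 127 - 1 = 126 and the predicate is
// ICMP_SLT: while the current value is signed-less-than 126, adding a step
// of at most 1 cannot move it past 127, so no signed overflow occurs.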
1223static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1224 ICmpInst::Predicate *Pred,
1225 ScalarEvolution *SE) {
1226 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1227 if (SE->isKnownPositive(Step)) {
1228 *Pred = ICmpInst::ICMP_SLT;
1230 SE->getSignedRangeMax(Step));
1231 }
1232 if (SE->isKnownNegative(Step)) {
1233 *Pred = ICmpInst::ICMP_SGT;
1235 SE->getSignedRangeMin(Step));
1236 }
1237 return nullptr;
1238}
1239
1240// Get the limit of a recurrence such that incrementing by Step cannot cause
1241// unsigned overflow as long as the value of the recurrence within the loop does
1242// not exceed this limit before incrementing.
1244 ICmpInst::Predicate *Pred,
1245 ScalarEvolution *SE) {
1246 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1247 *Pred = ICmpInst::ICMP_ULT;
1248
1250 SE->getUnsignedRangeMax(Step));
1251}
1252
1253namespace {
1254
1255struct ExtendOpTraitsBase {
1256 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1257 unsigned);
1258};
1259
1260// Used to make code generic over signed and unsigned overflow.
1261template <typename ExtendOp> struct ExtendOpTraits {
1262 // Members present:
1263 //
1264 // static const SCEV::NoWrapFlags WrapType;
1265 //
1266 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1267 //
1268 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1269 // ICmpInst::Predicate *Pred,
1270 // ScalarEvolution *SE);
1271};
1272
1273template <>
1274struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1275 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1276
1277 static const GetExtendExprTy GetExtendExpr;
1278
1279 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1280 ICmpInst::Predicate *Pred,
1281 ScalarEvolution *SE) {
1282 return getSignedOverflowLimitForStep(Step, Pred, SE);
1283 }
1284};
1285
1286const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1288
1289template <>
1290struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1291 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1292
1293 static const GetExtendExprTy GetExtendExpr;
1294
1295 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1296 ICmpInst::Predicate *Pred,
1297 ScalarEvolution *SE) {
1298 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1299 }
1300};
1301
1302const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1304
1305} // end anonymous namespace
1306
1307// The recurrence AR has been shown to have no signed/unsigned wrap or something
1308// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1309// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1310// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1311// Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1312// expression "Step + sext/zext(PreIncAR)" is congruent with
1313// "sext/zext(PostIncAR)"
1314template <typename ExtendOpTy>
1315static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1316 ScalarEvolution *SE, unsigned Depth) {
1317 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1318 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1319
1320 const Loop *L = AR->getLoop();
1321 const SCEV *Start = AR->getStart();
1322 const SCEV *Step = AR->getStepRecurrence(*SE);
1323
1324 // Check for a simple looking step prior to loop entry.
1325 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1326 if (!SA)
1327 return nullptr;
1328
1329 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1330 // subtraction is expensive. For this purpose, perform a quick and dirty
1331// difference, by checking for Step in the operand list. Note that
1332 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1334 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1335 if (*It == Step) {
1336 DiffOps.erase(It);
1337 break;
1338 }
1339
1340 if (DiffOps.size() == SA->getNumOperands())
1341 return nullptr;
1342
1343 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1344 // `Step`:
1345
1346 // 1. NSW/NUW flags on the step increment.
1347 auto PreStartFlags =
1349 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1351 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1352
1353 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1354 // "S+X does not sign/unsign-overflow".
1355 //
1356
1357 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1358 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1359 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1360 return PreStart;
1361
1362 // 2. Direct overflow check on the step operation's expression.
1363 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1364 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1365 const SCEV *OperandExtendedStart =
1366 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1367 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1368 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1369 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1370 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1371 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1372 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1373 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1374 }
1375 return PreStart;
1376 }
1377
1378 // 3. Loop precondition.
1380 const SCEV *OverflowLimit =
1381 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1382
1383 if (OverflowLimit &&
1384 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1385 return PreStart;
1386
1387 return nullptr;
1388}
1389
1390// Get the normalized zero or sign extended expression for this AddRec's Start.
1391template <typename ExtendOpTy>
1392static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1393 ScalarEvolution *SE,
1394 unsigned Depth) {
1395 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1396
1397 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1398 if (!PreStart)
1399 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1400
1401 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1402 Depth),
1403 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1404}
1405
1406// Try to prove away overflow by looking at "nearby" add recurrences. A
1407// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1408// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1409//
1410// Formally:
1411//
1412// {S,+,X} == {S-T,+,X} + T
1413// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1414//
1415// If ({S-T,+,X} + T) does not overflow ... (1)
1416//
1417// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1418//
1419// If {S-T,+,X} does not overflow ... (2)
1420//
1421// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1422// == {Ext(S-T)+Ext(T),+,Ext(X)}
1423//
1424// If (S-T)+T does not overflow ... (3)
1425//
1426// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1427// == {Ext(S),+,Ext(X)} == LHS
1428//
1429// Thus, if (1), (2) and (3) are true for some T, then
1430// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1431//
1432// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1433// does not overflow" restricted to the 0th iteration. Therefore we only need
1434// to check for (1) and (2).
1435//
1436// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1437// is `Delta` (defined below).
1438template <typename ExtendOpTy>
1439bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1440 const SCEV *Step,
1441 const Loop *L) {
1442 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1443
1444 // We restrict `Start` to a constant to prevent SCEV from spending too much
1445 // time here. It is correct (but more expensive) to continue with a
1446 // non-constant `Start` and do a general SCEV subtraction to compute
1447 // `PreStart` below.
1448 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1449 if (!StartC)
1450 return false;
1451
1452 APInt StartAI = StartC->getAPInt();
1453
1454 for (unsigned Delta : {-2, -1, 1, 2}) {
1455 const SCEV *PreStart = getConstant(StartAI - Delta);
1456
1457 FoldingSetNodeID ID;
1458 ID.AddInteger(scAddRecExpr);
1459 ID.AddPointer(PreStart);
1460 ID.AddPointer(Step);
1461 ID.AddPointer(L);
1462 void *IP = nullptr;
1463 const auto *PreAR =
1464 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1465
1466 // Give up if we don't already have the add recurrence we need because
1467 // actually constructing an add recurrence is relatively expensive.
1468 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1469 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1471 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1472 DeltaS, &Pred, this);
1473 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1474 return true;
1475 }
1476 }
1477
1478 return false;
1479}
1480
1481// Finds an integer D for an expression (C + x + y + ...) such that the top
1482// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1483// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1484// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1485// the (C + x + y + ...) expression is \p WholeAddExpr.
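// For example, with C = 37 and a single other operand x = 8 * %n (three
// trailing zeros), D is chosen as 37 mod 8 = 5: the remainder (32 + 8 * %n)
// still has three trailing zeros, and since D < 2^3 the outer addition
// 5 + (32 + 8 * %n) produces no carries out of the low bits and therefore
// cannot wrap, signed or unsigned.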
1487 const SCEVConstant *ConstantTerm,
1488 const SCEVAddExpr *WholeAddExpr) {
1489 const APInt &C = ConstantTerm->getAPInt();
1490 const unsigned BitWidth = C.getBitWidth();
1491 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1492 uint32_t TZ = BitWidth;
1493 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1494 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1495 if (TZ) {
1496 // Set D to be as many least significant bits of C as possible while still
1497 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1498 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1499 }
1500 return APInt(BitWidth, 0);
1501}
1502
1503// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1504// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1505// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1506// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1508 const APInt &ConstantStart,
1509 const SCEV *Step) {
1510 const unsigned BitWidth = ConstantStart.getBitWidth();
1511 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1512 if (TZ)
1513 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1514 : ConstantStart;
1515 return APInt(BitWidth, 0);
1516}
1517
1519 const ScalarEvolution::FoldID &ID, const SCEV *S,
1522 &FoldCacheUser) {
1523 auto I = FoldCache.insert({ID, S});
1524 if (!I.second) {
1525 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1526 // entry.
1527 auto &UserIDs = FoldCacheUser[I.first->second];
1528 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1529 for (unsigned I = 0; I != UserIDs.size(); ++I)
1530 if (UserIDs[I] == ID) {
1531 std::swap(UserIDs[I], UserIDs.back());
1532 break;
1533 }
1534 UserIDs.pop_back();
1535 I.first->second = S;
1536 }
1537 FoldCacheUser[S].push_back(ID);
1538}
1539
1540const SCEV *
1542 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1543 "This is not an extending conversion!");
1544 assert(isSCEVable(Ty) &&
1545 "This is not a conversion to a SCEVable type!");
1546 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1547 Ty = getEffectiveSCEVType(Ty);
1548
1549 FoldID ID(scZeroExtend, Op, Ty);
1550 if (const SCEV *S = FoldCache.lookup(ID))
1551 return S;
1552
1553 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1555 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1556 return S;
1557}
1558
1560 unsigned Depth) {
1561 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1562 "This is not an extending conversion!");
1563 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1564 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1565
1566 // Fold if the operand is constant.
1567 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1568 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1569
1570 // zext(zext(x)) --> zext(x)
1572 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1573
1574 // Before doing any expensive analysis, check to see if we've already
1575 // computed a SCEV for this Op and Ty.
1577 ID.AddInteger(scZeroExtend);
1578 ID.AddPointer(Op);
1579 ID.AddPointer(Ty);
1580 void *IP = nullptr;
1581 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1582 if (Depth > MaxCastDepth) {
1583 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1584 Op, Ty);
1585 UniqueSCEVs.InsertNode(S, IP);
1586 registerUser(S, Op);
1587 return S;
1588 }
1589
1590 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1592 // It's possible the bits taken off by the truncate were all zero bits. If
1593 // so, we should be able to simplify this further.
1594 const SCEV *X = ST->getOperand();
1596 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1597 unsigned NewBits = getTypeSizeInBits(Ty);
1598 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1599 CR.zextOrTrunc(NewBits)))
1600 return getTruncateOrZeroExtend(X, Ty, Depth);
1601 }
1602
1603 // If the input value is a chrec scev, and we can prove that the value
1604 // did not overflow the old, smaller, value, we can zero extend all of the
1605 // operands (often constants). This allows analysis of something like
1606 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1608 if (AR->isAffine()) {
1609 const SCEV *Start = AR->getStart();
1610 const SCEV *Step = AR->getStepRecurrence(*this);
1611 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1612 const Loop *L = AR->getLoop();
1613
1614 // If we have special knowledge that this addrec won't overflow,
1615 // we don't need to do any further analysis.
1616 if (AR->hasNoUnsignedWrap()) {
1617 Start =
1619 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1620 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1621 }
1622
1623 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1624 // Note that this serves two purposes: It filters out loops that are
1625 // simply not analyzable, and it covers the case where this code is
1626 // being called from within backedge-taken count analysis, such that
1627 // attempting to ask for the backedge-taken count would likely result
1628// in infinite recursion. In the latter case, the analysis code will
1629 // cope with a conservative value, and it will take care to purge
1630 // that value once it has finished.
1631 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1632 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1633 // Manually compute the final value for AR, checking for overflow.
1634
1635 // Check whether the backedge-taken count can be losslessly casted to
1636 // the addrec's type. The count is always unsigned.
1637 const SCEV *CastedMaxBECount =
1638 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1639 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1640 CastedMaxBECount, MaxBECount->getType(), Depth);
1641 if (MaxBECount == RecastedMaxBECount) {
1642 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1643 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1644 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1646 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1648 Depth + 1),
1649 WideTy, Depth + 1);
1650 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1651 const SCEV *WideMaxBECount =
1652 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1653 const SCEV *OperandExtendedAdd =
1654 getAddExpr(WideStart,
1655 getMulExpr(WideMaxBECount,
1656 getZeroExtendExpr(Step, WideTy, Depth + 1),
1659 if (ZAdd == OperandExtendedAdd) {
1660 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1661 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1662 // Return the expression with the addrec on the outside.
1663 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1664 Depth + 1);
1665 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1666 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1667 }
1668 // Similar to above, only this time treat the step value as signed.
1669 // This covers loops that count down.
1670 OperandExtendedAdd =
1671 getAddExpr(WideStart,
1672 getMulExpr(WideMaxBECount,
1673 getSignExtendExpr(Step, WideTy, Depth + 1),
1674 SCEV::FlagAnyWrap, Depth + 1),
1675 SCEV::FlagAnyWrap, Depth + 1);
1676 if (ZAdd == OperandExtendedAdd) {
1677 // Cache knowledge of AR NW, which is propagated to this AddRec.
1678 // Negative step causes unsigned wrap, but it still can't self-wrap.
1679 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1680 // Return the expression with the addrec on the outside.
1681 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1682 Depth + 1);
1683 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1684 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1685 }
1686 }
1687 }
1688
1689 // Normally, in the cases we can prove no-overflow via a
1690 // backedge guarding condition, we can also compute a backedge
1691 // taken count for the loop. The exceptions are assumptions and
1692 // guards present in the loop -- SCEV is not great at exploiting
1693 // these to compute max backedge taken counts, but can still use
1694 // these to prove lack of overflow. Use this fact to avoid
1695 // doing extra work that may not pay off.
1696 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1697 !AC.assumptions().empty()) {
1698
1699 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1700 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1701 if (AR->hasNoUnsignedWrap()) {
1702 // Same as nuw case above - duplicated here to avoid a compile time
1703 // issue. It's not clear that the order of checks matters, but it's
1704 // one of two possible causes for a change which was reverted.
1705 // Be conservative for the moment.
1706 Start =
1707 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1708 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1709 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1710 }
1711
1712 // For a negative step, we can extend the operands iff doing so only
1713 // traverses values in the range zext([0,UINT_MAX]).
1714 if (isKnownNegative(Step)) {
1715 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1716 getSignedRangeMin(Step));
1717 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1718 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1719 // Cache knowledge of AR NW, which is propagated to this
1720 // AddRec. Negative step causes unsigned wrap, but it
1721 // still can't self-wrap.
1722 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1723 // Return the expression with the addrec on the outside.
1724 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1725 Depth + 1);
1726 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1727 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1728 }
1729 }
1730 }
1731
1732 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1733 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1734 // where D maximizes the number of trailing zeros of (C - D + Step * n)
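// [Illustrative note, not from the upstream source:] e.g. with C = 11 and
// Step = 4, taking D = 3 gives zext({11,+,4}) = (3 + zext({8,+,4})); the
// residual recurrence then has two known trailing zero bits on every
// iteration.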
1735 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1736 const APInt &C = SC->getAPInt();
1737 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1738 if (D != 0) {
1739 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1740 const SCEV *SResidual =
1741 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1742 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1743 return getAddExpr(SZExtD, SZExtR,
1744 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1745 Depth + 1);
1746 }
1747 }
1748
1749 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1750 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1751 Start =
1752 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1753 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1754 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1755 }
1756 }
1757
1758 // zext(A % B) --> zext(A) % zext(B)
1759 {
1760 const SCEV *LHS;
1761 const SCEV *RHS;
1762 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1763 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1764 getZeroExtendExpr(RHS, Ty, Depth + 1));
1765 }
1766
1767 // zext(A / B) --> zext(A) / zext(B).
1768 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1769 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1770 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1771
1772 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1773 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1774 if (SA->hasNoUnsignedWrap()) {
1775 // If the addition does not unsign overflow then we can, by definition,
1776 // commute the zero extension with the addition operation.
1777 SmallVector<const SCEV *, 4> Ops;
1778 for (const auto *Op : SA->operands())
1779 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1780 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1781 }
1782
1783 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1784 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1785 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1786 //
1787 // Often address arithmetic contains expressions like
1788 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1789 // This transformation is useful while proving that such expressions are
1790 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
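// [Illustrative note, not from the upstream source:] for (zext (5 + (4 * X)))
// the split takes D = 1 and yields 1 + (zext (4 + (4 * X))), exposing the
// small constant offset outside the zero-extend.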
1791 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1792 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1793 if (D != 0) {
1794 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1795 const SCEV *SResidual =
1796 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1797 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1798 return getAddExpr(SZExtD, SZExtR,
1799 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1800 Depth + 1);
1801 }
1802 }
1803 }
1804
1805 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1806 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1807 if (SM->hasNoUnsignedWrap()) {
1808 // If the multiply does not unsign overflow then we can, by definition,
1809 // commute the zero extension with the multiply operation.
1810 SmallVector<const SCEV *, 4> Ops;
1811 for (const auto *Op : SM->operands())
1812 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1813 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1814 }
1815
1816 // zext(2^K * (trunc X to iN)) to iM ->
1817 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1818 //
1819 // Proof:
1820 //
1821 // zext(2^K * (trunc X to iN)) to iM
1822 // = zext((trunc X to iN) << K) to iM
1823 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1824 // (because shl removes the top K bits)
1825 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1826 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1827 //
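// [Illustrative note, not from the upstream source:] with N = 8, K = 2 and
// M = 32 this turns (zext (4 * (trunc X to i8)) to i32) into
// (4 * (zext (trunc X to i6) to i32))<nuw>.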
1828 const APInt *C;
1829 const SCEV *TruncRHS;
1830 if (match(SM,
1831 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1832 C->isPowerOf2()) {
1833 int NewTruncBits =
1834 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1835 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1836 return getMulExpr(
1837 getZeroExtendExpr(SM->getOperand(0), Ty),
1838 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1839 SCEV::FlagNUW, Depth + 1);
1840 }
1841 }
1842
1843 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1844 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1845 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1846 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1847 SmallVector<const SCEV *, 4> Operands;
1848 for (auto *Operand : MinMax->operands())
1849 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1850 if (isa<SCEVUMinExpr>(MinMax))
1851 return getUMinExpr(Operands);
1852 return getUMaxExpr(Operands);
1853 }
1854
1855 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1856 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1857 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1858 SmallVector<const SCEV *, 4> Operands;
1859 for (auto *Operand : MinMax->operands())
1860 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1861 return getUMinExpr(Operands, /*Sequential*/ true);
1862 }
1863
1864 // The cast wasn't folded; create an explicit cast node.
1865 // Recompute the insert position, as it may have been invalidated.
1866 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1867 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1868 Op, Ty);
1869 UniqueSCEVs.InsertNode(S, IP);
1870 registerUser(S, Op);
1871 return S;
1872}
1873
1874 const SCEV *
1875 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1876 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1877 "This is not an extending conversion!");
1878 assert(isSCEVable(Ty) &&
1879 "This is not a conversion to a SCEVable type!");
1880 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1881 Ty = getEffectiveSCEVType(Ty);
1882
1883 FoldID ID(scSignExtend, Op, Ty);
1884 if (const SCEV *S = FoldCache.lookup(ID))
1885 return S;
1886
1887 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1888 if (!isa<SCEVSignExtendExpr>(S))
1889 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1890 return S;
1891}
1892
1893 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1894 unsigned Depth) {
1895 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1896 "This is not an extending conversion!");
1897 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1898 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1899 Ty = getEffectiveSCEVType(Ty);
1900
1901 // Fold if the operand is constant.
1902 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1903 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1904
1905 // sext(sext(x)) --> sext(x)
1906 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1907 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1908
1909 // sext(zext(x)) --> zext(x)
1910 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1911 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1912
1913 // Before doing any expensive analysis, check to see if we've already
1914 // computed a SCEV for this Op and Ty.
1915 FoldingSetNodeID ID;
1916 ID.AddInteger(scSignExtend);
1917 ID.AddPointer(Op);
1918 ID.AddPointer(Ty);
1919 void *IP = nullptr;
1920 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1921 // Limit recursion depth.
1922 if (Depth > MaxCastDepth) {
1923 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1924 Op, Ty);
1925 UniqueSCEVs.InsertNode(S, IP);
1926 registerUser(S, Op);
1927 return S;
1928 }
1929
1930 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1931 if (auto *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1932 // It's possible the bits taken off by the truncate were all sign bits. If
1933 // so, we should be able to simplify this further.
1934 const SCEV *X = ST->getOperand();
1935 ConstantRange CR = getSignedRange(X);
1936 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1937 unsigned NewBits = getTypeSizeInBits(Ty);
1938 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1939 CR.sextOrTrunc(NewBits)))
1940 return getTruncateOrSignExtend(X, Ty, Depth);
1941 }
1942
1943 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1944 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1945 if (SA->hasNoSignedWrap()) {
1946 // If the addition does not sign overflow then we can, by definition,
1947 // commute the sign extension with the addition operation.
1948 SmallVector<const SCEV *, 4> Ops;
1949 for (const auto *Op : SA->operands())
1950 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1951 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1952 }
1953
1954 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1955 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1956 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1957 //
1958 // For instance, this will bring two seemingly different expressions:
1959 // 1 + sext(5 + 20 * %x + 24 * %y) and
1960 // sext(6 + 20 * %x + 24 * %y)
1961 // to the same form:
1962 // 2 + sext(4 + 20 * %x + 24 * %y)
1963 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1964 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1965 if (D != 0) {
1966 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1967 const SCEV *SResidual =
1968 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1969 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1970 return getAddExpr(SSExtD, SSExtR,
1971 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1972 Depth + 1);
1973 }
1974 }
1975 }
1976 // If the input value is a chrec scev, and we can prove that the value
1977 // did not overflow the old, smaller, value, we can sign extend all of the
1978 // operands (often constants). This allows analysis of something like
1979 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1980 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1981 if (AR->isAffine()) {
1982 const SCEV *Start = AR->getStart();
1983 const SCEV *Step = AR->getStepRecurrence(*this);
1984 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1985 const Loop *L = AR->getLoop();
1986
1987 // If we have special knowledge that this addrec won't overflow,
1988 // we don't need to do any further analysis.
1989 if (AR->hasNoSignedWrap()) {
1990 Start =
1991 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
1992 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1993 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
1994 }
1995
1996 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1997 // Note that this serves two purposes: It filters out loops that are
1998 // simply not analyzable, and it covers the case where this code is
1999 // being called from within backedge-taken count analysis, such that
2000 // attempting to ask for the backedge-taken count would likely result
2001 // in infinite recursion. In the latter case, the analysis code will
2002 // cope with a conservative value, and it will take care to purge
2003 // that value once it has finished.
2004 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2005 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2006 // Manually compute the final value for AR, checking for
2007 // overflow.
2008
2009 // Check whether the backedge-taken count can be losslessly casted to
2010 // the addrec's type. The count is always unsigned.
2011 const SCEV *CastedMaxBECount =
2012 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2013 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2014 CastedMaxBECount, MaxBECount->getType(), Depth);
2015 if (MaxBECount == RecastedMaxBECount) {
2016 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2017 // Check whether Start+Step*MaxBECount has no signed overflow.
2018 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2019 SCEV::FlagAnyWrap, Depth + 1);
2020 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2021 SCEV::FlagAnyWrap,
2022 Depth + 1),
2023 WideTy, Depth + 1);
2024 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2025 const SCEV *WideMaxBECount =
2026 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2027 const SCEV *OperandExtendedAdd =
2028 getAddExpr(WideStart,
2029 getMulExpr(WideMaxBECount,
2030 getSignExtendExpr(Step, WideTy, Depth + 1),
2031 SCEV::FlagAnyWrap, Depth + 1),
2032 SCEV::FlagAnyWrap, Depth + 1);
2033 if (SAdd == OperandExtendedAdd) {
2034 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2035 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2036 // Return the expression with the addrec on the outside.
2037 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2038 Depth + 1);
2039 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2040 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2041 }
2042 // Similar to above, only this time treat the step value as unsigned.
2043 // This covers loops that count up with an unsigned step.
2044 OperandExtendedAdd =
2045 getAddExpr(WideStart,
2046 getMulExpr(WideMaxBECount,
2047 getZeroExtendExpr(Step, WideTy, Depth + 1),
2048 SCEV::FlagAnyWrap, Depth + 1),
2049 SCEV::FlagAnyWrap, Depth + 1);
2050 if (SAdd == OperandExtendedAdd) {
2051 // If AR wraps around then
2052 //
2053 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2054 // => SAdd != OperandExtendedAdd
2055 //
2056 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2057 // (SAdd == OperandExtendedAdd => AR is NW)
2058
2059 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2060
2061 // Return the expression with the addrec on the outside.
2062 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2063 Depth + 1);
2064 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2065 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2066 }
2067 }
2068 }
2069
2070 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2071 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2072 if (AR->hasNoSignedWrap()) {
2073 // Same as nsw case above - duplicated here to avoid a compile time
2074 // issue. It's not clear that the order of checks matters, but it's
2075 // one of two possible causes for a change which was reverted.
2076 // Be conservative for the moment.
2077 Start =
2078 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2079 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2080 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2081 }
2082
2083 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2084 // if D + (C - D + Step * n) could be proven to not signed wrap
2085 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2086 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2087 const APInt &C = SC->getAPInt();
2088 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2089 if (D != 0) {
2090 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2091 const SCEV *SResidual =
2092 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2093 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2094 return getAddExpr(SSExtD, SSExtR,
2095 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2096 Depth + 1);
2097 }
2098 }
2099
2100 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2101 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2102 Start =
2103 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2104 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2105 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2106 }
2107 }
2108
2109 // If the input value is provably positive and we could not simplify
2110 // away the sext, build a zext instead.
2111 if (isKnownNonNegative(Op))
2112 return getZeroExtendExpr(Op, Ty, Depth + 1);
2113
2114 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2115 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2116 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2117 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2118 SmallVector<const SCEV *, 4> Operands;
2119 for (auto *Operand : MinMax->operands())
2120 Operands.push_back(getSignExtendExpr(Operand, Ty));
2121 if (isa<SCEVSMinExpr>(MinMax))
2122 return getSMinExpr(Operands);
2123 return getSMaxExpr(Operands);
2124 }
2125
2126 // The cast wasn't folded; create an explicit cast node.
2127 // Recompute the insert position, as it may have been invalidated.
2128 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2129 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2130 Op, Ty);
2131 UniqueSCEVs.InsertNode(S, IP);
2132 registerUser(S, { Op });
2133 return S;
2134}
2135
2136 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2137 Type *Ty) {
2138 switch (Kind) {
2139 case scTruncate:
2140 return getTruncateExpr(Op, Ty);
2141 case scZeroExtend:
2142 return getZeroExtendExpr(Op, Ty);
2143 case scSignExtend:
2144 return getSignExtendExpr(Op, Ty);
2145 case scPtrToInt:
2146 return getPtrToIntExpr(Op, Ty);
2147 default:
2148 llvm_unreachable("Not a SCEV cast expression!");
2149 }
2150}
2151
2152/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2153/// unspecified bits out to the given type.
2154 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2155 Type *Ty) {
2156 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2157 "This is not an extending conversion!");
2158 assert(isSCEVable(Ty) &&
2159 "This is not a conversion to a SCEVable type!");
2160 Ty = getEffectiveSCEVType(Ty);
2161
2162 // Sign-extend negative constants.
2163 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2164 if (SC->getAPInt().isNegative())
2165 return getSignExtendExpr(Op, Ty);
2166
2167 // Peel off a truncate cast.
2168 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2169 const SCEV *NewOp = T->getOperand();
2170 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2171 return getAnyExtendExpr(NewOp, Ty);
2172 return getTruncateOrNoop(NewOp, Ty);
2173 }
2174
2175 // Next try a zext cast. If the cast is folded, use it.
2176 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2177 if (!isa<SCEVZeroExtendExpr>(ZExt))
2178 return ZExt;
2179
2180 // Next try a sext cast. If the cast is folded, use it.
2181 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2182 if (!isa<SCEVSignExtendExpr>(SExt))
2183 return SExt;
2184
2185 // Force the cast to be folded into the operands of an addrec.
2186 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2187 SmallVector<const SCEV *, 4> Ops;
2188 for (const SCEV *Op : AR->operands())
2189 Ops.push_back(getAnyExtendExpr(Op, Ty));
2190 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2191 }
2192
2193 // If the expression is obviously signed, use the sext cast value.
2194 if (isa<SCEVSMaxExpr>(Op))
2195 return SExt;
2196
2197 // Absent any other information, use the zext cast value.
2198 return ZExt;
2199}
2200
2201/// Process the given Ops list, which is a list of operands to be added under
2202/// the given scale, update the given map. This is a helper function for
2203/// getAddRecExpr. As an example of what it does, given a sequence of operands
2204/// that would form an add expression like this:
2205///
2206/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2207///
2208/// where A and B are constants, update the map with these values:
2209///
2210/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2211///
2212/// and add 13 + A*B*29 to AccumulatedConstant.
2213/// This will allow getAddRecExpr to produce this:
2214///
2215/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2216///
2217/// This form often exposes folding opportunities that are hidden in
2218/// the original operand list.
2219///
2220/// Return true iff it appears that any interesting folding opportunities
2221/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2222/// the common case where no interesting opportunities are present, and
2223/// is also used as a check to avoid infinite recursion.
2224 static bool
2225 CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
2226 SmallVectorImpl<const SCEV *> &NewOps,
2227 APInt &AccumulatedConstant,
2228 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2229 ScalarEvolution &SE) {
2230 bool Interesting = false;
2231
2232 // Iterate over the add operands. They are sorted, with constants first.
2233 unsigned i = 0;
2234 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2235 ++i;
2236 // Pull a buried constant out to the outside.
2237 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2238 Interesting = true;
2239 AccumulatedConstant += Scale * C->getAPInt();
2240 }
2241
2242 // Next comes everything else. We're especially interested in multiplies
2243 // here, but they're in the middle, so just visit the rest with one loop.
2244 for (; i != Ops.size(); ++i) {
2245 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2246 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2247 APInt NewScale =
2248 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2249 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2250 // A multiplication of a constant with another add; recurse.
2251 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2252 Interesting |=
2253 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2254 Add->operands(), NewScale, SE);
2255 } else {
2256 // A multiplication of a constant with some other value. Update
2257 // the map.
2258 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2259 const SCEV *Key = SE.getMulExpr(MulOps);
2260 auto Pair = M.insert({Key, NewScale});
2261 if (Pair.second) {
2262 NewOps.push_back(Pair.first->first);
2263 } else {
2264 Pair.first->second += NewScale;
2265 // The map already had an entry for this value, which may indicate
2266 // a folding opportunity.
2267 Interesting = true;
2268 }
2269 }
2270 } else {
2271 // An ordinary operand. Update the map.
2272 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2273 M.insert({Ops[i], Scale});
2274 if (Pair.second) {
2275 NewOps.push_back(Pair.first->first);
2276 } else {
2277 Pair.first->second += Scale;
2278 // The map already had an entry for this value, which may indicate
2279 // a folding opportunity.
2280 Interesting = true;
2281 }
2282 }
2283 }
2284
2285 return Interesting;
2286}
2287
2288 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2289 const SCEV *LHS, const SCEV *RHS,
2290 const Instruction *CtxI) {
2291 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2292 SCEV::NoWrapFlags, unsigned);
2293 switch (BinOp) {
2294 default:
2295 llvm_unreachable("Unsupported binary op");
2296 case Instruction::Add:
2297 Operation = &ScalarEvolution::getAddExpr;
2298 break;
2299 case Instruction::Sub:
2300 Operation = &ScalarEvolution::getMinusSCEV;
2301 break;
2302 case Instruction::Mul:
2303 Operation = &ScalarEvolution::getMulExpr;
2304 break;
2305 }
2306
2307 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2308 Signed ? &ScalarEvolution::getSignExtendExpr
2309 : &ScalarEvolution::getZeroExtendExpr;
2310
2311 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2312 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2313 auto *WideTy =
2314 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2315
2316 const SCEV *A = (this->*Extension)(
2317 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2318 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2319 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2320 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2321 if (A == B)
2322 return true;
2323 // Can we use context to prove the fact we need?
2324 if (!CtxI)
2325 return false;
2326 // TODO: Support mul.
2327 if (BinOp == Instruction::Mul)
2328 return false;
2329 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2330 // TODO: Lift this limitation.
2331 if (!RHSC)
2332 return false;
2333 APInt C = RHSC->getAPInt();
2334 unsigned NumBits = C.getBitWidth();
2335 bool IsSub = (BinOp == Instruction::Sub);
2336 bool IsNegativeConst = (Signed && C.isNegative());
2337 // Compute the direction and magnitude by which we need to check overflow.
2338 bool OverflowDown = IsSub ^ IsNegativeConst;
2339 APInt Magnitude = C;
2340 if (IsNegativeConst) {
2341 if (C == APInt::getSignedMinValue(NumBits))
2342 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2343 // want to deal with that.
2344 return false;
2345 Magnitude = -C;
2346 }
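// [Illustrative note, not from the upstream source:] e.g. for an unsigned i8
// add of the constant 7, OverflowDown is false and the check below becomes
// LHS <= 255 - 7, i.e. the addition cannot wrap iff LHS <= 248 at CtxI.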
2347
2348 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2349 if (OverflowDown) {
2350 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2351 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2352 : APInt::getMinValue(NumBits);
2353 APInt Limit = Min + Magnitude;
2354 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2355 } else {
2356 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2357 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2358 : APInt::getMaxValue(NumBits);
2359 APInt Limit = Max - Magnitude;
2360 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2361 }
2362}
2363
2364 std::optional<SCEV::NoWrapFlags>
2365 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2366 const OverflowingBinaryOperator *OBO) {
2367 // It cannot be done any better.
2368 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2369 return std::nullopt;
2370
2371 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2372
2373 if (OBO->hasNoUnsignedWrap())
2374 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2375 if (OBO->hasNoSignedWrap())
2376 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2377
2378 bool Deduced = false;
2379
2380 if (OBO->getOpcode() != Instruction::Add &&
2381 OBO->getOpcode() != Instruction::Sub &&
2382 OBO->getOpcode() != Instruction::Mul)
2383 return std::nullopt;
2384
2385 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2386 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2387
2388 const Instruction *CtxI =
2389 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2390 if (!OBO->hasNoUnsignedWrap() &&
2391 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2392 /* Signed */ false, LHS, RHS, CtxI)) {
2393 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2394 Deduced = true;
2395 }
2396
2397 if (!OBO->hasNoSignedWrap() &&
2398 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2399 /* Signed */ true, LHS, RHS, CtxI)) {
2400 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2401 Deduced = true;
2402 }
2403
2404 if (Deduced)
2405 return Flags;
2406 return std::nullopt;
2407}
2408
2409// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2410// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2411// can't-overflow flags for the operation if possible.
2412 static SCEV::NoWrapFlags
2413 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2414 const ArrayRef<const SCEV *> Ops,
2415 SCEV::NoWrapFlags Flags) {
2416 using namespace std::placeholders;
2417
2418 using OBO = OverflowingBinaryOperator;
2419
2420 bool CanAnalyze =
2421 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2422 (void)CanAnalyze;
2423 assert(CanAnalyze && "don't call from other places!");
2424
2425 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2426 SCEV::NoWrapFlags SignOrUnsignWrap =
2427 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2428
2429 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
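// [Illustrative note, not from the upstream source:] e.g. (%a + %b)<nsw> with
// both operands known non-negative stays within [0, SINT_MAX], so it cannot
// wrap in the unsigned sense either.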
2430 auto IsKnownNonNegative = [&](const SCEV *S) {
2431 return SE->isKnownNonNegative(S);
2432 };
2433
2434 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2435 Flags =
2436 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2437
2438 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2439
2440 if (SignOrUnsignWrap != SignOrUnsignMask &&
2441 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2442 isa<SCEVConstant>(Ops[0])) {
2443
2444 auto Opcode = [&] {
2445 switch (Type) {
2446 case scAddExpr:
2447 return Instruction::Add;
2448 case scMulExpr:
2449 return Instruction::Mul;
2450 default:
2451 llvm_unreachable("Unexpected SCEV op.");
2452 }
2453 }();
2454
2455 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2456
2457 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2458 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2459 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2460 Opcode, C, OBO::NoSignedWrap);
2461 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2462 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2463 }
2464
2465 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2466 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2467 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2468 Opcode, C, OBO::NoUnsignedWrap);
2469 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2470 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2471 }
2472 }
2473
2474 // <0,+,nonnegative><nw> is also nuw
2475 // TODO: Add corresponding nsw case
2476 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2477 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2478 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2479 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2480
2481 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2482 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2483 Ops.size() == 2) {
2484 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2485 if (UDiv->getOperand(1) == Ops[1])
2486 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2487 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2488 if (UDiv->getOperand(1) == Ops[0])
2489 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2490 }
2491
2492 return Flags;
2493}
2494
2495 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2496 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2497}
2498
2499/// Get a canonical add expression, or something simpler if possible.
2500 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2501 SCEV::NoWrapFlags OrigFlags,
2502 unsigned Depth) {
2503 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2504 "only nuw or nsw allowed");
2505 assert(!Ops.empty() && "Cannot get empty add!");
2506 if (Ops.size() == 1) return Ops[0];
2507#ifndef NDEBUG
2508 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2509 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2510 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2511 "SCEVAddExpr operand types don't match!");
2512 unsigned NumPtrs = count_if(
2513 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2514 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2515#endif
2516
2517 const SCEV *Folded = constantFoldAndGroupOps(
2518 *this, LI, DT, Ops,
2519 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2520 [](const APInt &C) { return C.isZero(); }, // identity
2521 [](const APInt &C) { return false; }); // absorber
2522 if (Folded)
2523 return Folded;
2524
2525 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2526
2527 // Delay expensive flag strengthening until necessary.
2528 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
2529 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2530 };
2531
2532 // Limit recursion calls depth.
2533 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2534 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2535
2536 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2537 // Don't strengthen flags if we have no new information.
2538 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2539 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2540 Add->setNoWrapFlags(ComputeFlags(Ops));
2541 return S;
2542 }
2543
2544 // Okay, check to see if the same value occurs in the operand list more than
2545 // once. If so, merge them together into a multiply expression. Since we
2546 // sorted the list, these values are required to be adjacent.
2547 Type *Ty = Ops[0]->getType();
2548 bool FoundMatch = false;
2549 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2550 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2551 // Scan ahead to count how many equal operands there are.
2552 unsigned Count = 2;
2553 while (i+Count != e && Ops[i+Count] == Ops[i])
2554 ++Count;
2555 // Merge the values into a multiply.
2556 const SCEV *Scale = getConstant(Ty, Count);
2557 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2558 if (Ops.size() == Count)
2559 return Mul;
2560 Ops[i] = Mul;
2561 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2562 --i; e -= Count - 1;
2563 FoundMatch = true;
2564 }
2565 if (FoundMatch)
2566 return getAddExpr(Ops, OrigFlags, Depth + 1);
2567
2568 // Check for truncates. If all the operands are truncated from the same
2569 // type, see if factoring out the truncate would permit the result to be
2570 // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2571 // if the contents of the resulting outer trunc fold to something simple.
2572 auto FindTruncSrcType = [&]() -> Type * {
2573 // We're ultimately looking to fold an addrec of truncs and muls of only
2574 // constants and truncs, so if we find any other types of SCEV
2575 // as operands of the addrec then we bail and return nullptr here.
2576 // Otherwise, we return the type of the operand of a trunc that we find.
2577 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2578 return T->getOperand()->getType();
2579 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2580 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2581 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2582 return T->getOperand()->getType();
2583 }
2584 return nullptr;
2585 };
2586 if (auto *SrcType = FindTruncSrcType()) {
2587 SmallVector<const SCEV *, 8> LargeOps;
2588 bool Ok = true;
2589 // Check all the operands to see if they can be represented in the
2590 // source type of the truncate.
2591 for (const SCEV *Op : Ops) {
2592 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2593 if (T->getOperand()->getType() != SrcType) {
2594 Ok = false;
2595 break;
2596 }
2597 LargeOps.push_back(T->getOperand());
2598 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2599 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2600 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2601 SmallVector<const SCEV *, 8> LargeMulOps;
2602 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2603 if (const SCEVTruncateExpr *T =
2604 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2605 if (T->getOperand()->getType() != SrcType) {
2606 Ok = false;
2607 break;
2608 }
2609 LargeMulOps.push_back(T->getOperand());
2610 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2611 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2612 } else {
2613 Ok = false;
2614 break;
2615 }
2616 }
2617 if (Ok)
2618 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2619 } else {
2620 Ok = false;
2621 break;
2622 }
2623 }
2624 if (Ok) {
2625 // Evaluate the expression in the larger type.
2626 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2627 // If it folds to something simple, use it. Otherwise, don't.
2628 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2629 return getTruncateExpr(Fold, Ty);
2630 }
2631 }
2632
2633 if (Ops.size() == 2) {
2634 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2635 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2636 // C1).
2637 const SCEV *A = Ops[0];
2638 const SCEV *B = Ops[1];
2639 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2640 auto *C = dyn_cast<SCEVConstant>(A);
2641 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2642 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2643 auto C2 = C->getAPInt();
2644 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2645
2646 APInt ConstAdd = C1 + C2;
2647 auto AddFlags = AddExpr->getNoWrapFlags();
2648 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2649 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2650 ConstAdd.ule(C1)) {
2651 PreservedFlags =
2652 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2653 }
2654
2655 // Adding a constant with the same sign and small magnitude is NSW, if the
2656 // original AddExpr was NSW.
2657 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2658 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2659 ConstAdd.abs().ule(C1.abs())) {
2660 PreservedFlags =
2661 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2662 }
2663
2664 if (PreservedFlags != SCEV::FlagAnyWrap) {
2665 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2666 NewOps[0] = getConstant(ConstAdd);
2667 return getAddExpr(NewOps, PreservedFlags);
2668 }
2669 }
2670
2671 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2672 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2673 const SCEVAddExpr *InnerAdd;
2674 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2675 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2676 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2677 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2678 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2679 SCEV::FlagAnyWrap),
2680 SCEV::FlagNUW)) {
2681 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2682 }
2683 }
2684 }
2685
2686 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
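// [Illustrative note, not from the upstream source:] e.g. for X = 17, Y = 5:
// (-1 * (17 urem 5)) + 17 = -2 + 17 = 15 = 5 * (17 /u 5).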
2687 const SCEV *Y;
2688 if (Ops.size() == 2 &&
2689 match(Ops[0],
2690 m_scev_Mul(m_scev_AllOnes(),
2691 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2692 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2693
2694 // Skip past any other cast SCEVs.
2695 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2696 ++Idx;
2697
2698 // If there are add operands they would be next.
2699 if (Idx < Ops.size()) {
2700 bool DeletedAdd = false;
2701 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2702 // common NUW flag for expression after inlining. Other flags cannot be
2703 // preserved, because they may depend on the original order of operations.
2704 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2705 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2706 if (Ops.size() > AddOpsInlineThreshold ||
2707 Add->getNumOperands() > AddOpsInlineThreshold)
2708 break;
2709 // If we have an add, expand the add operands onto the end of the operands
2710 // list.
2711 Ops.erase(Ops.begin()+Idx);
2712 append_range(Ops, Add->operands());
2713 DeletedAdd = true;
2714 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2715 }
2716
2717 // If we deleted at least one add, we added operands to the end of the list,
2718 // and they are not necessarily sorted. Recurse to resort and resimplify
2719 // any operands we just acquired.
2720 if (DeletedAdd)
2721 return getAddExpr(Ops, CommonFlags, Depth + 1);
2722 }
2723
2724 // Skip over the add expression until we get to a multiply.
2725 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2726 ++Idx;
2727
2728 // Check to see if there are any folding opportunities present with
2729 // operands multiplied by constant values.
2730 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2731 uint64_t BitWidth = getTypeSizeInBits(Ty);
2732 SmallDenseMap<const SCEV *, APInt, 16> M;
2733 SmallVector<const SCEV *, 8> NewOps;
2734 APInt AccumulatedConstant(BitWidth, 0);
2735 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2736 Ops, APInt(BitWidth, 1), *this)) {
2737 struct APIntCompare {
2738 bool operator()(const APInt &LHS, const APInt &RHS) const {
2739 return LHS.ult(RHS);
2740 }
2741 };
2742
2743 // Some interesting folding opportunity is present, so it's worthwhile to
2744 // re-generate the operands list. Group the operands by constant scale,
2745 // to avoid multiplying by the same constant scale multiple times.
2746 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2747 for (const SCEV *NewOp : NewOps)
2748 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2749 // Re-generate the operands list.
2750 Ops.clear();
2751 if (AccumulatedConstant != 0)
2752 Ops.push_back(getConstant(AccumulatedConstant));
2753 for (auto &MulOp : MulOpLists) {
2754 if (MulOp.first == 1) {
2755 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2756 } else if (MulOp.first != 0) {
2757 Ops.push_back(getMulExpr(
2758 getConstant(MulOp.first),
2759 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2760 SCEV::FlagAnyWrap, Depth + 1));
2761 }
2762 }
2763 if (Ops.empty())
2764 return getZero(Ty);
2765 if (Ops.size() == 1)
2766 return Ops[0];
2767 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2768 }
2769 }
2770
2771 // If we are adding something to a multiply expression, make sure the
2772 // something is not already an operand of the multiply. If so, merge it into
2773 // the multiply.
2774 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2775 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2776 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2777 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2778 if (isa<SCEVConstant>(MulOpSCEV))
2779 continue;
2780 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2781 if (MulOpSCEV == Ops[AddOp]) {
2782 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2783 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2784 if (Mul->getNumOperands() != 2) {
2785 // If the multiply has more than two operands, we must get the
2786 // Y*Z term.
2787 SmallVector<const SCEV *, 4> MulOps(
2788 Mul->operands().take_front(MulOp));
2789 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2790 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2791 }
2792 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2793 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2794 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2795 SCEV::FlagAnyWrap, Depth + 1);
2796 if (Ops.size() == 2) return OuterMul;
2797 if (AddOp < Idx) {
2798 Ops.erase(Ops.begin()+AddOp);
2799 Ops.erase(Ops.begin()+Idx-1);
2800 } else {
2801 Ops.erase(Ops.begin()+Idx);
2802 Ops.erase(Ops.begin()+AddOp-1);
2803 }
2804 Ops.push_back(OuterMul);
2805 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2806 }
2807
2808 // Check this multiply against other multiplies being added together.
2809 for (unsigned OtherMulIdx = Idx+1;
2810 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2811 ++OtherMulIdx) {
2812 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2813 // If MulOp occurs in OtherMul, we can fold the two multiplies
2814 // together.
2815 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2816 OMulOp != e; ++OMulOp)
2817 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2818 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2819 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2820 if (Mul->getNumOperands() != 2) {
2821 SmallVector<const SCEV *, 4> MulOps(
2822 Mul->operands().take_front(MulOp));
2823 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2824 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2825 }
2826 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2827 if (OtherMul->getNumOperands() != 2) {
2828 SmallVector<const SCEV *, 4> MulOps(
2829 OtherMul->operands().take_front(OMulOp));
2830 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2831 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2832 }
2833 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2834 const SCEV *InnerMulSum =
2835 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2836 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2837 SCEV::FlagAnyWrap, Depth + 1);
2838 if (Ops.size() == 2) return OuterMul;
2839 Ops.erase(Ops.begin()+Idx);
2840 Ops.erase(Ops.begin()+OtherMulIdx-1);
2841 Ops.push_back(OuterMul);
2842 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2843 }
2844 }
2845 }
2846 }
2847
2848 // If there are any add recurrences in the operands list, see if any other
2849 // added values are loop invariant. If so, we can fold them into the
2850 // recurrence.
2851 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2852 ++Idx;
2853
2854 // Scan over all recurrences, trying to fold loop invariants into them.
2855 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2856 // Scan all of the other operands to this add and add them to the vector if
2857 // they are loop invariant w.r.t. the recurrence.
2858 SmallVector<const SCEV *, 8> LIOps;
2859 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2860 const Loop *AddRecLoop = AddRec->getLoop();
2861 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2862 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2863 LIOps.push_back(Ops[i]);
2864 Ops.erase(Ops.begin()+i);
2865 --i; --e;
2866 }
2867
2868 // If we found some loop invariants, fold them into the recurrence.
2869 if (!LIOps.empty()) {
2870 // Compute nowrap flags for the addition of the loop-invariant ops and
2871 // the addrec. Temporarily push it as an operand for that purpose. These
2872 // flags are valid in the scope of the addrec only.
2873 LIOps.push_back(AddRec);
2874 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2875 LIOps.pop_back();
2876
2877 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2878 LIOps.push_back(AddRec->getStart());
2879
2880 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2881
2882 // It is not in general safe to propagate flags valid on an add within
2883 // the addrec scope to one outside it. We must prove that the inner
2884 // scope is guaranteed to execute if the outer one does to be able to
2885 // safely propagate. We know the program is undefined if poison is
2886 // produced on the inner scoped addrec. We also know that *for this use*
2887 // the outer scoped add can't overflow (because of the flags we just
2888 // computed for the inner scoped add) without the program being undefined.
2889 // Proving that entry to the outer scope necessitates entry to the inner
2890 // scope, thus proves the program undefined if the flags would be violated
2891 // in the outer scope.
2892 SCEV::NoWrapFlags AddFlags = Flags;
2893 if (AddFlags != SCEV::FlagAnyWrap) {
2894 auto *DefI = getDefiningScopeBound(LIOps);
2895 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2896 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2897 AddFlags = SCEV::FlagAnyWrap;
2898 }
2899 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2900
2901 // Build the new addrec. Propagate the NUW and NSW flags if both the
2902 // outer add and the inner addrec are guaranteed to have no overflow.
2903 // Always propagate NW.
2904 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2905 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2906
2907 // If all of the other operands were loop invariant, we are done.
2908 if (Ops.size() == 1) return NewRec;
2909
2910 // Otherwise, add the folded AddRec by the non-invariant parts.
2911 for (unsigned i = 0;; ++i)
2912 if (Ops[i] == AddRec) {
2913 Ops[i] = NewRec;
2914 break;
2915 }
2916 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2917 }
2918
2919 // Okay, if there weren't any loop invariants to be folded, check to see if
2920 // there are multiple AddRec's with the same loop induction variable being
2921 // added together. If so, we can fold them.
2922 for (unsigned OtherIdx = Idx+1;
2923 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2924 ++OtherIdx) {
2925 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2926 // so that the 1st found AddRecExpr is dominated by all others.
2927 assert(DT.dominates(
2928 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2929 AddRec->getLoop()->getHeader()) &&
2930 "AddRecExprs are not sorted in reverse dominance order?");
2931 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2932 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2933 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2934 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2935 ++OtherIdx) {
2936 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2937 if (OtherAddRec->getLoop() == AddRecLoop) {
2938 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2939 i != e; ++i) {
2940 if (i >= AddRecOps.size()) {
2941 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2942 break;
2943 }
2944 SmallVector<const SCEV *, 2> TwoOps = {
2945 AddRecOps[i], OtherAddRec->getOperand(i)};
2946 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2947 }
2948 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2949 }
2950 }
2951 // Step size has changed, so we cannot guarantee no self-wraparound.
2952 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2953 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2954 }
2955 }
2956
2957 // Otherwise couldn't fold anything into this recurrence. Move onto the
2958 // next one.
2959 }
2960
2961 // Okay, it looks like we really DO need an add expr. Check to see if we
2962 // already have one, otherwise create a new one.
2963 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2964}
2965
2966const SCEV *
2967ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2968 SCEV::NoWrapFlags Flags) {
2969 FoldingSetNodeID ID;
2970 ID.AddInteger(scAddExpr);
2971 for (const SCEV *Op : Ops)
2972 ID.AddPointer(Op);
2973 void *IP = nullptr;
2974 SCEVAddExpr *S =
2975 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2976 if (!S) {
2977 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2978 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2979 S = new (SCEVAllocator)
2980 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2981 UniqueSCEVs.InsertNode(S, IP);
2982 registerUser(S, Ops);
2983 }
2984 S->setNoWrapFlags(Flags);
2985 return S;
2986}
2987
2988const SCEV *
2989ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2990 const Loop *L, SCEV::NoWrapFlags Flags) {
2991 FoldingSetNodeID ID;
2992 ID.AddInteger(scAddRecExpr);
2993 for (const SCEV *Op : Ops)
2994 ID.AddPointer(Op);
2995 ID.AddPointer(L);
2996 void *IP = nullptr;
2997 SCEVAddRecExpr *S =
2998 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2999 if (!S) {
3000 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3001 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3002 S = new (SCEVAllocator)
3003 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3004 UniqueSCEVs.InsertNode(S, IP);
3005 LoopUsers[L].push_back(S);
3006 registerUser(S, Ops);
3007 }
3008 setNoWrapFlags(S, Flags);
3009 return S;
3010}
3011
3012const SCEV *
3013ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3014 SCEV::NoWrapFlags Flags) {
3015 FoldingSetNodeID ID;
3016 ID.AddInteger(scMulExpr);
3017 for (const SCEV *Op : Ops)
3018 ID.AddPointer(Op);
3019 void *IP = nullptr;
3020 SCEVMulExpr *S =
3021 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3022 if (!S) {
3023 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3024 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3025 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3026 O, Ops.size());
3027 UniqueSCEVs.InsertNode(S, IP);
3028 registerUser(S, Ops);
3029 }
3030 S->setNoWrapFlags(Flags);
3031 return S;
3032}
3033
3034static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3035 uint64_t k = i*j;
3036 if (j > 1 && k / j != i) Overflow = true;
3037 return k;
3038}
3039
3040/// Compute the result of "n choose k", the binomial coefficient. If an
3041/// intermediate computation overflows, Overflow will be set and the return will
3042/// be garbage. Overflow is not cleared on absence of overflow.
3043static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3044 // We use the multiplicative formula:
3045 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3046 // At each iteration, we take the n-th term of the numerator and divide by the
3047 // (k-n)th term of the denominator. This division will always produce an
3048 // integral result, and helps reduce the chance of overflow in the
3049 // intermediate computations. However, we can still overflow even when the
3050 // final result would fit.
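// [Illustrative note, not from the upstream source:] e.g. Choose(6, 2)
// computes r = 6 after the first iteration and r = (6 * 5) / 2 = 15 after the
// second, matching C(6, 2) = 15 without ever forming 6!/(2!*4!).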
3051
3052 if (n == 0 || n == k) return 1;
3053 if (k > n) return 0;
3054
3055 if (k > n/2)
3056 k = n-k;
3057
3058 uint64_t r = 1;
3059 for (uint64_t i = 1; i <= k; ++i) {
3060 r = umul_ov(r, n-(i-1), Overflow);
3061 r /= i;
3062 }
3063 return r;
3064}
3065
3066/// Determine if any of the operands in this SCEV are a constant or if
3067/// any of the add or multiply expressions in this SCEV contain a constant.
3068static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3069 struct FindConstantInAddMulChain {
3070 bool FoundConstant = false;
3071
3072 bool follow(const SCEV *S) {
3073 FoundConstant |= isa<SCEVConstant>(S);
3074 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3075 }
3076
3077 bool isDone() const {
3078 return FoundConstant;
3079 }
3080 };
3081
3082 FindConstantInAddMulChain F;
3083 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3084 ST.visitAll(StartExpr);
3085 return F.FoundConstant;
3086}
3087
3088/// Get a canonical multiply expression, or something simpler if possible.
3089 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3090 SCEV::NoWrapFlags OrigFlags,
3091 unsigned Depth) {
3092 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3093 "only nuw or nsw allowed");
3094 assert(!Ops.empty() && "Cannot get empty mul!");
3095 if (Ops.size() == 1) return Ops[0];
3096#ifndef NDEBUG
3097 Type *ETy = Ops[0]->getType();
3098 assert(!ETy->isPointerTy());
3099 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3100 assert(Ops[i]->getType() == ETy &&
3101 "SCEVMulExpr operand types don't match!");
3102#endif
3103
3104 const SCEV *Folded = constantFoldAndGroupOps(
3105 *this, LI, DT, Ops,
3106 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3107 [](const APInt &C) { return C.isOne(); }, // identity
3108 [](const APInt &C) { return C.isZero(); }); // absorber
3109 if (Folded)
3110 return Folded;
3111
3112 // Delay expensive flag strengthening until necessary.
3113 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
3114 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3115 };
3116
3117 // Limit recursion calls depth.
3118 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3119 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3120
3121 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3122 // Don't strengthen flags if we have no new information.
3123 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3124 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3125 Mul->setNoWrapFlags(ComputeFlags(Ops));
3126 return S;
3127 }
3128
3129 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3130 if (Ops.size() == 2) {
3131 // C1*(C2+V) -> C1*C2 + C1*V
3132 // If any of Add's ops are Adds or Muls with a constant, apply this
3133 // transformation as well.
3134 //
3135 // TODO: There are some cases where this transformation is not
3136 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3137 // this transformation should be narrowed down.
3138 const SCEV *Op0, *Op1;
3139 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3140 containsConstantInAddMulChain(Ops[1])) {
3141 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3142 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3143 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3144 }
3145
3146 if (Ops[0]->isAllOnesValue()) {
3147 // If we have a mul by -1 of an add, try distributing the -1 among the
3148 // add operands.
3149 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3150 SmallVector<const SCEV *, 4> NewOps;
3151 bool AnyFolded = false;
3152 for (const SCEV *AddOp : Add->operands()) {
3153 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3154 Depth + 1);
3155 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3156 NewOps.push_back(Mul);
3157 }
3158 if (AnyFolded)
3159 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3160 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3161 // Negation preserves a recurrence's no self-wrap property.
3162 SmallVector<const SCEV *, 4> Operands;
3163 for (const SCEV *AddRecOp : AddRec->operands())
3164 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3165 Depth + 1));
3166 // Let M be the minimum representable signed value. AddRec with nsw
3167 // multiplied by -1 can have signed overflow if and only if it takes a
3168 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3169 // maximum signed value. In all other cases signed overflow is
3170 // impossible.
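// [Illustrative note, not from the upstream source:] e.g. for i8, M = -128:
// -128 * -1 wraps back to -128, while -127 * -1 = 127 is fine, so nsw can be
// kept only when the recurrence is known to stay above -128.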
3171 auto FlagsMask = SCEV::FlagNW;
3172 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3173 auto MinInt =
3174 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3175 if (getSignedRangeMin(AddRec) != MinInt)
3176 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3177 }
3178 return getAddRecExpr(Operands, AddRec->getLoop(),
3179 AddRec->getNoWrapFlags(FlagsMask));
3180 }
3181 }
3182
3183 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3184 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3185 const SCEVAddExpr *InnerAdd;
3186 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3187 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3188 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3189 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3190 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3191 SCEV::FlagAnyWrap),
3192 SCEV::FlagNUW)) {
3193 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3194 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3195 };
3196 }
3197
3198 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3199 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3200 // of C1, fold to (D /u (C2 /u C1)).
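// [Illustrative note, not from the upstream source:] e.g. when D is known to
// be a multiple of 8, 8 * (D /u 2) folds to 4 * D and 2 * (D /u 8) folds to
// D /u 4.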
3201 const SCEV *D;
3202 APInt C1V = LHSC->getAPInt();
3203 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3204 // as -1 * 1, as it won't enable additional folds.
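      // For example, (-8 * D /u 2) is analyzed as the negation of
      // (8 * D /u 2).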
3205 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3206 C1V = C1V.abs();
3207 const SCEVConstant *C2;
3208 if (C1V.isPowerOf2() &&
3209          match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
3210          C2->getAPInt().isPowerOf2() &&
3211 C1V.logBase2() <= getMinTrailingZeros(D)) {
3212 const SCEV *NewMul = nullptr;
3213 if (C1V.uge(C2->getAPInt())) {
3214 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3215 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3216 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3217 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3218 }
3219 if (NewMul)
3220 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3221 }
3222 }
3223 }
3224
3225 // Skip over the add expression until we get to a multiply.
3226 unsigned Idx = 0;
3227 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3228 ++Idx;
3229
3230 // If there are mul operands inline them all into this expression.
3231 if (Idx < Ops.size()) {
3232 bool DeletedMul = false;
3233 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3234 if (Ops.size() > MulOpsInlineThreshold)
3235 break;
3236        // If we have a mul, expand the mul operands onto the end of the
3237 // operands list.
3238 Ops.erase(Ops.begin()+Idx);
3239 append_range(Ops, Mul->operands());
3240 DeletedMul = true;
3241 }
3242
3243 // If we deleted at least one mul, we added operands to the end of the
3244 // list, and they are not necessarily sorted. Recurse to resort and
3245 // resimplify any operands we just acquired.
3246 if (DeletedMul)
3247 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3248 }
3249
3250 // If there are any add recurrences in the operands list, see if any other
3251 // added values are loop invariant. If so, we can fold them into the
3252 // recurrence.
3253 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3254 ++Idx;
3255
3256 // Scan over all recurrences, trying to fold loop invariants into them.
3257 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3258 // Scan all of the other operands to this mul and add them to the vector
3259 // if they are loop invariant w.r.t. the recurrence.
3260    SmallVector<const SCEV *, 8> LIOps;
3261    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3262 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3263 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3264 LIOps.push_back(Ops[i]);
3265 Ops.erase(Ops.begin()+i);
3266 --i; --e;
3267 }
3268
3269 // If we found some loop invariants, fold them into the recurrence.
3270 if (!LIOps.empty()) {
3271 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3272      SmallVector<const SCEV *, 4> NewOps;
3273      NewOps.reserve(AddRec->getNumOperands());
3274 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3275
3276 // If both the mul and addrec are nuw, we can preserve nuw.
3277 // If both the mul and addrec are nsw, we can only preserve nsw if either
3278 // a) they are also nuw, or
3279 // b) all multiplications of addrec operands with scale are nsw.
3280 SCEV::NoWrapFlags Flags =
3281 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3282
3283 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3284 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3285 SCEV::FlagAnyWrap, Depth + 1));
3286
3287 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3288          auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3289              Instruction::Mul, getSignedRange(Scale),
3290              OverflowingBinaryOperator::NoSignedWrap);
3291          if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3292 Flags = clearFlags(Flags, SCEV::FlagNSW);
3293 }
3294 }
3295
3296 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3297
3298 // If all of the other operands were loop invariant, we are done.
3299 if (Ops.size() == 1) return NewRec;
3300
3301 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3302 for (unsigned i = 0;; ++i)
3303 if (Ops[i] == AddRec) {
3304 Ops[i] = NewRec;
3305 break;
3306 }
3307 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3308 }
3309
3310 // Okay, if there weren't any loop invariants to be folded, check to see
3311 // if there are multiple AddRec's with the same loop induction variable
3312 // being multiplied together. If so, we can fold them.
3313
3314 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3315 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3316 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3317 // ]]],+,...up to x=2n}.
3318 // Note that the arguments to choose() are always integers with values
3319 // known at compile time, never SCEV objects.
3320 //
3321 // The implementation avoids pointless extra computations when the two
3322 // addrec's are of different length (mathematically, it's equivalent to
3323 // an infinite stream of zeros on the right).
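    // For two affine recurrences this reduces to:
    //   {A1,+,A2}<L> * {B1,+,B2}<L> = {A1*B1,+,A1*B2+A2*B1+A2*B2,+,2*A2*B2}<L>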
3324 bool OpsModified = false;
3325 for (unsigned OtherIdx = Idx+1;
3326 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3327 ++OtherIdx) {
3328 const SCEVAddRecExpr *OtherAddRec =
3329 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3330 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3331 continue;
3332
3333 // Limit max number of arguments to avoid creation of unreasonably big
3334 // SCEVAddRecs with very complex operands.
3335 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3336 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3337 continue;
3338
3339 bool Overflow = false;
3340 Type *Ty = AddRec->getType();
3341 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3342      SmallVector<const SCEV *, 7> AddRecOps;
3343      for (int x = 0, xe = AddRec->getNumOperands() +
3344 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3345        SmallVector<const SCEV *, 7> SumOps;
3346        for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3347 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3348 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3349 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3350 z < ze && !Overflow; ++z) {
3351 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3352 uint64_t Coeff;
3353 if (LargerThan64Bits)
3354 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3355 else
3356 Coeff = Coeff1*Coeff2;
3357 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3358 const SCEV *Term1 = AddRec->getOperand(y-z);
3359 const SCEV *Term2 = OtherAddRec->getOperand(z);
3360 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3361 SCEV::FlagAnyWrap, Depth + 1));
3362 }
3363 }
3364 if (SumOps.empty())
3365 SumOps.push_back(getZero(Ty));
3366 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3367 }
3368 if (!Overflow) {
3369 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3370                                              SCEV::FlagAnyWrap);
3371        if (Ops.size() == 2) return NewAddRec;
3372 Ops[Idx] = NewAddRec;
3373 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3374 OpsModified = true;
3375 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3376 if (!AddRec)
3377 break;
3378 }
3379 }
3380 if (OpsModified)
3381 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3382
3383 // Otherwise couldn't fold anything into this recurrence. Move onto the
3384 // next one.
3385 }
3386
3387  // Okay, it looks like we really DO need a mul expr.  Check to see if we
3388 // already have one, otherwise create a new one.
3389 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3390}
3391
3392/// Represents an unsigned remainder expression based on unsigned division.
3393const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3394                                         const SCEV *RHS) {
3395 assert(getEffectiveSCEVType(LHS->getType()) ==
3396 getEffectiveSCEVType(RHS->getType()) &&
3397 "SCEVURemExpr operand types don't match!");
3398
3399 // Short-circuit easy cases
3400 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3401 // If constant is one, the result is trivial
3402 if (RHSC->getValue()->isOne())
3403 return getZero(LHS->getType()); // X urem 1 --> 0
3404
3405 // If constant is a power of two, fold into a zext(trunc(LHS)).
3406 if (RHSC->getAPInt().isPowerOf2()) {
3407 Type *FullTy = LHS->getType();
3408 Type *TruncTy =
3409 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3410 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3411 }
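    // For example, %x urem 8 becomes (zext (trunc %x to i3) to the original
    // width), i.e. it keeps only the low three bits of %x.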
3412 }
3413
3414  // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
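  // For example, 13 urem 5 == 13 - (13 /u 5) * 5 == 13 - 10 == 3.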
3415 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3416 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3417 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3418}
3419
3420/// Get a canonical unsigned division expression, or something simpler if
3421/// possible.
3422const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3423                                         const SCEV *RHS) {
3424 assert(!LHS->getType()->isPointerTy() &&
3425 "SCEVUDivExpr operand can't be pointer!");
3426 assert(LHS->getType() == RHS->getType() &&
3427 "SCEVUDivExpr operand types don't match!");
3428
3429  FoldingSetNodeID ID;
3430  ID.AddInteger(scUDivExpr);
3431 ID.AddPointer(LHS);
3432 ID.AddPointer(RHS);
3433 void *IP = nullptr;
3434 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3435 return S;
3436
3437 // 0 udiv Y == 0
3438 if (match(LHS, m_scev_Zero()))
3439 return LHS;
3440
3441 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3442 if (RHSC->getValue()->isOne())
3443 return LHS; // X udiv 1 --> x
3444 // If the denominator is zero, the result of the udiv is undefined. Don't
3445 // try to analyze it, because the resolution chosen here may differ from
3446 // the resolution chosen in other parts of the compiler.
3447 if (!RHSC->getValue()->isZero()) {
3448 // Determine if the division can be folded into the operands of
3449 // its operands.
3450 // TODO: Generalize this to non-constants by using known-bits information.
3451 Type *Ty = LHS->getType();
3452 unsigned LZ = RHSC->getAPInt().countl_zero();
3453 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3454 // For non-power-of-two values, effectively round the value up to the
3455 // nearest power of two.
3456 if (!RHSC->getAPInt().isPowerOf2())
3457 ++MaxShiftAmt;
3458 IntegerType *ExtTy =
3459 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3460 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3461 if (const SCEVConstant *Step =
3462 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3463 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
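            // For example, {8,+,4}<L> /u 2 becomes {4,+,2}<L> when the
            // zero-extended forms agree, i.e. the recurrence does not wrap.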
3464 const APInt &StepInt = Step->getAPInt();
3465 const APInt &DivInt = RHSC->getAPInt();
3466 if (!StepInt.urem(DivInt) &&
3467 getZeroExtendExpr(AR, ExtTy) ==
3468 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3469 getZeroExtendExpr(Step, ExtTy),
3470 AR->getLoop(), SCEV::FlagAnyWrap)) {
3471            SmallVector<const SCEV *, 4> Operands;
3472            for (const SCEV *Op : AR->operands())
3473 Operands.push_back(getUDivExpr(Op, RHS));
3474 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3475 }
3476 /// Get a canonical UDivExpr for a recurrence.
3477 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3478 const APInt *StartRem;
3479 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3480 m_scev_APInt(StartRem))) {
3481 bool NoWrap =
3482 getZeroExtendExpr(AR, ExtTy) ==
3483 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3484 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3485                                    SCEV::FlagAnyWrap);
3486
3487 // With N <= C and both N, C as powers-of-2, the transformation
3488 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3489 // if wrapping occurs, as the division results remain equivalent for
3490              // all offsets in [(X - X%N), X).
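              // For example, with N = 4, C = 8, X = 10: {10,+,4} /u 8 produces
              // 1, 1, 2, 2, ... and {8,+,4} /u 8 produces the same sequence,
              // so rewriting the start from 10 to 8 is safe.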
3491 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3492 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3493 // Only fold if the subtraction can be folded in the start
3494 // expression.
3495 const SCEV *NewStart =
3496 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3497 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3498 !isa<SCEVAddExpr>(NewStart)) {
3499 const SCEV *NewLHS =
3500 getAddRecExpr(NewStart, Step, AR->getLoop(),
3501 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3502 if (LHS != NewLHS) {
3503 LHS = NewLHS;
3504
3505 // Reset the ID to include the new LHS, and check if it is
3506 // already cached.
3507 ID.clear();
3508 ID.AddInteger(scUDivExpr);
3509 ID.AddPointer(LHS);
3510 ID.AddPointer(RHS);
3511 IP = nullptr;
3512 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3513 return S;
3514 }
3515 }
3516 }
3517 }
3518 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3519 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3520      SmallVector<const SCEV *, 4> Operands;
3521      for (const SCEV *Op : M->operands())
3522 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3523 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3524 // Find an operand that's safely divisible.
3525 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3526 const SCEV *Op = M->getOperand(i);
3527 const SCEV *Div = getUDivExpr(Op, RHSC);
3528 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3529 Operands = SmallVector<const SCEV *, 4>(M->operands());
3530 Operands[i] = Div;
3531 return getMulExpr(Operands);
3532 }
3533 }
3534 }
3535
3536 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3537 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3538 if (auto *DivisorConstant =
3539 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3540 bool Overflow = false;
3541 APInt NewRHS =
3542 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3543 if (Overflow) {
3544 return getConstant(RHSC->getType(), 0, false);
3545 }
3546 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3547 }
3548 }
3549
3550 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3551 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3552      SmallVector<const SCEV *, 4> Operands;
3553      for (const SCEV *Op : A->operands())
3554 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3555 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3556 Operands.clear();
3557 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3558 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3559 if (isa<SCEVUDivExpr>(Op) ||
3560 getMulExpr(Op, RHS) != A->getOperand(i))
3561 break;
3562 Operands.push_back(Op);
3563 }
3564 if (Operands.size() == A->getNumOperands())
3565 return getAddExpr(Operands);
3566 }
3567 }
3568
3569 // Fold if both operands are constant.
3570 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3571 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3572 }
3573 }
3574
3575 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3576 const APInt *NegC, *C;
3577 if (match(LHS,
3580 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3581 return getZero(LHS->getType());
3582
3583 // TODO: Generalize to handle any common factors.
3584 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
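  // For example, ((4 * vscale)<nuw> /u (2 * vscale)<nuw>) simplifies to
  // (4 /u 2) == 2.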
3585 const SCEV *NewLHS, *NewRHS;
3586 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3587 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3588 return getUDivExpr(NewLHS, NewRHS);
3589
3590 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3591 // changes). Make sure we get a new one.
3592 IP = nullptr;
3593 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3594 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3595 LHS, RHS);
3596 UniqueSCEVs.InsertNode(S, IP);
3597 registerUser(S, {LHS, RHS});
3598 return S;
3599}
3600
3601APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3602 APInt A = C1->getAPInt().abs();
3603 APInt B = C2->getAPInt().abs();
3604 uint32_t ABW = A.getBitWidth();
3605 uint32_t BBW = B.getBitWidth();
3606
3607 if (ABW > BBW)
3608 B = B.zext(ABW);
3609 else if (ABW < BBW)
3610 A = A.zext(BBW);
3611
3612 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3613}
3614
3615/// Get a canonical unsigned division expression, or something simpler if
3616/// possible. There is no representation for an exact udiv in SCEV IR, but we
3617/// can attempt to remove factors from the LHS and RHS. We can't do this when
3618/// it's not exact because the udiv may be clearing bits.
3619const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3620                                              const SCEV *RHS) {
3621 // TODO: we could try to find factors in all sorts of things, but for now we
3622 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3623 // end of this file for inspiration.
3624
3625  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3626  if (!Mul || !Mul->hasNoUnsignedWrap())
3627 return getUDivExpr(LHS, RHS);
3628
3629 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3630 // If the mulexpr multiplies by a constant, then that constant must be the
3631 // first element of the mulexpr.
3632 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3633 if (LHSCst == RHSCst) {
3634 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3635 return getMulExpr(Operands);
3636 }
3637
3638 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3639 // that there's a factor provided by one of the other terms. We need to
3640 // check.
3641 APInt Factor = gcd(LHSCst, RHSCst);
3642 if (!Factor.isIntN(1)) {
3643 LHSCst =
3644 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3645 RHSCst =
3646 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3647        SmallVector<const SCEV *, 2> Operands;
3648        Operands.push_back(LHSCst);
3649 append_range(Operands, Mul->operands().drop_front());
3650 LHS = getMulExpr(Operands);
3651 RHS = RHSCst;
3652        Mul = dyn_cast<SCEVMulExpr>(LHS);
3653        if (!Mul)
3654 return getUDivExactExpr(LHS, RHS);
3655 }
3656 }
3657 }
3658
3659 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3660 if (Mul->getOperand(i) == RHS) {
3661      SmallVector<const SCEV *, 2> Operands;
3662      append_range(Operands, Mul->operands().take_front(i));
3663 append_range(Operands, Mul->operands().drop_front(i + 1));
3664 return getMulExpr(Operands);
3665 }
3666 }
3667
3668 return getUDivExpr(LHS, RHS);
3669}
3670
3671/// Get an add recurrence expression for the specified loop. Simplify the
3672/// expression as much as possible.
3673const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3674 const Loop *L,
3675 SCEV::NoWrapFlags Flags) {
3676  SmallVector<const SCEV *, 4> Operands;
3677  Operands.push_back(Start);
3678 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3679 if (StepChrec->getLoop() == L) {
3680 append_range(Operands, StepChrec->operands());
3681 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3682 }
3683
3684 Operands.push_back(Step);
3685 return getAddRecExpr(Operands, L, Flags);
3686}
3687
3688/// Get an add recurrence expression for the specified loop. Simplify the
3689/// expression as much as possible.
3690const SCEV *
3691ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3692                               const Loop *L, SCEV::NoWrapFlags Flags) {
3693 if (Operands.size() == 1) return Operands[0];
3694#ifndef NDEBUG
3695 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3696 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3697 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3698 "SCEVAddRecExpr operand types don't match!");
3699 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3700 }
3701 for (const SCEV *Op : Operands)
3702    assert(isAvailableAtLoopEntry(Op, L) &&
3703           "SCEVAddRecExpr operand is not available at loop entry!");
3704#endif
3705
3706 if (Operands.back()->isZero()) {
3707 Operands.pop_back();
3708 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3709 }
3710
3711  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3712 // use that information to infer NUW and NSW flags. However, computing a
3713 // BE count requires calling getAddRecExpr, so we may not yet have a
3714 // meaningful BE count at this point (and if we don't, we'd be stuck
3715 // with a SCEVCouldNotCompute as the cached BE count).
3716
3717 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3718
3719  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3720 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3721 const Loop *NestedLoop = NestedAR->getLoop();
3722 if (L->contains(NestedLoop)
3723 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3724 : (!NestedLoop->contains(L) &&
3725 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3726 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3727 Operands[0] = NestedAR->getStart();
3728 // AddRecs require their operands be loop-invariant with respect to their
3729 // loops. Don't perform this transformation if it would break this
3730 // requirement.
3731 bool AllInvariant = all_of(
3732 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3733
3734 if (AllInvariant) {
3735 // Create a recurrence for the outer loop with the same step size.
3736 //
3737 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3738 // inner recurrence has the same property.
3739 SCEV::NoWrapFlags OuterFlags =
3740 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3741
3742 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3743 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3744 return isLoopInvariant(Op, NestedLoop);
3745 });
3746
3747 if (AllInvariant) {
3748 // Ok, both add recurrences are valid after the transformation.
3749 //
3750 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3751 // the outer recurrence has the same property.
3752 SCEV::NoWrapFlags InnerFlags =
3753 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3754 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3755 }
3756 }
3757 // Reset Operands to its original state.
3758 Operands[0] = NestedAR;
3759 }
3760 }
3761
3762 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3763 // already have one, otherwise create a new one.
3764 return getOrCreateAddRecExpr(Operands, L, Flags);
3765}
3766
3767const SCEV *ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3768                                        ArrayRef<const SCEV *> IndexExprs) {
3769 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3770 // getSCEV(Base)->getType() has the same address space as Base->getType()
3771 // because SCEV::getType() preserves the address space.
3772 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3773 if (NW != GEPNoWrapFlags::none()) {
3774 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3775 // but to do that, we have to ensure that said flag is valid in the entire
3776 // defined scope of the SCEV.
3777 // TODO: non-instructions have global scope. We might be able to prove
3778 // some global scope cases
3779 auto *GEPI = dyn_cast<Instruction>(GEP);
3780 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3781 NW = GEPNoWrapFlags::none();
3782 }
3783
3784 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3785}
3786
3787const SCEV *ScalarEvolution::getGEPExpr(const SCEV *BaseExpr,
3788                                        ArrayRef<const SCEV *> IndexExprs,
3789 Type *SrcElementTy, GEPNoWrapFlags NW) {
3790  SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3791  if (NW.hasNoUnsignedSignedWrap())
3792 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3793 if (NW.hasNoUnsignedWrap())
3794 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3795
3796 Type *CurTy = BaseExpr->getType();
3797 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3798 bool FirstIter = true;
3799  SmallVector<const SCEV *, 4> Offsets;
3800  for (const SCEV *IndexExpr : IndexExprs) {
3801 // Compute the (potentially symbolic) offset in bytes for this index.
3802 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3803 // For a struct, add the member offset.
3804 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3805 unsigned FieldNo = Index->getZExtValue();
3806 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3807 Offsets.push_back(FieldOffset);
3808
3809 // Update CurTy to the type of the field at Index.
3810 CurTy = STy->getTypeAtIndex(Index);
3811 } else {
3812 // Update CurTy to its element type.
3813 if (FirstIter) {
3814 assert(isa<PointerType>(CurTy) &&
3815 "The first index of a GEP indexes a pointer");
3816 CurTy = SrcElementTy;
3817 FirstIter = false;
3818 } else {
3819        CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3820      }
3821 // For an array, add the element offset, explicitly scaled.
3822 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3823 // Getelementptr indices are signed.
3824 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3825
3826 // Multiply the index by the element size to compute the element offset.
3827 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3828 Offsets.push_back(LocalOffset);
3829 }
3830 }
3831
3832 // Handle degenerate case of GEP without offsets.
3833 if (Offsets.empty())
3834 return BaseExpr;
3835
3836 // Add the offsets together, assuming nsw if inbounds.
3837 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3838 // Add the base address and the offset. We cannot use the nsw flag, as the
3839 // base address is unsigned. However, if we know that the offset is
3840 // non-negative, we can use nuw.
3841 bool NUW = NW.hasNoUnsignedWrap() ||
3842             (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3843  SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3844  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3845 assert(BaseExpr->getType() == GEPExpr->getType() &&
3846 "GEP should not change type mid-flight.");
3847 return GEPExpr;
3848}
3849
3850SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3851                                               ArrayRef<const SCEV *> Ops) {
3852  FoldingSetNodeID ID;
3853  ID.AddInteger(SCEVType);
3854 for (const SCEV *Op : Ops)
3855 ID.AddPointer(Op);
3856 void *IP = nullptr;
3857 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3858}
3859
3860const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3861  auto Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3862  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3863}
3864
3865const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3866                                           SmallVectorImpl<const SCEV *> &Ops) {
3867  assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3868 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3869 if (Ops.size() == 1) return Ops[0];
3870#ifndef NDEBUG
3871 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3872 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3873 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3874 "Operand types don't match!");
3875 assert(Ops[0]->getType()->isPointerTy() ==
3876 Ops[i]->getType()->isPointerTy() &&
3877 "min/max should be consistently pointerish");
3878 }
3879#endif
3880
3881 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3882 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3883
3884 const SCEV *Folded = constantFoldAndGroupOps(
3885 *this, LI, DT, Ops,
3886 [&](const APInt &C1, const APInt &C2) {
3887 switch (Kind) {
3888 case scSMaxExpr:
3889 return APIntOps::smax(C1, C2);
3890 case scSMinExpr:
3891 return APIntOps::smin(C1, C2);
3892 case scUMaxExpr:
3893 return APIntOps::umax(C1, C2);
3894 case scUMinExpr:
3895 return APIntOps::umin(C1, C2);
3896 default:
3897 llvm_unreachable("Unknown SCEV min/max opcode");
3898 }
3899 },
3900 [&](const APInt &C) {
3901 // identity
3902 if (IsMax)
3903 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3904 else
3905 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3906 },
3907 [&](const APInt &C) {
3908 // absorber
3909 if (IsMax)
3910 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3911 else
3912 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3913 });
3914 if (Folded)
3915 return Folded;
3916
3917 // Check if we have created the same expression before.
3918 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3919 return S;
3920 }
3921
3922 // Find the first operation of the same kind
3923 unsigned Idx = 0;
3924 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3925 ++Idx;
3926
3927 // Check to see if one of the operands is of the same kind. If so, expand its
3928 // operands onto our operand list, and recurse to simplify.
3929 if (Idx < Ops.size()) {
3930 bool DeletedAny = false;
3931 while (Ops[Idx]->getSCEVType() == Kind) {
3932 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3933 Ops.erase(Ops.begin()+Idx);
3934 append_range(Ops, SMME->operands());
3935 DeletedAny = true;
3936 }
3937
3938 if (DeletedAny)
3939 return getMinMaxExpr(Kind, Ops);
3940 }
3941
3942 // Okay, check to see if the same value occurs in the operand list twice. If
3943 // so, delete one. Since we sorted the list, these values are required to
3944 // be adjacent.
3945  llvm::CmpInst::Predicate GEPred =
3946      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3947  llvm::CmpInst::Predicate LEPred =
3948      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3949  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3950 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3951 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3952 if (Ops[i] == Ops[i + 1] ||
3953 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3954 // X op Y op Y --> X op Y
3955 // X op Y --> X, if we know X, Y are ordered appropriately
3956 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3957 --i;
3958 --e;
3959 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3960 Ops[i + 1])) {
3961 // X op Y --> Y, if we know X, Y are ordered appropriately
3962 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3963 --i;
3964 --e;
3965 }
3966 }
3967
3968 if (Ops.size() == 1) return Ops[0];
3969
3970 assert(!Ops.empty() && "Reduced smax down to nothing!");
3971
3972 // Okay, it looks like we really DO need an expr. Check to see if we
3973 // already have one, otherwise create a new one.
3974  FoldingSetNodeID ID;
3975  ID.AddInteger(Kind);
3976 for (const SCEV *Op : Ops)
3977 ID.AddPointer(Op);
3978 void *IP = nullptr;
3979 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3980 if (ExistingSCEV)
3981 return ExistingSCEV;
3982 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3983  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3984  SCEV *S = new (SCEVAllocator)
3985 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3986
3987 UniqueSCEVs.InsertNode(S, IP);
3988 registerUser(S, Ops);
3989 return S;
3990}
3991
3992namespace {
3993
3994class SCEVSequentialMinMaxDeduplicatingVisitor final
3995 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3996 std::optional<const SCEV *>> {
3997 using RetVal = std::optional<const SCEV *>;
3998  using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3999
4000 ScalarEvolution &SE;
4001 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4002 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4003  SmallPtrSet<const SCEV *, 16> SeenOps;
4004
4005 bool canRecurseInto(SCEVTypes Kind) const {
4006 // We can only recurse into the SCEV expression of the same effective type
4007 // as the type of our root SCEV expression.
4008 return RootKind == Kind || NonSequentialRootKind == Kind;
4009 };
4010
4011 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4012    assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
4013           "Only for min/max expressions.");
4014 SCEVTypes Kind = S->getSCEVType();
4015
4016 if (!canRecurseInto(Kind))
4017 return S;
4018
4019 auto *NAry = cast<SCEVNAryExpr>(S);
4020    SmallVector<const SCEV *> NewOps;
4021    bool Changed = visit(Kind, NAry->operands(), NewOps);
4022
4023 if (!Changed)
4024 return S;
4025 if (NewOps.empty())
4026 return std::nullopt;
4027
4028    return isa<SCEVSequentialMinMaxExpr>(S)
4029               ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4030 : SE.getMinMaxExpr(Kind, NewOps);
4031 }
4032
4033 RetVal visit(const SCEV *S) {
4034 // Has the whole operand been seen already?
4035 if (!SeenOps.insert(S).second)
4036 return std::nullopt;
4037 return Base::visit(S);
4038 }
4039
4040public:
4041 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4042 SCEVTypes RootKind)
4043 : SE(SE), RootKind(RootKind),
4044 NonSequentialRootKind(
4045 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4046 RootKind)) {}
4047
4048 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4049 SmallVectorImpl<const SCEV *> &NewOps) {
4050 bool Changed = false;
4052 Ops.reserve(OrigOps.size());
4053
4054 for (const SCEV *Op : OrigOps) {
4055 RetVal NewOp = visit(Op);
4056 if (NewOp != Op)
4057 Changed = true;
4058 if (NewOp)
4059 Ops.emplace_back(*NewOp);
4060 }
4061
4062 if (Changed)
4063 NewOps = std::move(Ops);
4064 return Changed;
4065 }
4066
4067 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4068
4069 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4070
4071 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4072
4073 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4074
4075 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4076
4077 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4078
4079 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4080
4081 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4082
4083 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4084
4085 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4086
4087 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4088 return visitAnyMinMaxExpr(Expr);
4089 }
4090
4091 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4092 return visitAnyMinMaxExpr(Expr);
4093 }
4094
4095 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4096 return visitAnyMinMaxExpr(Expr);
4097 }
4098
4099 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4100 return visitAnyMinMaxExpr(Expr);
4101 }
4102
4103 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4104 return visitAnyMinMaxExpr(Expr);
4105 }
4106
4107 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4108
4109 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4110};
4111
4112} // namespace
4113
4114static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4115  switch (Kind) {
4116 case scConstant:
4117 case scVScale:
4118 case scTruncate:
4119 case scZeroExtend:
4120 case scSignExtend:
4121 case scPtrToInt:
4122 case scAddExpr:
4123 case scMulExpr:
4124 case scUDivExpr:
4125 case scAddRecExpr:
4126 case scUMaxExpr:
4127 case scSMaxExpr:
4128 case scUMinExpr:
4129 case scSMinExpr:
4130 case scUnknown:
4131 // If any operand is poison, the whole expression is poison.
4132 return true;
4133  case scSequentialUMinExpr:
4134    // FIXME: if the *first* operand is poison, the whole expression is poison.
4135 return false; // Pessimistically, say that it does not propagate poison.
4136 case scCouldNotCompute:
4137 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4138 }
4139 llvm_unreachable("Unknown SCEV kind!");
4140}
4141
4142namespace {
4143// The only way poison may be introduced in a SCEV expression is from a
4144// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4145// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4146// introduce poison -- they encode guaranteed, non-speculated knowledge.
4147//
4148// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4149// with the notable exception of umin_seq, where only poison from the first
4150// operand is (unconditionally) propagated.
4151struct SCEVPoisonCollector {
4152 bool LookThroughMaybePoisonBlocking;
4153 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4154 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4155 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4156
4157 bool follow(const SCEV *S) {
4158 if (!LookThroughMaybePoisonBlocking &&
4159        !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4160      return false;
4161
4162 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4163 if (!isGuaranteedNotToBePoison(SU->getValue()))
4164 MaybePoison.insert(SU);
4165 }
4166 return true;
4167 }
4168 bool isDone() const { return false; }
4169};
4170} // namespace
4171
4172/// Return true if V is poison given that AssumedPoison is already poison.
4173static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4174 // First collect all SCEVs that might result in AssumedPoison to be poison.
4175 // We need to look through potentially poison-blocking operations here,
4176 // because we want to find all SCEVs that *might* result in poison, not only
4177 // those that are *required* to.
4178 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4179 visitAll(AssumedPoison, PC1);
4180
4181 // AssumedPoison is never poison. As the assumption is false, the implication
4182 // is true. Don't bother walking the other SCEV in this case.
4183 if (PC1.MaybePoison.empty())
4184 return true;
4185
4186 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4187 // as well. We cannot look through potentially poison-blocking operations
4188 // here, as their arguments only *may* make the result poison.
4189 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4190 visitAll(S, PC2);
4191
4192 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4193 // it will also make S poison by being part of PC2.MaybePoison.
4194 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4195}
4196
4197void ScalarEvolution::getPoisonGeneratingValues(
4198    SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4199 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4200 visitAll(S, PC);
4201 for (const SCEVUnknown *SU : PC.MaybePoison)
4202 Result.insert(SU->getValue());
4203}
4204
4205bool ScalarEvolution::canReuseInstruction(
4206    const SCEV *S, Instruction *I,
4207 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4208 // If the instruction cannot be poison, it's always safe to reuse.
4210 return true;
4211
4212  // Otherwise, it is possible that I is more poisonous than S. Collect the
4213 // poison-contributors of S, and then check whether I has any additional
4214 // poison-contributors. Poison that is contributed through poison-generating
4215 // flags is handled by dropping those flags instead.
4216  SmallPtrSet<const Value *, 8> PoisonVals;
4217  getPoisonGeneratingValues(PoisonVals, S);
4218
4219 SmallVector<Value *> Worklist;
4220  SmallPtrSet<Value *, 8> Visited;
4221  Worklist.push_back(I);
4222 while (!Worklist.empty()) {
4223 Value *V = Worklist.pop_back_val();
4224 if (!Visited.insert(V).second)
4225 continue;
4226
4227 // Avoid walking large instruction graphs.
4228 if (Visited.size() > 16)
4229 return false;
4230
4231 // Either the value can't be poison, or the S would also be poison if it
4232 // is.
4233 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4234 continue;
4235
4236 auto *I = dyn_cast<Instruction>(V);
4237 if (!I)
4238 return false;
4239
4240 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4241 // can't replace an arbitrary add with disjoint or, even if we drop the
4242 // flag. We would need to convert the or into an add.
4243 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4244 if (PDI->isDisjoint())
4245 return false;
4246
4247 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4248 // because SCEV currently assumes it can't be poison. Remove this special
4249    // case once we properly model when vscale can be poison.
4250 if (auto *II = dyn_cast<IntrinsicInst>(I);
4251 II && II->getIntrinsicID() == Intrinsic::vscale)
4252 continue;
4253
4254 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4255 return false;
4256
4257 // If the instruction can't create poison, we can recurse to its operands.
4258 if (I->hasPoisonGeneratingAnnotations())
4259 DropPoisonGeneratingInsts.push_back(I);
4260
4261 llvm::append_range(Worklist, I->operands());
4262 }
4263 return true;
4264}
4265
4266const SCEV *
4267ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4268                                         SmallVectorImpl<const SCEV *> &Ops) {
4269  assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4270 "Not a SCEVSequentialMinMaxExpr!");
4271 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4272 if (Ops.size() == 1)
4273 return Ops[0];
4274#ifndef NDEBUG
4275 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4276 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4277 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4278 "Operand types don't match!");
4279 assert(Ops[0]->getType()->isPointerTy() ==
4280 Ops[i]->getType()->isPointerTy() &&
4281 "min/max should be consistently pointerish");
4282 }
4283#endif
4284
4285 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4286 // so we can *NOT* do any kind of sorting of the expressions!
4287
4288 // Check if we have created the same expression before.
4289 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4290 return S;
4291
4292 // FIXME: there are *some* simplifications that we can do here.
4293
4294 // Keep only the first instance of an operand.
4295 {
4296 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4297 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4298 if (Changed)
4299 return getSequentialMinMaxExpr(Kind, Ops);
4300 }
4301
4302 // Check to see if one of the operands is of the same kind. If so, expand its
4303 // operands onto our operand list, and recurse to simplify.
4304 {
4305 unsigned Idx = 0;
4306 bool DeletedAny = false;
4307 while (Idx < Ops.size()) {
4308 if (Ops[Idx]->getSCEVType() != Kind) {
4309 ++Idx;
4310 continue;
4311 }
4312 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4313 Ops.erase(Ops.begin() + Idx);
4314 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4315 SMME->operands().end());
4316 DeletedAny = true;
4317 }
4318
4319 if (DeletedAny)
4320 return getSequentialMinMaxExpr(Kind, Ops);
4321 }
4322
4323 const SCEV *SaturationPoint;
4324  ICmpInst::Predicate Pred;
4325  switch (Kind) {
4326  case scSequentialUMinExpr:
4327    SaturationPoint = getZero(Ops[0]->getType());
4328 Pred = ICmpInst::ICMP_ULE;
4329 break;
4330 default:
4331 llvm_unreachable("Not a sequential min/max type.");
4332 }
4333
4334 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4335 if (!isGuaranteedNotToCauseUB(Ops[i]))
4336 continue;
4337 // We can replace %x umin_seq %y with %x umin %y if either:
4338 // * %y being poison implies %x is also poison.
4339 // * %x cannot be the saturating value (e.g. zero for umin).
4340 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4341 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4342 SaturationPoint)) {
4343 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4344 Ops[i - 1] = getMinMaxExpr(
4345          SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4346          SeqOps);
4347 Ops.erase(Ops.begin() + i);
4348 return getSequentialMinMaxExpr(Kind, Ops);
4349 }
4350 // Fold %x umin_seq %y to %x if %x ule %y.
4351 // TODO: We might be able to prove the predicate for a later operand.
4352 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4353 Ops.erase(Ops.begin() + i);
4354 return getSequentialMinMaxExpr(Kind, Ops);
4355 }
4356 }
4357
4358 // Okay, it looks like we really DO need an expr. Check to see if we
4359 // already have one, otherwise create a new one.
4360  FoldingSetNodeID ID;
4361  ID.AddInteger(Kind);
4362 for (const SCEV *Op : Ops)
4363 ID.AddPointer(Op);
4364 void *IP = nullptr;
4365 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4366 if (ExistingSCEV)
4367 return ExistingSCEV;
4368
4369 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4370  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4371  SCEV *S = new (SCEVAllocator)
4372 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4373
4374 UniqueSCEVs.InsertNode(S, IP);
4375 registerUser(S, Ops);
4376 return S;
4377}
4378
4379const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4380 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4381 return getSMaxExpr(Ops);
4382}
4383
4384const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4385  return getMinMaxExpr(scSMaxExpr, Ops);
4386}
4387
4388const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4389 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4390 return getUMaxExpr(Ops);
4391}
4392
4393const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4394  return getMinMaxExpr(scUMaxExpr, Ops);
4395}
4396
4397const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4398                                         const SCEV *RHS) {
4399 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4400 return getSMinExpr(Ops);
4401}
4402
4403const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4404  return getMinMaxExpr(scSMinExpr, Ops);
4405}
4406
4407const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4408 bool Sequential) {
4409 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4410 return getUMinExpr(Ops, Sequential);
4411}
4412
4413const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4414                                         bool Sequential) {
4415  return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4416                    : getMinMaxExpr(scUMinExpr, Ops);
4417}
4418
4419const SCEV *
4420ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4421  const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4422 if (Size.isScalable())
4423 Res = getMulExpr(Res, getVScale(IntTy));
4424 return Res;
4425}
4426
4427const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4428  return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4429}
4430
4431const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4432  return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4433}
4434
4435const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4436                                             StructType *STy,
4437 unsigned FieldNo) {
4438 // We can bypass creating a target-independent constant expression and then
4439 // folding it back into a ConstantInt. This is just a compile-time
4440 // optimization.
4441 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4442 assert(!SL->getSizeInBits().isScalable() &&
4443 "Cannot get offset for structure containing scalable vector types");
4444 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4445}
4446
4447const SCEV *ScalarEvolution::getUnknown(Value *V) {
4448  // Don't attempt to do anything other than create a SCEVUnknown object
4449 // here. createSCEV only calls getUnknown after checking for all other
4450 // interesting possibilities, and any other code that calls getUnknown
4451 // is doing so in order to hide a value from SCEV canonicalization.
4452
4453  FoldingSetNodeID ID;
4454  ID.AddInteger(scUnknown);
4455 ID.AddPointer(V);
4456 void *IP = nullptr;
4457 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4458 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4459 "Stale SCEVUnknown in uniquing map!");
4460 return S;
4461 }
4462 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4463 FirstUnknown);
4464 FirstUnknown = cast<SCEVUnknown>(S);
4465 UniqueSCEVs.InsertNode(S, IP);
4466 return S;
4467}
4468
4469//===----------------------------------------------------------------------===//
4470// Basic SCEV Analysis and PHI Idiom Recognition Code
4471//
4472
4473/// Test if values of the given type are analyzable within the SCEV
4474/// framework. This primarily includes integer types, and it can optionally
4475/// include pointer types if the ScalarEvolution class has access to
4476/// target-specific information.
4477bool ScalarEvolution::isSCEVable(Type *Ty) const {
4478  // Integers and pointers are always SCEVable.
4479 return Ty->isIntOrPtrTy();
4480}
4481
4482/// Return the size in bits of the specified type, for which isSCEVable must
4483/// return true.
4484uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4485  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4486 if (Ty->isPointerTy())
4487    return getDataLayout().getIndexTypeSizeInBits(Ty);
4488  return getDataLayout().getTypeSizeInBits(Ty);
4489}
4490
4491/// Return a type with the same bitwidth as the given type and which represents
4492/// how SCEV will treat the given type, for which isSCEVable must return
4493/// true. For pointer types, this is the pointer index sized integer type.
4494Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4495  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4496
4497 if (Ty->isIntegerTy())
4498 return Ty;
4499
4500  // The only other supported type is pointer.
4501 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4502 return getDataLayout().getIndexType(Ty);
4503}
4504
4505Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4506  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4507}
4508
4509bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4510                                                        const SCEV *B) {
4511 /// For a valid use point to exist, the defining scope of one operand
4512 /// must dominate the other.
4513 bool PreciseA, PreciseB;
4514 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4515 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4516 if (!PreciseA || !PreciseB)
4517 // Can't tell.
4518 return false;
4519 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4520 DT.dominates(ScopeB, ScopeA);
4521}
4522
4523const SCEV *ScalarEvolution::getCouldNotCompute() {
4524  return CouldNotCompute.get();
4525}
4526
4527bool ScalarEvolution::checkValidity(const SCEV *S) const {
4528 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4529 auto *SU = dyn_cast<SCEVUnknown>(S);
4530 return SU && SU->getValue() == nullptr;
4531 });
4532
4533 return !ContainsNulls;
4534}
4535
4536bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4537  HasRecMapType::iterator I = HasRecMap.find(S);
4538 if (I != HasRecMap.end())
4539 return I->second;
4540
4541 bool FoundAddRec =
4542 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4543 HasRecMap.insert({S, FoundAddRec});
4544 return FoundAddRec;
4545}
4546
4547/// Return the ValueOffsetPair set for \p S. \p S can be represented
4548/// by the value and offset from any ValueOffsetPair in the set.
4549ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4550 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4551 if (SI == ExprValueMap.end())
4552 return {};
4553 return SI->second.getArrayRef();
4554}
4555
4556/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4557/// cannot be used separately. eraseValueFromMap should be used to remove
4558/// V from ValueExprMap and ExprValueMap at the same time.
4559void ScalarEvolution::eraseValueFromMap(Value *V) {
4560 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4561 if (I != ValueExprMap.end()) {
4562 auto EVIt = ExprValueMap.find(I->second);
4563 bool Removed = EVIt->second.remove(V);
4564 (void) Removed;
4565 assert(Removed && "Value not in ExprValueMap?");
4566 ValueExprMap.erase(I);
4567 }
4568}
4569
4570void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4571 // A recursive query may have already computed the SCEV. It should be
4572 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4573 // inferred nowrap flags.
4574 auto It = ValueExprMap.find_as(V);
4575 if (It == ValueExprMap.end()) {
4576 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4577 ExprValueMap[S].insert(V);
4578 }
4579}
4580
4581/// Return an existing SCEV if it exists, otherwise analyze the expression and
4582/// create a new one.
4583const SCEV *ScalarEvolution::getSCEV(Value *V) {
4584  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4585
4586 if (const SCEV *S = getExistingSCEV(V))
4587 return S;
4588 return createSCEVIter(V);
4589}
4590
4591const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4592  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4593
4594 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4595 if (I != ValueExprMap.end()) {
4596 const SCEV *S = I->second;
4597 assert(checkValidity(S) &&
4598 "existing SCEV has not been properly invalidated");
4599 return S;
4600 }
4601 return nullptr;
4602}
4603
4604/// Return a SCEV corresponding to -V = -1*V
4605const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4606                                             SCEV::NoWrapFlags Flags) {
4607 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4608 return getConstant(
4609 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4610
4611 Type *Ty = V->getType();
4612 Ty = getEffectiveSCEVType(Ty);
4613 return getMulExpr(V, getMinusOne(Ty), Flags);
4614}
4615
4616/// If Expr computes ~A, return A else return nullptr
4617static const SCEV *MatchNotExpr(const SCEV *Expr) {
4618 const SCEV *MulOp;
4619 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4620 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4621 return MulOp;
4622 return nullptr;
4623}
4624
4625/// Return a SCEV corresponding to ~V = -1-V
4626const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4627  assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4628
4629 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4630 return getConstant(
4631 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4632
4633 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4634 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4635 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4636 SmallVector<const SCEV *, 2> MatchedOperands;
4637 for (const SCEV *Operand : MME->operands()) {
4638 const SCEV *Matched = MatchNotExpr(Operand);
4639 if (!Matched)
4640 return (const SCEV *)nullptr;
4641 MatchedOperands.push_back(Matched);
4642 }
4643 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4644 MatchedOperands);
4645 };
4646 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4647 return Replaced;
4648 }
4649
4650 Type *Ty = V->getType();
4651 Ty = getEffectiveSCEVType(Ty);
4652 return getMinusSCEV(getMinusOne(Ty), V);
4653}
4654
4655const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4656  assert(P->getType()->isPointerTy());
4657
4658 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4659 // The base of an AddRec is the first operand.
4660 SmallVector<const SCEV *> Ops{AddRec->operands()};
4661 Ops[0] = removePointerBase(Ops[0]);
4662 // Don't try to transfer nowrap flags for now. We could in some cases
4663 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4664 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4665 }
4666 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4667 // The base of an Add is the pointer operand.
4668 SmallVector<const SCEV *> Ops{Add->operands()};
4669 const SCEV **PtrOp = nullptr;
4670 for (const SCEV *&AddOp : Ops) {
4671 if (AddOp->getType()->isPointerTy()) {
4672 assert(!PtrOp && "Cannot have multiple pointer ops");
4673 PtrOp = &AddOp;
4674 }
4675 }
4676 *PtrOp = removePointerBase(*PtrOp);
4677 // Don't try to transfer nowrap flags for now. We could in some cases
4678 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4679 return getAddExpr(Ops);
4680 }
4681 // Any other expression must be a pointer base.
4682 return getZero(P->getType());
4683}
4684
4685const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4686 SCEV::NoWrapFlags Flags,
4687 unsigned Depth) {
4688 // Fast path: X - X --> 0.
4689 if (LHS == RHS)
4690 return getZero(LHS->getType());
4691
4692 // If we subtract two pointers with different pointer bases, bail.
4693 // Eventually, we're going to add an assertion to getMulExpr that we
4694 // can't multiply by a pointer.
4695 if (RHS->getType()->isPointerTy()) {
4696 if (!LHS->getType()->isPointerTy() ||
4697 getPointerBase(LHS) != getPointerBase(RHS))
4698 return getCouldNotCompute();
4699 LHS = removePointerBase(LHS);
4700 RHS = removePointerBase(RHS);
4701 }
4702
4703 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4704 // makes it so that we cannot make much use of NUW.
4705 auto AddFlags = SCEV::FlagAnyWrap;
4706 const bool RHSIsNotMinSigned =
4707      !getSignedRangeMin(RHS).isMinSignedValue();
4708  if (hasFlags(Flags, SCEV::FlagNSW)) {
4709 // Let M be the minimum representable signed value. Then (-1)*RHS
4710 // signed-wraps if and only if RHS is M. That can happen even for
4711 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4712 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4713 // (-1)*RHS, we need to prove that RHS != M.
4714 //
4715 // If LHS is non-negative and we know that LHS - RHS does not
4716 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4717 // either by proving that RHS > M or that LHS >= 0.
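    // For example, with i8 operands and RHS == -128, LHS - RHS can be nsw
    // (e.g. LHS == -100), yet (-1) * RHS wraps, so nsw cannot be transferred
    // blindly.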
4718 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4719 AddFlags = SCEV::FlagNSW;
4720 }
4721 }
4722
4723 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4724 // RHS is NSW and LHS >= 0.
4725 //
4726 // The difficulty here is that the NSW flag may have been proven
4727 // relative to a loop that is to be found in a recurrence in LHS and
4728 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4729 // larger scope than intended.
4730 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4731
4732 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4733}
4734
4735const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4736                                                     unsigned Depth) {
4737 Type *SrcTy = V->getType();
4738 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4739 "Cannot truncate or zero extend with non-integer arguments!");
4740 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4741 return V; // No conversion
4742 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4743 return getTruncateExpr(V, Ty, Depth);
4744 return getZeroExtendExpr(V, Ty, Depth);
4745}
4746
4747const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4748                                                     unsigned Depth) {
4749 Type *SrcTy = V->getType();
4750 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4751 "Cannot truncate or zero extend with non-integer arguments!");
4752 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4753 return V; // No conversion
4754 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4755 return getTruncateExpr(V, Ty, Depth);
4756 return getSignExtendExpr(V, Ty, Depth);
4757}
4758
4759const SCEV *
4760ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4761  Type *SrcTy = V->getType();
4762 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4763 "Cannot noop or zero extend with non-integer arguments!");
4764  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4765         "getNoopOrZeroExtend cannot truncate!");
4766 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4767 return V; // No conversion
4768 return getZeroExtendExpr(V, Ty);
4769}
4770
4771const SCEV *
4772ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4773  Type *SrcTy = V->getType();
4774 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4775 "Cannot noop or sign extend with non-integer arguments!");
4776  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4777         "getNoopOrSignExtend cannot truncate!");
4778 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4779 return V; // No conversion
4780 return getSignExtendExpr(V, Ty);
4781}
4782
4783const SCEV *
4784ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4785  Type *SrcTy = V->getType();
4786 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4787 "Cannot noop or any extend with non-integer arguments!");
4788  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4789         "getNoopOrAnyExtend cannot truncate!");
4790 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4791 return V; // No conversion
4792 return getAnyExtendExpr(V, Ty);
4793}
4794
4795const SCEV *
4796ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4797  Type *SrcTy = V->getType();
4798 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4799 "Cannot truncate or noop with non-integer arguments!");
4800  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4801         "getTruncateOrNoop cannot extend!");
4802 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4803 return V; // No conversion
4804 return getTruncateExpr(V, Ty);
4805}
4806
4807const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4808                                                        const SCEV *RHS) {
4809 const SCEV *PromotedLHS = LHS;
4810 const SCEV *PromotedRHS = RHS;
4811
4812 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4813 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4814 else
4815 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4816
4817 return getUMaxExpr(PromotedLHS, PromotedRHS);
4818}
4819
4820const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4821                                                        const SCEV *RHS,
4822 bool Sequential) {
4823 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4824 return getUMinFromMismatchedTypes(Ops, Sequential);
4825}
4826
4827const SCEV *
4828ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4829                                            bool Sequential) {
4830 assert(!Ops.empty() && "At least one operand must be!");
4831 // Trivial case.
4832 if (Ops.size() == 1)
4833 return Ops[0];
4834
4835 // Find the max type first.
4836 Type *MaxType = nullptr;
4837 for (const auto *S : Ops)
4838 if (MaxType)
4839 MaxType = getWiderType(MaxType, S->getType());
4840 else
4841 MaxType = S->getType();
4842 assert(MaxType && "Failed to find maximum type!");
4843
4844 // Extend all ops to max type.
4845 SmallVector<const SCEV *, 2> PromotedOps;
4846 for (const auto *S : Ops)
4847 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4848
4849 // Generate umin.
4850 return getUMinExpr(PromotedOps, Sequential);
4851}
4852
4853const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4854 // A pointer operand may evaluate to a nonpointer expression, such as null.
4855 if (!V->getType()->isPointerTy())
4856 return V;
4857
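  // For example, for (16 + {%ptr,+,8}<%L>) we first step to the AddRec's
  // start (16 + %ptr), then to the add's single pointer operand, and return
  // %ptr.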
4858 while (true) {
4859 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4860 V = AddRec->getStart();
4861 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4862 const SCEV *PtrOp = nullptr;
4863 for (const SCEV *AddOp : Add->operands()) {
4864 if (AddOp->getType()->isPointerTy()) {
4865 assert(!PtrOp && "Cannot have multiple pointer ops");
4866 PtrOp = AddOp;
4867 }
4868 }
4869 assert(PtrOp && "Must have pointer op");
4870 V = PtrOp;
4871 } else // Not something we can look further into.
4872 return V;
4873 }
4874}
4875
4876/// Push users of the given Instruction onto the given Worklist.
4877static void PushDefUseChildren(Instruction *I,
4878 SmallVectorImpl<Instruction *> &Worklist,
4879 SmallPtrSetImpl<Instruction *> &Visited) {
4880 // Push the def-use children onto the Worklist stack.
4881 for (User *U : I->users()) {
4882 auto *UserInsn = cast<Instruction>(U);
4883 if (Visited.insert(UserInsn).second)
4884 Worklist.push_back(UserInsn);
4885 }
4886}
4887
4888namespace {
4889
4890/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4891/// expression if its loop is L. If the loop is not L and IgnoreOtherLoops is
4892/// true, the AddRec itself is used; otherwise the rewrite cannot be done.
4893/// If the SCEV contains a SCEVUnknown that is not invariant in L, the
4894/// rewrite cannot be done either.
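/// For example, with L as the target loop, {%a,+,%s}<%L> rewrites to %a,
/// while an AddRec of a different loop is either kept as-is
/// (IgnoreOtherLoops) or causes the whole rewrite to fail.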
4895class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4896public:
4897 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4898 bool IgnoreOtherLoops = true) {
4899 SCEVInitRewriter Rewriter(L, SE);
4900 const SCEV *Result = Rewriter.visit(S);
4901 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4902 return SE.getCouldNotCompute();
4903 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4904 ? SE.getCouldNotCompute()
4905 : Result;
4906 }
4907
4908 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4909 if (!SE.isLoopInvariant(Expr, L))
4910 SeenLoopVariantSCEVUnknown = true;
4911 return Expr;
4912 }
4913
4914 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4915 // Only re-write AddRecExprs for this loop.
4916 if (Expr->getLoop() == L)
4917 return Expr->getStart();
4918 SeenOtherLoops = true;
4919 return Expr;
4920 }
4921
4922 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4923
4924 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4925
4926private:
4927 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4928 : SCEVRewriteVisitor(SE), L(L) {}
4929
4930 const Loop *L;
4931 bool SeenLoopVariantSCEVUnknown = false;
4932 bool SeenOtherLoops = false;
4933};
4934
4935/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4936/// increment expression if its loop is L; otherwise the AddRec itself is used.
4937/// If the SCEV contains a SCEVUnknown that is not invariant in L, the rewrite
4938/// cannot be done.
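/// For example, {%a,+,%s}<%L> rewrites to its post-increment form
/// {(%a + %s),+,%s}<%L>.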
4939class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4940public:
4941 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4942 SCEVPostIncRewriter Rewriter(L, SE);
4943 const SCEV *Result = Rewriter.visit(S);
4944 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4945 ? SE.getCouldNotCompute()
4946 : Result;
4947 }
4948
4949 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4950 if (!SE.isLoopInvariant(Expr, L))
4951 SeenLoopVariantSCEVUnknown = true;
4952 return Expr;
4953 }
4954
4955 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4956 // Only re-write AddRecExprs for this loop.
4957 if (Expr->getLoop() == L)
4958 return Expr->getPostIncExpr(SE);
4959 SeenOtherLoops = true;
4960 return Expr;
4961 }
4962
4963 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4964
4965 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4966
4967private:
4968 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4969 : SCEVRewriteVisitor(SE), L(L) {}
4970
4971 const Loop *L;
4972 bool SeenLoopVariantSCEVUnknown = false;
4973 bool SeenOtherLoops = false;
4974};
4975
4976/// This class evaluates the compare condition by matching it against the
4977/// condition of the loop latch. If there is a match, we assume a true value
4978/// for the condition while building SCEV nodes.
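/// For example, if the latch ends in `br i1 %c, label %header, label %exit`,
/// then a `select i1 %c, %x, %y` inside the loop is treated as %x while its
/// SCEV is built.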
4979class SCEVBackedgeConditionFolder
4980 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4981public:
4982 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4983 ScalarEvolution &SE) {
4984 bool IsPosBECond = false;
4985 Value *BECond = nullptr;
4986 if (BasicBlock *Latch = L->getLoopLatch()) {
4987 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4988 if (BI && BI->isConditional()) {
4989 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4990 "Both outgoing branches should not target same header!");
4991 BECond = BI->getCondition();
4992 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4993 } else {
4994 return S;
4995 }
4996 }
4997 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4998 return Rewriter.visit(S);
4999 }
5000
5001 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5002 const SCEV *Result = Expr;
5003 bool InvariantF = SE.isLoopInvariant(Expr, L);
5004
5005 if (!InvariantF) {
5006 Instruction *I = cast<Instruction>(Expr->getValue());
5007 switch (I->getOpcode()) {
5008 case Instruction::Select: {
5009 SelectInst *SI = cast<SelectInst>(I);
5010 std::optional<const SCEV *> Res =
5011 compareWithBackedgeCondition(SI->getCondition());
5012 if (Res) {
5013 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5014 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5015 }
5016 break;
5017 }
5018 default: {
5019 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5020 if (Res)
5021 Result = *Res;
5022 break;
5023 }
5024 }
5025 }
5026 return Result;
5027 }
5028
5029private:
5030 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5031 bool IsPosBECond, ScalarEvolution &SE)
5032 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5033 IsPositiveBECond(IsPosBECond) {}
5034
5035 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5036
5037 const Loop *L;
5038 /// Loop back condition.
5039 Value *BackedgeCond = nullptr;
5040 /// Set to true if loop back is on positive branch condition.
5041 bool IsPositiveBECond;
5042};
5043
5044std::optional<const SCEV *>
5045SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5046
5047 // If value matches the backedge condition for loop latch,
5048 // then return a constant evolution node based on loopback
5049 // branch taken.
5050 if (BackedgeCond == IC)
5051 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5052 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5053 return std::nullopt;
5054}
5055
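/// Rewrite affine AddRecs of loop L one iteration back: {A,+,S}<%L> becomes
/// ({A,+,S}<%L> - S), i.e. {(A - S),+,S}<%L>. Any loop-variant SCEVUnknown or
/// AddRec of another loop marks the rewrite as invalid.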
5056class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5057public:
5058 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5059 ScalarEvolution &SE) {
5060 SCEVShiftRewriter Rewriter(L, SE);
5061 const SCEV *Result = Rewriter.visit(S);
5062 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5063 }
5064
5065 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5066 // Only allow AddRecExprs for this loop.
5067 if (!SE.isLoopInvariant(Expr, L))
5068 Valid = false;
5069 return Expr;
5070 }
5071
5072 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5073 if (Expr->getLoop() == L && Expr->isAffine())
5074 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5075 Valid = false;
5076 return Expr;
5077 }
5078
5079 bool isValid() { return Valid; }
5080
5081private:
5082 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5083 : SCEVRewriteVisitor(SE), L(L) {}
5084
5085 const Loop *L;
5086 bool Valid = true;
5087};
5088
5089} // end anonymous namespace
5090
5091SCEV::NoWrapFlags
5092ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5093 if (!AR->isAffine())
5094 return SCEV::FlagAnyWrap;
5095
5096 using OBO = OverflowingBinaryOperator;
5097
5098 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5099
5100 if (!AR->hasNoSelfWrap()) {
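    // If the iteration count and the step together cannot cover the full bit
    // width, the recurrence can never revisit a value. E.g. a constant max
    // backedge-taken count of 100 (7 active bits) and a step range needing 4
    // signed bits give 7 + 4 = 11 <= 32, so an i32 AddRec cannot self-wrap.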
5101 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5102 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5103 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5104 const APInt &BECountAP = BECountMax->getAPInt();
5105 unsigned NoOverflowBitWidth =
5106 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5107 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5108 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5109 }
5110 }
5111
5112 if (!AR->hasNoSignedWrap()) {
5113 ConstantRange AddRecRange = getSignedRange(AR);
5114 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5115
5116 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5117 Instruction::Add, IncRange, OBO::NoSignedWrap);
5118 if (NSWRegion.contains(AddRecRange))
5119 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5120 }
5121
5122 if (!AR->hasNoUnsignedWrap()) {
5123 ConstantRange AddRecRange = getUnsignedRange(AR);
5124 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5125
5126 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5127 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5128 if (NUWRegion.contains(AddRecRange))
5129 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5130 }
5131
5132 return Result;
5133}
5134
5135SCEV::NoWrapFlags
5136ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5137 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5138
5139 if (AR->hasNoSignedWrap())
5140 return Result;
5141
5142 if (!AR->isAffine())
5143 return Result;
5144
5145 // This function can be expensive, only try to prove NSW once per AddRec.
5146 if (!SignedWrapViaInductionTried.insert(AR).second)
5147 return Result;
5148
5149 const SCEV *Step = AR->getStepRecurrence(*this);
5150 const Loop *L = AR->getLoop();
5151
5152 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5153 // Note that this serves two purposes: It filters out loops that are
5154 // simply not analyzable, and it covers the case where this code is
5155 // being called from within backedge-taken count analysis, such that
5156 // attempting to ask for the backedge-taken count would likely result
5157 // in infinite recursion. In the latter case, the analysis code will
5158 // cope with a conservative value, and it will take care to purge
5159 // that value once it has finished.
5160 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5161
5162 // Normally, in the cases we can prove no-overflow via a
5163 // backedge guarding condition, we can also compute a backedge
5164 // taken count for the loop. The exceptions are assumptions and
5165 // guards present in the loop -- SCEV is not great at exploiting
5166 // these to compute max backedge taken counts, but can still use
5167 // these to prove lack of overflow. Use this fact to avoid
5168 // doing extra work that may not pay off.
5169
5170 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5171 AC.assumptions().empty())
5172 return Result;
5173
5174 // If the backedge is guarded by a comparison with the pre-inc value the
5175 // addrec is safe. Also, if the entry is guarded by a comparison with the
5176 // start value and the backedge is guarded by a comparison with the post-inc
5177 // value, the addrec is safe.
5178 ICmpInst::Predicate Pred;
5179 const SCEV *OverflowLimit =
5180 getSignedOverflowLimitForStep(Step, &Pred, this);
5181 if (OverflowLimit &&
5182 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5183 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5184 Result = setFlags(Result, SCEV::FlagNSW);
5185 }
5186 return Result;
5187}
5188SCEV::NoWrapFlags
5189ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5190 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5191
5192 if (AR->hasNoUnsignedWrap())
5193 return Result;
5194
5195 if (!AR->isAffine())
5196 return Result;
5197
5198 // This function can be expensive, only try to prove NUW once per AddRec.
5199 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5200 return Result;
5201
5202 const SCEV *Step = AR->getStepRecurrence(*this);
5203 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5204 const Loop *L = AR->getLoop();
5205
5206 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5207 // Note that this serves two purposes: It filters out loops that are
5208 // simply not analyzable, and it covers the case where this code is
5209 // being called from within backedge-taken count analysis, such that
5210 // attempting to ask for the backedge-taken count would likely result
5211 // in infinite recursion. In the latter case, the analysis code will
5212 // cope with a conservative value, and it will take care to purge
5213 // that value once it has finished.
5214 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5215
5216 // Normally, in the cases we can prove no-overflow via a
5217 // backedge guarding condition, we can also compute a backedge
5218 // taken count for the loop. The exceptions are assumptions and
5219 // guards present in the loop -- SCEV is not great at exploiting
5220 // these to compute max backedge taken counts, but can still use
5221 // these to prove lack of overflow. Use this fact to avoid
5222 // doing extra work that may not pay off.
5223
5224 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5225 AC.assumptions().empty())
5226 return Result;
5227
5228 // If the backedge is guarded by a comparison with the pre-inc value the
5229 // addrec is safe. Also, if the entry is guarded by a comparison with the
5230 // start value and the backedge is guarded by a comparison with the post-inc
5231 // value, the addrec is safe.
5232 if (isKnownPositive(Step)) {
5233 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5234 getUnsignedRangeMax(Step));
5235 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5236 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5237 Result = setFlags(Result, SCEV::FlagNUW);
5238 }
5239 }
5240
5241 return Result;
5242}
5243
5244namespace {
5245
5246/// Represents an abstract binary operation. This may exist as a
5247/// normal instruction or constant expression, or may have been
5248/// derived from an expression tree.
5249struct BinaryOp {
5250 unsigned Opcode;
5251 Value *LHS;
5252 Value *RHS;
5253 bool IsNSW = false;
5254 bool IsNUW = false;
5255
5256 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5257 /// constant expression.
5258 Operator *Op = nullptr;
5259
5260 explicit BinaryOp(Operator *Op)
5261 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5262 Op(Op) {
5263 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5264 IsNSW = OBO->hasNoSignedWrap();
5265 IsNUW = OBO->hasNoUnsignedWrap();
5266 }
5267 }
5268
5269 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5270 bool IsNUW = false)
5271 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5272};
5273
5274} // end anonymous namespace
5275
5276/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5277static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5278 AssumptionCache &AC,
5279 const DominatorTree &DT,
5280 const Instruction *CxtI) {
5281 auto *Op = dyn_cast<Operator>(V);
5282 if (!Op)
5283 return std::nullopt;
5284
5285 // Implementation detail: all the cleverness here should happen without
5286 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5287 // SCEV expressions when possible, and we should not break that.
5288
5289 switch (Op->getOpcode()) {
5290 case Instruction::Add:
5291 case Instruction::Sub:
5292 case Instruction::Mul:
5293 case Instruction::UDiv:
5294 case Instruction::URem:
5295 case Instruction::And:
5296 case Instruction::AShr:
5297 case Instruction::Shl:
5298 return BinaryOp(Op);
5299
5300 case Instruction::Or: {
5301 // Convert or disjoint into add nuw nsw.
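    // The disjoint flag guarantees the operands share no set bits, so e.g.
    // (or disjoint i32 %a, %b) computes the same value as
    // (add nuw nsw i32 %a, %b): no carries can occur.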
5302 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5303 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5304 /*IsNSW=*/true, /*IsNUW=*/true);
5305 return BinaryOp(Op);
5306 }
5307
5308 case Instruction::Xor:
5309 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5310 // If the RHS of the xor is a signmask, then this is just an add.
5311 // Instcombine turns add of signmask into xor as a strength reduction step.
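      // e.g. (xor i8 %x, 0x80) computes the same value as (add i8 %x, 0x80):
      // only the sign bit is flipped, and any carry out is discarded.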
5312 if (RHSC->getValue().isSignMask())
5313 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5314 // Binary `xor` is a bit-wise `add`.
5315 if (V->getType()->isIntegerTy(1))
5316 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5317 return BinaryOp(Op);
5318
5319 case Instruction::LShr:
5320 // Turn logical shift right of a constant into an unsigned divide.
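    // e.g. (lshr i32 %x, 4) is analyzed as (%x /u 16).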
5321 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5322 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5323
5324 // If the shift count is not less than the bitwidth, the result of
5325 // the shift is undefined. Don't try to analyze it, because the
5326 // resolution chosen here may differ from the resolution chosen in
5327 // other parts of the compiler.
5328 if (SA->getValue().ult(BitWidth)) {
5329 Constant *X =
5330 ConstantInt::get(SA->getContext(),
5331 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5332 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5333 }
5334 }
5335 return BinaryOp(Op);
5336
5337 case Instruction::ExtractValue: {
5338 auto *EVI = cast<ExtractValueInst>(Op);
5339 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5340 break;
5341
5342 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5343 if (!WO)
5344 break;
5345
5346 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5347 bool Signed = WO->isSigned();
5348 // TODO: Should add nuw/nsw flags for mul as well.
5349 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5350 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5351
5352 // Now that we know that all uses of the arithmetic-result component of
5353 // CI are guarded by the overflow check, we can go ahead and pretend
5354 // that the arithmetic is non-overflowing.
5355 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5356 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5357 }
5358
5359 default:
5360 break;
5361 }
5362
5363 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5364 // semantics as a Sub, return a binary sub expression.
5365 if (auto *II = dyn_cast<IntrinsicInst>(V))
5366 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5367 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5368
5369 return std::nullopt;
5370}
5371
5372/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5373/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5374/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5375/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5376/// follows one of the following patterns:
5377/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5378/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5379/// If the SCEV expression of \p Op conforms with one of the expected patterns
5380/// we return the type of the truncation operation, and indicate whether the
5381/// truncated type should be treated as signed/unsigned by setting
5382/// \p Signed to true/false, respectively.
5383static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5384 bool &Signed, ScalarEvolution &SE) {
5385 // The case where Op == SymbolicPHI (that is, with no type conversions on
5386 // the way) is handled by the regular add recurrence creating logic and
5387 // would have already been triggered in createAddRecForPHI. Reaching it here
5388 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5389 // because one of the other operands of the SCEVAddExpr updating this PHI is
5390 // not invariant).
5391 //
5392 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5393 // this case predicates that allow us to prove that Op == SymbolicPHI will
5394 // be added.
5395 if (Op == SymbolicPHI)
5396 return nullptr;
5397
5398 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5399 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5400 if (SourceBits != NewBits)
5401 return nullptr;
5402
5403 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5404 Signed = true;
5405 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5406 }
5407 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5408 Signed = false;
5409 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5410 }
5411 return nullptr;
5412}
5413
5414static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5415 if (!PN->getType()->isIntegerTy())
5416 return nullptr;
5417 const Loop *L = LI.getLoopFor(PN->getParent());
5418 if (!L || L->getHeader() != PN->getParent())
5419 return nullptr;
5420 return L;
5421}
5422
5423// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5424// computation that updates the phi follows the following pattern:
5425// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5426// which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5427// If so, try to see if it can be rewritten as an AddRecExpr under some
5428// Predicates. If successful, return them as a pair. Also cache the results
5429// of the analysis.
5430//
5431// Example usage scenario:
5432// Say the Rewriter is called for the following SCEV:
5433// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5434// where:
5435// %X = phi i64 (%Start, %BEValue)
5436// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5437// and call this function with %SymbolicPHI = %X.
5438//
5439// The analysis will find that the value coming around the backedge has
5440// the following SCEV:
5441// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5442// Upon concluding that this matches the desired pattern, the function
5443// will return the pair {NewAddRec, SmallPredsVec} where:
5444// NewAddRec = {%Start,+,%Step}
5445// SmallPredsVec = {P1, P2, P3} as follows:
5446// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5447// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5448// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5449// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5450// under the predicates {P1,P2,P3}.
5451// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5452// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5453//
5454// TODO's:
5455//
5456// 1) Extend the Induction descriptor to also support inductions that involve
5457// casts: When needed (namely, when we are called in the context of the
5458// vectorizer induction analysis), a Set of cast instructions will be
5459// populated by this method, and provided back to isInductionPHI. This is
5460// needed to allow the vectorizer to properly record them to be ignored by
5461// the cost model and to avoid vectorizing them (otherwise these casts,
5462// which are redundant under the runtime overflow checks, will be
5463// vectorized, which can be costly).
5464//
5465// 2) Support additional induction/PHISCEV patterns: We also want to support
5466// inductions where the sext-trunc / zext-trunc operations (partly) occur
5467// after the induction update operation (the induction increment):
5468//
5469// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5470// which corresponds to a phi->add->trunc->sext/zext->phi update chain.
5471//
5472// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5473// which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5474//
5475// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5476std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5477ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5478 SmallVector<const SCEVPredicate *, 3> Predicates;
5479
5480 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5481 // return an AddRec expression under some predicate.
5482
5483 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5484 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5485 assert(L && "Expecting an integer loop header phi");
5486
5487 // The loop may have multiple entrances or multiple exits; we can analyze
5488 // this phi as an addrec if it has a unique entry value and a unique
5489 // backedge value.
5490 Value *BEValueV = nullptr, *StartValueV = nullptr;
5491 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5492 Value *V = PN->getIncomingValue(i);
5493 if (L->contains(PN->getIncomingBlock(i))) {
5494 if (!BEValueV) {
5495 BEValueV = V;
5496 } else if (BEValueV != V) {
5497 BEValueV = nullptr;
5498 break;
5499 }
5500 } else if (!StartValueV) {
5501 StartValueV = V;
5502 } else if (StartValueV != V) {
5503 StartValueV = nullptr;
5504 break;
5505 }
5506 }
5507 if (!BEValueV || !StartValueV)
5508 return std::nullopt;
5509
5510 const SCEV *BEValue = getSCEV(BEValueV);
5511
5512 // If the value coming around the backedge is an add with the symbolic
5513 // value we just inserted, possibly with casts that we can ignore under
5514 // an appropriate runtime guard, then we found a simple induction variable!
5515 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5516 if (!Add)
5517 return std::nullopt;
5518
5519 // If there is a single occurrence of the symbolic value, possibly
5520 // casted, replace it with a recurrence.
5521 unsigned FoundIndex = Add->getNumOperands();
5522 Type *TruncTy = nullptr;
5523 bool Signed;
5524 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5525 if ((TruncTy =
5526 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5527 if (FoundIndex == e) {
5528 FoundIndex = i;
5529 break;
5530 }
5531
5532 if (FoundIndex == Add->getNumOperands())
5533 return std::nullopt;
5534
5535 // Create an add with everything but the specified operand.
5536 SmallVector<const SCEV *, 8> Ops;
5537 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5538 if (i != FoundIndex)
5539 Ops.push_back(Add->getOperand(i));
5540 const SCEV *Accum = getAddExpr(Ops);
5541
5542 // The runtime checks will not be valid if the step amount is
5543 // varying inside the loop.
5544 if (!isLoopInvariant(Accum, L))
5545 return std::nullopt;
5546
5547 // *** Part2: Create the predicates
5548
5549 // Analysis was successful: we have a phi-with-cast pattern for which we
5550 // can return an AddRec expression under the following predicates:
5551 //
5552 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5553 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5554 // P2: An Equal predicate that guarantees that
5555 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5556 // P3: An Equal predicate that guarantees that
5557 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5558 //
5559 // As we next prove, the above predicates guarantee that:
5560 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5561 //
5562 //
5563 // More formally, we want to prove that:
5564 // Expr(i+1) = Start + (i+1) * Accum
5565 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5566 //
5567 // Given that:
5568 // 1) Expr(0) = Start
5569 // 2) Expr(1) = Start + Accum
5570 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5571 // 3) Induction hypothesis (step i):
5572 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5573 //
5574 // Proof:
5575 // Expr(i+1) =
5576 // = Start + (i+1)*Accum
5577 // = (Start + i*Accum) + Accum
5578 // = Expr(i) + Accum
5579 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5580 // :: from step i
5581 //
5582 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5583 //
5584 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5585 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5586 // + Accum :: from P3
5587 //
5588 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5589 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5590 //
5591 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5592 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5593 //
5594 // By induction, the same applies to all iterations 1<=i<n:
5595 //
5596
5597 // Create a truncated addrec for which we will add a no overflow check (P1).
5598 const SCEV *StartVal = getSCEV(StartValueV);
5599 const SCEV *PHISCEV =
5600 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5601 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5602
5603 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5604 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5605 // will be constant.
5606 //
5607 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5608 // add P1.
5609 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5610 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5611 Signed ? SCEVWrapPredicate::IncrementNSSW
5612 : SCEVWrapPredicate::IncrementNUSW;
5613 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5614 Predicates.push_back(AddRecPred);
5615 }
5616
5617 // Create the Equal Predicates P2,P3:
5618
5619 // It is possible that the predicates P2 and/or P3 are computable at
5620 // compile time due to StartVal and/or Accum being constants.
5621 // If either one is, then we can check that now and escape if either P2
5622 // or P3 is false.
5623
5624 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5625 // for each of StartVal and Accum
5626 auto getExtendedExpr = [&](const SCEV *Expr,
5627 bool CreateSignExtend) -> const SCEV * {
5628 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5629 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5630 const SCEV *ExtendedExpr =
5631 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5632 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5633 return ExtendedExpr;
5634 };
5635
5636 // Given:
5637 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5638 // = getExtendedExpr(Expr)
5639 // Determine whether the predicate P: Expr == ExtendedExpr
5640 // is known to be false at compile time
5641 auto PredIsKnownFalse = [&](const SCEV *Expr,
5642 const SCEV *ExtendedExpr) -> bool {
5643 return Expr != ExtendedExpr &&
5644 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5645 };
5646
5647 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5648 if (PredIsKnownFalse(StartVal, StartExtended)) {
5649 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5650 return std::nullopt;
5651 }
5652
5653 // The Step is always Signed (because the overflow checks are either
5654 // NSSW or NUSW)
5655 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5656 if (PredIsKnownFalse(Accum, AccumExtended)) {
5657 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5658 return std::nullopt;
5659 }
5660
5661 auto AppendPredicate = [&](const SCEV *Expr,
5662 const SCEV *ExtendedExpr) -> void {
5663 if (Expr != ExtendedExpr &&
5664 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5665 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5666 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5667 Predicates.push_back(Pred);
5668 }
5669 };
5670
5671 AppendPredicate(StartVal, StartExtended);
5672 AppendPredicate(Accum, AccumExtended);
5673
5674 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5675 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5676 // into NewAR if it will also add the runtime overflow checks specified in
5677 // Predicates.
5678 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5679
5680 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5681 std::make_pair(NewAR, Predicates);
5682 // Remember the result of the analysis for this SCEV at this location.
5683 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5684 return PredRewrite;
5685}
5686
5687std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5688ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5689 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5690 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5691 if (!L)
5692 return std::nullopt;
5693
5694 // Check to see if we already analyzed this PHI.
5695 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5696 if (I != PredicatedSCEVRewrites.end()) {
5697 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5698 I->second;
5699 // Analysis was done before and failed to create an AddRec:
5700 if (Rewrite.first == SymbolicPHI)
5701 return std::nullopt;
5702 // Analysis was done before and succeeded in creating an AddRec under
5703 // a predicate:
5704 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5705 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5706 return Rewrite;
5707 }
5708
5709 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5710 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5711
5712 // Record in the cache that the analysis failed
5713 if (!Rewrite) {
5714 SmallVector<const SCEVPredicate *, 3> Predicates;
5715 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5716 return std::nullopt;
5717 }
5718
5719 return Rewrite;
5720}
5721
5722// FIXME: This utility is currently required because the Rewriter currently
5723// does not rewrite this expression:
5724// {0, +, (sext ix (trunc iy to ix) to iy)}
5725// into {0, +, %step},
5726// even when the following Equal predicate exists:
5727// "%step == (sext ix (trunc iy to ix) to iy)".
5728bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5729 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5730 if (AR1 == AR2)
5731 return true;
5732
5733 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5734 if (Expr1 != Expr2 &&
5735 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5736 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5737 return false;
5738 return true;
5739 };
5740
5741 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5742 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5743 return false;
5744 return true;
5745}
5746
5747/// A helper function for createAddRecFromPHI to handle simple cases.
5748///
5749/// This function tries to find an AddRec expression for the simplest (yet most
5750/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5751/// If it fails, createAddRecFromPHI will use a more general, but slow,
5752/// technique for finding the AddRec expression.
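/// For example, %iv = phi i32 [ 0, %preheader ], [ %iv.next, %latch ] with
/// %iv.next = add i32 %iv, %step (and %step loop-invariant) yields the AddRec
/// {0,+,%step}<%L>.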
5753const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5754 Value *BEValueV,
5755 Value *StartValueV) {
5756 const Loop *L = LI.getLoopFor(PN->getParent());
5757 assert(L && L->getHeader() == PN->getParent());
5758 assert(BEValueV && StartValueV);
5759
5760 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5761 if (!BO)
5762 return nullptr;
5763
5764 if (BO->Opcode != Instruction::Add)
5765 return nullptr;
5766
5767 const SCEV *Accum = nullptr;
5768 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5769 Accum = getSCEV(BO->RHS);
5770 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5771 Accum = getSCEV(BO->LHS);
5772
5773 if (!Accum)
5774 return nullptr;
5775
5776 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5777 if (BO->IsNUW)
5778 Flags = setFlags(Flags, SCEV::FlagNUW);
5779 if (BO->IsNSW)
5780 Flags = setFlags(Flags, SCEV::FlagNSW);
5781
5782 const SCEV *StartVal = getSCEV(StartValueV);
5783 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5784 insertValueToMap(PN, PHISCEV);
5785
5786 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5787 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5788 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5789 proveNoWrapViaConstantRanges(AR)));
5790 }
5791
5792 // We can add Flags to the post-inc expression only if we
5793 // know that it is *undefined behavior* for BEValueV to
5794 // overflow.
5795 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5796 assert(isLoopInvariant(Accum, L) &&
5797 "Accum is defined outside L, but is not invariant?");
5798 if (isAddRecNeverPoison(BEInst, L))
5799 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5800 }
5801
5802 return PHISCEV;
5803}
5804
5805const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5806 const Loop *L = LI.getLoopFor(PN->getParent());
5807 if (!L || L->getHeader() != PN->getParent())
5808 return nullptr;
5809
5810 // The loop may have multiple entrances or multiple exits; we can analyze
5811 // this phi as an addrec if it has a unique entry value and a unique
5812 // backedge value.
5813 Value *BEValueV = nullptr, *StartValueV = nullptr;
5814 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5815 Value *V = PN->getIncomingValue(i);
5816 if (L->contains(PN->getIncomingBlock(i))) {
5817 if (!BEValueV) {
5818 BEValueV = V;
5819 } else if (BEValueV != V) {
5820 BEValueV = nullptr;
5821 break;
5822 }
5823 } else if (!StartValueV) {
5824 StartValueV = V;
5825 } else if (StartValueV != V) {
5826 StartValueV = nullptr;
5827 break;
5828 }
5829 }
5830 if (!BEValueV || !StartValueV)
5831 return nullptr;
5832
5833 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5834 "PHI node already processed?");
5835
5836 // First, try to find an AddRec expression without creating a fictitious
5837 // symbolic value for PN.
5838 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5839 return S;
5840
5841 // Handle PHI node value symbolically.
5842 const SCEV *SymbolicName = getUnknown(PN);
5843 insertValueToMap(PN, SymbolicName);
5844
5845 // Using this symbolic name for the PHI, analyze the value coming around
5846 // the back-edge.
5847 const SCEV *BEValue = getSCEV(BEValueV);
5848
5849 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5850 // has a special value for the first iteration of the loop.
5851
5852 // If the value coming around the backedge is an add with the symbolic
5853 // value we just inserted, then we found a simple induction variable!
5854 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5855 // If there is a single occurrence of the symbolic value, replace it
5856 // with a recurrence.
5857 unsigned FoundIndex = Add->getNumOperands();
5858 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5859 if (Add->getOperand(i) == SymbolicName)
5860 if (FoundIndex == e) {
5861 FoundIndex = i;
5862 break;
5863 }
5864
5865 if (FoundIndex != Add->getNumOperands()) {
5866 // Create an add with everything but the specified operand.
5867 SmallVector<const SCEV *, 8> Ops;
5868 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5869 if (i != FoundIndex)
5870 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5871 L, *this));
5872 const SCEV *Accum = getAddExpr(Ops);
5873
5874 // This is not a valid addrec if the step amount is varying each
5875 // loop iteration, but is not itself an addrec in this loop.
5876 if (isLoopInvariant(Accum, L) ||
5877 (isa<SCEVAddRecExpr>(Accum) &&
5878 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5879 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5880
5881 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5882 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5883 if (BO->IsNUW)
5884 Flags = setFlags(Flags, SCEV::FlagNUW);
5885 if (BO->IsNSW)
5886 Flags = setFlags(Flags, SCEV::FlagNSW);
5887 }
5888 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5889 if (GEP->getOperand(0) == PN) {
5890 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5891 // If the increment has any nowrap flags, then we know the address
5892 // space cannot be wrapped around.
5893 if (NW != GEPNoWrapFlags::none())
5894 Flags = setFlags(Flags, SCEV::FlagNW);
5895 // If the GEP is nuw or nusw with non-negative offset, we know that
5896 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5897 // offset is treated as signed, while the base is unsigned.
5898 if (NW.hasNoUnsignedWrap() ||
5899 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5900 Flags = setFlags(Flags, SCEV::FlagNUW);
5901 }
5902
5903 // We cannot transfer nuw and nsw flags from subtraction
5904 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5905 // for instance.
5906 }
5907
5908 const SCEV *StartVal = getSCEV(StartValueV);
5909 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5910
5911 // Okay, for the entire analysis of this edge we assumed the PHI
5912 // to be symbolic. We now need to go back and purge all of the
5913 // entries for the scalars that use the symbolic expression.
5914 forgetMemoizedResults(SymbolicName);
5915 insertValueToMap(PN, PHISCEV);
5916
5917 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5918 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5919 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5920 proveNoWrapViaConstantRanges(AR)));
5921 }
5922
5923 // We can add Flags to the post-inc expression only if we
5924 // know that it is *undefined behavior* for BEValueV to
5925 // overflow.
5926 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5927 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5928 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5929
5930 return PHISCEV;
5931 }
5932 }
5933 } else {
5934 // Otherwise, this could be a loop like this:
5935 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5936 // In this case, j = {1,+,1} and BEValue is j.
5937 // Because the other in-value of i (0) fits the evolution of BEValue,
5938 // i really is an addrec evolution.
5939 //
5940 // We can generalize this saying that i is the shifted value of BEValue
5941 // by one iteration:
5942 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5943
5944 // Do not allow refinement in rewriting of BEValue.
5945 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5946 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5947 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5948 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5949 const SCEV *StartVal = getSCEV(StartValueV);
5950 if (Start == StartVal) {
5951 // Okay, for the entire analysis of this edge we assumed the PHI
5952 // to be symbolic. We now need to go back and purge all of the
5953 // entries for the scalars that use the symbolic expression.
5954 forgetMemoizedResults(SymbolicName);
5955 insertValueToMap(PN, Shifted);
5956 return Shifted;
5957 }
5958 }
5959 }
5960
5961 // Remove the temporary PHI node SCEV that has been inserted while intending
5962 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5963 // as it would prevent later (possibly simpler) SCEV expressions from being
5964 // added to the ValueExprMap.
5965 eraseValueFromMap(PN);
5966
5967 return nullptr;
5968}
5969
5970// Try to match a control flow sequence that branches out at BI and merges back
5971// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5972// match.
5973static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5974 Value *&C, Value *&LHS, Value *&RHS) {
5975 C = BI->getCondition();
5976
5977 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5978 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5979
5980 if (!LeftEdge.isSingleEdge())
5981 return false;
5982
5983 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5984
5985 Use &LeftUse = Merge->getOperandUse(0);
5986 Use &RightUse = Merge->getOperandUse(1);
5987
5988 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5989 LHS = LeftUse;
5990 RHS = RightUse;
5991 return true;
5992 }
5993
5994 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5995 LHS = RightUse;
5996 RHS = LeftUse;
5997 return true;
5998 }
5999
6000 return false;
6001}
6002
6003const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6004 auto IsReachable =
6005 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6006 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6007 // Try to match
6008 //
6009 // br %cond, label %left, label %right
6010 // left:
6011 // br label %merge
6012 // right:
6013 // br label %merge
6014 // merge:
6015 // V = phi [ %x, %left ], [ %y, %right ]
6016 //
6017 // as "select %cond, %x, %y"
6018
6019 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6020 assert(IDom && "At least the entry block should dominate PN");
6021
6022 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6023 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6024
6025 if (BI && BI->isConditional() &&
6026 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6027 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6028 properlyDominates(getSCEV(RHS), PN->getParent()))
6029 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6030 }
6031
6032 return nullptr;
6033}
6034
6035/// Returns SCEV for the first operand of a phi if all phi operands have
6036/// identical opcodes and operands
6037/// eg.
6038/// a: %add = %a + %b
6039/// br %c
6040/// b: %add1 = %a + %b
6041/// br %c
6042/// c: %phi = phi [%add, a], [%add1, b]
6043/// scev(%phi) => scev(%add)
6044const SCEV *
6045ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6046 BinaryOperator *CommonInst = nullptr;
6047 // Check if instructions are identical.
6048 for (Value *Incoming : PN->incoming_values()) {
6049 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6050 if (!IncomingInst)
6051 return nullptr;
6052 if (CommonInst) {
6053 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6054 return nullptr; // Not identical, give up
6055 } else {
6056 // Remember binary operator
6057 CommonInst = IncomingInst;
6058 }
6059 }
6060 if (!CommonInst)
6061 return nullptr;
6062
6063 // Check if SCEV exprs for instructions are identical.
6064 const SCEV *CommonSCEV = getSCEV(CommonInst);
6065 bool SCEVExprsIdentical =
6066 all_of(drop_begin(PN->incoming_values()),
6067 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6068 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6069}
6070
6071const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6072 if (const SCEV *S = createAddRecFromPHI(PN))
6073 return S;
6074
6075 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6076 // phi node for X.
6077 if (Value *V = simplifyInstruction(
6078 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6079 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6080 return getSCEV(V);
6081
6082 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6083 return S;
6084
6085 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6086 return S;
6087
6088 // If it's not a loop phi, we can't handle it yet.
6089 return getUnknown(PN);
6090}
6091
6092bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6093 SCEVTypes RootKind) {
6094 struct FindClosure {
6095 const SCEV *OperandToFind;
6096 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6097 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6098
6099 bool Found = false;
6100
6101 bool canRecurseInto(SCEVTypes Kind) const {
6102 // We can only recurse into the SCEV expression of the same effective type
6103 // as the type of our root SCEV expression, and into zero-extensions.
6104 return RootKind == Kind || NonSequentialRootKind == Kind ||
6105 scZeroExtend == Kind;
6106 };
6107
6108 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6109 : OperandToFind(OperandToFind), RootKind(RootKind),
6110 NonSequentialRootKind(
6111 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6112 RootKind)) {}
6113
6114 bool follow(const SCEV *S) {
6115 Found = S == OperandToFind;
6116
6117 return !isDone() && canRecurseInto(S->getSCEVType());
6118 }
6119
6120 bool isDone() const { return Found; }
6121 };
6122
6123 FindClosure FC(OperandToFind, RootKind);
6124 visitAll(Root, FC);
6125 return FC.Found;
6126}
6127
6128std::optional<const SCEV *>
6129ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6130 ICmpInst *Cond,
6131 Value *TrueVal,
6132 Value *FalseVal) {
6133 // Try to match some simple smax or umax patterns.
6134 auto *ICI = Cond;
6135
6136 Value *LHS = ICI->getOperand(0);
6137 Value *RHS = ICI->getOperand(1);
6138
6139 switch (ICI->getPredicate()) {
6140 case ICmpInst::ICMP_SLT:
6141 case ICmpInst::ICMP_SLE:
6142 case ICmpInst::ICMP_ULT:
6143 case ICmpInst::ICMP_ULE:
6144 std::swap(LHS, RHS);
6145 [[fallthrough]];
6146 case ICmpInst::ICMP_SGT:
6147 case ICmpInst::ICMP_SGE:
6148 case ICmpInst::ICMP_UGT:
6149 case ICmpInst::ICMP_UGE:
6150 // a > b ? a+x : b+x -> max(a, b)+x
6151 // a > b ? b+x : a+x -> min(a, b)+x
6152 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6153 bool Signed = ICI->isSigned();
6154 const SCEV *LA = getSCEV(TrueVal);
6155 const SCEV *RA = getSCEV(FalseVal);
6156 const SCEV *LS = getSCEV(LHS);
6157 const SCEV *RS = getSCEV(RHS);
6158 if (LA->getType()->isPointerTy()) {
6159 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6160 // Need to make sure we can't produce weird expressions involving
6161 // negated pointers.
6162 if (LA == LS && RA == RS)
6163 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6164 if (LA == RS && RA == LS)
6165 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6166 }
6167 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6168 if (Op->getType()->isPointerTy()) {
6169 Op = getLosslessPtrToIntExpr(Op);
6170 if (isa<SCEVCouldNotCompute>(Op))
6171 return Op;
6172 }
6173 if (Signed)
6174 Op = getNoopOrSignExtend(Op, Ty);
6175 else
6176 Op = getNoopOrZeroExtend(Op, Ty);
6177 return Op;
6178 };
6179 LS = CoerceOperand(LS);
6180 RS = CoerceOperand(RS);
6181 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6182 break;
6183 const SCEV *LDiff = getMinusSCEV(LA, LS);
6184 const SCEV *RDiff = getMinusSCEV(RA, RS);
6185 if (LDiff == RDiff)
6186 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6187 LDiff);
6188 LDiff = getMinusSCEV(LA, RS);
6189 RDiff = getMinusSCEV(RA, LS);
6190 if (LDiff == RDiff)
6191 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6192 LDiff);
6193 }
6194 break;
6195 case ICmpInst::ICMP_NE:
6196 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6197 std::swap(TrueVal, FalseVal);
6198 [[fallthrough]];
6199 case ICmpInst::ICMP_EQ:
6200 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6201 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6202 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6203 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6204 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6205 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6206 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6207 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6208 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6209 return getAddExpr(getUMaxExpr(X, C), Y);
6210 }
6211 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6212 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6213 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6214 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6215 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6216 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6217 const SCEV *X = getSCEV(LHS);
6218 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6219 X = ZExt->getOperand();
6220 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6221 const SCEV *FalseValExpr = getSCEV(FalseVal);
6222 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6223 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6224 /*Sequential=*/true);
6225 }
6226 }
6227 break;
6228 default:
6229 break;
6230 }
6231
6232 return std::nullopt;
6233}
6234
6235static std::optional<const SCEV *>
6236createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6237 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6238 assert(CondExpr->getType()->isIntegerTy(1) &&
6239 TrueExpr->getType() == FalseExpr->getType() &&
6240 TrueExpr->getType()->isIntegerTy(1) &&
6241 "Unexpected operands of a select.");
6242
6243 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6244 // --> C + (umin_seq cond, x - C)
6245 //
6246 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6247 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6248 // --> C + (umin_seq ~cond, x - C)
6249
6250 // FIXME: while we can't legally model the case where both of the hands
6251 // are fully variable, we only require that the *difference* is constant.
6252 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6253 return std::nullopt;
6254
6255 const SCEV *X, *C;
6256 if (isa<SCEVConstant>(TrueExpr)) {
6257 CondExpr = SE->getNotSCEV(CondExpr);
6258 X = FalseExpr;
6259 C = TrueExpr;
6260 } else {
6261 X = TrueExpr;
6262 C = FalseExpr;
6263 }
6264 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6265 /*Sequential=*/true));
6266}
6267
6268static std::optional<const SCEV *>
6269createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6270 Value *FalseVal) {
6271 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6272 return std::nullopt;
6273
6274 const auto *SECond = SE->getSCEV(Cond);
6275 const auto *SETrue = SE->getSCEV(TrueVal);
6276 const auto *SEFalse = SE->getSCEV(FalseVal);
6277 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6278}
6279
6280const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6281 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6282 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6283 assert(TrueVal->getType() == FalseVal->getType() &&
6284 V->getType() == TrueVal->getType() &&
6285 "Types of select hands and of the result must match.");
6286
6287 // For now, only deal with i1-typed `select`s.
6288 if (!V->getType()->isIntegerTy(1))
6289 return getUnknown(V);
6290
6291 if (std::optional<const SCEV *> S =
6292 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6293 return *S;
6294
6295 return getUnknown(V);
6296}
6297
6298const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6299 Value *TrueVal,
6300 Value *FalseVal) {
6301 // Handle "constant" branch or select. This can occur for instance when a
6302 // loop pass transforms an inner loop and moves on to process the outer loop.
6303 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6304 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6305
6306 if (auto *I = dyn_cast<Instruction>(V)) {
6307 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6308 if (std::optional<const SCEV *> S =
6309 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6310 TrueVal, FalseVal))
6311 return *S;
6312 }
6313 }
6314
6315 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6316}
6317
6318/// Expand GEP instructions into add and multiply operations. This allows them
6319/// to be analyzed by regular SCEV code.
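/// For example, `getelementptr i32, ptr %p, i64 %i` becomes (%p + (4 * %i)),
/// using the element size from the DataLayout (4 bytes for i32).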
6320const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6321 assert(GEP->getSourceElementType()->isSized() &&
6322 "GEP source element type must be sized");
6323
6324 SmallVector<const SCEV *, 4> IndexExprs;
6325 for (Value *Index : GEP->indices())
6326 IndexExprs.push_back(getSCEV(Index));
6327 return getGEPExpr(GEP, IndexExprs);
6328}
6329
6330APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6331 const Instruction *CtxI) {
6332 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6333 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6334 return TrailingZeros >= BitWidth
6335 ? APInt::getZero(BitWidth)
6336 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6337 };
6338 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6339 // The result is GCD of all operands results.
6340 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6341 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6342 Res = APIntOps::GreatestCommonDivisor(
6343 Res, getConstantMultiple(N->getOperand(I), CtxI));
6344 return Res;
6345 };
6346
6347 switch (S->getSCEVType()) {
6348 case scConstant:
6349 return cast<SCEVConstant>(S)->getAPInt();
6350 case scPtrToInt:
6351 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
6352 case scUDivExpr:
6353 case scVScale:
6354 return APInt(BitWidth, 1);
6355 case scTruncate: {
6356 // Only multiples that are a power of 2 will hold after truncation.
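    // e.g. an operand that is a multiple of 12 only keeps the power-of-2
    // factor 4 after truncation; the odd factor 3 need not survive.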
6357 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6358 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6359 return GetShiftedByZeros(TZ);
6360 }
6361 case scZeroExtend: {
6362 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6363 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6364 }
6365 case scSignExtend: {
6366 // Only multiples that are a power of 2 will hold after sext.
6367 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6368 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6369 return GetShiftedByZeros(TZ);
6370 }
6371 case scMulExpr: {
6372 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6373 if (M->hasNoUnsignedWrap()) {
6374 // The result is the product of all operand results.
6375 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6376 for (const SCEV *Operand : M->operands().drop_front())
6377 Res = Res * getConstantMultiple(Operand, CtxI);
6378 return Res;
6379 }
6380
6381 // If there are no wrap guarantees, find the trailing zeros, which is the
6382 // sum of the trailing zeros of all its operands.
6383 uint32_t TZ = 0;
6384 for (const SCEV *Operand : M->operands())
6385 TZ += getMinTrailingZeros(Operand, CtxI);
6386 return GetShiftedByZeros(TZ);
6387 }
6388 case scAddExpr:
6389 case scAddRecExpr: {
6390 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6391 if (N->hasNoUnsignedWrap())
6392 return GetGCDMultiple(N);
6393 // Find the trailing zero bits, which is the minimum over all its operands.
6394 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6395 for (const SCEV *Operand : N->operands().drop_front())
6396 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6397 return GetShiftedByZeros(TZ);
6398 }
6399 case scUMaxExpr:
6400 case scSMaxExpr:
6401 case scUMinExpr:
6402 case scSMinExpr:
6403 case scSequentialUMinExpr:
6404 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6405 case scUnknown: {
6406 // Ask ValueTracking for known bits. SCEVUnknowns only become available at
6407 // the point their underlying IR instruction has been defined. If CtxI was
6408 // not provided, use:
6409 // * the first instruction in the entry block if it is an argument
6410 // * the instruction itself otherwise.
6411 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6412 if (!CtxI) {
6413 if (isa<Argument>(U->getValue()))
6414 CtxI = &*F.getEntryBlock().begin();
6415 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6416 CtxI = I;
6417 }
6418 unsigned Known =
6419 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6420 .countMinTrailingZeros();
6421 return GetShiftedByZeros(Known);
6422 }
6423 case scCouldNotCompute:
6424 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6425 }
6426 llvm_unreachable("Unknown SCEV kind!");
6427}
6428
6429 APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6430 const Instruction *CtxI) {
6431 // Skip looking up and updating the cache if there is a context instruction,
6432 // as the result will only be valid in the specified context.
6433 if (CtxI)
6434 return getConstantMultipleImpl(S, CtxI);
6435
6436 auto I = ConstantMultipleCache.find(S);
6437 if (I != ConstantMultipleCache.end())
6438 return I->second;
6439
6440 APInt Result = getConstantMultipleImpl(S, CtxI);
6441 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6442 assert(InsertPair.second && "Should insert a new key");
6443 return InsertPair.first->second;
6444}
6445
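// getConstantMultiple returns 0 when S is known to be zero (a multiple of
// 2^BitWidth); map that to 1 so the result is always safe to use as a divisor.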
6446 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6447 APInt Multiple = getConstantMultiple(S);
6448 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6449}
6450
6451 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6452 const Instruction *CtxI) {
6453 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6454 (unsigned)getTypeSizeInBits(S->getType()));
6455}
6456
6457/// Helper method to assign a range to V from metadata present in the IR.
6458static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6459 if (Instruction *I = dyn_cast<Instruction>(V)) {
6460 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6461 return getConstantRangeFromMetadata(*MD);
6462 if (const auto *CB = dyn_cast<CallBase>(V))
6463 if (std::optional<ConstantRange> Range = CB->getRange())
6464 return Range;
6465 }
6466 if (auto *A = dyn_cast<Argument>(V))
6467 if (std::optional<ConstantRange> Range = A->getRange())
6468 return Range;
6469
6470 return std::nullopt;
6471}
6472
6473 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6474 SCEV::NoWrapFlags Flags) {
6475 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6476 AddRec->setNoWrapFlags(Flags);
6477 UnsignedRanges.erase(AddRec);
6478 SignedRanges.erase(AddRec);
6479 ConstantMultipleCache.erase(AddRec);
6480 }
6481}
6482
6483ConstantRange ScalarEvolution::
6484getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6485 const DataLayout &DL = getDataLayout();
6486
6487 unsigned BitWidth = getTypeSizeInBits(U->getType());
6488 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6489
6490 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6491 // use information about the trip count to improve our available range. Note
6492 // that the trip count independent cases are already handled by known bits.
6493 // WARNING: The definition of recurrence used here is subtly different than
6494 // the one used by AddRec (and thus most of this file). Step is allowed to
6495 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6496 // and other addrecs in the same loop (for non-affine addrecs). The code
6497 // below intentionally handles the case where step is not loop invariant.
6498 auto *P = dyn_cast<PHINode>(U->getValue());
6499 if (!P)
6500 return FullSet;
6501
6502 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6503 // even the values that are not available in these blocks may come from them,
6504 // and this leads to a false-positive recurrence test.
6505 for (auto *Pred : predecessors(P->getParent()))
6506 if (!DT.isReachableFromEntry(Pred))
6507 return FullSet;
6508
6509 BinaryOperator *BO;
6510 Value *Start, *Step;
6511 if (!matchSimpleRecurrence(P, BO, Start, Step))
6512 return FullSet;
6513
6514 // If we found a recurrence in reachable code, we must be in a loop. Note
6515 // that BO might be in some subloop of L, and that's completely okay.
6516 auto *L = LI.getLoopFor(P->getParent());
6517 assert(L && L->getHeader() == P->getParent());
6518 if (!L->contains(BO->getParent()))
6519 // NOTE: This bailout should be an assert instead. However, asserting
6520 // the condition here exposes a case where LoopFusion is querying SCEV
6521 // with malformed loop information in the midst of the transform.
6522 // There doesn't appear to be an obvious fix, so for the moment bail out
6523 // until the caller issue can be fixed. PR49566 tracks the bug.
6524 return FullSet;
6525
6526 // TODO: Extend to other opcodes such as mul, and div
6527 switch (BO->getOpcode()) {
6528 default:
6529 return FullSet;
6530 case Instruction::AShr:
6531 case Instruction::LShr:
6532 case Instruction::Shl:
6533 break;
6534 };
6535
6536 if (BO->getOperand(0) != P)
6537 // TODO: Handle the power function forms some day.
6538 return FullSet;
6539
6540 unsigned TC = getSmallConstantMaxTripCount(L);
6541 if (!TC || TC >= BitWidth)
6542 return FullSet;
6543
6544 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6545 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6546 assert(KnownStart.getBitWidth() == BitWidth &&
6547 KnownStep.getBitWidth() == BitWidth);
6548
6549 // Compute total shift amount, being careful of overflow and bitwidths.
6550 auto MaxShiftAmt = KnownStep.getMaxValue();
6551 APInt TCAP(BitWidth, TC-1);
6552 bool Overflow = false;
6553 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6554 if (Overflow)
6555 return FullSet;
6556
6557 switch (BO->getOpcode()) {
6558 default:
6559 llvm_unreachable("filtered out above");
6560 case Instruction::AShr: {
6561 // For each ashr, three cases:
6562 // shift = 0 => unchanged value
6563 // saturation => 0 or -1
6564 // other => a value closer to zero (of the same sign)
6565 // Thus, the end value is closer to zero than the start.
6566 auto KnownEnd = KnownBits::ashr(KnownStart,
6567 KnownBits::makeConstant(TotalShift));
6568 if (KnownStart.isNonNegative())
6569 // Analogous to lshr (simply not yet canonicalized)
6570 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6571 KnownStart.getMaxValue() + 1);
6572 if (KnownStart.isNegative())
6573 // End >=u Start && End <=s Start
6574 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6575 KnownEnd.getMaxValue() + 1);
6576 break;
6577 }
6578 case Instruction::LShr: {
6579 // For each lshr, three cases:
6580 // shift = 0 => unchanged value
6581 // saturation => 0
6582 // other => a smaller positive number
6583 // Thus, the low end of the unsigned range is the last value produced.
6584 auto KnownEnd = KnownBits::lshr(KnownStart,
6585 KnownBits::makeConstant(TotalShift));
6586 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6587 KnownStart.getMaxValue() + 1);
6588 }
6589 case Instruction::Shl: {
6590 // Iff no bits are shifted out, value increases on every shift.
6591 auto KnownEnd = KnownBits::shl(KnownStart,
6592 KnownBits::makeConstant(TotalShift));
6593 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6594 return ConstantRange(KnownStart.getMinValue(),
6595 KnownEnd.getMaxValue() + 1);
6596 break;
6597 }
6598 };
6599 return FullSet;
6600}
6601
6602const ConstantRange &
6603ScalarEvolution::getRangeRefIter(const SCEV *S,
6604 ScalarEvolution::RangeSignHint SignHint) {
6605 DenseMap<const SCEV *, ConstantRange> &Cache =
6606 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6607 : SignedRanges;
6608 SmallVector<const SCEV *> WorkList;
6609 SmallPtrSet<const SCEV *, 8> Seen;
6610
6611 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6612 // SCEVUnknown PHI node.
6613 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6614 if (!Seen.insert(Expr).second)
6615 return;
6616 if (Cache.contains(Expr))
6617 return;
6618 switch (Expr->getSCEVType()) {
6619 case scUnknown:
6620 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6621 break;
6622 [[fallthrough]];
6623 case scConstant:
6624 case scVScale:
6625 case scTruncate:
6626 case scZeroExtend:
6627 case scSignExtend:
6628 case scPtrToInt:
6629 case scAddExpr:
6630 case scMulExpr:
6631 case scUDivExpr:
6632 case scAddRecExpr:
6633 case scUMaxExpr:
6634 case scSMaxExpr:
6635 case scUMinExpr:
6636 case scSMinExpr:
6637 case scSequentialUMinExpr:
6638 WorkList.push_back(Expr);
6639 break;
6640 case scCouldNotCompute:
6641 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6642 }
6643 };
6644 AddToWorklist(S);
6645
6646 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6647 for (unsigned I = 0; I != WorkList.size(); ++I) {
6648 const SCEV *P = WorkList[I];
6649 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6650 // If it is not a `SCEVUnknown`, just recurse into operands.
6651 if (!UnknownS) {
6652 for (const SCEV *Op : P->operands())
6653 AddToWorklist(Op);
6654 continue;
6655 }
6656 // `SCEVUnknown`'s require special treatment.
6657 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6658 if (!PendingPhiRangesIter.insert(P).second)
6659 continue;
6660 for (auto &Op : reverse(P->operands()))
6661 AddToWorklist(getSCEV(Op));
6662 }
6663 }
6664
6665 if (!WorkList.empty()) {
6666 // Use getRangeRef to compute ranges for items in the worklist in reverse
6667 // order. This will force ranges for earlier operands to be computed before
6668 // their users in most cases.
6669 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6670 getRangeRef(P, SignHint);
6671
6672 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6673 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6674 PendingPhiRangesIter.erase(P);
6675 }
6676 }
6677
6678 return getRangeRef(S, SignHint, 0);
6679}
6680
6681/// Determine the range for a particular SCEV. If SignHint is
6682/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6683/// with a "cleaner" unsigned (resp. signed) representation.
6684const ConstantRange &ScalarEvolution::getRangeRef(
6685 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6686 DenseMap<const SCEV *, ConstantRange> &Cache =
6687 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6688 : SignedRanges;
6689 ConstantRange::PreferredRangeType RangeType =
6690 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6691 : ConstantRange::Signed;
6692
6693 // See if we've computed this range already.
6694 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6695 if (I != Cache.end())
6696 return I->second;
6697
6698 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6699 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6700
6701 // Switch to iteratively computing the range for S, if it is part of a deeply
6702 // nested expression.
6703 if (Depth > RangeIterThreshold)
6704 return getRangeRefIter(S, SignHint);
6705
6706 unsigned BitWidth = getTypeSizeInBits(S->getType());
6707 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6708 using OBO = OverflowingBinaryOperator;
6709
6710 // If the value has known zeros, the maximum value will have those known zeros
6711 // as well.
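// For example, if S is known to be a multiple of 4, its unsigned maximum is at
// most the largest multiple of 4 representable in BitWidth bits.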
6712 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6713 APInt Multiple = getNonZeroConstantMultiple(S);
6714 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6715 if (!Remainder.isZero())
6716 ConservativeResult =
6717 ConstantRange(APInt::getMinValue(BitWidth),
6718 APInt::getMaxValue(BitWidth) - Remainder + 1);
6719 }
6720 else {
6721 uint32_t TZ = getMinTrailingZeros(S);
6722 if (TZ != 0) {
6723 ConservativeResult = ConstantRange(
6724 APInt::getSignedMinValue(BitWidth),
6725 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6726 }
6727 }
6728
6729 switch (S->getSCEVType()) {
6730 case scConstant:
6731 llvm_unreachable("Already handled above.");
6732 case scVScale:
6733 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6734 case scTruncate: {
6735 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6736 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6737 return setRange(
6738 Trunc, SignHint,
6739 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6740 }
6741 case scZeroExtend: {
6742 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6743 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6744 return setRange(
6745 ZExt, SignHint,
6746 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6747 }
6748 case scSignExtend: {
6749 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6750 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6751 return setRange(
6752 SExt, SignHint,
6753 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6754 }
6755 case scPtrToInt: {
6756 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6757 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6758 return setRange(PtrToInt, SignHint, X);
6759 }
6760 case scAddExpr: {
6761 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6762 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6763 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6764 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6765 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6766 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6767 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6768 ConservativeResult =
6769 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6770 }
6771 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6772 unsigned WrapType = OBO::AnyWrap;
6773 if (Add->hasNoSignedWrap())
6774 WrapType |= OBO::NoSignedWrap;
6775 if (Add->hasNoUnsignedWrap())
6776 WrapType |= OBO::NoUnsignedWrap;
6777 for (const SCEV *Op : drop_begin(Add->operands()))
6778 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6779 RangeType);
6780 return setRange(Add, SignHint,
6781 ConservativeResult.intersectWith(X, RangeType));
6782 }
6783 case scMulExpr: {
6784 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6785 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6786 for (const SCEV *Op : drop_begin(Mul->operands()))
6787 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6788 return setRange(Mul, SignHint,
6789 ConservativeResult.intersectWith(X, RangeType));
6790 }
6791 case scUDivExpr: {
6792 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6793 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6794 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6795 return setRange(UDiv, SignHint,
6796 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6797 }
6798 case scAddRecExpr: {
6799 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6800 // If there's no unsigned wrap, the value will never be less than its
6801 // initial value.
6802 if (AddRec->hasNoUnsignedWrap()) {
6803 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6804 if (!UnsignedMinValue.isZero())
6805 ConservativeResult = ConservativeResult.intersectWith(
6806 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6807 }
6808
6809 // If there's no signed wrap, and all the operands except the initial value
6810 // have the same sign or are zero, the value won't ever be:
6811 // 1: smaller than the initial value if the operands are non-negative,
6812 // 2: bigger than the initial value if the operands are non-positive.
6813 // In both cases, the value cannot cross the signed min/max boundary.
6814 if (AddRec->hasNoSignedWrap()) {
6815 bool AllNonNeg = true;
6816 bool AllNonPos = true;
6817 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6818 if (!isKnownNonNegative(AddRec->getOperand(i)))
6819 AllNonNeg = false;
6820 if (!isKnownNonPositive(AddRec->getOperand(i)))
6821 AllNonPos = false;
6822 }
6823 if (AllNonNeg)
6824 ConservativeResult = ConservativeResult.intersectWith(
6825 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6826 APInt::getSignedMaxValue(BitWidth) + 1),
6827 RangeType);
6828 else if (AllNonPos)
6829 ConservativeResult = ConservativeResult.intersectWith(
6830 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6831 getSignedRangeMax(AddRec->getStart()) +
6832 1),
6833 RangeType);
6834 }
6835
6836 // TODO: non-affine addrec
6837 if (AddRec->isAffine()) {
6838 const SCEV *MaxBEScev =
6839 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6840 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6841 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6842
6843 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6844 // MaxBECount's active bits are all <= AddRec's bit width.
6845 if (MaxBECount.getBitWidth() > BitWidth &&
6846 MaxBECount.getActiveBits() <= BitWidth)
6847 MaxBECount = MaxBECount.trunc(BitWidth);
6848 else if (MaxBECount.getBitWidth() < BitWidth)
6849 MaxBECount = MaxBECount.zext(BitWidth);
6850
6851 if (MaxBECount.getBitWidth() == BitWidth) {
6852 auto RangeFromAffine = getRangeForAffineAR(
6853 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6854 ConservativeResult =
6855 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6856
6857 auto RangeFromFactoring = getRangeViaFactoring(
6858 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6859 ConservativeResult =
6860 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6861 }
6862 }
6863
6864 // Now try symbolic BE count and more powerful methods.
6865 if (UseExpensiveRangeSharpening) {
6866 const SCEV *SymbolicMaxBECount =
6867 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6868 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6869 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6870 AddRec->hasNoSelfWrap()) {
6871 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6872 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6873 ConservativeResult =
6874 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6875 }
6876 }
6877 }
6878
6879 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6880 }
6881 case scUMaxExpr:
6882 case scSMaxExpr:
6883 case scUMinExpr:
6884 case scSMinExpr:
6885 case scSequentialUMinExpr: {
6886 Intrinsic::ID ID;
6887 switch (S->getSCEVType()) {
6888 case scUMaxExpr:
6889 ID = Intrinsic::umax;
6890 break;
6891 case scSMaxExpr:
6892 ID = Intrinsic::smax;
6893 break;
6894 case scUMinExpr:
6895 case scSequentialUMinExpr:
6896 ID = Intrinsic::umin;
6897 break;
6898 case scSMinExpr:
6899 ID = Intrinsic::smin;
6900 break;
6901 default:
6902 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6903 }
6904
6905 const auto *NAry = cast<SCEVNAryExpr>(S);
6906 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6907 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6908 X = X.intrinsic(
6909 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6910 return setRange(S, SignHint,
6911 ConservativeResult.intersectWith(X, RangeType));
6912 }
6913 case scUnknown: {
6914 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6915 Value *V = U->getValue();
6916
6917 // Check if the IR explicitly contains !range metadata.
6918 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6919 if (MDRange)
6920 ConservativeResult =
6921 ConservativeResult.intersectWith(*MDRange, RangeType);
6922
6923 // Use facts about recurrences in the underlying IR. Note that add
6924 // recurrences are AddRecExprs and thus don't hit this path. This
6925 // primarily handles shift recurrences.
6926 auto CR = getRangeForUnknownRecurrence(U);
6927 ConservativeResult = ConservativeResult.intersectWith(CR);
6928
6929 // See if ValueTracking can give us a useful range.
6930 const DataLayout &DL = getDataLayout();
6931 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6932 if (Known.getBitWidth() != BitWidth)
6933 Known = Known.zextOrTrunc(BitWidth);
6934
6935 // ValueTracking may be able to compute a tighter result for the number of
6936 // sign bits than for the value of those sign bits.
6937 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6938 if (U->getType()->isPointerTy()) {
6939 // If the pointer size is larger than the index type size, this can cause
6940 // NS to be larger than BitWidth. So compensate for this.
6941 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6942 int ptrIdxDiff = ptrSize - BitWidth;
6943 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6944 NS -= ptrIdxDiff;
6945 }
6946
6947 if (NS > 1) {
6948 // If we know any of the sign bits, we know all of the sign bits.
6949 if (!Known.Zero.getHiBits(NS).isZero())
6950 Known.Zero.setHighBits(NS);
6951 if (!Known.One.getHiBits(NS).isZero())
6952 Known.One.setHighBits(NS);
6953 }
6954
6955 if (Known.getMinValue() != Known.getMaxValue() + 1)
6956 ConservativeResult = ConservativeResult.intersectWith(
6957 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6958 RangeType);
6959 if (NS > 1)
6960 ConservativeResult = ConservativeResult.intersectWith(
6961 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6962 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6963 RangeType);
6964
6965 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6966 // Strengthen the range if the underlying IR value is a
6967 // global/alloca/heap allocation using the size of the object.
6968 bool CanBeNull, CanBeFreed;
6969 uint64_t DerefBytes =
6970 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6971 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6972 // The highest address the object can start is DerefBytes bytes before
6973 // the end (unsigned max value). If this value is not a multiple of the
6974 // alignment, the last possible start value is the next lowest multiple
6975 // of the alignment. Note: The computations below cannot overflow,
6976 // because if they did, there would be no possible start address for the
6977 // object.
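// For example, with 16 dereferenceable bytes and 8-byte alignment in a 32-bit
// index space, the object can start no higher than 0xFFFFFFEF, which rounds
// down to the aligned address 0xFFFFFFE8.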
6978 APInt MaxVal =
6979 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6980 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6981 uint64_t Rem = MaxVal.urem(Align);
6982 MaxVal -= APInt(BitWidth, Rem);
6983 APInt MinVal = APInt::getZero(BitWidth);
6984 if (llvm::isKnownNonZero(V, DL))
6985 MinVal = Align;
6986 ConservativeResult = ConservativeResult.intersectWith(
6987 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6988 }
6989 }
6990
6991 // The range of a Phi is a subset of the union of the ranges of its inputs.
6992 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6993 // Make sure that we do not run over cycled Phis.
6994 if (PendingPhiRanges.insert(Phi).second) {
6995 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6996
6997 for (const auto &Op : Phi->operands()) {
6998 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6999 RangeFromOps = RangeFromOps.unionWith(OpRange);
7000 // No point to continue if we already have a full set.
7001 if (RangeFromOps.isFullSet())
7002 break;
7003 }
7004 ConservativeResult =
7005 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7006 bool Erased = PendingPhiRanges.erase(Phi);
7007 assert(Erased && "Failed to erase Phi properly?");
7008 (void)Erased;
7009 }
7010 }
7011
7012 // vscale can't be equal to zero
7013 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7014 if (II->getIntrinsicID() == Intrinsic::vscale) {
7015 ConstantRange Disallowed = APInt::getZero(BitWidth);
7016 ConservativeResult = ConservativeResult.difference(Disallowed);
7017 }
7018
7019 return setRange(U, SignHint, std::move(ConservativeResult));
7020 }
7021 case scCouldNotCompute:
7022 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7023 }
7024
7025 return setRange(S, SignHint, std::move(ConservativeResult));
7026}
7027
7028// Given a StartRange, Step and MaxBECount for an expression compute a range of
7029// values that the expression can take. Initially, the expression has a value
7030// from StartRange and then is changed by Step up to MaxBECount times. Signed
7031// argument defines if we treat Step as signed or unsigned.
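// For example, with StartRange = [10, 20), Step = 3 and MaxBECount = 5, the
// value can grow by at most 15, so the computed range is [10, 35).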
7032 static ConstantRange getRangeForAffineARHelper(APInt Step,
7033 const ConstantRange &StartRange,
7034 const APInt &MaxBECount,
7035 bool Signed) {
7036 unsigned BitWidth = Step.getBitWidth();
7037 assert(BitWidth == StartRange.getBitWidth() &&
7038 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7039 // If either Step or MaxBECount is 0, then the expression won't change, and we
7040 // just need to return the initial range.
7041 if (Step == 0 || MaxBECount == 0)
7042 return StartRange;
7043
7044 // If we don't know anything about the initial value (i.e. StartRange is
7045 // FullRange), then we don't know anything about the final range either.
7046 // Return FullRange.
7047 if (StartRange.isFullSet())
7048 return ConstantRange::getFull(BitWidth);
7049
7050 // If Step is signed and negative, then we use its absolute value, but we also
7051 // note that we're moving in the opposite direction.
7052 bool Descending = Signed && Step.isNegative();
7053
7054 if (Signed)
7055 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7056 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7057 // These equations hold true due to the well-defined wrap-around behavior of
7058 // APInt.
7059 Step = Step.abs();
7060
7061 // Check if Offset is more than the full span of BitWidth. If it is, the
7062 // expression is guaranteed to overflow.
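// For example, for an i8 value with Step = 10 and MaxBECount = 30, the value
// could move by 300 > 255, so the full range must be returned.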
7063 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7064 return ConstantRange::getFull(BitWidth);
7065
7066 // Offset is by how much the expression can change. Checks above guarantee no
7067 // overflow here.
7068 APInt Offset = Step * MaxBECount;
7069
7070 // Minimum value of the final range will match the minimal value of StartRange
7071 // if the expression is increasing and will be decreased by Offset otherwise.
7072 // Maximum value of the final range will match the maximal value of StartRange
7073 // if the expression is decreasing and will be increased by Offset otherwise.
7074 APInt StartLower = StartRange.getLower();
7075 APInt StartUpper = StartRange.getUpper() - 1;
7076 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7077 : (StartUpper + std::move(Offset));
7078
7079 // It's possible that the new minimum/maximum value will fall into the initial
7080 // range (due to wrap around). This means that the expression can take any
7081 // value in this bitwidth, and we have to return full range.
7082 if (StartRange.contains(MovedBoundary))
7083 return ConstantRange::getFull(BitWidth);
7084
7085 APInt NewLower =
7086 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7087 APInt NewUpper =
7088 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7089 NewUpper += 1;
7090
7091 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7092 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7093}
7094
7095ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7096 const SCEV *Step,
7097 const APInt &MaxBECount) {
7098 assert(getTypeSizeInBits(Start->getType()) ==
7099 getTypeSizeInBits(Step->getType()) &&
7100 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7101 "mismatched bit widths");
7102
7103 // First, consider step signed.
7104 ConstantRange StartSRange = getSignedRange(Start);
7105 ConstantRange StepSRange = getSignedRange(Step);
7106
7107 // If Step can be both positive and negative, we need to find ranges for the
7108 // maximum absolute step values in both directions and union them.
7109 ConstantRange SR = getRangeForAffineARHelper(
7110 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7111 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7112 StartSRange, MaxBECount,
7113 /* Signed = */ true));
7114
7115 // Next, consider step unsigned.
7116 ConstantRange UR = getRangeForAffineARHelper(
7117 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7118 /* Signed = */ false);
7119
7120 // Finally, intersect signed and unsigned ranges.
7121 return SR.intersectWith(UR, ConstantRange::Smallest);
7122}
7123
7124ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7125 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7126 ScalarEvolution::RangeSignHint SignHint) {
7127 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7128 assert(AddRec->hasNoSelfWrap() &&
7129 "This only works for non-self-wrapping AddRecs!");
7130 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7131 const SCEV *Step = AddRec->getStepRecurrence(*this);
7132 // Only deal with constant step to save compile time.
7133 if (!isa<SCEVConstant>(Step))
7134 return ConstantRange::getFull(BitWidth);
7135 // Let's make sure that we can prove that we do not self-wrap during
7136 // MaxBECount iterations. We need this because MaxBECount is a maximum
7137 // iteration count estimate, and we might infer nw from some exit for which we
7138 // do not know the max exit count (or via any other side reasoning).
7139 // TODO: Turn into assert at some point.
7140 if (getTypeSizeInBits(MaxBECount->getType()) >
7141 getTypeSizeInBits(AddRec->getType()))
7142 return ConstantRange::getFull(BitWidth);
7143 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7144 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7145 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7146 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7147 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7148 MaxItersWithoutWrap))
7149 return ConstantRange::getFull(BitWidth);
7150
7151 ICmpInst::Predicate LEPred =
7152 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7153 ICmpInst::Predicate GEPred =
7154 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7155 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7156
7157 // We know that there is no self-wrap. Let's take Start and End values and
7158 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7159 // the iteration. They either lie inside the range [Min(Start, End),
7160 // Max(Start, End)] or outside it:
7161 //
7162 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7163 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7164 //
7165 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7166 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7167 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7168 // Start <= End and step is positive, or Start >= End and step is negative.
7169 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7170 ConstantRange StartRange = getRangeRef(Start, SignHint);
7171 ConstantRange EndRange = getRangeRef(End, SignHint);
7172 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7173 // If they already cover full iteration space, we will know nothing useful
7174 // even if we prove what we want to prove.
7175 if (RangeBetween.isFullSet())
7176 return RangeBetween;
7177 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7178 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7179 : RangeBetween.isWrappedSet();
7180 if (IsWrappedSet)
7181 return ConstantRange::getFull(BitWidth);
7182
7183 if (isKnownPositive(Step) &&
7184 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7185 return RangeBetween;
7186 if (isKnownNegative(Step) &&
7187 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7188 return RangeBetween;
7189 return ConstantRange::getFull(BitWidth);
7190}
7191
7192ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7193 const SCEV *Step,
7194 const APInt &MaxBECount) {
7195 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7196 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7197
7198 unsigned BitWidth = MaxBECount.getBitWidth();
7199 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7200 getTypeSizeInBits(Step->getType()) == BitWidth &&
7201 "mismatched bit widths");
7202
7203 struct SelectPattern {
7204 Value *Condition = nullptr;
7205 APInt TrueValue;
7206 APInt FalseValue;
7207
7208 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7209 const SCEV *S) {
7210 std::optional<unsigned> CastOp;
7211 APInt Offset(BitWidth, 0);
7212
7213 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7214 "Should be!");
7215
7216 // Peel off a constant offset. In the future we could consider being
7217 // smarter here and handle {Start+Step,+,Step} too.
7218 const APInt *Off;
7219 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7220 Offset = *Off;
7221
7222 // Peel off a cast operation
7223 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7224 CastOp = SCast->getSCEVType();
7225 S = SCast->getOperand();
7226 }
7227
7228 using namespace llvm::PatternMatch;
7229
7230 auto *SU = dyn_cast<SCEVUnknown>(S);
7231 const APInt *TrueVal, *FalseVal;
7232 if (!SU ||
7233 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7234 m_APInt(FalseVal)))) {
7235 Condition = nullptr;
7236 return;
7237 }
7238
7239 TrueValue = *TrueVal;
7240 FalseValue = *FalseVal;
7241
7242 // Re-apply the cast we peeled off earlier
7243 if (CastOp)
7244 switch (*CastOp) {
7245 default:
7246 llvm_unreachable("Unknown SCEV cast type!");
7247
7248 case scTruncate:
7249 TrueValue = TrueValue.trunc(BitWidth);
7250 FalseValue = FalseValue.trunc(BitWidth);
7251 break;
7252 case scZeroExtend:
7253 TrueValue = TrueValue.zext(BitWidth);
7254 FalseValue = FalseValue.zext(BitWidth);
7255 break;
7256 case scSignExtend:
7257 TrueValue = TrueValue.sext(BitWidth);
7258 FalseValue = FalseValue.sext(BitWidth);
7259 break;
7260 }
7261
7262 // Re-apply the constant offset we peeled off earlier
7263 TrueValue += Offset;
7264 FalseValue += Offset;
7265 }
7266
7267 bool isRecognized() { return Condition != nullptr; }
7268 };
7269
7270 SelectPattern StartPattern(*this, BitWidth, Start);
7271 if (!StartPattern.isRecognized())
7272 return ConstantRange::getFull(BitWidth);
7273
7274 SelectPattern StepPattern(*this, BitWidth, Step);
7275 if (!StepPattern.isRecognized())
7276 return ConstantRange::getFull(BitWidth);
7277
7278 if (StartPattern.Condition != StepPattern.Condition) {
7279 // We don't handle this case today; but we could, by considering four
7280 // possibilities below instead of two. I'm not sure if there are cases where
7281 // that will help over what getRange already does, though.
7282 return ConstantRange::getFull(BitWidth);
7283 }
7284
7285 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7286 // construct arbitrary general SCEV expressions here. This function is called
7287 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7288 // say) can end up caching a suboptimal value.
7289
7290 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7291 // C2352 and C2512 (otherwise it isn't needed).
7292
7293 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7294 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7295 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7296 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7297
7298 ConstantRange TrueRange =
7299 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7300 ConstantRange FalseRange =
7301 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7302
7303 return TrueRange.unionWith(FalseRange);
7304}
7305
7306SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7307 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7308 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7309
7310 // Return early if there are no flags to propagate to the SCEV.
7311 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7312 if (BinOp->hasNoUnsignedWrap())
7313 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7314 if (BinOp->hasNoSignedWrap())
7315 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7316 if (Flags == SCEV::FlagAnyWrap)
7317 return SCEV::FlagAnyWrap;
7318
7319 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7320}
7321
7322const Instruction *
7323ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7324 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7325 return &*AddRec->getLoop()->getHeader()->begin();
7326 if (auto *U = dyn_cast<SCEVUnknown>(S))
7327 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7328 return I;
7329 return nullptr;
7330}
7331
7332const Instruction *
7333ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7334 bool &Precise) {
7335 Precise = true;
7336 // Do a bounded search of the def relation of the requested SCEVs.
7337 SmallPtrSet<const SCEV *, 16> Visited;
7338 SmallVector<const SCEV *> Worklist;
7339 auto pushOp = [&](const SCEV *S) {
7340 if (!Visited.insert(S).second)
7341 return;
7342 // Threshold of 30 here is arbitrary.
7343 if (Visited.size() > 30) {
7344 Precise = false;
7345 return;
7346 }
7347 Worklist.push_back(S);
7348 };
7349
7350 for (const auto *S : Ops)
7351 pushOp(S);
7352
7353 const Instruction *Bound = nullptr;
7354 while (!Worklist.empty()) {
7355 auto *S = Worklist.pop_back_val();
7356 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7357 if (!Bound || DT.dominates(Bound, DefI))
7358 Bound = DefI;
7359 } else {
7360 for (const auto *Op : S->operands())
7361 pushOp(Op);
7362 }
7363 }
7364 return Bound ? Bound : &*F.getEntryBlock().begin();
7365}
7366
7367const Instruction *
7368ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7369 bool Discard;
7370 return getDefiningScopeBound(Ops, Discard);
7371}
7372
7373bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7374 const Instruction *B) {
7375 if (A->getParent() == B->getParent() &&
7376 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7377 B->getIterator()))
7378 return true;
7379
7380 auto *BLoop = LI.getLoopFor(B->getParent());
7381 if (BLoop && BLoop->getHeader() == B->getParent() &&
7382 BLoop->getLoopPreheader() == A->getParent() &&
7383 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7384 A->getParent()->end()) &&
7385 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7386 B->getIterator()))
7387 return true;
7388 return false;
7389}
7390
7391bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7392 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7393 visitAll(Op, PC);
7394 return PC.MaybePoison.empty();
7395}
7396
7397bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7398 return !SCEVExprContains(Op, [this](const SCEV *S) {
7399 const SCEV *Op1;
7400 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7401 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7402 // is a non-zero constant, we have to assume the UDiv may be UB.
7403 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7404 });
7405}
7406
7407bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7408 // Only proceed if we can prove that I does not yield poison.
7409 if (!programUndefinedIfPoison(I))
7410 return false;
7411
7412 // At this point we know that if I is executed, then it does not wrap
7413 // according to at least one of NSW or NUW. If I is not executed, then we do
7414 // not know if the calculation that I represents would wrap. Multiple
7415 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7416 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7417 // derived from other instructions that map to the same SCEV. We cannot make
7418 // that guarantee for cases where I is not executed. So we need to find an
7419 // upper bound on the defining scope for the SCEV, and prove that I is
7420 // executed every time we enter that scope. When the bounding scope is a
7421 // loop (the common case), this is equivalent to proving I executes on every
7422 // iteration of that loop.
7423 SmallVector<const SCEV *, 4> SCEVOps;
7424 for (const Use &Op : I->operands()) {
7425 // I could be an extractvalue from a call to an overflow intrinsic.
7426 // TODO: We can do better here in some cases.
7427 if (isSCEVable(Op->getType()))
7428 SCEVOps.push_back(getSCEV(Op));
7429 }
7430 auto *DefI = getDefiningScopeBound(SCEVOps);
7431 return isGuaranteedToTransferExecutionTo(DefI, I);
7432}
7433
7434bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7435 // If we know that \c I can never be poison period, then that's enough.
7436 if (isSCEVExprNeverPoison(I))
7437 return true;
7438
7439 // If the loop only has one exit, then we know that, if the loop is entered,
7440 // any instruction dominating that exit will be executed. If any such
7441 // instruction would result in UB, the addrec cannot be poison.
7442 //
7443 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7444 // also handles uses outside the loop header (they just need to dominate the
7445 // single exit).
7446
7447 auto *ExitingBB = L->getExitingBlock();
7448 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7449 return false;
7450
7451 SmallPtrSet<const Value *, 16> KnownPoison;
7452 SmallVector<const Instruction *, 16> Worklist;
7453
7454 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7455 // things that are known to be poison under that assumption go on the
7456 // Worklist.
7457 KnownPoison.insert(I);
7458 Worklist.push_back(I);
7459
7460 while (!Worklist.empty()) {
7461 const Instruction *Poison = Worklist.pop_back_val();
7462
7463 for (const Use &U : Poison->uses()) {
7464 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7465 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7466 DT.dominates(PoisonUser->getParent(), ExitingBB))
7467 return true;
7468
7469 if (propagatesPoison(U) && L->contains(PoisonUser))
7470 if (KnownPoison.insert(PoisonUser).second)
7471 Worklist.push_back(PoisonUser);
7472 }
7473 }
7474
7475 return false;
7476}
7477
7478ScalarEvolution::LoopProperties
7479ScalarEvolution::getLoopProperties(const Loop *L) {
7480 using LoopProperties = ScalarEvolution::LoopProperties;
7481
7482 auto Itr = LoopPropertiesCache.find(L);
7483 if (Itr == LoopPropertiesCache.end()) {
7484 auto HasSideEffects = [](Instruction *I) {
7485 if (auto *SI = dyn_cast<StoreInst>(I))
7486 return !SI->isSimple();
7487
7488 if (I->mayThrow())
7489 return true;
7490
7491 // Non-volatile memset / memcpy do not count as side-effect for forward
7492 // progress.
7493 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7494 return false;
7495
7496 return I->mayWriteToMemory();
7497 };
7498
7499 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7500 /*HasNoSideEffects*/ true};
7501
7502 for (auto *BB : L->getBlocks())
7503 for (auto &I : *BB) {
7504 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7505 LP.HasNoAbnormalExits = false;
7506 if (HasSideEffects(&I))
7507 LP.HasNoSideEffects = false;
7508 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7509 break; // We're already as pessimistic as we can get.
7510 }
7511
7512 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7513 assert(InsertPair.second && "We just checked!");
7514 Itr = InsertPair.first;
7515 }
7516
7517 return Itr->second;
7518}
7519
7520 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7521 // A mustprogress loop without side effects must be finite.
7522 // TODO: The check used here is very conservative. It's only *specific*
7523 // side effects which are well defined in infinite loops.
7524 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7525}
7526
7527const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7528 // Worklist item with a Value and a bool indicating whether all operands have
7529 // been visited already.
7530 using PointerTy = PointerIntPair<Value *, 1, bool>;
7531 SmallVector<PointerTy> Stack;
7532
7533 Stack.emplace_back(V, true);
7534 Stack.emplace_back(V, false);
7535 while (!Stack.empty()) {
7536 auto E = Stack.pop_back_val();
7537 Value *CurV = E.getPointer();
7538
7539 if (getExistingSCEV(CurV))
7540 continue;
7541
7542 SmallVector<Value *> Ops;
7543 const SCEV *CreatedSCEV = nullptr;
7544 // If all operands have been visited already, create the SCEV.
7545 if (E.getInt()) {
7546 CreatedSCEV = createSCEV(CurV);
7547 } else {
7548 // Otherwise get the operands we need to create SCEVs for before creating
7549 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7550 // just use it.
7551 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7552 }
7553
7554 if (CreatedSCEV) {
7555 insertValueToMap(CurV, CreatedSCEV);
7556 } else {
7557 // Queue CurV for SCEV creation, followed by its operands, which need to
7558 // be constructed first.
7559 Stack.emplace_back(CurV, true);
7560 for (Value *Op : Ops)
7561 Stack.emplace_back(Op, false);
7562 }
7563 }
7564
7565 return getExistingSCEV(V);
7566}
7567
7568const SCEV *
7569ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7570 if (!isSCEVable(V->getType()))
7571 return getUnknown(V);
7572
7573 if (Instruction *I = dyn_cast<Instruction>(V)) {
7574 // Don't attempt to analyze instructions in blocks that aren't
7575 // reachable. Such instructions don't matter, and they aren't required
7576 // to obey basic rules for definitions dominating uses which this
7577 // analysis depends on.
7578 if (!DT.isReachableFromEntry(I->getParent()))
7579 return getUnknown(PoisonValue::get(V->getType()));
7580 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7581 return getConstant(CI);
7582 else if (isa<GlobalAlias>(V))
7583 return getUnknown(V);
7584 else if (!isa<ConstantExpr>(V))
7585 return getUnknown(V);
7586
7587 Operator *U = cast<Operator>(V);
7588 if (auto BO =
7589 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7590 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7591 switch (BO->Opcode) {
7592 case Instruction::Add:
7593 case Instruction::Mul: {
7594 // For additions and multiplications, traverse add/mul chains for which we
7595 // can potentially create a single SCEV, to reduce the number of
7596 // get{Add,Mul}Expr calls.
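// For example, for ((a + b) + c) + d the chain is walked so that a single SCEV
// can later be built for the whole sum rather than one SCEV per addition.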
7597 do {
7598 if (BO->Op) {
7599 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7600 Ops.push_back(BO->Op);
7601 break;
7602 }
7603 }
7604 Ops.push_back(BO->RHS);
7605 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7606 dyn_cast<Instruction>(V));
7607 if (!NewBO ||
7608 (BO->Opcode == Instruction::Add &&
7609 (NewBO->Opcode != Instruction::Add &&
7610 NewBO->Opcode != Instruction::Sub)) ||
7611 (BO->Opcode == Instruction::Mul &&
7612 NewBO->Opcode != Instruction::Mul)) {
7613 Ops.push_back(BO->LHS);
7614 break;
7615 }
7616 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7617 // requires a SCEV for the LHS.
7618 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7619 auto *I = dyn_cast<Instruction>(BO->Op);
7620 if (I && programUndefinedIfPoison(I)) {
7621 Ops.push_back(BO->LHS);
7622 break;
7623 }
7624 }
7625 BO = NewBO;
7626 } while (true);
7627 return nullptr;
7628 }
7629 case Instruction::Sub:
7630 case Instruction::UDiv:
7631 case Instruction::URem:
7632 break;
7633 case Instruction::AShr:
7634 case Instruction::Shl:
7635 case Instruction::Xor:
7636 if (!IsConstArg)
7637 return nullptr;
7638 break;
7639 case Instruction::And:
7640 case Instruction::Or:
7641 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7642 return nullptr;
7643 break;
7644 case Instruction::LShr:
7645 return getUnknown(V);
7646 default:
7647 llvm_unreachable("Unhandled binop");
7648 break;
7649 }
7650
7651 Ops.push_back(BO->LHS);
7652 Ops.push_back(BO->RHS);
7653 return nullptr;
7654 }
7655
7656 switch (U->getOpcode()) {
7657 case Instruction::Trunc:
7658 case Instruction::ZExt:
7659 case Instruction::SExt:
7660 case Instruction::PtrToInt:
7661 Ops.push_back(U->getOperand(0));
7662 return nullptr;
7663
7664 case Instruction::BitCast:
7665 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7666 Ops.push_back(U->getOperand(0));
7667 return nullptr;
7668 }
7669 return getUnknown(V);
7670
7671 case Instruction::SDiv:
7672 case Instruction::SRem:
7673 Ops.push_back(U->getOperand(0));
7674 Ops.push_back(U->getOperand(1));
7675 return nullptr;
7676
7677 case Instruction::GetElementPtr:
7678 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7679 "GEP source element type must be sized");
7680 llvm::append_range(Ops, U->operands());
7681 return nullptr;
7682
7683 case Instruction::IntToPtr:
7684 return getUnknown(V);
7685
7686 case Instruction::PHI:
7687 // Keep constructing SCEVs for phis recursively for now.
7688 return nullptr;
7689
7690 case Instruction::Select: {
7691 // Check if U is a select that can be simplified to a SCEVUnknown.
7692 auto CanSimplifyToUnknown = [this, U]() {
7693 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7694 return false;
7695
7696 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7697 if (!ICI)
7698 return false;
7699 Value *LHS = ICI->getOperand(0);
7700 Value *RHS = ICI->getOperand(1);
7701 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7702 ICI->getPredicate() == CmpInst::ICMP_NE) {
7703 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7704 return true;
7705 } else if (getTypeSizeInBits(LHS->getType()) >
7706 getTypeSizeInBits(U->getType()))
7707 return true;
7708 return false;
7709 };
7710 if (CanSimplifyToUnknown())
7711 return getUnknown(U);
7712
7713 llvm::append_range(Ops, U->operands());
7714 return nullptr;
7715 break;
7716 }
7717 case Instruction::Call:
7718 case Instruction::Invoke:
7719 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7720 Ops.push_back(RV);
7721 return nullptr;
7722 }
7723
7724 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7725 switch (II->getIntrinsicID()) {
7726 case Intrinsic::abs:
7727 Ops.push_back(II->getArgOperand(0));
7728 return nullptr;
7729 case Intrinsic::umax:
7730 case Intrinsic::umin:
7731 case Intrinsic::smax:
7732 case Intrinsic::smin:
7733 case Intrinsic::usub_sat:
7734 case Intrinsic::uadd_sat:
7735 Ops.push_back(II->getArgOperand(0));
7736 Ops.push_back(II->getArgOperand(1));
7737 return nullptr;
7738 case Intrinsic::start_loop_iterations:
7739 case Intrinsic::annotation:
7740 case Intrinsic::ptr_annotation:
7741 Ops.push_back(II->getArgOperand(0));
7742 return nullptr;
7743 default:
7744 break;
7745 }
7746 }
7747 break;
7748 }
7749
7750 return nullptr;
7751}
7752
7753const SCEV *ScalarEvolution::createSCEV(Value *V) {
7754 if (!isSCEVable(V->getType()))
7755 return getUnknown(V);
7756
7757 if (Instruction *I = dyn_cast<Instruction>(V)) {
7758 // Don't attempt to analyze instructions in blocks that aren't
7759 // reachable. Such instructions don't matter, and they aren't required
7760 // to obey basic rules for definitions dominating uses which this
7761 // analysis depends on.
7762 if (!DT.isReachableFromEntry(I->getParent()))
7763 return getUnknown(PoisonValue::get(V->getType()));
7764 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7765 return getConstant(CI);
7766 else if (isa<GlobalAlias>(V))
7767 return getUnknown(V);
7768 else if (!isa<ConstantExpr>(V))
7769 return getUnknown(V);
7770
7771 const SCEV *LHS;
7772 const SCEV *RHS;
7773
7774 Operator *U = cast<Operator>(V);
7775 if (auto BO =
7776 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7777 switch (BO->Opcode) {
7778 case Instruction::Add: {
7779 // The simple thing to do would be to just call getSCEV on both operands
7780 // and call getAddExpr with the result. However if we're looking at a
7781 // bunch of things all added together, this can be quite inefficient,
7782 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7783 // Instead, gather up all the operands and make a single getAddExpr call.
7784 // LLVM IR canonical form means we need only traverse the left operands.
7785 SmallVector<const SCEV *, 4> AddOps;
7786 do {
7787 if (BO->Op) {
7788 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7789 AddOps.push_back(OpSCEV);
7790 break;
7791 }
7792
7793 // If a NUW or NSW flag can be applied to the SCEV for this
7794 // addition, then compute the SCEV for this addition by itself
7795 // with a separate call to getAddExpr. We need to do that
7796 // instead of pushing the operands of the addition onto AddOps,
7797 // since the flags are only known to apply to this particular
7798 // addition - they may not apply to other additions that can be
7799 // formed with operands from AddOps.
7800 const SCEV *RHS = getSCEV(BO->RHS);
7801 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7802 if (Flags != SCEV::FlagAnyWrap) {
7803 const SCEV *LHS = getSCEV(BO->LHS);
7804 if (BO->Opcode == Instruction::Sub)
7805 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7806 else
7807 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7808 break;
7809 }
7810 }
7811
7812 if (BO->Opcode == Instruction::Sub)
7813 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7814 else
7815 AddOps.push_back(getSCEV(BO->RHS));
7816
7817 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7818 dyn_cast<Instruction>(V));
7819 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7820 NewBO->Opcode != Instruction::Sub)) {
7821 AddOps.push_back(getSCEV(BO->LHS));
7822 break;
7823 }
7824 BO = NewBO;
7825 } while (true);
7826
7827 return getAddExpr(AddOps);
7828 }
7829
7830 case Instruction::Mul: {
7831 SmallVector<const SCEV *, 4> MulOps;
7832 do {
7833 if (BO->Op) {
7834 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7835 MulOps.push_back(OpSCEV);
7836 break;
7837 }
7838
7839 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7840 if (Flags != SCEV::FlagAnyWrap) {
7841 LHS = getSCEV(BO->LHS);
7842 RHS = getSCEV(BO->RHS);
7843 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7844 break;
7845 }
7846 }
7847
7848 MulOps.push_back(getSCEV(BO->RHS));
7849 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7850 dyn_cast<Instruction>(V));
7851 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7852 MulOps.push_back(getSCEV(BO->LHS));
7853 break;
7854 }
7855 BO = NewBO;
7856 } while (true);
7857
7858 return getMulExpr(MulOps);
7859 }
7860 case Instruction::UDiv:
7861 LHS = getSCEV(BO->LHS);
7862 RHS = getSCEV(BO->RHS);
7863 return getUDivExpr(LHS, RHS);
7864 case Instruction::URem:
7865 LHS = getSCEV(BO->LHS);
7866 RHS = getSCEV(BO->RHS);
7867 return getURemExpr(LHS, RHS);
7868 case Instruction::Sub: {
7869 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7870 if (BO->Op)
7871 Flags = getNoWrapFlagsFromUB(BO->Op);
7872 LHS = getSCEV(BO->LHS);
7873 RHS = getSCEV(BO->RHS);
7874 return getMinusSCEV(LHS, RHS, Flags);
7875 }
7876 case Instruction::And:
7877 // For an expression like x&255 that merely masks off the high bits,
7878 // use zext(trunc(x)) as the SCEV expression.
7879 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7880 if (CI->isZero())
7881 return getSCEV(BO->RHS);
7882 if (CI->isMinusOne())
7883 return getSCEV(BO->LHS);
7884 const APInt &A = CI->getValue();
7885
7886 // Instcombine's ShrinkDemandedConstant may strip bits out of
7887 // constants, obscuring what would otherwise be a low-bits mask.
7888 // Use computeKnownBits to compute what ShrinkDemandedConstant
7889 // knew about to reconstruct a low-bits mask value.
7890 unsigned LZ = A.countl_zero();
7891 unsigned TZ = A.countr_zero();
7892 unsigned BitWidth = A.getBitWidth();
7893 KnownBits Known(BitWidth);
7894 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7895
7896 APInt EffectiveMask =
7897 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
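// For example, for an i8 expression x & 0x38 (TZ = 3, LZ = 2), the SCEV built
// below is equivalent to (zext i3 (trunc (x /u 8)) to i8) * 8.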
7898 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7899 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7900 const SCEV *LHS = getSCEV(BO->LHS);
7901 const SCEV *ShiftedLHS = nullptr;
7902 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7903 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7904 // For an expression like (x * 8) & 8, simplify the multiply.
7905 unsigned MulZeros = OpC->getAPInt().countr_zero();
7906 unsigned GCD = std::min(MulZeros, TZ);
7907 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7908 SmallVector<const SCEV *, 4> MulOps;
7909 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7910 append_range(MulOps, LHSMul->operands().drop_front());
7911 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7912 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7913 }
7914 }
7915 if (!ShiftedLHS)
7916 ShiftedLHS = getUDivExpr(LHS, MulCount);
7917 return getMulExpr(
7918 getZeroExtendExpr(
7919 getTruncateExpr(ShiftedLHS,
7920 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7921 BO->LHS->getType()),
7922 MulCount);
7923 }
7924 }
7925 // Binary `and` is a bit-wise `umin`.
7926 if (BO->LHS->getType()->isIntegerTy(1)) {
7927 LHS = getSCEV(BO->LHS);
7928 RHS = getSCEV(BO->RHS);
7929 return getUMinExpr(LHS, RHS);
7930 }
7931 break;
7932
7933 case Instruction::Or:
7934 // Binary `or` is a bit-wise `umax`.
7935 if (BO->LHS->getType()->isIntegerTy(1)) {
7936 LHS = getSCEV(BO->LHS);
7937 RHS = getSCEV(BO->RHS);
7938 return getUMaxExpr(LHS, RHS);
7939 }
7940 break;
7941
7942 case Instruction::Xor:
7943 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7944 // If the RHS of xor is -1, then this is a not operation.
7945 if (CI->isMinusOne())
7946 return getNotSCEV(getSCEV(BO->LHS));
7947
7948 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7949 // This is a variant of the check for xor with -1, and it handles
7950 // the case where instcombine has trimmed non-demanded bits out
7951 // of an xor with -1.
7952 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7953 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7954 if (LBO->getOpcode() == Instruction::And &&
7955 LCI->getValue() == CI->getValue())
7956 if (const SCEVZeroExtendExpr *Z =
7957 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7958 Type *UTy = BO->LHS->getType();
7959 const SCEV *Z0 = Z->getOperand();
7960 Type *Z0Ty = Z0->getType();
7961 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7962
7963 // If C is a low-bits mask, the zero extend is serving to
7964 // mask off the high bits. Complement the operand and
7965 // re-apply the zext.
7966 if (CI->getValue().isMask(Z0TySize))
7967 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7968
7969 // If C is a single bit, it may be in the sign-bit position
7970 // before the zero-extend. In this case, represent the xor
7971 // using an add, which is equivalent, and re-apply the zext.
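// E.g. for an i8 operand, (%v ^ 128) == (%v + 128) modulo 256 (xor with the
// sign bit flips it, and so does adding 128), so the xor of the extended
// value can be rewritten as the zext of that add.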
7972 APInt Trunc = CI->getValue().trunc(Z0TySize);
7973 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7974 Trunc.isSignMask())
7975 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7976 UTy);
7977 }
7978 }
7979 break;
7980
7981 case Instruction::Shl:
7982 // Turn shift left of a constant amount into a multiply.
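// E.g. (%x << 3) is modeled as (%x * 8).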
7983 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7984 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7985
7986 // If the shift count is not less than the bitwidth, the result of
7987 // the shift is undefined. Don't try to analyze it, because the
7988 // resolution chosen here may differ from the resolution chosen in
7989 // other parts of the compiler.
7990 if (SA->getValue().uge(BitWidth))
7991 break;
7992
7993 // We can safely preserve the nuw flag in all cases. It's also safe to
7994 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7995 // requires special handling. It can be preserved as long as we're not
7996 // left shifting by bitwidth - 1.
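// E.g. in i8, (shl nsw i8 -1, 7) is a well-defined -128, but the multiply
// constant 1 << 7 wraps to -128 in i8 and (mul nsw i8 -1, -128) would
// overflow, so nsw alone cannot be transferred for a shift by bitwidth - 1.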
7997 auto Flags = SCEV::FlagAnyWrap;
7998 if (BO->Op) {
7999 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8000 if ((MulFlags & SCEV::FlagNSW) &&
8001 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8002 Flags = setFlags(Flags, SCEV::FlagNSW);
8003 if (MulFlags & SCEV::FlagNUW)
8004 Flags = setFlags(Flags, SCEV::FlagNUW);
8005 }
8006
8007 ConstantInt *X = ConstantInt::get(
8008 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8009 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8010 }
8011 break;
8012
8013 case Instruction::AShr:
8014 // AShr X, C, where C is a constant.
8015 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8016 if (!CI)
8017 break;
8018
8019 Type *OuterTy = BO->LHS->getType();
8020 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8021 // If the shift count is not less than the bitwidth, the result of
8022 // the shift is undefined. Don't try to analyze it, because the
8023 // resolution chosen here may differ from the resolution chosen in
8024 // other parts of the compiler.
8025 if (CI->getValue().uge(BitWidth))
8026 break;
8027
8028 if (CI->isZero())
8029 return getSCEV(BO->LHS); // shift by zero --> noop
8030
8031 uint64_t AShrAmt = CI->getZExtValue();
8032 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8033
8034 Operator *L = dyn_cast<Operator>(BO->LHS);
8035 const SCEV *AddTruncateExpr = nullptr;
8036 ConstantInt *ShlAmtCI = nullptr;
8037 const SCEV *AddConstant = nullptr;
8038
8039 if (L && L->getOpcode() == Instruction::Add) {
8040 // X = Shl A, n
8041 // Y = Add X, c
8042 // Z = AShr Y, m
8043 // n, c and m are constants.
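// For example (i32): X = A << 4, Y = X + 32, Z = Y ashr 2 is modeled below
// as (sext ((trunc A to i30) * 4 + 8) to i32).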
8044
8045 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8046 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8047 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8048 if (AddOperandCI) {
8049 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8050 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8051 // Since we truncate to TruncTy, the AddConstant must be of the same
8052 // type, so create a new constant whose type is TruncTy. Also, the add
8053 // constant must be shifted right by the AShr amount.
8054 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8055 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8056 // We model the expression as sext(add(mul(trunc(A), 2^(n-m)), c >> m)).
8057 // Since the sext(trunc) part is already handled below, we create an
8058 // AddExpr(TruncExp) which will be used later.
8059 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8060 }
8061 }
8062 } else if (L && L->getOpcode() == Instruction::Shl) {
8063 // X = Shl A, n
8064 // Y = AShr X, m
8065 // Both n and m are constant.
8066
8067 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8068 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8069 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8070 }
8071
8072 if (AddTruncateExpr && ShlAmtCI) {
8073 // We can merge the two given cases into a single SCEV statement;
8074 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8075 // a simpler case. The following code handles the two cases:
8076 //
8077 // 1) For a two-shift sext-inreg, i.e. n = m,
8078 // use sext(trunc(x)) as the SCEV expression.
8079 //
8080 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8081 // expression. We already checked that ShlAmt < BitWidth, so
8082 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8083 // ShlAmt - AShrAmt < BitWidth - AShrAmt, the width of TruncTy.
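// E.g. with n = m = 24 on i32, ((A << 24) ashr 24) becomes
// (sext (trunc A to i8) to i32).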
8084 const APInt &ShlAmt = ShlAmtCI->getValue();
8085 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8086 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8087 ShlAmtCI->getZExtValue() - AShrAmt);
8088 const SCEV *CompositeExpr =
8089 getMulExpr(AddTruncateExpr, getConstant(Mul));
8090 if (L->getOpcode() != Instruction::Shl)
8091 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8092
8093 return getSignExtendExpr(CompositeExpr, OuterTy);
8094 }
8095 }
8096 break;
8097 }
8098 }
8099
8100 switch (U->getOpcode()) {
8101 case Instruction::Trunc:
8102 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8103
8104 case Instruction::ZExt:
8105 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8106
8107 case Instruction::SExt:
8108 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8109 dyn_cast<Instruction>(V))) {
8110 // The NSW flag of a subtract does not always survive the conversion to
8111 // A + (-1)*B. By pushing sign extension onto its operands we are much
8112 // more likely to preserve NSW and allow later AddRec optimisations.
8113 //
8114 // NOTE: This is effectively duplicating this logic from getSignExtend:
8115 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8116 // but by that point the NSW information has potentially been lost.
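// E.g. (sext (sub nsw %a, %b) to i64) becomes (sext %a) - (sext %b), with
// the subtraction done in 64 bits and the nsw flag preserved.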
8117 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8118 Type *Ty = U->getType();
8119 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8120 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8121 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8122 }
8123 }
8124 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8125
8126 case Instruction::BitCast:
8127 // BitCasts are no-op casts so we just eliminate the cast.
8128 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8129 return getSCEV(U->getOperand(0));
8130 break;
8131
8132 case Instruction::PtrToInt: {
8133 // A pointer-to-integer cast is straightforward, so do model it.
8134 const SCEV *Op = getSCEV(U->getOperand(0));
8135 Type *DstIntTy = U->getType();
8136 // But only if effective SCEV (integer) type is wide enough to represent
8137 // all possible pointer values.
8138 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8139 if (isa<SCEVCouldNotCompute>(IntOp))
8140 return getUnknown(V);
8141 return IntOp;
8142 }
8143 case Instruction::IntToPtr:
8144 // Just don't deal with inttoptr casts.
8145 return getUnknown(V);
8146
8147 case Instruction::SDiv:
8148 // If both operands are non-negative, this is just an udiv.
8149 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8150 isKnownNonNegative(getSCEV(U->getOperand(1))))
8151 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8152 break;
8153
8154 case Instruction::SRem:
8155 // If both operands are non-negative, this is just an urem.
8156 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8157 isKnownNonNegative(getSCEV(U->getOperand(1))))
8158 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8159 break;
8160
8161 case Instruction::GetElementPtr:
8162 return createNodeForGEP(cast<GEPOperator>(U));
8163
8164 case Instruction::PHI:
8165 return createNodeForPHI(cast<PHINode>(U));
8166
8167 case Instruction::Select:
8168 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8169 U->getOperand(2));
8170
8171 case Instruction::Call:
8172 case Instruction::Invoke:
8173 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8174 return getSCEV(RV);
8175
8176 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8177 switch (II->getIntrinsicID()) {
8178 case Intrinsic::abs:
8179 return getAbsExpr(
8180 getSCEV(II->getArgOperand(0)),
8181 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8182 case Intrinsic::umax:
8183 LHS = getSCEV(II->getArgOperand(0));
8184 RHS = getSCEV(II->getArgOperand(1));
8185 return getUMaxExpr(LHS, RHS);
8186 case Intrinsic::umin:
8187 LHS = getSCEV(II->getArgOperand(0));
8188 RHS = getSCEV(II->getArgOperand(1));
8189 return getUMinExpr(LHS, RHS);
8190 case Intrinsic::smax:
8191 LHS = getSCEV(II->getArgOperand(0));
8192 RHS = getSCEV(II->getArgOperand(1));
8193 return getSMaxExpr(LHS, RHS);
8194 case Intrinsic::smin:
8195 LHS = getSCEV(II->getArgOperand(0));
8196 RHS = getSCEV(II->getArgOperand(1));
8197 return getSMinExpr(LHS, RHS);
8198 case Intrinsic::usub_sat: {
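// usub.sat(X, Y) is X - umin(X, Y): if Y <= X this is X - Y, otherwise the
// umin clamps to X and the result is 0. The subtraction can never wrap,
// hence the NUW flag below.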
8199 const SCEV *X = getSCEV(II->getArgOperand(0));
8200 const SCEV *Y = getSCEV(II->getArgOperand(1));
8201 const SCEV *ClampedY = getUMinExpr(X, Y);
8202 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8203 }
8204 case Intrinsic::uadd_sat: {
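// uadd.sat(X, Y) is umin(X, ~Y) + Y: if X + Y does not overflow then
// X <= ~Y and the result is X + Y; otherwise the umin clamps to ~Y and
// ~Y + Y is the saturated all-ones value. The add can never wrap, hence
// the NUW flag below.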
8205 const SCEV *X = getSCEV(II->getArgOperand(0));
8206 const SCEV *Y = getSCEV(II->getArgOperand(1));
8207 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8208 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8209 }
8210 case Intrinsic::start_loop_iterations:
8211 case Intrinsic::annotation:
8212 case Intrinsic::ptr_annotation:
8213 // A start_loop_iterations, llvm.annotation or llvm.ptr.annotation call is
8214 // just equivalent to its first operand for SCEV purposes.
8215 return getSCEV(II->getArgOperand(0));
8216 case Intrinsic::vscale:
8217 return getVScale(II->getType());
8218 default:
8219 break;
8220 }
8221 }
8222 break;
8223 }
8224
8225 return getUnknown(V);
8226}
8227
8228//===----------------------------------------------------------------------===//
8229// Iteration Count Computation Code
8230//
8231
8232 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8233 if (isa<SCEVCouldNotCompute>(ExitCount))
8234 return getCouldNotCompute();
8235
8236 auto *ExitCountType = ExitCount->getType();
8237 assert(ExitCountType->isIntegerTy());
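// Evaluate the +1 in a type one bit wider than the exit count so that it
// cannot wrap: e.g. an i32 exit count of 4294967295 yields an i33 trip
// count of 2^32.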
8238 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8239 1 + ExitCountType->getScalarSizeInBits());
8240 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8241}
8242
8243 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8244 Type *EvalTy,
8245 const Loop *L) {
8246 if (isa<SCEVCouldNotCompute>(ExitCount))
8247 return getCouldNotCompute();
8248
8249 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8250 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8251
8252 auto CanAddOneWithoutOverflow = [&]() {
8253 ConstantRange ExitCountRange =
8254 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8255 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8256 return true;
8257
8258 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8259 getMinusOne(ExitCount->getType()));
8260 };
8261
8262 // If we need to zero extend the backedge count, check if we can add one to
8263 // it prior to zero extending without overflow. Provided this is safe, it
8264 // allows better simplification of the +1.
8265 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8266 return getZeroExtendExpr(
8267 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8268
8269 // Get the total trip count from the count by adding 1. This may wrap.
8270 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8271}
8272
8273static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8274 if (!ExitCount)
8275 return 0;
8276
8277 ConstantInt *ExitConst = ExitCount->getValue();
8278
8279 // Guard against huge trip counts.
8280 if (ExitConst->getValue().getActiveBits() > 32)
8281 return 0;
8282
8283 // In case of integer overflow, this returns 0, which is correct.
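// (The only way the +1 can wrap here is an exit count of 0xFFFFFFFF, whose
// trip count 2^32 does not fit in 32 bits, so reporting 0, i.e. "unknown",
// is the right answer.)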
8284 return ((unsigned)ExitConst->getZExtValue()) + 1;
8285}
8286
8287 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8288 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8289 return getConstantTripCount(ExitCount);
8290}
8291
8292unsigned
8293 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8294 const BasicBlock *ExitingBlock) {
8295 assert(ExitingBlock && "Must pass a non-null exiting block!");
8296 assert(L->isLoopExiting(ExitingBlock) &&
8297 "Exiting block must actually branch out of the loop!");
8298 const SCEVConstant *ExitCount =
8299 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8300 return getConstantTripCount(ExitCount);
8301}
8302
8303 unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8304 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8305
8306 const auto *MaxExitCount =
8307 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8308 : getConstantMaxBackedgeTakenCount(L);
8309 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8310}
8311
8312 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8313 SmallVector<BasicBlock *, 8> ExitingBlocks;
8314 L->getExitingBlocks(ExitingBlocks);
8315
8316 std::optional<unsigned> Res;
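// E.g. if one exit guarantees a trip multiple of 4 and another a multiple
// of 6, the trip count of the loop as a whole is only known to be a
// multiple of gcd(4, 6) = 2.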
8317 for (auto *ExitingBB : ExitingBlocks) {
8318 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8319 if (!Res)
8320 Res = Multiple;
8321 Res = std::gcd(*Res, Multiple);
8322 }
8323 return Res.value_or(1);
8324}
8325
8326 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8327 const SCEV *ExitCount) {
8328 if (isa<SCEVCouldNotCompute>(ExitCount))
8329 return 1;
8330
8331 // Get the trip count
8332 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8333
8334 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8335 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8336 // the greatest power of 2 divisor less than 2^32.
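// E.g. a trip count known to be a multiple of 3 * 2^33 is reported as 2^31.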
8337 return Multiple.getActiveBits() > 32
8338 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8339 : (unsigned)Multiple.getZExtValue();
8340}
8341
8342/// Returns the largest constant divisor of the trip count of this loop as a
8343/// normal unsigned value, if possible. This means that the actual trip count is
8344/// always a multiple of the returned value (don't forget the trip count could
8345/// very well be zero as well!).
8346///
8347/// Returns 1 if the trip count is unknown or not guaranteed to be the
8348/// multiple of a constant (which is also the case if the trip count is simply
8349 /// constant; use getSmallConstantTripCount for that case). It will also return 1
8350/// if the trip count is very large (>= 2^32).
8351///
8352/// As explained in the comments for getSmallConstantTripCount, this assumes
8353/// that control exits the loop via ExitingBlock.
8354unsigned
8355 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8356 const BasicBlock *ExitingBlock) {
8357 assert(ExitingBlock && "Must pass a non-null exiting block!");
8358 assert(L->isLoopExiting(ExitingBlock) &&
8359 "Exiting block must actually branch out of the loop!");
8360 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8361 return getSmallConstantTripMultiple(L, ExitCount);
8362}
8363
8364 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8365 const BasicBlock *ExitingBlock,
8366 ExitCountKind Kind) {
8367 switch (Kind) {
8368 case Exact:
8369 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8370 case SymbolicMaximum:
8371 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8372 case ConstantMaximum:
8373 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8374 };
8375 llvm_unreachable("Invalid ExitCountKind!");
8376}
8377
8378 const SCEV *ScalarEvolution::getPredicatedExitCount(
8379 const Loop *L, const BasicBlock *ExitingBlock,
8380 SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8381 switch (Kind) {
8382 case Exact:
8383 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8384 Predicates);
8385 case SymbolicMaximum:
8386 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8387 Predicates);
8388 case ConstantMaximum:
8389 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8390 Predicates);
8391 };
8392 llvm_unreachable("Invalid ExitCountKind!");
8393}
8394
8395 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8396 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8397 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8398}
8399
8400 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8401 ExitCountKind Kind) {
8402 switch (Kind) {
8403 case Exact:
8404 return getBackedgeTakenInfo(L).getExact(L, this);
8405 case ConstantMaximum:
8406 return getBackedgeTakenInfo(L).getConstantMax(this);
8407 case SymbolicMaximum:
8408 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8409 };
8410 llvm_unreachable("Invalid ExitCountKind!");
8411}
8412
8413 const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8414 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8415 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8416}
8417
8418 const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8419 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8420 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8421}
8422
8423 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8424 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8425}
8426
8427/// Push PHI nodes in the header of the given loop onto the given Worklist.
8428static void PushLoopPHIs(const Loop *L,
8431 BasicBlock *Header = L->getHeader();
8432
8433 // Push all Loop-header PHIs onto the Worklist stack.
8434 for (PHINode &PN : Header->phis())
8435 if (Visited.insert(&PN).second)
8436 Worklist.push_back(&PN);
8437}
8438
8439ScalarEvolution::BackedgeTakenInfo &
8440ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8441 auto &BTI = getBackedgeTakenInfo(L);
8442 if (BTI.hasFullInfo())
8443 return BTI;
8444
8445 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8446
8447 if (!Pair.second)
8448 return Pair.first->second;
8449
8450 BackedgeTakenInfo Result =
8451 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8452
8453 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8454}
8455
8456ScalarEvolution::BackedgeTakenInfo &
8457ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8458 // Initially insert an invalid entry for this loop. If the insertion
8459 // succeeds, proceed to actually compute a backedge-taken count and
8460 // update the value. The temporary CouldNotCompute value tells SCEV
8461 // code elsewhere that it shouldn't attempt to request a new
8462 // backedge-taken count, which could result in infinite recursion.
8463 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8464 BackedgeTakenCounts.try_emplace(L);
8465 if (!Pair.second)
8466 return Pair.first->second;
8467
8468 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8469 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8470 // must be cleared in this scope.
8471 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8472
8473 // Now that we know more about the trip count for this loop, forget any
8474 // existing SCEV values for PHI nodes in this loop since they are only
8475 // conservative estimates made without the benefit of trip count
8476 // information. This invalidation is not necessary for correctness, and is
8477 // only done to produce more precise results.
8478 if (Result.hasAnyInfo()) {
8479 // Invalidate any expression using an addrec in this loop.
8481 auto LoopUsersIt = LoopUsers.find(L);
8482 if (LoopUsersIt != LoopUsers.end())
8483 append_range(ToForget, LoopUsersIt->second);
8484 forgetMemoizedResults(ToForget);
8485
8486 // Invalidate constant-evolved loop header phis.
8487 for (PHINode &PN : L->getHeader()->phis())
8488 ConstantEvolutionLoopExitValue.erase(&PN);
8489 }
8490
8491 // Re-lookup the insert position, since the call to
8492 // computeBackedgeTakenCount above could result in a
8493 // recursive call to getBackedgeTakenInfo (on a different
8494 // loop), which would invalidate the iterator computed
8495 // earlier.
8496 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8497}
8498
8499 void ScalarEvolution::forgetAllLoops() {
8500 // This method is intended to forget all info about loops. It should
8501 // invalidate caches as if the following happened:
8502 // - The trip counts of all loops have changed arbitrarily
8503 // - Every llvm::Value has been updated in place to produce a different
8504 // result.
8505 BackedgeTakenCounts.clear();
8506 PredicatedBackedgeTakenCounts.clear();
8507 BECountUsers.clear();
8508 LoopPropertiesCache.clear();
8509 ConstantEvolutionLoopExitValue.clear();
8510 ValueExprMap.clear();
8511 ValuesAtScopes.clear();
8512 ValuesAtScopesUsers.clear();
8513 LoopDispositions.clear();
8514 BlockDispositions.clear();
8515 UnsignedRanges.clear();
8516 SignedRanges.clear();
8517 ExprValueMap.clear();
8518 HasRecMap.clear();
8519 ConstantMultipleCache.clear();
8520 PredicatedSCEVRewrites.clear();
8521 FoldCache.clear();
8522 FoldCacheUser.clear();
8523}
8524void ScalarEvolution::visitAndClearUsers(
8528 while (!Worklist.empty()) {
8529 Instruction *I = Worklist.pop_back_val();
8530 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8531 continue;
8532
8533 ValueExprMapType::iterator It =
8534 ValueExprMap.find_as(static_cast<Value *>(I));
8535 if (It != ValueExprMap.end()) {
8536 eraseValueFromMap(It->first);
8537 ToForget.push_back(It->second);
8538 if (PHINode *PN = dyn_cast<PHINode>(I))
8539 ConstantEvolutionLoopExitValue.erase(PN);
8540 }
8541
8542 PushDefUseChildren(I, Worklist, Visited);
8543 }
8544}
8545
8547 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8551
8552 // Iterate over all the loops and sub-loops to drop SCEV information.
8553 while (!LoopWorklist.empty()) {
8554 auto *CurrL = LoopWorklist.pop_back_val();
8555
8556 // Drop any stored trip count value.
8557 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8558 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8559
8560 // Drop information about predicated SCEV rewrites for this loop.
8561 for (auto I = PredicatedSCEVRewrites.begin();
8562 I != PredicatedSCEVRewrites.end();) {
8563 std::pair<const SCEV *, const Loop *> Entry = I->first;
8564 if (Entry.second == CurrL)
8565 PredicatedSCEVRewrites.erase(I++);
8566 else
8567 ++I;
8568 }
8569
8570 auto LoopUsersItr = LoopUsers.find(CurrL);
8571 if (LoopUsersItr != LoopUsers.end())
8572 llvm::append_range(ToForget, LoopUsersItr->second);
8573
8574 // Drop information about expressions based on loop-header PHIs.
8575 PushLoopPHIs(CurrL, Worklist, Visited);
8576 visitAndClearUsers(Worklist, Visited, ToForget);
8577
8578 LoopPropertiesCache.erase(CurrL);
8579 // Forget all contained loops too, to avoid dangling entries in the
8580 // ValuesAtScopes map.
8581 LoopWorklist.append(CurrL->begin(), CurrL->end());
8582 }
8583 forgetMemoizedResults(ToForget);
8584}
8585
8586 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8587 forgetLoop(L->getOutermostLoop());
8588}
8589
8590 void ScalarEvolution::forgetValue(Value *V) {
8591 Instruction *I = dyn_cast<Instruction>(V);
8592 if (!I) return;
8593
8594 // Drop information about expressions based on loop-header PHIs.
8598 Worklist.push_back(I);
8599 Visited.insert(I);
8600 visitAndClearUsers(Worklist, Visited, ToForget);
8601
8602 forgetMemoizedResults(ToForget);
8603}
8604
8605 void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8606 if (!isSCEVable(V->getType()))
8607 return;
8608
8609 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8610 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8611 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8612 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8613 if (const SCEV *S = getExistingSCEV(V)) {
8614 struct InvalidationRootCollector {
8615 Loop *L;
8617
8618 InvalidationRootCollector(Loop *L) : L(L) {}
8619
8620 bool follow(const SCEV *S) {
8621 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8622 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8623 if (L->contains(I))
8624 Roots.push_back(S);
8625 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8626 if (L->contains(AddRec->getLoop()))
8627 Roots.push_back(S);
8628 }
8629 return true;
8630 }
8631 bool isDone() const { return false; }
8632 };
8633
8634 InvalidationRootCollector C(L);
8635 visitAll(S, C);
8636 forgetMemoizedResults(C.Roots);
8637 }
8638
8639 // Also perform the normal invalidation.
8640 forgetValue(V);
8641}
8642
8643void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8644
8645 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8646 // Unless a specific value is passed to invalidation, completely clear both
8647 // caches.
8648 if (!V) {
8649 BlockDispositions.clear();
8650 LoopDispositions.clear();
8651 return;
8652 }
8653
8654 if (!isSCEVable(V->getType()))
8655 return;
8656
8657 const SCEV *S = getExistingSCEV(V);
8658 if (!S)
8659 return;
8660
8661 // Invalidate the block and loop dispositions cached for S. Dispositions of
8662 // S's users may change if S's disposition changes (i.e. a user may change to
8663 // loop-invariant, if S changes to loop invariant), so also invalidate
8664 // dispositions of S's users recursively.
8665 SmallVector<const SCEV *, 8> Worklist = {S};
8667 while (!Worklist.empty()) {
8668 const SCEV *Curr = Worklist.pop_back_val();
8669 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8670 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8671 if (!LoopDispoRemoved && !BlockDispoRemoved)
8672 continue;
8673 auto Users = SCEVUsers.find(Curr);
8674 if (Users != SCEVUsers.end())
8675 for (const auto *User : Users->second)
8676 if (Seen.insert(User).second)
8677 Worklist.push_back(User);
8678 }
8679}
8680
8681/// Get the exact loop backedge taken count considering all loop exits. A
8682/// computable result can only be returned for loops with all exiting blocks
8683/// dominating the latch. howFarToZero assumes that the limit of each loop test
8684/// is never skipped. This is a valid assumption as long as the loop exits via
8685/// that test. For precise results, it is the caller's responsibility to specify
8686/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8687const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8688 const Loop *L, ScalarEvolution *SE,
8690 // If any exits were not computable, the loop is not computable.
8691 if (!isComplete() || ExitNotTaken.empty())
8692 return SE->getCouldNotCompute();
8693
8694 const BasicBlock *Latch = L->getLoopLatch();
8695 // All exiting blocks we have collected must dominate the only backedge.
8696 if (!Latch)
8697 return SE->getCouldNotCompute();
8698
8699 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8700 // count is simply a minimum out of all these calculated exit counts.
8702 for (const auto &ENT : ExitNotTaken) {
8703 const SCEV *BECount = ENT.ExactNotTaken;
8704 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8705 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8706 "We should only have known counts for exiting blocks that dominate "
8707 "latch!");
8708
8709 Ops.push_back(BECount);
8710
8711 if (Preds)
8712 append_range(*Preds, ENT.Predicates);
8713
8714 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8715 "Predicate should be always true!");
8716 }
8717
8718 // If an earlier exit exits on the first iteration (exit count zero), then
8719 // a later poison exit count should not propagate into the result. These are
8720 // exactly the semantics provided by umin_seq.
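// (For instance, umin_seq(0, poison) is 0, whereas a plain umin of the two
// operands would be poison.)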
8721 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8722}
8723
8724const ScalarEvolution::ExitNotTakenInfo *
8725ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8726 const BasicBlock *ExitingBlock,
8727 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8728 for (const auto &ENT : ExitNotTaken)
8729 if (ENT.ExitingBlock == ExitingBlock) {
8730 if (ENT.hasAlwaysTruePredicate())
8731 return &ENT;
8732 else if (Predicates) {
8733 append_range(*Predicates, ENT.Predicates);
8734 return &ENT;
8735 }
8736 }
8737
8738 return nullptr;
8739}
8740
8741/// getConstantMax - Get the constant max backedge taken count for the loop.
8742const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8743 ScalarEvolution *SE,
8744 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8745 if (!getConstantMax())
8746 return SE->getCouldNotCompute();
8747
8748 for (const auto &ENT : ExitNotTaken)
8749 if (!ENT.hasAlwaysTruePredicate()) {
8750 if (!Predicates)
8751 return SE->getCouldNotCompute();
8752 append_range(*Predicates, ENT.Predicates);
8753 }
8754
8755 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8756 isa<SCEVConstant>(getConstantMax())) &&
8757 "No point in having a non-constant max backedge taken count!");
8758 return getConstantMax();
8759}
8760
8761const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8762 const Loop *L, ScalarEvolution *SE,
8763 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8764 if (!SymbolicMax) {
8765 // Form an expression for the maximum exit count possible for this loop. We
8766 // merge the max and exact information to approximate a version of
8767 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8768 // constants.
8770
8771 for (const auto &ENT : ExitNotTaken) {
8772 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8773 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8774 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8775 "We should only have known counts for exiting blocks that "
8776 "dominate latch!");
8777 ExitCounts.push_back(ExitCount);
8778 if (Predicates)
8779 append_range(*Predicates, ENT.Predicates);
8780
8781 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8782 "Predicate should be always true!");
8783 }
8784 }
8785 if (ExitCounts.empty())
8786 SymbolicMax = SE->getCouldNotCompute();
8787 else
8788 SymbolicMax =
8789 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8790 }
8791 return SymbolicMax;
8792}
8793
8794bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8795 ScalarEvolution *SE) const {
8796 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8797 return !ENT.hasAlwaysTruePredicate();
8798 };
8799 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8800}
8801
8804
8806 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8807 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8811 // If we prove the max count is zero, so is the symbolic bound. This happens
8812 // in practice due to differences in a) how context sensitive we've chosen
8813 // to be and b) how we reason about bounds implied by UB.
8814 if (ConstantMaxNotTaken->isZero()) {
8815 this->ExactNotTaken = E = ConstantMaxNotTaken;
8816 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8817 }
8818
8821 "Exact is not allowed to be less precise than Constant Max");
8824 "Exact is not allowed to be less precise than Symbolic Max");
8827 "Symbolic Max is not allowed to be less precise than Constant Max");
8830 "No point in having a non-constant max backedge taken count!");
8832 for (const auto PredList : PredLists)
8833 for (const auto *P : PredList) {
8834 if (SeenPreds.contains(P))
8835 continue;
8836 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8837 SeenPreds.insert(P);
8838 Predicates.push_back(P);
8839 }
8840 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8841 "Backedge count should be int");
8843 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8844 "Max backedge count should be int");
8845}
8846
8854
8855/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8856/// computable exit into a persistent ExitNotTakenInfo array.
8857ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8859 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8860 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8861 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8862
8863 ExitNotTaken.reserve(ExitCounts.size());
8864 std::transform(ExitCounts.begin(), ExitCounts.end(),
8865 std::back_inserter(ExitNotTaken),
8866 [&](const EdgeExitInfo &EEI) {
8867 BasicBlock *ExitBB = EEI.first;
8868 const ExitLimit &EL = EEI.second;
8869 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8870 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8871 EL.Predicates);
8872 });
8873 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8874 isa<SCEVConstant>(ConstantMax)) &&
8875 "No point in having a non-constant max backedge taken count!");
8876}
8877
8878/// Compute the number of times the backedge of the specified loop will execute.
8879ScalarEvolution::BackedgeTakenInfo
8880ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8881 bool AllowPredicates) {
8882 SmallVector<BasicBlock *, 8> ExitingBlocks;
8883 L->getExitingBlocks(ExitingBlocks);
8884
8885 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8886
8888 bool CouldComputeBECount = true;
8889 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8890 const SCEV *MustExitMaxBECount = nullptr;
8891 const SCEV *MayExitMaxBECount = nullptr;
8892 bool MustExitMaxOrZero = false;
8893 bool IsOnlyExit = ExitingBlocks.size() == 1;
8894
8895 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8896 // and compute maxBECount.
8897 // Do a union of all the predicates here.
8898 for (BasicBlock *ExitBB : ExitingBlocks) {
8899 // We canonicalize untaken exits to br (constant); ignore them so that
8900 // proving an exit untaken doesn't negatively impact our ability to reason
8901 // about the loop as a whole.
8902 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8903 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8904 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8905 if (ExitIfTrue == CI->isZero())
8906 continue;
8907 }
8908
8909 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8910
8911 assert((AllowPredicates || EL.Predicates.empty()) &&
8912 "Predicated exit limit when predicates are not allowed!");
8913
8914 // 1. For each exit that can be computed, add an entry to ExitCounts.
8915 // CouldComputeBECount is true only if all exits can be computed.
8916 if (EL.ExactNotTaken != getCouldNotCompute())
8917 ++NumExitCountsComputed;
8918 else
8919 // We couldn't compute an exact value for this exit, so
8920 // we won't be able to compute an exact value for the loop.
8921 CouldComputeBECount = false;
8922 // Remember exit count if either exact or symbolic is known. Because
8923 // Exact always implies symbolic, only check symbolic.
8924 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8925 ExitCounts.emplace_back(ExitBB, EL);
8926 else {
8927 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8928 "Exact is known but symbolic isn't?");
8929 ++NumExitCountsNotComputed;
8930 }
8931
8932 // 2. Derive the loop's MaxBECount from each exit's max number of
8933 // non-exiting iterations. Partition the loop exits into two kinds:
8934 // LoopMustExits and LoopMayExits.
8935 //
8936 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8937 // is a LoopMayExit. If any computable LoopMustExit is found, then
8938 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8939 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8940 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8941 // any
8942 // computable EL.ConstantMaxNotTaken.
8943 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8944 DT.dominates(ExitBB, Latch)) {
8945 if (!MustExitMaxBECount) {
8946 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8947 MustExitMaxOrZero = EL.MaxOrZero;
8948 } else {
8949 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8950 EL.ConstantMaxNotTaken);
8951 }
8952 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8953 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8954 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8955 else {
8956 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8957 EL.ConstantMaxNotTaken);
8958 }
8959 }
8960 }
8961 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8962 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8963 // The loop backedge will be taken the maximum or zero times if there's
8964 // a single exit that must be taken the maximum or zero times.
8965 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8966
8967 // Remember which SCEVs are used in exit limits for invalidation purposes.
8968 // We only care about non-constant SCEVs here, so we can ignore
8969 // EL.ConstantMaxNotTaken
8970 // and MaxBECount, which must be SCEVConstant.
8971 for (const auto &Pair : ExitCounts) {
8972 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8973 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8974 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8975 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8976 {L, AllowPredicates});
8977 }
8978 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8979 MaxBECount, MaxOrZero);
8980}
8981
8982ScalarEvolution::ExitLimit
8983ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8984 bool IsOnlyExit, bool AllowPredicates) {
8985 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8986 // If our exiting block does not dominate the latch, then its connection with
8987 // loop's exit limit may be far from trivial.
8988 const BasicBlock *Latch = L->getLoopLatch();
8989 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8990 return getCouldNotCompute();
8991
8992 Instruction *Term = ExitingBlock->getTerminator();
8993 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8994 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8995 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8996 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8997 "It should have one successor in loop and one exit block!");
8998 // Proceed to the next level to examine the exit condition expression.
8999 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9000 /*ControlsOnlyExit=*/IsOnlyExit,
9001 AllowPredicates);
9002 }
9003
9004 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9005 // For switch, make sure that there is a single exit from the loop.
9006 BasicBlock *Exit = nullptr;
9007 for (auto *SBB : successors(ExitingBlock))
9008 if (!L->contains(SBB)) {
9009 if (Exit) // Multiple exit successors.
9010 return getCouldNotCompute();
9011 Exit = SBB;
9012 }
9013 assert(Exit && "Exiting block must have at least one exit");
9014 return computeExitLimitFromSingleExitSwitch(
9015 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9016 }
9017
9018 return getCouldNotCompute();
9019}
9020
9022 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9023 bool AllowPredicates) {
9024 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9025 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9026 ControlsOnlyExit, AllowPredicates);
9027}
9028
9029std::optional<ScalarEvolution::ExitLimit>
9030ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9031 bool ExitIfTrue, bool ControlsOnlyExit,
9032 bool AllowPredicates) {
9033 (void)this->L;
9034 (void)this->ExitIfTrue;
9035 (void)this->AllowPredicates;
9036
9037 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9038 this->AllowPredicates == AllowPredicates &&
9039 "Variance in assumed invariant key components!");
9040 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9041 if (Itr == TripCountMap.end())
9042 return std::nullopt;
9043 return Itr->second;
9044}
9045
9046void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9047 bool ExitIfTrue,
9048 bool ControlsOnlyExit,
9049 bool AllowPredicates,
9050 const ExitLimit &EL) {
9051 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9052 this->AllowPredicates == AllowPredicates &&
9053 "Variance in assumed invariant key components!");
9054
9055 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9056 assert(InsertResult.second && "Expected successful insertion!");
9057 (void)InsertResult;
9058 (void)ExitIfTrue;
9059}
9060
9061ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9062 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9063 bool ControlsOnlyExit, bool AllowPredicates) {
9064
9065 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9066 AllowPredicates))
9067 return *MaybeEL;
9068
9069 ExitLimit EL = computeExitLimitFromCondImpl(
9070 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9071 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9072 return EL;
9073}
9074
9075ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9076 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9077 bool ControlsOnlyExit, bool AllowPredicates) {
9078 // Handle BinOp conditions (And, Or).
9079 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9080 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9081 return *LimitFromBinOp;
9082
9083 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9084 // Proceed to the next level to examine the icmp.
9085 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9086 ExitLimit EL =
9087 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9088 if (EL.hasFullInfo() || !AllowPredicates)
9089 return EL;
9090
9091 // Try again, but use SCEV predicates this time.
9092 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9093 ControlsOnlyExit,
9094 /*AllowPredicates=*/true);
9095 }
9096
9097 // Check for a constant condition. These are normally stripped out by
9098 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9099 // preserve the CFG and is temporarily leaving constant conditions
9100 // in place.
9101 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9102 if (ExitIfTrue == !CI->getZExtValue())
9103 // The backedge is always taken.
9104 return getCouldNotCompute();
9105 // The backedge is never taken.
9106 return getZero(CI->getType());
9107 }
9108
9109 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9110 // with a constant step, we can form an equivalent icmp predicate and figure
9111 // out how many iterations will be taken before we exit.
9112 const WithOverflowInst *WO;
9113 const APInt *C;
9114 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9115 match(WO->getRHS(), m_APInt(C))) {
9116 ConstantRange NWR =
9118 WO->getNoWrapKind());
9119 CmpInst::Predicate Pred;
9120 APInt NewRHSC, Offset;
9121 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9122 if (!ExitIfTrue)
9123 Pred = ICmpInst::getInversePredicate(Pred);
9124 auto *LHS = getSCEV(WO->getLHS());
9125 if (Offset != 0)
9127 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9128 ControlsOnlyExit, AllowPredicates);
9129 if (EL.hasAnyInfo())
9130 return EL;
9131 }
9132
9133 // If it's not an integer or pointer comparison then compute it the hard way.
9134 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9135}
9136
9137std::optional<ScalarEvolution::ExitLimit>
9138ScalarEvolution::computeExitLimitFromCondFromBinOp(
9139 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9140 bool ControlsOnlyExit, bool AllowPredicates) {
9141 // Check if the controlling expression for this loop is an And or Or.
9142 Value *Op0, *Op1;
9143 bool IsAnd = false;
9144 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9145 IsAnd = true;
9146 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9147 IsAnd = false;
9148 else
9149 return std::nullopt;
9150
9151 // EitherMayExit is true in these two cases:
9152 // br (and Op0 Op1), loop, exit
9153 // br (or Op0 Op1), exit, loop
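// E.g. for "br i1 (and %c0, %c1), label %loop, label %exit" the loop is
// left as soon as either condition becomes false, so the exit counts below
// are combined with a umin (sequential when the and/or is in select form).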
9154 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9155 ExitLimit EL0 = computeExitLimitFromCondCached(
9156 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9157 AllowPredicates);
9158 ExitLimit EL1 = computeExitLimitFromCondCached(
9159 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9160 AllowPredicates);
9161
9162 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9163 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9164 if (isa<ConstantInt>(Op1))
9165 return Op1 == NeutralElement ? EL0 : EL1;
9166 if (isa<ConstantInt>(Op0))
9167 return Op0 == NeutralElement ? EL1 : EL0;
9168
9169 const SCEV *BECount = getCouldNotCompute();
9170 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9171 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9172 if (EitherMayExit) {
9173 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9174 // Both conditions must be same for the loop to continue executing.
9175 // Choose the less conservative count.
9176 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9177 EL1.ExactNotTaken != getCouldNotCompute()) {
9178 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9179 UseSequentialUMin);
9180 }
9181 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9182 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9183 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9184 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9185 else
9186 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9187 EL1.ConstantMaxNotTaken);
9188 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9189 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9190 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9191 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9192 else
9193 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9194 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9195 } else {
9196 // Both conditions must be same at the same time for the loop to exit.
9197 // For now, be conservative.
9198 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9199 BECount = EL0.ExactNotTaken;
9200 }
9201
9202 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9203 // to be more aggressive when computing BECount than when computing
9204 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9205 // and
9206 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9207 // EL1.ConstantMaxNotTaken to not.
9208 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9209 !isa<SCEVCouldNotCompute>(BECount))
9210 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9211 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9212 SymbolicMaxBECount =
9213 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9214 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9215 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9216}
9217
9218ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9219 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9220 bool AllowPredicates) {
9221 // If the condition was exit on true, convert the condition to exit on false
9222 CmpPredicate Pred;
9223 if (!ExitIfTrue)
9224 Pred = ExitCond->getCmpPredicate();
9225 else
9226 Pred = ExitCond->getInverseCmpPredicate();
9227 const ICmpInst::Predicate OriginalPred = Pred;
9228
9229 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9230 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9231
9232 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9233 AllowPredicates);
9234 if (EL.hasAnyInfo())
9235 return EL;
9236
9237 auto *ExhaustiveCount =
9238 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9239
9240 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9241 return ExhaustiveCount;
9242
9243 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9244 ExitCond->getOperand(1), L, OriginalPred);
9245}
9246ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9247 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9248 bool ControlsOnlyExit, bool AllowPredicates) {
9249
9250 // Try to evaluate any dependencies out of the loop.
9251 LHS = getSCEVAtScope(LHS, L);
9252 RHS = getSCEVAtScope(RHS, L);
9253
9254 // At this point, we would like to compute how many iterations of the
9255 // loop the predicate will return true for these inputs.
9256 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9257 // If there is a loop-invariant, force it into the RHS.
9258 std::swap(LHS, RHS);
9260 }
9261
9262 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9264 // Simplify the operands before analyzing them.
9265 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9266
9267 // If we have a comparison of a chrec against a constant, try to use value
9268 // ranges to answer this query.
9269 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9270 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9271 if (AddRec->getLoop() == L) {
9272 // Form the constant range.
9273 ConstantRange CompRange =
9274 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9275
9276 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9277 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9278 }
9279
9280 // If this loop must exit based on this condition (or execute undefined
9281 // behaviour), see if we can improve wrap flags. This is essentially
9282 // a must execute style proof.
9283 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9284 // If we can prove the test sequence produced must repeat the same values
9285 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9286 // because if it did, we'd have an infinite (undefined) loop.
9287 // TODO: We can peel off any functions which are invertible *in L*. Loop
9288 // invariant terms are effectively constants for our purposes here.
9289 auto *InnerLHS = LHS;
9290 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9291 InnerLHS = ZExt->getOperand();
9292 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9293 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9294 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9295 /*OrNegative=*/true)) {
9296 auto Flags = AR->getNoWrapFlags();
9297 Flags = setFlags(Flags, SCEV::FlagNW);
9298 SmallVector<const SCEV *> Operands{AR->operands()};
9299 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9300 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9301 }
9302
9303 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9304 // From no-self-wrap, this follows trivially from the fact that every
9305 // (un)signed-wrapped, but not self-wrapped, value must be less than the
9306 // last value before the (un)signed wrap. Since we know that last value
9307 // didn't exit, neither will any smaller one.
9308 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9309 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9310 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9311 AR && AR->getLoop() == L && AR->isAffine() &&
9312 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9313 isKnownPositive(AR->getStepRecurrence(*this))) {
9314 auto Flags = AR->getNoWrapFlags();
9315 Flags = setFlags(Flags, WrapType);
9316 SmallVector<const SCEV*> Operands{AR->operands()};
9317 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9318 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9319 }
9320 }
9321 }
9322
9323 switch (Pred) {
9324 case ICmpInst::ICMP_NE: { // while (X != Y)
9325 // Convert to: while (X-Y != 0)
9326 if (LHS->getType()->isPointerTy()) {
9329 return LHS;
9330 }
9331 if (RHS->getType()->isPointerTy()) {
9334 return RHS;
9335 }
9336 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9337 AllowPredicates);
9338 if (EL.hasAnyInfo())
9339 return EL;
9340 break;
9341 }
9342 case ICmpInst::ICMP_EQ: { // while (X == Y)
9343 // Convert to: while (X-Y == 0)
9344 if (LHS->getType()->isPointerTy()) {
9347 return LHS;
9348 }
9349 if (RHS->getType()->isPointerTy()) {
9352 return RHS;
9353 }
9354 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9355 if (EL.hasAnyInfo()) return EL;
9356 break;
9357 }
9358 case ICmpInst::ICMP_SLE:
9359 case ICmpInst::ICMP_ULE:
9360 // Since the loop is finite, an invariant RHS cannot include the boundary
9361 // value, otherwise it would loop forever.
9362 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9363 !isLoopInvariant(RHS, L)) {
9364 // Otherwise, perform the addition in a wider type, to avoid overflow.
9365 // If the LHS is an addrec with the appropriate nowrap flag, the
9366 // extension will be sunk into it and the exit count can be analyzed.
9367 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9368 if (!OldType)
9369 break;
9370 // Prefer doubling the bitwidth over adding a single bit to make it more
9371 // likely that we use a legal type.
9372 auto *NewType =
9373 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9374 if (ICmpInst::isSigned(Pred)) {
9375 LHS = getSignExtendExpr(LHS, NewType);
9376 RHS = getSignExtendExpr(RHS, NewType);
9377 } else {
9378 LHS = getZeroExtendExpr(LHS, NewType);
9379 RHS = getZeroExtendExpr(RHS, NewType);
9380 }
9381 }
9383 [[fallthrough]];
9384 case ICmpInst::ICMP_SLT:
9385 case ICmpInst::ICMP_ULT: { // while (X < Y)
9386 bool IsSigned = ICmpInst::isSigned(Pred);
9387 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9388 AllowPredicates);
9389 if (EL.hasAnyInfo())
9390 return EL;
9391 break;
9392 }
9393 case ICmpInst::ICMP_SGE:
9394 case ICmpInst::ICMP_UGE:
9395 // Since the loop is finite, an invariant RHS cannot include the boundary
9396 // value, otherwise it would loop forever.
9397 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9398 !isLoopInvariant(RHS, L))
9399 break;
9401 [[fallthrough]];
9402 case ICmpInst::ICMP_SGT:
9403 case ICmpInst::ICMP_UGT: { // while (X > Y)
9404 bool IsSigned = ICmpInst::isSigned(Pred);
9405 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9406 AllowPredicates);
9407 if (EL.hasAnyInfo())
9408 return EL;
9409 break;
9410 }
9411 default:
9412 break;
9413 }
9414
9415 return getCouldNotCompute();
9416}
9417
9418ScalarEvolution::ExitLimit
9419ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9420 SwitchInst *Switch,
9421 BasicBlock *ExitingBlock,
9422 bool ControlsOnlyExit) {
9423 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9424
9425 // Give up if the exit is the default dest of a switch.
9426 if (Switch->getDefaultDest() == ExitingBlock)
9427 return getCouldNotCompute();
9428
9429 assert(L->contains(Switch->getDefaultDest()) &&
9430 "Default case must not exit the loop!");
9431 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9432 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9433
9434 // while (X != Y) --> while (X-Y != 0)
9435 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9436 if (EL.hasAnyInfo())
9437 return EL;
9438
9439 return getCouldNotCompute();
9440}
9441
9442static ConstantInt *
9444 ScalarEvolution &SE) {
9445 const SCEV *InVal = SE.getConstant(C);
9446 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9448 "Evaluation of SCEV at constant didn't fold correctly?");
9449 return cast<SCEVConstant>(Val)->getValue();
9450}
9451
9452ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9453 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9454 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9455 if (!RHS)
9456 return getCouldNotCompute();
9457
9458 const BasicBlock *Latch = L->getLoopLatch();
9459 if (!Latch)
9460 return getCouldNotCompute();
9461
9462 const BasicBlock *Predecessor = L->getLoopPredecessor();
9463 if (!Predecessor)
9464 return getCouldNotCompute();
9465
9466 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9467 // Return LHS in OutLHS and shift_op in OutOpCode.
9468 auto MatchPositiveShift =
9469 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9470
9471 using namespace PatternMatch;
9472
9473 ConstantInt *ShiftAmt;
9474 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9475 OutOpCode = Instruction::LShr;
9476 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9477 OutOpCode = Instruction::AShr;
9478 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9479 OutOpCode = Instruction::Shl;
9480 else
9481 return false;
9482
9483 return ShiftAmt->getValue().isStrictlyPositive();
9484 };
9485
9486 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9487 //
9488 // loop:
9489 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9490 // %iv.shifted = lshr i32 %iv, <positive constant>
9491 //
9492 // Return true on a successful match. Return the corresponding PHI node (%iv
9493 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9494 auto MatchShiftRecurrence =
9495 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9496 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9497
9498 {
9499 Instruction::BinaryOps OpC;
9500 Value *V;
9501
9502 // If we encounter a shift instruction, "peel off" the shift operation,
9503 // and remember that we did so. Later when we inspect %iv's backedge
9504 // value, we will make sure that the backedge value uses the same
9505 // operation.
9506 //
9507 // Note: the peeled shift operation does not have to be the same
9508 // instruction as the one feeding into the PHI's backedge value. We only
9509 // really care about it being the same *kind* of shift instruction --
9510 // that's all that is required for our later inferences to hold.
9511 if (MatchPositiveShift(LHS, V, OpC)) {
9512 PostShiftOpCode = OpC;
9513 LHS = V;
9514 }
9515 }
9516
9517 PNOut = dyn_cast<PHINode>(LHS);
9518 if (!PNOut || PNOut->getParent() != L->getHeader())
9519 return false;
9520
9521 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9522 Value *OpLHS;
9523
9524 return
9525 // The backedge value for the PHI node must be a shift by a positive
9526 // amount
9527 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9528
9529 // of the PHI node itself
9530 OpLHS == PNOut &&
9531
9532 // and the kind of shift should match the kind of shift we peeled
9533 // off, if any.
9534 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9535 };
9536
9537 PHINode *PN;
9538 Instruction::BinaryOps OpCode;
9539 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9540 return getCouldNotCompute();
9541
9542 const DataLayout &DL = getDataLayout();
9543
9544 // The key rationale for this optimization is that for some kinds of shift
9545 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9546 // within a finite number of iterations. If the condition guarding the
9547 // backedge (in the sense that the backedge is taken if the condition is true)
9548 // is false for the value the shift recurrence stabilizes to, then we know
9549 // that the backedge is taken only a finite number of times.
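 // As an illustration, an 8-bit recurrence {X,lshr,1} produces X, X>>1,
 // X>>2, ... and reaches 0 after at most 8 iterations, while {X,ashr,1}
 // settles at 0 for a non-negative X and at -1 for a negative X; if the
 // exit condition is false for that stable value, the backedge can be taken
 // at most bitwidth-many times.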
9550
9551 ConstantInt *StableValue = nullptr;
9552 switch (OpCode) {
9553 default:
9554 llvm_unreachable("Impossible case!");
9555
9556 case Instruction::AShr: {
9557 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9558 // bitwidth(K) iterations.
9559 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9560 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9561 Predecessor->getTerminator(), &DT);
9562 auto *Ty = cast<IntegerType>(RHS->getType());
9563 if (Known.isNonNegative())
9564 StableValue = ConstantInt::get(Ty, 0);
9565 else if (Known.isNegative())
9566 StableValue = ConstantInt::get(Ty, -1, true);
9567 else
9568 return getCouldNotCompute();
9569
9570 break;
9571 }
9572 case Instruction::LShr:
9573 case Instruction::Shl:
9574 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9575 // stabilize to 0 in at most bitwidth(K) iterations.
9576 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9577 break;
9578 }
9579
9580 auto *Result =
9581 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9582 assert(Result->getType()->isIntegerTy(1) &&
9583 "Otherwise cannot be an operand to a branch instruction");
9584
9585 if (Result->isZeroValue()) {
9586 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9587 const SCEV *UpperBound =
9588 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9589 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9590 }
9591
9592 return getCouldNotCompute();
9593}
9594
9595/// Return true if we can constant fold an instruction of the specified type,
9596/// assuming that all operands were constants.
9597static bool CanConstantFold(const Instruction *I) {
9598 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) ||
9599 isa<CastInst>(I) || isa<GetElementPtrInst>(I) || isa<LoadInst>(I) ||
9600 isa<ExtractValueInst>(I))
9601 return true;
9602
9603 if (const CallInst *CI = dyn_cast<CallInst>(I))
9604 if (const Function *F = CI->getCalledFunction())
9605 return canConstantFoldCallTo(CI, F);
9606 return false;
9607}
9608
9609/// Determine whether this instruction can constant evolve within this loop
9610/// assuming its operands can all constant evolve.
9611static bool canConstantEvolve(Instruction *I, const Loop *L) {
9612 // An instruction outside of the loop can't be derived from a loop PHI.
9613 if (!L->contains(I)) return false;
9614
9615 if (isa<PHINode>(I)) {
9616 // We don't currently keep track of the control flow needed to evaluate
9617 // PHIs, so we cannot handle PHIs inside of loops.
9618 return L->getHeader() == I->getParent();
9619 }
9620
9621 // If we won't be able to constant fold this expression even if the operands
9622 // are constants, bail early.
9623 return CanConstantFold(I);
9624}
9625
9626/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9627/// recursing through each instruction operand until reaching a loop header phi.
9628static PHINode *
9629getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9630 DenseMap<Instruction *, PHINode *> &PHIMap,
9631 unsigned Depth) {
9632 if (Depth > MaxConstantEvolvingDepth)
9633 return nullptr;
9634
9635 // Otherwise, we can evaluate this instruction if all of its operands are
9636 // constant or derived from a PHI node themselves.
9637 PHINode *PHI = nullptr;
9638 for (Value *Op : UseInst->operands()) {
9639 if (isa<Constant>(Op)) continue;
9640
9641 Instruction *OpInst = dyn_cast<Instruction>(Op);
9642 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9643
9644 PHINode *P = dyn_cast<PHINode>(OpInst);
9645 if (!P)
9646 // If this operand is already visited, reuse the prior result.
9647 // We may have P != PHI if this is the deepest point at which the
9648 // inconsistent paths meet.
9649 P = PHIMap.lookup(OpInst);
9650 if (!P) {
9651 // Recurse and memoize the results, whether a phi is found or not.
9652 // This recursive call invalidates pointers into PHIMap.
9653 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9654 PHIMap[OpInst] = P;
9655 }
9656 if (!P)
9657 return nullptr; // Not evolving from PHI
9658 if (PHI && PHI != P)
9659 return nullptr; // Evolving from multiple different PHIs.
9660 PHI = P;
9661 }
9662 // This is an expression evolving from a constant PHI!
9663 return PHI;
9664}
9665
9666/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9667/// in the loop that V is derived from. We allow arbitrary operations along the
9668/// way, but the operands of an operation must either be constants or a value
9669/// derived from a constant PHI. If this expression does not fit with these
9670/// constraints, return null.
9671static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9672 Instruction *I = dyn_cast<Instruction>(V);
9673 if (!I || !canConstantEvolve(I, L)) return nullptr;
9674
9675 if (PHINode *PN = dyn_cast<PHINode>(I))
9676 return PN;
9677
9678 // Record non-constant instructions contained by the loop.
9679 DenseMap<Instruction *, PHINode *> PHIMap;
9680 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9681}
9682
9683/// EvaluateExpression - Given an expression that passes the
9684/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9685/// in the loop has the value PHIVal. If we can't fold this expression for some
9686/// reason, return null.
9687static Constant *EvaluateExpression(Value *V, const Loop *L,
9688 DenseMap<Instruction *, Constant *> &Vals,
9689 const DataLayout &DL,
9690 const TargetLibraryInfo *TLI) {
9691 // Convenient constant check, but redundant for recursive calls.
9692 if (Constant *C = dyn_cast<Constant>(V)) return C;
9693 Instruction *I = dyn_cast<Instruction>(V);
9694 if (!I) return nullptr;
9695
9696 if (Constant *C = Vals.lookup(I)) return C;
9697
9698 // An instruction inside the loop depends on a value outside the loop that we
9699 // weren't given a mapping for, or a value such as a call inside the loop.
9700 if (!canConstantEvolve(I, L)) return nullptr;
9701
9702 // An unmapped PHI can be due to a branch or another loop inside this loop,
9703 // or due to this not being the initial iteration through a loop where we
9704 // couldn't compute the evolution of this particular PHI last time.
9705 if (isa<PHINode>(I)) return nullptr;
9706
9707 std::vector<Constant*> Operands(I->getNumOperands());
9708
9709 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9710 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9711 if (!Operand) {
9712 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9713 if (!Operands[i]) return nullptr;
9714 continue;
9715 }
9716 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9717 Vals[Operand] = C;
9718 if (!C) return nullptr;
9719 Operands[i] = C;
9720 }
9721
9722 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9723 /*AllowNonDeterministic=*/false);
9724}
9725
9726
9727// If every incoming value to PN except the one for BB is a specific Constant,
9728// return that, else return nullptr.
9729static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9730 Constant *IncomingVal = nullptr;
9731
9732 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9733 if (PN->getIncomingBlock(i) == BB)
9734 continue;
9735
9736 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9737 if (!CurrentVal)
9738 return nullptr;
9739
9740 if (IncomingVal != CurrentVal) {
9741 if (IncomingVal)
9742 return nullptr;
9743 IncomingVal = CurrentVal;
9744 }
9745 }
9746
9747 return IncomingVal;
9748}
9749
9750/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9751/// in the header of its containing loop, we know the loop executes a
9752/// constant number of times, and the PHI node is just a recurrence
9753/// involving constants, fold it.
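/// As an illustration, for a header PHI %x that starts at 3 and has the
/// backedge value %x.next = mul i32 %x, %x, a backedge-taken count of 2
/// evaluates the sequence 3, 9, 81, so the computed exit value is 81.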
9754Constant *
9755ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9756 const APInt &BEs,
9757 const Loop *L) {
9758 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9759 if (!Inserted)
9760 return I->second;
9761
9762 if (BEs.ugt(MaxBruteForceIterations))
9763 return nullptr; // Not going to evaluate it.
9764
9765 Constant *&RetVal = I->second;
9766
9767 DenseMap<Instruction *, Constant *> CurrentIterVals;
9768 BasicBlock *Header = L->getHeader();
9769 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9770
9771 BasicBlock *Latch = L->getLoopLatch();
9772 if (!Latch)
9773 return nullptr;
9774
9775 for (PHINode &PHI : Header->phis()) {
9776 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9777 CurrentIterVals[&PHI] = StartCST;
9778 }
9779 if (!CurrentIterVals.count(PN))
9780 return RetVal = nullptr;
9781
9782 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9783
9784 // Execute the loop symbolically to determine the exit value.
9785 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9786 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9787
9788 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9789 unsigned IterationNum = 0;
9790 const DataLayout &DL = getDataLayout();
9791 for (; ; ++IterationNum) {
9792 if (IterationNum == NumIterations)
9793 return RetVal = CurrentIterVals[PN]; // Got exit value!
9794
9795 // Compute the value of the PHIs for the next iteration.
9796 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9797 DenseMap<Instruction *, Constant *> NextIterVals;
9798 Constant *NextPHI =
9799 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9800 if (!NextPHI)
9801 return nullptr; // Couldn't evaluate!
9802 NextIterVals[PN] = NextPHI;
9803
9804 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9805
9806 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9807 // cease to be able to evaluate one of them or if they stop evolving,
9808 // because that doesn't necessarily prevent us from computing PN.
9809 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9810 for (const auto &I : CurrentIterVals) {
9811 PHINode *PHI = dyn_cast<PHINode>(I.first);
9812 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9813 PHIsToCompute.emplace_back(PHI, I.second);
9814 }
9815 // We use two distinct loops because EvaluateExpression may invalidate any
9816 // iterators into CurrentIterVals.
9817 for (const auto &I : PHIsToCompute) {
9818 PHINode *PHI = I.first;
9819 Constant *&NextPHI = NextIterVals[PHI];
9820 if (!NextPHI) { // Not already computed.
9821 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9822 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9823 }
9824 if (NextPHI != I.second)
9825 StoppedEvolving = false;
9826 }
9827
9828 // If all entries in CurrentIterVals == NextIterVals then we can stop
9829 // iterating; the loop can't continue to change.
9830 if (StoppedEvolving)
9831 return RetVal = CurrentIterVals[PN];
9832
9833 CurrentIterVals.swap(NextIterVals);
9834 }
9835}
9836
9837const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9838 Value *Cond,
9839 bool ExitWhen) {
9840 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9841 if (!PN) return getCouldNotCompute();
9842
9843 // If the loop is canonicalized, the PHI will have exactly two entries.
9844 // That's the only form we support here.
9845 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9846
9847 DenseMap<Instruction *, Constant *> CurrentIterVals;
9848 BasicBlock *Header = L->getHeader();
9849 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9850
9851 BasicBlock *Latch = L->getLoopLatch();
9852 assert(Latch && "Should follow from NumIncomingValues == 2!");
9853
9854 for (PHINode &PHI : Header->phis()) {
9855 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9856 CurrentIterVals[&PHI] = StartCST;
9857 }
9858 if (!CurrentIterVals.count(PN))
9859 return getCouldNotCompute();
9860
9861 // Okay, we have found a PHI node that defines the trip count of this loop. Execute
9862 // the loop symbolically to determine when the condition gets a value of
9863 // "ExitWhen".
9864 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9865 const DataLayout &DL = getDataLayout();
9866 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9867 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9868 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9869
9870 // Couldn't symbolically evaluate.
9871 if (!CondVal) return getCouldNotCompute();
9872
9873 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9874 ++NumBruteForceTripCountsComputed;
9875 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9876 }
9877
9878 // Update all the PHI nodes for the next iteration.
9879 DenseMap<Instruction *, Constant *> NextIterVals;
9880
9881 // Create a list of which PHIs we need to compute. We want to do this before
9882 // calling EvaluateExpression on them because that may invalidate iterators
9883 // into CurrentIterVals.
9884 SmallVector<PHINode *, 8> PHIsToCompute;
9885 for (const auto &I : CurrentIterVals) {
9886 PHINode *PHI = dyn_cast<PHINode>(I.first);
9887 if (!PHI || PHI->getParent() != Header) continue;
9888 PHIsToCompute.push_back(PHI);
9889 }
9890 for (PHINode *PHI : PHIsToCompute) {
9891 Constant *&NextPHI = NextIterVals[PHI];
9892 if (NextPHI) continue; // Already computed!
9893
9894 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9895 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9896 }
9897 CurrentIterVals.swap(NextIterVals);
9898 }
9899
9900 // Too many iterations were needed to evaluate.
9901 return getCouldNotCompute();
9902}
9903
9904const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9905 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9906 ValuesAtScopes[V];
9907 // Check to see if we've folded this expression at this loop before.
9908 for (auto &LS : Values)
9909 if (LS.first == L)
9910 return LS.second ? LS.second : V;
9911
9912 Values.emplace_back(L, nullptr);
9913
9914 // Otherwise compute it.
9915 const SCEV *C = computeSCEVAtScope(V, L);
9916 for (auto &LS : reverse(ValuesAtScopes[V]))
9917 if (LS.first == L) {
9918 LS.second = C;
9919 if (!isa<SCEVConstant>(C))
9920 ValuesAtScopesUsers[C].push_back({L, V});
9921 break;
9922 }
9923 return C;
9924}
9925
9926/// This builds up a Constant using the ConstantExpr interface. That way, we
9927/// will return Constants for objects which aren't represented by a
9928/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9929/// Returns NULL if the SCEV isn't representable as a Constant.
9930static Constant *BuildConstantFromSCEV(const SCEV *V) {
9931 switch (V->getSCEVType()) {
9932 case scCouldNotCompute:
9933 case scAddRecExpr:
9934 case scVScale:
9935 return nullptr;
9936 case scConstant:
9937 return cast<SCEVConstant>(V)->getValue();
9938 case scUnknown:
9939 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9940 case scPtrToInt: {
9941 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9942 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9943 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9944
9945 return nullptr;
9946 }
9947 case scTruncate: {
9948 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9949 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9950 return ConstantExpr::getTrunc(CastOp, ST->getType());
9951 return nullptr;
9952 }
9953 case scAddExpr: {
9954 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9955 Constant *C = nullptr;
9956 for (const SCEV *Op : SA->operands()) {
9957 Constant *OpC = BuildConstantFromSCEV(Op);
9958 if (!OpC)
9959 return nullptr;
9960 if (!C) {
9961 C = OpC;
9962 continue;
9963 }
9964 assert(!C->getType()->isPointerTy() &&
9965 "Can only have one pointer, and it must be last");
9966 if (OpC->getType()->isPointerTy()) {
9967 // The offsets have been converted to bytes. We can add bytes using
9968 // an i8 GEP.
9969 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9970 OpC, C);
9971 } else {
9972 C = ConstantExpr::getAdd(C, OpC);
9973 }
9974 }
9975 return C;
9976 }
9977 case scMulExpr:
9978 case scSignExtend:
9979 case scZeroExtend:
9980 case scUDivExpr:
9981 case scSMaxExpr:
9982 case scUMaxExpr:
9983 case scSMinExpr:
9984 case scUMinExpr:
9985 case scSequentialUMinExpr:
9986 return nullptr;
9987 }
9988 llvm_unreachable("Unknown SCEV kind!");
9989}
9990
9991const SCEV *
9992ScalarEvolution::getWithOperands(const SCEV *S,
9993 SmallVectorImpl<const SCEV *> &NewOps) {
9994 switch (S->getSCEVType()) {
9995 case scTruncate:
9996 case scZeroExtend:
9997 case scSignExtend:
9998 case scPtrToInt:
9999 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10000 case scAddRecExpr: {
10001 auto *AddRec = cast<SCEVAddRecExpr>(S);
10002 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10003 }
10004 case scAddExpr:
10005 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10006 case scMulExpr:
10007 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10008 case scUDivExpr:
10009 return getUDivExpr(NewOps[0], NewOps[1]);
10010 case scUMaxExpr:
10011 case scSMaxExpr:
10012 case scUMinExpr:
10013 case scSMinExpr:
10014 return getMinMaxExpr(S->getSCEVType(), NewOps);
10015 case scSequentialUMinExpr:
10016 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10017 case scConstant:
10018 case scVScale:
10019 case scUnknown:
10020 return S;
10021 case scCouldNotCompute:
10022 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10023 }
10024 llvm_unreachable("Unknown SCEV kind!");
10025}
10026
10027const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10028 switch (V->getSCEVType()) {
10029 case scConstant:
10030 case scVScale:
10031 return V;
10032 case scAddRecExpr: {
10033 // If this is a loop recurrence for a loop that does not contain L, then we
10034 // are dealing with the final value computed by the loop.
10035 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10036 // First, attempt to evaluate each operand.
10037 // Avoid performing the look-up in the common case where the specified
10038 // expression has no loop-variant portions.
10039 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10040 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10041 if (OpAtScope == AddRec->getOperand(i))
10042 continue;
10043
10044 // Okay, at least one of these operands is loop variant but might be
10045 // foldable. Build a new instance of the folded commutative expression.
10046 SmallVector<const SCEV *, 8> NewOps;
10047 NewOps.reserve(AddRec->getNumOperands());
10048 append_range(NewOps, AddRec->operands().take_front(i));
10049 NewOps.push_back(OpAtScope);
10050 for (++i; i != e; ++i)
10051 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10052
10053 const SCEV *FoldedRec = getAddRecExpr(
10054 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10055 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10056 // The addrec may be folded to a nonrecurrence, for example, if the
10057 // induction variable is multiplied by zero after constant folding. Go
10058 // ahead and return the folded value.
10059 if (!AddRec)
10060 return FoldedRec;
10061 break;
10062 }
10063
10064 // If the scope is outside the addrec's loop, evaluate it by using the
10065 // loop exit value of the addrec.
10066 if (!AddRec->getLoop()->contains(L)) {
10067 // To evaluate this recurrence, we need to know how many times the AddRec
10068 // loop iterates. Compute this now.
10069 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10070 if (BackedgeTakenCount == getCouldNotCompute())
10071 return AddRec;
10072
10073 // Then, evaluate the AddRec.
10074 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10075 }
10076
10077 return AddRec;
10078 }
10079 case scTruncate:
10080 case scZeroExtend:
10081 case scSignExtend:
10082 case scPtrToInt:
10083 case scAddExpr:
10084 case scMulExpr:
10085 case scUDivExpr:
10086 case scUMaxExpr:
10087 case scSMaxExpr:
10088 case scUMinExpr:
10089 case scSMinExpr:
10090 case scSequentialUMinExpr: {
10091 ArrayRef<const SCEV *> Ops = V->operands();
10092 // Avoid performing the look-up in the common case where the specified
10093 // expression has no loop-variant portions.
10094 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10095 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10096 if (OpAtScope != Ops[i]) {
10097 // Okay, at least one of these operands is loop variant but might be
10098 // foldable. Build a new instance of the folded commutative expression.
10099 SmallVector<const SCEV *, 8> NewOps;
10100 NewOps.reserve(Ops.size());
10101 append_range(NewOps, Ops.take_front(i));
10102 NewOps.push_back(OpAtScope);
10103
10104 for (++i; i != e; ++i) {
10105 OpAtScope = getSCEVAtScope(Ops[i], L);
10106 NewOps.push_back(OpAtScope);
10107 }
10108
10109 return getWithOperands(V, NewOps);
10110 }
10111 }
10112 // If we got here, all operands are loop invariant.
10113 return V;
10114 }
10115 case scUnknown: {
10116 // If this instruction is evolved from a constant-evolving PHI, compute the
10117 // exit value from the loop without using SCEVs.
10118 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10119 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10120 if (!I)
10121 return V; // This is some other type of SCEVUnknown, just return it.
10122
10123 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10124 const Loop *CurrLoop = this->LI[I->getParent()];
10125 // Looking for loop exit value.
10126 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10127 PN->getParent() == CurrLoop->getHeader()) {
10128 // Okay, there is no closed form solution for the PHI node. Check
10129 // to see if the loop that contains it has a known backedge-taken
10130 // count. If so, we may be able to force computation of the exit
10131 // value.
10132 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10133 // This trivial case can show up in some degenerate cases where
10134 // the incoming IR has not yet been fully simplified.
10135 if (BackedgeTakenCount->isZero()) {
10136 Value *InitValue = nullptr;
10137 bool MultipleInitValues = false;
10138 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10139 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10140 if (!InitValue)
10141 InitValue = PN->getIncomingValue(i);
10142 else if (InitValue != PN->getIncomingValue(i)) {
10143 MultipleInitValues = true;
10144 break;
10145 }
10146 }
10147 }
10148 if (!MultipleInitValues && InitValue)
10149 return getSCEV(InitValue);
10150 }
10151 // Do we have a loop invariant value flowing around the backedge
10152 // for a loop which must execute the backedge?
10153 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10154 isKnownNonZero(BackedgeTakenCount) &&
10155 PN->getNumIncomingValues() == 2) {
10156
10157 unsigned InLoopPred =
10158 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10159 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10160 if (CurrLoop->isLoopInvariant(BackedgeVal))
10161 return getSCEV(BackedgeVal);
10162 }
10163 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10164 // Okay, we know how many times the containing loop executes. If
10165 // this is a constant evolving PHI node, get the final value at
10166 // the specified iteration number.
10167 Constant *RV =
10168 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10169 if (RV)
10170 return getSCEV(RV);
10171 }
10172 }
10173 }
10174
10175 // Okay, this is an expression that we cannot symbolically evaluate
10176 // into a SCEV. Check to see if it's possible to symbolically evaluate
10177 // the arguments into constants, and if so, try to constant propagate the
10178 // result. This is particularly useful for computing loop exit values.
10179 if (!CanConstantFold(I))
10180 return V; // This is some other type of SCEVUnknown, just return it.
10181
10182 SmallVector<Constant *, 4> Operands;
10183 Operands.reserve(I->getNumOperands());
10184 bool MadeImprovement = false;
10185 for (Value *Op : I->operands()) {
10186 if (Constant *C = dyn_cast<Constant>(Op)) {
10187 Operands.push_back(C);
10188 continue;
10189 }
10190
10191 // If any of the operands is non-constant and if they are
10192 // non-integer and non-pointer, don't even try to analyze them
10193 // with scev techniques.
10194 if (!isSCEVable(Op->getType()))
10195 return V;
10196
10197 const SCEV *OrigV = getSCEV(Op);
10198 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10199 MadeImprovement |= OrigV != OpV;
10200
10201 Constant *C = BuildConstantFromSCEV(OpV);
10202 if (!C)
10203 return V;
10204 assert(C->getType() == Op->getType() && "Type mismatch");
10205 Operands.push_back(C);
10206 }
10207
10208 // Check to see if getSCEVAtScope actually made an improvement.
10209 if (!MadeImprovement)
10210 return V; // This is some other type of SCEVUnknown, just return it.
10211
10212 Constant *C = nullptr;
10213 const DataLayout &DL = getDataLayout();
10214 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10215 /*AllowNonDeterministic=*/false);
10216 if (!C)
10217 return V;
10218 return getSCEV(C);
10219 }
10220 case scCouldNotCompute:
10221 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10222 }
10223 llvm_unreachable("Unknown SCEV type!");
10224}
10225
10226const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10227 return getSCEVAtScope(getSCEV(V), L);
10228}
10229
10230const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10231 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10232 return stripInjectiveFunctions(ZExt->getOperand());
10233 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10234 return stripInjectiveFunctions(SExt->getOperand());
10235 return S;
10236}
10237
10238/// Finds the minimum unsigned root of the following equation:
10239///
10240/// A * X = B (mod N)
10241///
10242/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10243/// A and B isn't important.
10244///
10245/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
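///
/// For example, with BW = 4, A = 6 and B = 4: D = gcd(6, 16) = 2 divides B,
/// A/D = 3 has multiplicative inverse 3 modulo 8, and the minimum root is
/// (3 * 4 mod 16) / 2 = 6; indeed 6 * 6 = 36 = 4 (mod 16).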
10246static const SCEV *
10247SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10248 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10249 ScalarEvolution &SE, const Loop *L) {
10250 uint32_t BW = A.getBitWidth();
10251 assert(BW == SE.getTypeSizeInBits(B->getType()));
10252 assert(A != 0 && "A must be non-zero.");
10253
10254 // 1. D = gcd(A, N)
10255 //
10256 // The gcd of A and N may have only one prime factor: 2. The number of
10257 // trailing zeros in A is its multiplicity
10258 uint32_t Mult2 = A.countr_zero();
10259 // D = 2^Mult2
10260
10261 // 2. Check if B is divisible by D.
10262 //
10263 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10264 // is not less than the multiplicity of this prime factor for D.
10265 unsigned MinTZ = SE.getMinTrailingZeros(B);
10266 // Try again with the terminator of the loop predecessor for context-specific
10267 // result, if MinTZ is too small.
10268 if (MinTZ < Mult2 && L->getLoopPredecessor())
10269 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10270 if (MinTZ < Mult2) {
10271 // Check if we can prove there's no remainder using URem.
10272 const SCEV *URem =
10273 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10274 const SCEV *Zero = SE.getZero(B->getType());
10275 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10276 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10277 if (!Predicates)
10278 return SE.getCouldNotCompute();
10279
10280 // Avoid adding a predicate that is known to be false.
10281 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10282 return SE.getCouldNotCompute();
10283 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10284 }
10285 }
10286
10287 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10288 // modulo (N / D).
10289 //
10290 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10291 // (N / D) in general. The inverse itself always fits into BW bits, though,
10292 // so we immediately truncate it.
10293 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10294 APInt I = AD.multiplicativeInverse().zext(BW);
10295
10296 // 4. Compute the minimum unsigned root of the equation:
10297 // I * (B / D) mod (N / D)
10298 // To simplify the computation, we factor out the divide by D:
10299 // (I * B mod N) / D
10300 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10301 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10302}
10303
10304/// For a given quadratic addrec, generate coefficients of the corresponding
10305/// quadratic equation, multiplied by a common value to ensure that they are
10306/// integers.
10307/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10308/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10309/// were multiplied by, and BitWidth is the bit width of the original addrec
10310/// coefficients.
10311/// This function returns std::nullopt if the addrec coefficients are not
10312/// compile-time constants.
10313static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10314GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10315 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10316 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10317 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10318 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10319 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10320 << *AddRec << '\n');
10321
10322 // We currently can only solve this if the coefficients are constants.
10323 if (!LC || !MC || !NC) {
10324 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10325 return std::nullopt;
10326 }
10327
10328 APInt L = LC->getAPInt();
10329 APInt M = MC->getAPInt();
10330 APInt N = NC->getAPInt();
10331 assert(!N.isZero() && "This is not a quadratic addrec");
10332
10333 unsigned BitWidth = LC->getAPInt().getBitWidth();
10334 unsigned NewWidth = BitWidth + 1;
10335 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10336 << BitWidth << '\n');
10337 // The sign-extension (as opposed to a zero-extension) here matches the
10338 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10339 N = N.sext(NewWidth);
10340 M = M.sext(NewWidth);
10341 L = L.sext(NewWidth);
10342
10343 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10344 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10345 // L+M, L+2M+N, L+3M+3N, ...
10346 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10347 //
10348 // The equation Acc = 0 is then
10349 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10350 // In a quadratic form it becomes:
10351 // N n^2 + (2M-N) n + 2L = 0.
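 // As a quick check, the chrec {-4,+,1,+,2} accumulates -4, -3, 0, ... and
 // the derived equation 2n^2 + 0*n - 8 = 0 yields n = 2, the first
 // iteration at which the accumulated value is 0.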
10352
10353 APInt A = N;
10354 APInt B = 2 * M - A;
10355 APInt C = 2 * L;
10356 APInt T = APInt(NewWidth, 2);
10357 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10358 << "x + " << C << ", coeff bw: " << NewWidth
10359 << ", multiplied by " << T << '\n');
10360 return std::make_tuple(A, B, C, T, BitWidth);
10361}
10362
10363/// Helper function to compare optional APInts:
10364/// (a) if X and Y both exist, return min(X, Y),
10365/// (b) if neither X nor Y exist, return std::nullopt,
10366/// (c) if exactly one of X and Y exists, return that value.
10367static std::optional<APInt> MinOptional(std::optional<APInt> X,
10368 std::optional<APInt> Y) {
10369 if (X && Y) {
10370 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10371 APInt XW = X->sext(W);
10372 APInt YW = Y->sext(W);
10373 return XW.slt(YW) ? *X : *Y;
10374 }
10375 if (!X && !Y)
10376 return std::nullopt;
10377 return X ? *X : *Y;
10378}
10379
10380/// Helper function to truncate an optional APInt to a given BitWidth.
10381/// When solving addrec-related equations, it is preferable to return a value
10382/// that has the same bit width as the original addrec's coefficients. If the
10383/// solution fits in the original bit width, truncate it (except for i1).
10384/// Returning a value of a different bit width may inhibit some optimizations.
10385///
10386/// In general, a solution to a quadratic equation generated from an addrec
10387/// may require BW+1 bits, where BW is the bit width of the addrec's
10388/// coefficients. The reason is that the coefficients of the quadratic
10389/// equation are BW+1 bits wide (to avoid truncation when converting from
10390/// the addrec to the equation).
10391static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10392 unsigned BitWidth) {
10393 if (!X)
10394 return std::nullopt;
10395 unsigned W = X->getBitWidth();
10396 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10397 return X->trunc(BitWidth);
10398 return X;
10399}
10400
10401/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10402/// iterations. The values L, M, N are assumed to be signed, and they
10403/// should all have the same bit width.
10404/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10405/// where BW is the bit width of the addrec's coefficients.
10406/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10407/// returned as such, otherwise the bit width of the returned value may
10408/// be greater than BW.
10409///
10410/// This function returns std::nullopt if
10411/// (a) the addrec coefficients are not constant, or
10412/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10413/// like x^2 = 5, no integer solutions exist, in other cases an integer
10414/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10415static std::optional<APInt>
10416SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10417 APInt A, B, C, M;
10418 unsigned BitWidth;
10419 auto T = GetQuadraticEquation(AddRec);
10420 if (!T)
10421 return std::nullopt;
10422
10423 std::tie(A, B, C, M, BitWidth) = *T;
10424 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10425 std::optional<APInt> X =
10426 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10427 if (!X)
10428 return std::nullopt;
10429
10430 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10431 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10432 if (!V->isZero())
10433 return std::nullopt;
10434
10435 return TruncIfPossible(X, BitWidth);
10436}
10437
10438/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10439/// iterations. The values M, N are assumed to be signed, and they
10440/// should have the same bit width.
10441/// Find the least n such that c(n) does not belong to the given range,
10442/// while c(n-1) does.
10443///
10444/// This function returns std::nullopt if
10445/// (a) the addrec coefficients are not constant, or
10446/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10447/// bounds of the range.
10448static std::optional<APInt>
10449SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10450 const ConstantRange &Range, ScalarEvolution &SE) {
10451 assert(AddRec->getOperand(0)->isZero() &&
10452 "Starting value of addrec should be 0");
10453 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10454 << Range << ", addrec " << *AddRec << '\n');
10455 // This case is handled in getNumIterationsInRange. Here we can assume that
10456 // we start in the range.
10457 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10458 "Addrec's initial value should be in range");
10459
10460 APInt A, B, C, M;
10461 unsigned BitWidth;
10462 auto T = GetQuadraticEquation(AddRec);
10463 if (!T)
10464 return std::nullopt;
10465
10466 // Be careful about the return value: there can be two reasons for not
10467 // returning an actual number. First, if no solutions to the equations
10468 // were found, and second, if the solutions don't leave the given range.
10469 // The first case means that the actual solution is "unknown", the second
10470 // means that it's known, but not valid. If the solution is unknown, we
10471 // cannot make any conclusions.
10472 // Return a pair: the optional solution and a flag indicating if the
10473 // solution was found.
10474 auto SolveForBoundary =
10475 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10476 // Solve for signed overflow and unsigned overflow, pick the lower
10477 // solution.
10478 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10479 << Bound << " (before multiplying by " << M << ")\n");
10480 Bound *= M; // The quadratic equation multiplier.
10481
10482 std::optional<APInt> SO;
10483 if (BitWidth > 1) {
10484 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10485 "signed overflow\n");
10486 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10487 }
10488 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10489 "unsigned overflow\n");
10490 std::optional<APInt> UO =
10491 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10492
10493 auto LeavesRange = [&] (const APInt &X) {
10494 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10495 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10496 if (Range.contains(V0->getValue()))
10497 return false;
10498 // X should be at least 1, so X-1 is non-negative.
10499 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10500 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10501 if (Range.contains(V1->getValue()))
10502 return true;
10503 return false;
10504 };
10505
10506 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10507 // can be a solution, but the function failed to find it. We cannot treat it
10508 // as "no solution".
10509 if (!SO || !UO)
10510 return {std::nullopt, false};
10511
10512 // Check the smaller value first to see if it leaves the range.
10513 // At this point, both SO and UO must have values.
10514 std::optional<APInt> Min = MinOptional(SO, UO);
10515 if (LeavesRange(*Min))
10516 return { Min, true };
10517 std::optional<APInt> Max = Min == SO ? UO : SO;
10518 if (LeavesRange(*Max))
10519 return { Max, true };
10520
10521 // Solutions were found, but were eliminated, hence the "true".
10522 return {std::nullopt, true};
10523 };
10524
10525 std::tie(A, B, C, M, BitWidth) = *T;
10526 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10527 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10528 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10529 auto SL = SolveForBoundary(Lower);
10530 auto SU = SolveForBoundary(Upper);
10531 // If any of the solutions was unknown, no meaningful conclusions can
10532 // be made.
10533 if (!SL.second || !SU.second)
10534 return std::nullopt;
10535
10536 // Claim: The correct solution is not some value between Min and Max.
10537 //
10538 // Justification: Assuming that Min and Max are different values, one of
10539 // them is when the first signed overflow happens, the other is when the
10540 // first unsigned overflow happens. Crossing the range boundary is only
10541 // possible via an overflow (treating 0 as a special case of it, modeling
10542 // an overflow as crossing k*2^W for some k).
10543 //
10544 // The interesting case here is when Min was eliminated as an invalid
10545 // solution, but Max was not. The argument is that if there was another
10546 // overflow between Min and Max, it would also have been eliminated if
10547 // it was considered.
10548 //
10549 // For a given boundary, it is possible to have two overflows of the same
10550 // type (signed/unsigned) without having the other type in between: this
10551 // can happen when the vertex of the parabola is between the iterations
10552 // corresponding to the overflows. This is only possible when the two
10553 // overflows cross k*2^W for the same k. In such case, if the second one
10554 // left the range (and was the first one to do so), the first overflow
10555 // would have to enter the range, which would mean that either we had left
10556 // the range before or that we started outside of it. Both of these cases
10557 // are contradictions.
10558 //
10559 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10560 // solution is not some value between the Max for this boundary and the
10561 // Min of the other boundary.
10562 //
10563 // Justification: Assume that we had such Max_A and Min_B corresponding
10564 // to range boundaries A and B and such that Max_A < Min_B. If there was
10565 // a solution between Max_A and Min_B, it would have to be caused by an
10566 // overflow corresponding to either A or B. It cannot correspond to B,
10567 // since Min_B is the first occurrence of such an overflow. If it
10568 // corresponded to A, it would have to be either a signed or an unsigned
10569 // overflow that is larger than both eliminated overflows for A. But
10570 // between the eliminated overflows and this overflow, the values would
10571 // cover the entire value space, thus crossing the other boundary, which
10572 // is a contradiction.
10573
10574 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10575}
10576
10577ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10578 const Loop *L,
10579 bool ControlsOnlyExit,
10580 bool AllowPredicates) {
10581
10582 // This is only used for loops with a "x != y" exit test. The exit condition
10583 // is now expressed as a single expression, V = x-y. So the exit test is
10584 // effectively V != 0. We know, and take advantage of the fact, that this
10585 // expression is only ever used in a comparison-with-zero context.
10586
10587 SmallVector<const SCEVPredicate *> Predicates;
10588 // If the value is a constant
10589 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10590 // If the value is already zero, the branch will execute zero times.
10591 if (C->getValue()->isZero()) return C;
10592 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10593 }
10594
10595 const SCEVAddRecExpr *AddRec =
10596 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10597
10598 if (!AddRec && AllowPredicates)
10599 // Try to make this an AddRec using runtime tests, in the first X
10600 // iterations of this loop, where X is the SCEV expression found by the
10601 // algorithm below.
10602 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10603
10604 if (!AddRec || AddRec->getLoop() != L)
10605 return getCouldNotCompute();
10606
10607 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10608 // the quadratic equation to solve it.
10609 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10610 // We can only use this value if the chrec ends up with an exact zero
10611 // value at this index. When solving for "X*X != 5", for example, we
10612 // should not accept a root of 2.
10613 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10614 const auto *R = cast<SCEVConstant>(getConstant(*S));
10615 return ExitLimit(R, R, R, false, Predicates);
10616 }
10617 return getCouldNotCompute();
10618 }
10619
10620 // Otherwise we can only handle this if it is affine.
10621 if (!AddRec->isAffine())
10622 return getCouldNotCompute();
10623
10624 // If this is an affine expression, the execution count of this branch is
10625 // the minimum unsigned root of the following equation:
10626 //
10627 // Start + Step*N = 0 (mod 2^BW)
10628 //
10629 // equivalent to:
10630 //
10631 // Step*N = -Start (mod 2^BW)
10632 //
10633 // where BW is the common bit width of Start and Step.
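 // For example, for the affine addrec {-10,+,2} the equation 2*N = 10
 // (mod 2^BW) has minimum unsigned root N = 5, matching the five steps the
 // chrec needs to reach zero: -10, -8, -6, -4, -2, 0.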
10634
10635 // Get the initial value for the loop.
10636 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10637 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10638
10639 if (!isLoopInvariant(Step, L))
10640 return getCouldNotCompute();
10641
10642 LoopGuards Guards = LoopGuards::collect(L, *this);
10643 // Specialize step for this loop so we get context sensitive facts below.
10644 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10645
10646 // For positive steps (counting up until unsigned overflow):
10647 // N = -Start/Step (as unsigned)
10648 // For negative steps (counting down to zero):
10649 // N = Start/-Step
10650 // First compute the unsigned distance from zero in the direction of Step.
10651 bool CountDown = isKnownNegative(StepWLG);
10652 if (!CountDown && !isKnownNonNegative(StepWLG))
10653 return getCouldNotCompute();
10654
10655 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10656 // Handle unitary steps, which cannot wraparound.
10657 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10658 // N = Distance (as unsigned)
10659
10660 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10661 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10662 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10663
10664 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10665 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10666 // case, and see if we can improve the bound.
10667 //
10668 // Explicitly handling this here is necessary because getUnsignedRange
10669 // isn't context-sensitive; it doesn't know that we only care about the
10670 // range inside the loop.
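 // In that rotated form, Distance is n - 1 and Distance + 1 is n, so an
 // entry guard proving n != 0 lets us bound the backedge-taken count by
 // unsigned_max(n) - 1 instead of the full unsigned range.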
10671 const SCEV *Zero = getZero(Distance->getType());
10672 const SCEV *One = getOne(Distance->getType());
10673 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10674 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10675 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10676 // as "unsigned_max(Distance + 1) - 1".
10677 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10678 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10679 }
10680 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10681 Predicates);
10682 }
10683
10684 // If the condition controls loop exit (the loop exits only if the expression
10685 // is true) and the addition is no-wrap we can use unsigned divide to
10686 // compute the backedge count. In this case, the step may not divide the
10687 // distance, but we don't care because if the condition is "missed" the loop
10688 // will have undefined behavior due to wrapping.
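 // For example, in "for (i = 0; i != 10; i += 4)" the exit test is never
 // satisfied exactly; reaching that point would require the no-self-wrap IV
 // to wrap, which is undefined behavior, so reporting 10 udiv 4 = 2 here is
 // still sound.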
10689 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10690 loopHasNoAbnormalExits(AddRec->getLoop())) {
10691
10692 // If the stride is zero and the start is non-zero, the loop must be
10693 // infinite. In C++, most loops are finite by assumption, in which case the
10694 // step being zero implies UB must execute if the loop is entered.
10695 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10696 !isKnownNonZero(StepWLG))
10697 return getCouldNotCompute();
10698
10699 const SCEV *Exact =
10700 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10701 const SCEV *ConstantMax = getCouldNotCompute();
10702 if (Exact != getCouldNotCompute()) {
10703 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10704 ConstantMax =
10705 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10706 }
10707 const SCEV *SymbolicMax =
10708 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10709 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10710 }
10711
10712 // Solve the general equation.
10713 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10714 if (!StepC || StepC->getValue()->isZero())
10715 return getCouldNotCompute();
10716 const SCEV *E = SolveLinEquationWithOverflow(
10717 StepC->getAPInt(), getNegativeSCEV(Start),
10718 AllowPredicates ? &Predicates : nullptr, *this, L);
10719
10720 const SCEV *M = E;
10721 if (E != getCouldNotCompute()) {
10722 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10723 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10724 }
10725 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10726 return ExitLimit(E, M, S, false, Predicates);
10727}
10728
10729ScalarEvolution::ExitLimit
10730ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10731 // Loops that look like: while (X == 0) are very strange indeed. We don't
10732 // handle them yet except for the trivial case. This could be expanded in the
10733 // future as needed.
10734
10735 // If the value is a constant, check to see if it is known to be non-zero
10736 // already. If so, the backedge will execute zero times.
10737 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10738 if (!C->getValue()->isZero())
10739 return getZero(C->getType());
10740 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10741 }
10742
10743 // We could implement others, but I really doubt anyone writes loops like
10744 // this, and if they did, they would already be constant folded.
10745 return getCouldNotCompute();
10746}
10747
10748std::pair<const BasicBlock *, const BasicBlock *>
10749ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10750 const {
10751 // If the block has a unique predecessor, then there is no path from the
10752 // predecessor to the block that does not go through the direct edge
10753 // from the predecessor to the block.
10754 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10755 return {Pred, BB};
10756
10757 // A loop's header is defined to be a block that dominates the loop.
10758 // If the header has a unique predecessor outside the loop, it must be
10759 // a block that has exactly one successor that can reach the loop.
10760 if (const Loop *L = LI.getLoopFor(BB))
10761 return {L->getLoopPredecessor(), L->getHeader()};
10762
10763 return {nullptr, BB};
10764}
10765
10766/// SCEV structural equivalence is usually sufficient for testing whether two
10767/// expressions are equal, however for the purposes of looking for a condition
10768/// guarding a loop, it can be useful to be a little more general, since a
10769/// front-end may have replicated the controlling expression.
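/// For instance, a front-end may emit two identical 'xor' instructions for
/// the same guard expression in different blocks; each is wrapped in its own
/// SCEVUnknown, yet both compute the same value.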
10770static bool HasSameValue(const SCEV *A, const SCEV *B) {
10771 // Quick check to see if they are the same SCEV.
10772 if (A == B) return true;
10773
10774 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10775 // Not all instructions that are "identical" compute the same value. For
10776 // instance, two distinct alloca instructions allocating the same type are
10777 // identical and do not read memory, but they compute distinct values.
10778 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10779 };
10780
10781 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10782 // two different instructions with the same value. Check for this case.
10783 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10784 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10785 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10786 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10787 if (ComputesEqualValues(AI, BI))
10788 return true;
10789
10790 // Otherwise assume they may have a different value.
10791 return false;
10792}
10793
10794static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10795 const SCEV *Op0, *Op1;
10796 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
10797 return false;
10798 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10799 LHS = Op1;
10800 return true;
10801 }
10802 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10803 LHS = Op0;
10804 return true;
10805 }
10806 return false;
10807}
10808
10809bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10810 const SCEV *&RHS, unsigned Depth) {
10811 bool Changed = false;
10812 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10813 // '0 != 0'.
10814 auto TrivialCase = [&](bool TriviallyTrue) {
10815 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10816 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10817 return true;
10818 };
10819 // If we hit the max recursion limit bail out.
10820 if (Depth >= 3)
10821 return false;
10822
10823 const SCEV *NewLHS, *NewRHS;
10824 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10825 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10826 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10827 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10828
10829 // (X * vscale) pred (Y * vscale) ==> X pred Y
10830 // when both multiples are NSW.
10831 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10832 // when both multiples are NUW.
10833 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10834 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10835 !ICmpInst::isSigned(Pred))) {
10836 LHS = NewLHS;
10837 RHS = NewRHS;
10838 Changed = true;
10839 }
10840 }
10841
10842 // Canonicalize a constant to the right side.
10843 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10844 // Check for both operands constant.
10845 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10846 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10847 return TrivialCase(false);
10848 return TrivialCase(true);
10849 }
10850 // Otherwise swap the operands to put the constant on the right.
10851 std::swap(LHS, RHS);
10852 Pred = ICmpInst::getSwappedPredicate(Pred);
10853 Changed = true;
10854 }
10855
10856 // If we're comparing an addrec with a value which is loop-invariant in the
10857 // addrec's loop, put the addrec on the left. Also make a dominance check,
10858 // as both operands could be addrecs loop-invariant in each other's loop.
10859 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10860 const Loop *L = AR->getLoop();
10861 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10862 std::swap(LHS, RHS);
10863 Pred = ICmpInst::getSwappedPredicate(Pred);
10864 Changed = true;
10865 }
10866 }
10867
10868 // If there's a constant operand, canonicalize comparisons with boundary
10869 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10870 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10871 const APInt &RA = RC->getAPInt();
10872
10873 bool SimplifiedByConstantRange = false;
10874
10875 if (!ICmpInst::isEquality(Pred)) {
10876 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10877 if (ExactCR.isFullSet())
10878 return TrivialCase(true);
10879 if (ExactCR.isEmptySet())
10880 return TrivialCase(false);
10881
10882 APInt NewRHS;
10883 CmpInst::Predicate NewPred;
10884 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10885 ICmpInst::isEquality(NewPred)) {
10886 // We were able to convert an inequality to an equality.
10887 Pred = NewPred;
10888 RHS = getConstant(NewRHS);
10889 Changed = SimplifiedByConstantRange = true;
10890 }
10891 }
10892
10893 if (!SimplifiedByConstantRange) {
10894 switch (Pred) {
10895 default:
10896 break;
10897 case ICmpInst::ICMP_EQ:
10898 case ICmpInst::ICMP_NE:
10899 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10900 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10901 Changed = true;
10902 break;
10903
10904 // The "Should have been caught earlier!" messages refer to the fact
10905 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10906 // should have fired on the corresponding cases, and canonicalized the
10907 // check to the trivial case.
10908
10909 case ICmpInst::ICMP_UGE:
10910 assert(!RA.isMinValue() && "Should have been caught earlier!");
10911 Pred = ICmpInst::ICMP_UGT;
10912 RHS = getConstant(RA - 1);
10913 Changed = true;
10914 break;
10915 case ICmpInst::ICMP_ULE:
10916 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10917 Pred = ICmpInst::ICMP_ULT;
10918 RHS = getConstant(RA + 1);
10919 Changed = true;
10920 break;
10921 case ICmpInst::ICMP_SGE:
10922 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10923 Pred = ICmpInst::ICMP_SGT;
10924 RHS = getConstant(RA - 1);
10925 Changed = true;
10926 break;
10927 case ICmpInst::ICMP_SLE:
10928 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10929 Pred = ICmpInst::ICMP_SLT;
10930 RHS = getConstant(RA + 1);
10931 Changed = true;
10932 break;
10933 }
10934 }
10935 }
10936
10937 // Check for obvious equality.
10938 if (HasSameValue(LHS, RHS)) {
10939 if (ICmpInst::isTrueWhenEqual(Pred))
10940 return TrivialCase(true);
10941 if (ICmpInst::isFalseWhenEqual(Pred))
10942 return TrivialCase(false);
10943 }
10944
10945 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10946 // adding or subtracting 1 from one of the operands.
10947 switch (Pred) {
10948 case ICmpInst::ICMP_SLE:
10949 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10950 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10951 SCEV::FlagNSW);
10952 Pred = ICmpInst::ICMP_SLT;
10953 Changed = true;
10954 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10955 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10956 SCEV::FlagNSW);
10957 Pred = ICmpInst::ICMP_SLT;
10958 Changed = true;
10959 }
10960 break;
10961 case ICmpInst::ICMP_SGE:
10962 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10963 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10964 SCEV::FlagNSW);
10965 Pred = ICmpInst::ICMP_SGT;
10966 Changed = true;
10967 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10968 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10969 SCEV::FlagNSW);
10970 Pred = ICmpInst::ICMP_SGT;
10971 Changed = true;
10972 }
10973 break;
10974 case ICmpInst::ICMP_ULE:
10975 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10976 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10977 SCEV::FlagNUW);
10978 Pred = ICmpInst::ICMP_ULT;
10979 Changed = true;
10980 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10981 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10982 Pred = ICmpInst::ICMP_ULT;
10983 Changed = true;
10984 }
10985 break;
10986 case ICmpInst::ICMP_UGE:
10987 // If RHS is an op we can fold the -1, try that first.
10988 // Otherwise prefer LHS to preserve the nuw flag.
10989 if ((isa<SCEVConstant>(RHS) ||
10990 (isa<SCEVAddExpr>(RHS) &&
10991 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
10992 !getUnsignedRangeMin(RHS).isMinValue()) {
10993 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10994 Pred = ICmpInst::ICMP_UGT;
10995 Changed = true;
10996 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10997 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10998 SCEV::FlagNUW);
10999 Pred = ICmpInst::ICMP_UGT;
11000 Changed = true;
11001 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11002 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11003 Pred = ICmpInst::ICMP_UGT;
11004 Changed = true;
11005 }
11006 break;
11007 default:
11008 break;
11009 }
11010
11011 // TODO: More simplifications are possible here.
11012
11013 // Recursively simplify until we either hit a recursion limit or nothing
11014 // changes.
11015 if (Changed)
11016 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11017
11018 return Changed;
11019}
11020
11021bool ScalarEvolution::isKnownNegative(const SCEV *S) {
11022 return getSignedRangeMax(S).isNegative();
11023}
11024
11025bool ScalarEvolution::isKnownPositive(const SCEV *S) {
11026 return getSignedRangeMin(S).isStrictlyPositive();
11027}
11028
11029bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
11030 return !getSignedRangeMin(S).isNegative();
11031}
11032
11033bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
11034 return !getSignedRangeMax(S).isStrictlyPositive();
11035}
11036
11037bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11038 // Query push down for cases where the unsigned range is
11039 // less than sufficient.
11040 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11041 return isKnownNonZero(SExt->getOperand(0));
11042 return getUnsignedRangeMin(S) != 0;
11043}
11044
11045bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11046 bool OrNegative) {
11047 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11048 if (auto *C = dyn_cast<SCEVConstant>(S))
11049 return C->getAPInt().isPowerOf2() ||
11050 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11051
11052 // The vscale_range indicates vscale is a power-of-two.
11053 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11054 };
11055
11056 if (NonRecursive(S))
11057 return true;
11058
11059 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11060 if (!Mul)
11061 return false;
11062 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11063}
11064
11065bool ScalarEvolution::isKnownMultipleOf(
11066 const SCEV *S, uint64_t M,
11067 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11068 if (M == 0)
11069 return false;
11070 if (M == 1)
11071 return true;
11072
11073 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11074 // starts with a multiple of M and at every iteration step S only adds
11075 // multiples of M.
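// For example, {8,+,4} is a known multiple of 4: its start (8) and its step
// (4) are both multiples of 4, so every value it takes is as well.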
11076 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11077 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11078 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11079
11080 // For a constant, check that "S % M == 0".
11081 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11082 APInt C = Cst->getAPInt();
11083 return C.urem(M) == 0;
11084 }
11085
11086 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11087
11088 // Basic tests have failed.
11089 // Check "S % M == 0" at compile time and record runtime Assumptions.
11090 auto *STy = dyn_cast<IntegerType>(S->getType());
11091 const SCEV *SmodM =
11092 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11093 const SCEV *Zero = getZero(STy);
11094
11095 // Check whether "S % M == 0" is known at compile time.
11096 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11097 return true;
11098
11099 // Check whether "S % M != 0" is known at compile time.
11100 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11101 return false;
11102
11103 const SCEVPredicate *P = getEqualPredicate(SmodM, Zero);
11104
11105 // Detect redundant predicates.
11106 for (auto *A : Assumptions)
11107 if (A->implies(P, *this))
11108 return true;
11109
11110 // Only record non-redundant predicates.
11111 Assumptions.push_back(P);
11112 return true;
11113}
11114
11115bool ScalarEvolution::haveSameSign(const SCEV *S1, const SCEV *S2) {
11116 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11117 (isKnownNegative(S1) && isKnownNegative(S2)));
11118}
11119
11120std::pair<const SCEV *, const SCEV *>
11121ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11122 // Compute SCEV on entry of loop L.
11123 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11124 if (Start == getCouldNotCompute())
11125 return { Start, Start };
11126 // Compute post increment SCEV for loop L.
11127 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11128 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11129 return { Start, PostInc };
11130}
11131
11132bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11133 const SCEV *RHS) {
11134 // First collect all loops.
11135 SmallPtrSet<const Loop *, 8> LoopsUsed;
11136 getUsedLoops(LHS, LoopsUsed);
11137 getUsedLoops(RHS, LoopsUsed);
11138
11139 if (LoopsUsed.empty())
11140 return false;
11141
11142 // Domination relationship must be a linear order on collected loops.
11143#ifndef NDEBUG
11144 for (const auto *L1 : LoopsUsed)
11145 for (const auto *L2 : LoopsUsed)
11146 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11147 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11148 "Domination relationship is not a linear order");
11149#endif
11150
11151 const Loop *MDL =
11152 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11153 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11154 });
11155
11156 // Get init and post increment value for LHS.
11157 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11158 // if LHS contains unknown non-invariant SCEV then bail out.
11159 if (SplitLHS.first == getCouldNotCompute())
11160 return false;
11161 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11162 // Get init and post increment value for RHS.
11163 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11164 // if RHS contains unknown non-invariant SCEV then bail out.
11165 if (SplitRHS.first == getCouldNotCompute())
11166 return false;
11167 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11168 // It is possible that init SCEV contains an invariant load but it does
11169 // not dominate MDL and is not available at MDL loop entry, so we should
11170 // check it here.
11171 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11172 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11173 return false;
11174
11175 // The backedge guard check appears to be faster than the entry guard check,
11176 // so in some cases it can short-circuit and speed up the whole estimation.
11177 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11178 SplitRHS.second) &&
11179 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11180}
11181
11182bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11183 const SCEV *RHS) {
11184 // Canonicalize the inputs first.
11185 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11186
11187 if (isKnownViaInduction(Pred, LHS, RHS))
11188 return true;
11189
11190 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11191 return true;
11192
11193 // Otherwise see what can be done with some simple reasoning.
11194 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11195}
11196
11197std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11198 const SCEV *LHS,
11199 const SCEV *RHS) {
11200 if (isKnownPredicate(Pred, LHS, RHS))
11201 return true;
11202 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11203 return false;
11204 return std::nullopt;
11205}
11206
11207bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11208 const SCEV *RHS,
11209 const Instruction *CtxI) {
11210 // TODO: Analyze guards and assumes from Context's block.
11211 return isKnownPredicate(Pred, LHS, RHS) ||
11212 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11213}
11214
11215std::optional<bool>
11216ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11217 const SCEV *RHS, const Instruction *CtxI) {
11218 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11219 if (KnownWithoutContext)
11220 return KnownWithoutContext;
11221
11222 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11223 return true;
11224 if (isBasicBlockEntryGuardedByCond(
11225 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11226 return false;
11227 return std::nullopt;
11228}
11229
11230bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11231 const SCEVAddRecExpr *LHS,
11232 const SCEV *RHS) {
11233 const Loop *L = LHS->getLoop();
11234 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11235 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11236}
11237
11238std::optional<ScalarEvolution::MonotonicPredicateType>
11239ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11240 ICmpInst::Predicate Pred) {
11241 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11242
11243#ifndef NDEBUG
11244 // Verify an invariant: inverting the predicate should turn a monotonically
11245 // increasing change to a monotonically decreasing one, and vice versa.
11246 if (Result) {
11247 auto ResultSwapped =
11248 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11249
11250 assert(*ResultSwapped != *Result &&
11251 "monotonicity should flip as we flip the predicate");
11252 }
11253#endif
11254
11255 return Result;
11256}
11257
11258std::optional<ScalarEvolution::MonotonicPredicateType>
11259ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11260 ICmpInst::Predicate Pred) {
11261 // A zero step value for LHS means the induction variable is essentially a
11262 // loop invariant value. We don't really depend on the predicate actually
11263 // flipping from false to true (for increasing predicates, and the other way
11264 // around for decreasing predicates), all we care about is that *if* the
11265 // predicate changes then it only changes from false to true.
11266 //
11267 // A zero step value in itself is not very useful, but there may be places
11268 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11269 // as general as possible.
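// For example, {0,+,1}<nuw> compared with "u> N" is monotonically increasing:
// since the addrec only grows and cannot wrap, the comparison can switch from
// false to true at most once and never back.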
11270
11271 // Only handle LE/LT/GE/GT predicates.
11272 if (!ICmpInst::isRelational(Pred))
11273 return std::nullopt;
11274
11275 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11276 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11277 "Should be greater or less!");
11278
11279 // Check that AR does not wrap.
11280 if (ICmpInst::isUnsigned(Pred)) {
11281 if (!LHS->hasNoUnsignedWrap())
11282 return std::nullopt;
11283 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11284 }
11285 assert(ICmpInst::isSigned(Pred) &&
11286 "Relational predicate is either signed or unsigned!");
11287 if (!LHS->hasNoSignedWrap())
11288 return std::nullopt;
11289
11290 const SCEV *Step = LHS->getStepRecurrence(*this);
11291
11292 if (isKnownNonNegative(Step))
11293 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11294
11295 if (isKnownNonPositive(Step))
11296 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11297
11298 return std::nullopt;
11299}
11300
11301std::optional<ScalarEvolution::LoopInvariantPredicate>
11302ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11303 const SCEV *RHS, const Loop *L,
11304 const Instruction *CtxI) {
11305 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11306 if (!isLoopInvariant(RHS, L)) {
11307 if (!isLoopInvariant(LHS, L))
11308 return std::nullopt;
11309
11310 std::swap(LHS, RHS);
11311 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11312 }
11313
11314 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11315 if (!ArLHS || ArLHS->getLoop() != L)
11316 return std::nullopt;
11317
11318 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11319 if (!MonotonicType)
11320 return std::nullopt;
11321 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11322 // true as the loop iterates, and the backedge is control dependent on
11323 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11324 //
11325 // * if the predicate was false in the first iteration then the predicate
11326 // is never evaluated again, since the loop exits without taking the
11327 // backedge.
11328 // * if the predicate was true in the first iteration then it will
11329 // continue to be true for all future iterations since it is
11330 // monotonically increasing.
11331 //
11332 // For both the above possibilities, we can replace the loop varying
11333 // predicate with its value on the first iteration of the loop (which is
11334 // loop invariant).
11335 //
11336 // A similar reasoning applies for a monotonically decreasing predicate, by
11337 // replacing true with false and false with true in the above two bullets.
11338 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11339 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11340
11341 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11342 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11343 RHS);
11344
11345 if (!CtxI)
11346 return std::nullopt;
11347 // Try to prove via context.
11348 // TODO: Support other cases.
11349 switch (Pred) {
11350 default:
11351 break;
11352 case ICmpInst::ICMP_ULE:
11353 case ICmpInst::ICMP_ULT: {
11354 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11355 // Given preconditions
11356 // (1) ArLHS does not cross the border of positive and negative parts of
11357 // range because of:
11358 // - Positive step; (TODO: lift this limitation)
11359 // - nuw - does not cross zero boundary;
11360 // - nsw - does not cross SINT_MAX boundary;
11361 // (2) ArLHS <s RHS
11362 // (3) RHS >=s 0
11363 // we can replace the loop variant ArLHS <u RHS condition with loop
11364 // invariant Start(ArLHS) <u RHS.
11365 //
11366 // Because of (1) there are two options:
11367 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11368 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11369 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11370 // Because of (2) ArLHS <u RHS is trivially true.
11371 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11372 // We can strengthen this to Start(ArLHS) <u RHS.
11373 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11374 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11375 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11376 isKnownNonNegative(RHS) &&
11377 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11378 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11379 RHS);
11380 }
11381 }
11382
11383 return std::nullopt;
11384}
11385
11386std::optional<ScalarEvolution::LoopInvariantPredicate>
11387ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11388 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11389 const Instruction *CtxI, const SCEV *MaxIter) {
11390 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11391 Pred, LHS, RHS, L, CtxI, MaxIter))
11392 return LIP;
11393 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11394 // Number of iterations expressed as UMIN isn't always great for expressing
11395 // the value on the last iteration. If the straightforward approach didn't
11396 // work, try the following trick: if the a predicate is invariant for X, it
11397 // is also invariant for umin(X, ...). So try to find something that works
11398 // among subexpressions of MaxIter expressed as umin.
11399 for (auto *Op : UMin->operands())
11400 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11401 Pred, LHS, RHS, L, CtxI, Op))
11402 return LIP;
11403 return std::nullopt;
11404}
11405
11406std::optional<ScalarEvolution::LoopInvariantPredicate>
11407ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11408 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11409 const Instruction *CtxI, const SCEV *MaxIter) {
11410 // Try to prove the following set of facts:
11411 // - The predicate is monotonic in the iteration space.
11412 // - If the check does not fail on the 1st iteration:
11413 // - No overflow will happen during first MaxIter iterations;
11414 // - It will not fail on the MaxIter'th iteration.
11415 // If the check does fail on the 1st iteration, we leave the loop and no
11416 // other checks matter.
11417
11418 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11419 if (!isLoopInvariant(RHS, L)) {
11420 if (!isLoopInvariant(LHS, L))
11421 return std::nullopt;
11422
11423 std::swap(LHS, RHS);
11424 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11425 }
11426
11427 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11428 if (!AR || AR->getLoop() != L)
11429 return std::nullopt;
11430
11431 // The predicate must be relational (i.e. <, <=, >=, >).
11432 if (!ICmpInst::isRelational(Pred))
11433 return std::nullopt;
11434
11435 // TODO: Support steps other than +/- 1.
11436 const SCEV *Step = AR->getStepRecurrence(*this);
11437 auto *One = getOne(Step->getType());
11438 auto *MinusOne = getNegativeSCEV(One);
11439 if (Step != One && Step != MinusOne)
11440 return std::nullopt;
11441
11442 // Type mismatch here means that MaxIter is potentially larger than max
11443 // unsigned value in start type, which mean we cannot prove no wrap for the
11444 // indvar.
11445 if (AR->getType() != MaxIter->getType())
11446 return std::nullopt;
11447
11448 // Value of IV on suggested last iteration.
11449 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11450 // Does it still meet the requirement?
11451 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11452 return std::nullopt;
11453 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11454 // not exceed max unsigned value of this type), this effectively proves
11455 // that there is no wrap during the iteration. To prove that there is no
11456 // signed/unsigned wrap, we need to check that
11457 // Start <= Last for step = 1 or Start >= Last for step = -1.
11458 ICmpInst::Predicate NoOverflowPred =
11459 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11460 if (Step == MinusOne)
11461 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11462 const SCEV *Start = AR->getStart();
11463 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11464 return std::nullopt;
11465
11466 // Everything is fine.
11467 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11468}
11469
11470bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11471 const SCEV *LHS,
11472 const SCEV *RHS) {
11473 if (HasSameValue(LHS, RHS))
11474 return ICmpInst::isTrueWhenEqual(Pred);
11475
11476 auto CheckRange = [&](bool IsSigned) {
11477 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11478 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11479 return RangeLHS.icmp(Pred, RangeRHS);
11480 };
11481
11482 // The check at the top of the function catches the case where the values are
11483 // known to be equal.
11484 if (Pred == CmpInst::ICMP_EQ)
11485 return false;
11486
11487 if (Pred == CmpInst::ICMP_NE) {
11488 if (CheckRange(true) || CheckRange(false))
11489 return true;
11490 auto *Diff = getMinusSCEV(LHS, RHS);
11491 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11492 }
11493
11494 return CheckRange(CmpInst::isSigned(Pred));
11495}
11496
11497bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11498 const SCEV *LHS,
11499 const SCEV *RHS) {
11500 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11501 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11502 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11503 // OutC1 and OutC2.
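// For example, X = %a and Y = (%a + 3)<nsw> match with OutC1 = 0 and
// OutC2 = 3 (X is treated as (%a + 0)), which lets the switch below prove
// facts such as "%a s<= (%a + 3)<nsw>".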
11504 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11505 APInt &OutC1, APInt &OutC2,
11506 SCEV::NoWrapFlags ExpectedFlags) {
11507 const SCEV *XNonConstOp, *XConstOp;
11508 const SCEV *YNonConstOp, *YConstOp;
11509 SCEV::NoWrapFlags XFlagsPresent;
11510 SCEV::NoWrapFlags YFlagsPresent;
11511
11512 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11513 XConstOp = getZero(X->getType());
11514 XNonConstOp = X;
11515 XFlagsPresent = ExpectedFlags;
11516 }
11517 if (!isa<SCEVConstant>(XConstOp))
11518 return false;
11519
11520 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11521 YConstOp = getZero(Y->getType());
11522 YNonConstOp = Y;
11523 YFlagsPresent = ExpectedFlags;
11524 }
11525
11526 if (YNonConstOp != XNonConstOp)
11527 return false;
11528
11529 if (!isa<SCEVConstant>(YConstOp))
11530 return false;
11531
11532 // When matching ADDs with NUW flags (and unsigned predicates), only the
11533 // second ADD (with the larger constant) requires NUW.
11534 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11535 return false;
11536 if (ExpectedFlags != SCEV::FlagNUW &&
11537 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11538 return false;
11539 }
11540
11541 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11542 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11543
11544 return true;
11545 };
11546
11547 APInt C1;
11548 APInt C2;
11549
11550 switch (Pred) {
11551 default:
11552 break;
11553
11554 case ICmpInst::ICMP_SGE:
11555 std::swap(LHS, RHS);
11556 [[fallthrough]];
11557 case ICmpInst::ICMP_SLE:
11558 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11559 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11560 return true;
11561
11562 break;
11563
11564 case ICmpInst::ICMP_SGT:
11565 std::swap(LHS, RHS);
11566 [[fallthrough]];
11567 case ICmpInst::ICMP_SLT:
11568 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11569 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11570 return true;
11571
11572 break;
11573
11574 case ICmpInst::ICMP_UGE:
11575 std::swap(LHS, RHS);
11576 [[fallthrough]];
11577 case ICmpInst::ICMP_ULE:
11578 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11579 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11580 return true;
11581
11582 break;
11583
11584 case ICmpInst::ICMP_UGT:
11585 std::swap(LHS, RHS);
11586 [[fallthrough]];
11587 case ICmpInst::ICMP_ULT:
11588 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11589 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11590 return true;
11591 break;
11592 }
11593
11594 return false;
11595}
11596
11597bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11598 const SCEV *LHS,
11599 const SCEV *RHS) {
11600 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11601 return false;
11602
11603 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11604 // the stack can result in exponential time complexity.
11605 SaveAndRestore Restore(ProvingSplitPredicate, true);
11606
11607 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11608 //
11609 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11610 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11611 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11612 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11613 // use isKnownPredicate later if needed.
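// For example, to prove "%i u< %len" when %len is known non-negative, it is
// enough to prove "%i s>= 0" and "%i s< %len" separately.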
11614 return isKnownNonNegative(RHS) &&
11615 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11616 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11617}
11618
11619bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11620 const SCEV *LHS, const SCEV *RHS) {
11621 // No need to even try if we know the module has no guards.
11622 if (!HasGuards)
11623 return false;
11624
11625 return any_of(*BB, [&](const Instruction &I) {
11626 using namespace llvm::PatternMatch;
11627
11628 Value *Condition;
11629 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11630 m_Value(Condition))) &&
11631 isImpliedCond(Pred, LHS, RHS, Condition, false);
11632 });
11633}
11634
11635/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11636/// protected by a conditional between LHS and RHS. This is used to
11637/// eliminate casts.
11638bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11639 CmpPredicate Pred,
11640 const SCEV *LHS,
11641 const SCEV *RHS) {
11642 // Interpret a null as meaning no loop, where there is obviously no guard
11643 // (interprocedural conditions notwithstanding). Do not bother about
11644 // unreachable loops.
11645 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11646 return true;
11647
11648 if (VerifyIR)
11649 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11650 "This cannot be done on broken IR!");
11651
11652
11653 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11654 return true;
11655
11656 BasicBlock *Latch = L->getLoopLatch();
11657 if (!Latch)
11658 return false;
11659
11660 BranchInst *LoopContinuePredicate =
11661 dyn_cast<BranchInst>(Latch->getTerminator());
11662 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11663 isImpliedCond(Pred, LHS, RHS,
11664 LoopContinuePredicate->getCondition(),
11665 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11666 return true;
11667
11668 // We don't want more than one activation of the following loops on the stack
11669 // -- that can lead to O(n!) time complexity.
11670 if (WalkingBEDominatingConds)
11671 return false;
11672
11673 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11674
11675 // See if we can exploit a trip count to prove the predicate.
11676 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11677 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11678 if (LatchBECount != getCouldNotCompute()) {
11679 // We know that Latch branches back to the loop header exactly
11680 // LatchBECount times. This means the backedge condition at Latch is
11681 // equivalent to "{0,+,1} u< LatchBECount".
11682 Type *Ty = LatchBECount->getType();
11683 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11684 const SCEV *LoopCounter =
11685 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11686 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11687 LatchBECount))
11688 return true;
11689 }
11690
11691 // Check conditions due to any @llvm.assume intrinsics.
11692 for (auto &AssumeVH : AC.assumptions()) {
11693 if (!AssumeVH)
11694 continue;
11695 auto *CI = cast<CallInst>(AssumeVH);
11696 if (!DT.dominates(CI, Latch->getTerminator()))
11697 continue;
11698
11699 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11700 return true;
11701 }
11702
11703 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11704 return true;
11705
11706 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11707 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11708 assert(DTN && "should reach the loop header before reaching the root!");
11709
11710 BasicBlock *BB = DTN->getBlock();
11711 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11712 return true;
11713
11714 BasicBlock *PBB = BB->getSinglePredecessor();
11715 if (!PBB)
11716 continue;
11717
11718 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11719 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11720 continue;
11721
11722 Value *Condition = ContinuePredicate->getCondition();
11723
11724 // If we have an edge `E` within the loop body that dominates the only
11725 // latch, the condition guarding `E` also guards the backedge. This
11726 // reasoning works only for loops with a single latch.
11727
11728 BasicBlockEdge DominatingEdge(PBB, BB);
11729 if (DominatingEdge.isSingleEdge()) {
11730 // We're constructively (and conservatively) enumerating edges within the
11731 // loop body that dominate the latch. The dominator tree better agree
11732 // with us on this:
11733 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11734
11735 if (isImpliedCond(Pred, LHS, RHS, Condition,
11736 BB != ContinuePredicate->getSuccessor(0)))
11737 return true;
11738 }
11739 }
11740
11741 return false;
11742}
11743
11744bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11745 CmpPredicate Pred,
11746 const SCEV *LHS,
11747 const SCEV *RHS) {
11748 // Do not bother proving facts for unreachable code.
11749 if (!DT.isReachableFromEntry(BB))
11750 return true;
11751 if (VerifyIR)
11752 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11753 "This cannot be done on broken IR!");
11754
11755 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11756 // the facts (a >= b && a != b) separately. A typical situation is when the
11757 // non-strict comparison is known from ranges and non-equality is known from
11758 // dominating predicates. If we are proving strict comparison, we always try
11759 // to prove non-equality and non-strict comparison separately.
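// For example, "a s> b" follows from proving "a s>= b" (say, via constant
// ranges) together with "a != b" (say, via a dominating condition).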
11760 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11761 const bool ProvingStrictComparison =
11762 Pred != NonStrictPredicate.dropSameSign();
11763 bool ProvedNonStrictComparison = false;
11764 bool ProvedNonEquality = false;
11765
11766 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11767 if (!ProvedNonStrictComparison)
11768 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11769 if (!ProvedNonEquality)
11770 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11771 if (ProvedNonStrictComparison && ProvedNonEquality)
11772 return true;
11773 return false;
11774 };
11775
11776 if (ProvingStrictComparison) {
11777 auto ProofFn = [&](CmpPredicate P) {
11778 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11779 };
11780 if (SplitAndProve(ProofFn))
11781 return true;
11782 }
11783
11784 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11785 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11786 const Instruction *CtxI = &BB->front();
11787 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11788 return true;
11789 if (ProvingStrictComparison) {
11790 auto ProofFn = [&](CmpPredicate P) {
11791 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11792 };
11793 if (SplitAndProve(ProofFn))
11794 return true;
11795 }
11796 return false;
11797 };
11798
11799 // Starting at the block's predecessor, climb up the predecessor chain, as long
11800 // as there are predecessors that can be found that have unique successors
11801 // leading to the original block.
11802 const Loop *ContainingLoop = LI.getLoopFor(BB);
11803 const BasicBlock *PredBB;
11804 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11805 PredBB = ContainingLoop->getLoopPredecessor();
11806 else
11807 PredBB = BB->getSinglePredecessor();
11808 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11809 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11810 const BranchInst *BlockEntryPredicate =
11811 dyn_cast<BranchInst>(Pair.first->getTerminator());
11812 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11813 continue;
11814
11815 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11816 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11817 return true;
11818 }
11819
11820 // Check conditions due to any @llvm.assume intrinsics.
11821 for (auto &AssumeVH : AC.assumptions()) {
11822 if (!AssumeVH)
11823 continue;
11824 auto *CI = cast<CallInst>(AssumeVH);
11825 if (!DT.dominates(CI, BB))
11826 continue;
11827
11828 if (ProveViaCond(CI->getArgOperand(0), false))
11829 return true;
11830 }
11831
11832 // Check conditions due to any @llvm.experimental.guard intrinsics.
11833 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11834 F.getParent(), Intrinsic::experimental_guard);
11835 if (GuardDecl)
11836 for (const auto *GU : GuardDecl->users())
11837 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11838 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11839 if (ProveViaCond(Guard->getArgOperand(0), false))
11840 return true;
11841 return false;
11842}
11843
11844bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11845 const SCEV *LHS,
11846 const SCEV *RHS) {
11847 // Interpret a null as meaning no loop, where there is obviously no guard
11848 // (interprocedural conditions notwithstanding).
11849 if (!L)
11850 return false;
11851
11852 // Both LHS and RHS must be available at loop entry.
11853 assert(isAvailableAtLoopEntry(LHS, L) &&
11854 "LHS is not available at Loop Entry");
11855 assert(isAvailableAtLoopEntry(RHS, L) &&
11856 "RHS is not available at Loop Entry");
11857
11858 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11859 return true;
11860
11861 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11862}
11863
11864bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11865 const SCEV *RHS,
11866 const Value *FoundCondValue, bool Inverse,
11867 const Instruction *CtxI) {
11868 // A false condition implies anything. Do not bother analyzing it further.
11869 if (FoundCondValue ==
11870 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11871 return true;
11872
11873 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11874 return false;
11875
11876 llvm::scope_exit ClearOnExit(
11877 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
11878
11879 // Recursively handle And and Or conditions.
11880 const Value *Op0, *Op1;
11881 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11882 if (!Inverse)
11883 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11884 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11885 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11886 if (Inverse)
11887 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11888 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11889 }
11890
11891 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11892 if (!ICI) return false;
11893
11894 // We have found a conditional branch that dominates the loop or controls
11895 // the loop latch. Check to see if it is the comparison we are looking for.
11896 CmpPredicate FoundPred;
11897 if (Inverse)
11898 FoundPred = ICI->getInverseCmpPredicate();
11899 else
11900 FoundPred = ICI->getCmpPredicate();
11901
11902 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11903 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11904
11905 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11906}
11907
11908bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11909 const SCEV *RHS, CmpPredicate FoundPred,
11910 const SCEV *FoundLHS, const SCEV *FoundRHS,
11911 const Instruction *CtxI) {
11912 // Balance the types.
11913 if (getTypeSizeInBits(LHS->getType()) <
11914 getTypeSizeInBits(FoundLHS->getType())) {
11915 // For unsigned and equality predicates, try to prove that both found
11916 // operands fit into narrow unsigned range. If so, try to prove facts in
11917 // narrow types.
11918 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11919 !FoundRHS->getType()->isPointerTy()) {
11920 auto *NarrowType = LHS->getType();
11921 auto *WideType = FoundLHS->getType();
11922 auto BitWidth = getTypeSizeInBits(NarrowType);
11923 const SCEV *MaxValue = getZeroExtendExpr(
11924 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11925 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11926 MaxValue) &&
11927 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11928 MaxValue)) {
11929 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11930 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11931 // We cannot preserve samesign after truncation.
11932 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11933 TruncFoundLHS, TruncFoundRHS, CtxI))
11934 return true;
11935 }
11936 }
11937
11938 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11939 return false;
11940 if (CmpInst::isSigned(Pred)) {
11941 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11942 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11943 } else {
11944 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11945 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11946 }
11947 } else if (getTypeSizeInBits(LHS->getType()) >
11948 getTypeSizeInBits(FoundLHS->getType())) {
11949 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11950 return false;
11951 if (CmpInst::isSigned(FoundPred)) {
11952 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11953 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11954 } else {
11955 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11956 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11957 }
11958 }
11959 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11960 FoundRHS, CtxI);
11961}
11962
11963bool ScalarEvolution::isImpliedCondBalancedTypes(
11964 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11965 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11966 assert(getTypeSizeInBits(LHS->getType()) ==
11967 getTypeSizeInBits(FoundLHS->getType()) &&
11968 "Types should be balanced!");
11969 // Canonicalize the query to match the way instcombine will have
11970 // canonicalized the comparison.
11971 if (SimplifyICmpOperands(Pred, LHS, RHS))
11972 if (LHS == RHS)
11973 return CmpInst::isTrueWhenEqual(Pred);
11974 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11975 if (FoundLHS == FoundRHS)
11976 return CmpInst::isFalseWhenEqual(FoundPred);
11977
11978 // Check to see if we can make the LHS or RHS match.
11979 if (LHS == FoundRHS || RHS == FoundLHS) {
11980 if (isa<SCEVConstant>(RHS)) {
11981 std::swap(FoundLHS, FoundRHS);
11982 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11983 } else {
11984 std::swap(LHS, RHS);
11985 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11986 }
11987 }
11988
11989 // Check whether the found predicate is the same as the desired predicate.
11990 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
11991 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11992
11993 // Check whether swapping the found predicate makes it the same as the
11994 // desired predicate.
11995 if (auto P = CmpPredicate::getMatching(
11996 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
11997 // We can write the implication
11998 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11999 // using one of the following ways:
12000 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12001 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12002 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12003 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12004 // Forms 1. and 2. require swapping the operands of one condition. Don't
12005 // do this if it would break canonical constant/addrec ordering.
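// Forms 3. and 4. rely on the identity "x `Pred` y <=> ~y `Pred` ~x", which
// holds because ~x = -1 - x reverses both the unsigned and the signed order.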
12006 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
12007 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12008 LHS, FoundLHS, FoundRHS, CtxI);
12009 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12010 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12011
12012 // There's no clear preference between forms 3. and 4., try both. Avoid
12013 // forming getNotSCEV of pointer values as the resulting subtract is
12014 // not legal.
12015 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12016 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12017 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12018 FoundRHS, CtxI))
12019 return true;
12020
12021 if (!FoundLHS->getType()->isPointerTy() &&
12022 !FoundRHS->getType()->isPointerTy() &&
12023 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12024 getNotSCEV(FoundRHS), CtxI))
12025 return true;
12026
12027 return false;
12028 }
12029
12030 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12031 CmpInst::Predicate P2) {
12032 assert(P1 != P2 && "Handled earlier!");
12033 return CmpInst::isRelational(P2) &&
12034 P1 == ICmpInst::getFlippedSignednessPredicate(P2);
12035 };
12036 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12037 // Unsigned comparison is the same as signed comparison when both the
12038 // operands are non-negative or negative.
12039 if (haveSameSign(FoundLHS, FoundRHS))
12040 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12041 // Create local copies that we can freely swap and canonicalize our
12042 // conditions to "le/lt".
12043 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12044 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12045 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12046 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12047 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12048 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12049 std::swap(CanonicalLHS, CanonicalRHS);
12050 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12051 }
12052 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12053 "Must be!");
12054 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12055 ICmpInst::isLE(CanonicalFoundPred)) &&
12056 "Must be!");
12057 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12058 // Use implication:
12059 // x <u y && y >=s 0 --> x <s y.
12060 // If we can prove the left part, the right part is also proven.
12061 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12062 CanonicalRHS, CanonicalFoundLHS,
12063 CanonicalFoundRHS);
12064 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12065 // Use implication:
12066 // x <s y && y <s 0 --> x <u y.
12067 // If we can prove the left part, the right part is also proven.
12068 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12069 CanonicalRHS, CanonicalFoundLHS,
12070 CanonicalFoundRHS);
12071 }
12072
12073 // Check if we can make progress by sharpening ranges.
12074 if (FoundPred == ICmpInst::ICMP_NE &&
12075 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12076
12077 const SCEVConstant *C = nullptr;
12078 const SCEV *V = nullptr;
12079
12080 if (isa<SCEVConstant>(FoundLHS)) {
12081 C = cast<SCEVConstant>(FoundLHS);
12082 V = FoundRHS;
12083 } else {
12084 C = cast<SCEVConstant>(FoundRHS);
12085 V = FoundLHS;
12086 }
12087
12088 // The guarding predicate tells us that C != V. If the known range
12089 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12090 // range we consider has to correspond to same signedness as the
12091 // predicate we're interested in folding.
12092
12093 APInt Min = ICmpInst::isSigned(Pred) ?
12094 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12095
12096 if (Min == C->getAPInt()) {
12097 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12098 // This is true even if (Min + 1) wraps around -- in case of
12099 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
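// For example, with an unsigned predicate, if the unsigned range of V starts
// at Min = 0 and the guard says V != 0, then "V u>= 1" (SharperMin) can be
// used below.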
12100
12101 APInt SharperMin = Min + 1;
12102
12103 switch (Pred) {
12104 case ICmpInst::ICMP_SGE:
12105 case ICmpInst::ICMP_UGE:
12106 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12107 // RHS, we're done.
12108 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12109 CtxI))
12110 return true;
12111 [[fallthrough]];
12112
12113 case ICmpInst::ICMP_SGT:
12114 case ICmpInst::ICMP_UGT:
12115 // We know from the range information that (V `Pred` Min ||
12116 // V == Min). We know from the guarding condition that !(V
12117 // == Min). This gives us
12118 //
12119 // V `Pred` Min || V == Min && !(V == Min)
12120 // => V `Pred` Min
12121 //
12122 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12123
12124 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12125 return true;
12126 break;
12127
12128 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12129 case ICmpInst::ICMP_SLE:
12130 case ICmpInst::ICMP_ULE:
12131 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12132 LHS, V, getConstant(SharperMin), CtxI))
12133 return true;
12134 [[fallthrough]];
12135
12136 case ICmpInst::ICMP_SLT:
12137 case ICmpInst::ICMP_ULT:
12138 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12139 LHS, V, getConstant(Min), CtxI))
12140 return true;
12141 break;
12142
12143 default:
12144 // No change
12145 break;
12146 }
12147 }
12148 }
12149
12150 // Check whether the actual condition is beyond sufficient.
12151 if (FoundPred == ICmpInst::ICMP_EQ)
12152 if (ICmpInst::isTrueWhenEqual(Pred))
12153 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12154 return true;
12155 if (Pred == ICmpInst::ICMP_NE)
12156 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12157 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12158 return true;
12159
12160 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12161 return true;
12162
12163 // Otherwise assume the worst.
12164 return false;
12165}
12166
12167bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12168 const SCEV *&L, const SCEV *&R,
12169 SCEV::NoWrapFlags &Flags) {
12170 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12171 return false;
12172
12173 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12174 return true;
12175}
12176
12177std::optional<APInt>
12178ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12179 // We avoid subtracting expressions here because this function is usually
12180 // fairly deep in the call stack (i.e. is called many times).
12181
12182 unsigned BW = getTypeSizeInBits(More->getType());
12183 APInt Diff(BW, 0);
12184 APInt DiffMul(BW, 1);
12185 // Try various simplifications to reduce the difference to a constant. Limit
12186 // the number of allowed simplifications to keep compile-time low.
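// For example, the difference between (%x + 5) and (%x + 2) is 3, and the
// difference between {%x,+,%s} and {%x + 2,+,%s} (in the same loop) is -2;
// both are found without materializing a subtraction SCEV.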
12187 for (unsigned I = 0; I < 8; ++I) {
12188 if (More == Less)
12189 return Diff;
12190
12191 // Reduce addrecs with identical steps to their start value.
12192 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12193 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12194 const auto *MAR = cast<SCEVAddRecExpr>(More);
12195
12196 if (LAR->getLoop() != MAR->getLoop())
12197 return std::nullopt;
12198
12199 // We look at affine expressions only; not for correctness but to keep
12200 // getStepRecurrence cheap.
12201 if (!LAR->isAffine() || !MAR->isAffine())
12202 return std::nullopt;
12203
12204 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12205 return std::nullopt;
12206
12207 Less = LAR->getStart();
12208 More = MAR->getStart();
12209 continue;
12210 }
12211
12212 // Try to match a common constant multiply.
12213 auto MatchConstMul =
12214 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12215 const APInt *C;
12216 const SCEV *Op;
12217 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12218 return {{Op, *C}};
12219 return std::nullopt;
12220 };
12221 if (auto MatchedMore = MatchConstMul(More)) {
12222 if (auto MatchedLess = MatchConstMul(Less)) {
12223 if (MatchedMore->second == MatchedLess->second) {
12224 More = MatchedMore->first;
12225 Less = MatchedLess->first;
12226 DiffMul *= MatchedMore->second;
12227 continue;
12228 }
12229 }
12230 }
12231
12232 // Try to cancel out common factors in two add expressions.
12233 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12234 auto Add = [&](const SCEV *S, int Mul) {
12235 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12236 if (Mul == 1) {
12237 Diff += C->getAPInt() * DiffMul;
12238 } else {
12239 assert(Mul == -1);
12240 Diff -= C->getAPInt() * DiffMul;
12241 }
12242 } else
12243 Multiplicity[S] += Mul;
12244 };
12245 auto Decompose = [&](const SCEV *S, int Mul) {
12246 if (isa<SCEVAddExpr>(S)) {
12247 for (const SCEV *Op : S->operands())
12248 Add(Op, Mul);
12249 } else
12250 Add(S, Mul);
12251 };
12252 Decompose(More, 1);
12253 Decompose(Less, -1);
12254
12255 // Check whether all the non-constants cancel out, or reduce to new
12256 // More/Less values.
12257 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12258 for (const auto &[S, Mul] : Multiplicity) {
12259 if (Mul == 0)
12260 continue;
12261 if (Mul == 1) {
12262 if (NewMore)
12263 return std::nullopt;
12264 NewMore = S;
12265 } else if (Mul == -1) {
12266 if (NewLess)
12267 return std::nullopt;
12268 NewLess = S;
12269 } else
12270 return std::nullopt;
12271 }
12272
12273 // Values stayed the same, no point in trying further.
12274 if (NewMore == More || NewLess == Less)
12275 return std::nullopt;
12276
12277 More = NewMore;
12278 Less = NewLess;
12279
12280 // Reduced to constant.
12281 if (!More && !Less)
12282 return Diff;
12283
12284 // Left with variable on only one side, bail out.
12285 if (!More || !Less)
12286 return std::nullopt;
12287 }
12288
12289 // Did not reduce to constant.
12290 return std::nullopt;
12291}
12292
12293bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12294 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12295 const SCEV *FoundRHS, const Instruction *CtxI) {
12296 // Try to recognize the following pattern:
12297 //
12298 // FoundRHS = ...
12299 // ...
12300 // loop:
12301 // FoundLHS = {Start,+,W}
12302 // context_bb: // Basic block from the same loop
12303 // known(Pred, FoundLHS, FoundRHS)
12304 //
12305 // If some predicate is known in the context of a loop, it is also known on
12306 // each iteration of this loop, including the first iteration. Therefore, in
12307 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12308 // prove the original pred using this fact.
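// For example, if "{%start,+,1} u< %len" is known to hold in a loop block
// that is guaranteed to execute on the first iteration (if the loop body runs
// at all), then "%start u< %len" also holds, and may in turn imply the
// queried predicate.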
12309 if (!CtxI)
12310 return false;
12311 const BasicBlock *ContextBB = CtxI->getParent();
12312 // Make sure AR varies in the context block.
12313 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12314 const Loop *L = AR->getLoop();
12315 // Make sure that context belongs to the loop and executes on 1st iteration
12316 // (if it ever executes at all).
12317 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12318 return false;
12319 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12320 return false;
12321 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12322 }
12323
12324 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12325 const Loop *L = AR->getLoop();
12326 // Make sure that context belongs to the loop and executes on 1st iteration
12327 // (if it ever executes at all).
12328 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12329 return false;
12330 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12331 return false;
12332 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12333 }
12334
12335 return false;
12336}
12337
12338bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12339 const SCEV *LHS,
12340 const SCEV *RHS,
12341 const SCEV *FoundLHS,
12342 const SCEV *FoundRHS) {
12343 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12344 return false;
12345
12346 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12347 if (!AddRecLHS)
12348 return false;
12349
12350 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12351 if (!AddRecFoundLHS)
12352 return false;
12353
12354 // We'd like to let SCEV reason about control dependencies, so we constrain
12355 // both the inequalities to be about add recurrences on the same loop. This
12356 // way we can use isLoopEntryGuardedByCond later.
12357
12358 const Loop *L = AddRecFoundLHS->getLoop();
12359 if (L != AddRecLHS->getLoop())
12360 return false;
12361
12362 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12363 //
12364 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12365 // ... (2)
12366 //
12367 // Informal proof for (2), assuming (1) [*]:
12368 //
12369 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12370 //
12371 // Then
12372 //
12373 // FoundLHS s< FoundRHS s< INT_MIN - C
12374 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12375 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12376 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12377 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12378 // <=> FoundLHS + C s< FoundRHS + C
12379 //
12380 // [*]: (1) can be proved by ruling out overflow.
12381 //
12382 // [**]: This can be proved by analyzing all the four possibilities:
12383 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12384 // (A s>= 0, B s>= 0).
12385 //
12386 // Note:
12387 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12388 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12389 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12390 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12391 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12392 // C)".
12393
12394 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12395 if (!LDiff)
12396 return false;
12397 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12398 if (!RDiff || *LDiff != *RDiff)
12399 return false;
12400
12401 if (LDiff->isMinValue())
12402 return true;
12403
12404 APInt FoundRHSLimit;
12405
12406 if (Pred == CmpInst::ICMP_ULT) {
12407 FoundRHSLimit = -(*RDiff);
12408 } else {
12409 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12410 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12411 }
12412
12413 // Try to prove (1) or (2), as needed.
12414 return isAvailableAtLoopEntry(FoundRHS, L) &&
12415 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12416 getConstant(FoundRHSLimit));
12417}
12418
12419bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12420 const SCEV *RHS, const SCEV *FoundLHS,
12421 const SCEV *FoundRHS, unsigned Depth) {
12422 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12423
12424 llvm::scope_exit ClearOnExit([&]() {
12425 if (LPhi) {
12426 bool Erased = PendingMerges.erase(LPhi);
12427 assert(Erased && "Failed to erase LPhi!");
12428 (void)Erased;
12429 }
12430 if (RPhi) {
12431 bool Erased = PendingMerges.erase(RPhi);
12432 assert(Erased && "Failed to erase RPhi!");
12433 (void)Erased;
12434 }
12435 });
12436
12437 // Find the respective Phis and check that they are not already pending.
12438 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12439 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12440 if (!PendingMerges.insert(Phi).second)
12441 return false;
12442 LPhi = Phi;
12443 }
12444 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12445 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12446 // If we detect a loop of Phi nodes being processed by this method, for
12447 // example:
12448 //
12449 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12450 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12451 //
12452 // we don't want to deal with a case that complex, so return conservative
12453 // answer false.
12454 if (!PendingMerges.insert(Phi).second)
12455 return false;
12456 RPhi = Phi;
12457 }
12458
12459 // If none of LHS, RHS is a Phi, nothing to do here.
12460 if (!LPhi && !RPhi)
12461 return false;
12462
12463 // If there is a SCEVUnknown Phi we are interested in, make it left.
12464 if (!LPhi) {
12465 std::swap(LHS, RHS);
12466 std::swap(FoundLHS, FoundRHS);
12467 std::swap(LPhi, RPhi);
12468 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12469 }
12470
12471 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12472 const BasicBlock *LBB = LPhi->getParent();
12473 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12474
12475 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12476 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12477 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12478 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12479 };
12480
12481 if (RPhi && RPhi->getParent() == LBB) {
12482 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12483 // If we compare two Phis from the same block, and for each entry block
12484 // the predicate is true for incoming values from this block, then the
12485 // predicate is also true for the Phis.
12486 for (const BasicBlock *IncBB : predecessors(LBB)) {
12487 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12488 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12489 if (!ProvedEasily(L, R))
12490 return false;
12491 }
12492 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12493 // Case two: RHS is also a Phi from the same basic block, and it is an
12494 // AddRec. It means that there is a loop which has both AddRec and Unknown
12495 // PHIs, for it we can compare incoming values of AddRec from above the loop
12496 // and latch with their respective incoming values of LPhi.
12497 // TODO: Generalize to handle loops with many inputs in a header.
12498 if (LPhi->getNumIncomingValues() != 2) return false;
12499
12500 auto *RLoop = RAR->getLoop();
12501 auto *Predecessor = RLoop->getLoopPredecessor();
12502 assert(Predecessor && "Loop with AddRec with no predecessor?");
12503 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12504 if (!ProvedEasily(L1, RAR->getStart()))
12505 return false;
12506 auto *Latch = RLoop->getLoopLatch();
12507 assert(Latch && "Loop with AddRec with no latch?");
12508 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12509 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12510 return false;
12511 } else {
12512 // In all other cases go over inputs of LHS and compare each of them to RHS,
12513 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12514 // At this point RHS is either a non-Phi, or it is a Phi from some block
12515 // different from LBB.
12516 for (const BasicBlock *IncBB : predecessors(LBB)) {
12517 // Check that RHS is available in this block.
12518 if (!dominates(RHS, IncBB))
12519 return false;
12520 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12521 // Make sure L does not refer to a value from a potentially previous
12522 // iteration of a loop.
12523 if (!properlyDominates(L, LBB))
12524 return false;
12525 // Addrecs are considered to properly dominate their loop, so are missed
12526 // by the previous check. Discard any values that have computable
12527 // evolution in this loop.
12528 if (auto *Loop = LI.getLoopFor(LBB))
12529 if (hasComputableLoopEvolution(L, Loop))
12530 return false;
12531 if (!ProvedEasily(L, RHS))
12532 return false;
12533 }
12534 }
12535 return true;
12536}
12537
12538bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12539 const SCEV *LHS,
12540 const SCEV *RHS,
12541 const SCEV *FoundLHS,
12542 const SCEV *FoundRHS) {
12543 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12545 // sure that we are dealing with the same LHS.
12545 if (RHS == FoundRHS) {
12546 std::swap(LHS, RHS);
12547 std::swap(FoundLHS, FoundRHS);
12549 }
12550 if (LHS != FoundLHS)
12551 return false;
12552
12553 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12554 if (!SUFoundRHS)
12555 return false;
12556
12557 Value *Shiftee, *ShiftValue;
12558
12559 using namespace PatternMatch;
12560 if (match(SUFoundRHS->getValue(),
12561 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12562 auto *ShifteeS = getSCEV(Shiftee);
12563 // Prove one of the following:
12564 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12565 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12566 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12567 // ---> LHS <s RHS
12568 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12569 // ---> LHS <=s RHS
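    // Illustrative instance of the unsigned rule: from LHS <u (X >> 2) and
    // X <=u RHS we get LHS <u floor(X/4) <=u X <=u RHS, hence LHS <u RHS.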
12570 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12571 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12572 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12573 if (isKnownNonNegative(ShifteeS))
12574 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12575 }
12576
12577 return false;
12578}
12579
12580bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12581 const SCEV *RHS,
12582 const SCEV *FoundLHS,
12583 const SCEV *FoundRHS,
12584 const Instruction *CtxI) {
12585 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12586 FoundRHS) ||
12587 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12588 FoundRHS) ||
12589 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12590 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12591 CtxI) ||
12592 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12593}
12594
12595/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12596template <typename MinMaxExprType>
12597static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12598 const SCEV *Candidate) {
12599 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12600 if (!MinMaxExpr)
12601 return false;
12602
12603 return is_contained(MinMaxExpr->operands(), Candidate);
12604}
12605
12606 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12607 CmpPredicate Pred, const SCEV *LHS,
12608 const SCEV *RHS) {
12609 // If both sides are affine addrecs for the same loop, with equal
12610 // steps, and we know the recurrences don't wrap, then we only
12611 // need to check the predicate on the starting values.
12612
12613 if (!ICmpInst::isRelational(Pred))
12614 return false;
12615
12616 const SCEV *LStart, *RStart, *Step;
12617 const Loop *L;
12618 if (!match(LHS,
12619 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12620 !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12621 m_SpecificLoop(L))))
12622 return false;
12623
12624 const SCEVAddRecExpr *LAR = cast<SCEVAddRecExpr>(LHS);
12625 const SCEVAddRecExpr *RAR = cast<SCEVAddRecExpr>(RHS);
12626 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ? SCEV::FlagNSW : SCEV::FlagNUW;
12627 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12628 return false;
12629
12630 return SE.isKnownPredicate(Pred, LStart, RStart);
12631}
12632
12633/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12634/// expression?
12635 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12636 const SCEV *LHS, const SCEV *RHS) {
12637 switch (Pred) {
12638 default:
12639 return false;
12640
12641 case ICmpInst::ICMP_SGE:
12642 std::swap(LHS, RHS);
12643 [[fallthrough]];
12644 case ICmpInst::ICMP_SLE:
12645 return
12646 // min(A, ...) <= A
12647 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12648 // A <= max(A, ...)
12649 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12650
12651 case ICmpInst::ICMP_UGE:
12652 std::swap(LHS, RHS);
12653 [[fallthrough]];
12654 case ICmpInst::ICMP_ULE:
12655 return
12656 // min(A, ...) <= A
12657 // FIXME: what about umin_seq?
12658 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12659 // A <= max(A, ...)
12660 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12661 }
12662
12663 llvm_unreachable("covered switch fell through?!");
12664}
12665
12666bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12667 const SCEV *RHS,
12668 const SCEV *FoundLHS,
12669 const SCEV *FoundRHS,
12670 unsigned Depth) {
12673 "LHS and RHS have different sizes?");
12674 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12675 getTypeSizeInBits(FoundRHS->getType()) &&
12676 "FoundLHS and FoundRHS have different sizes?");
12677 // We want to avoid hurting the compile time with analysis of too big trees.
12678 if (Depth > MaxSCEVOperationsImplicationDepth)
12679 return false;
12680
12681 // We only want to work with GT comparison so far.
12682 if (ICmpInst::isLT(Pred)) {
12684 std::swap(LHS, RHS);
12685 std::swap(FoundLHS, FoundRHS);
12686 }
12687
12689
12690 // For unsigned, try to reduce it to corresponding signed comparison.
12691 if (P == ICmpInst::ICMP_UGT)
12692 // We can replace unsigned predicate with its signed counterpart if all
12693 // involved values are non-negative.
12694 // TODO: We could have better support for unsigned.
12695 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12696 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12697 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12698 // use this fact to prove that LHS and RHS are non-negative.
12699 const SCEV *MinusOne = getMinusOne(LHS->getType());
12700 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12701 FoundRHS) &&
12702 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12703 FoundRHS))
12704 P = ICmpInst::ICMP_SGT;
12705 }
12706
12707 if (P != ICmpInst::ICMP_SGT)
12708 return false;
12709
12710 auto GetOpFromSExt = [&](const SCEV *S) {
12711 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12712 return Ext->getOperand();
12713 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12714 // the constant in some cases.
12715 return S;
12716 };
12717
12718 // Acquire values from extensions.
12719 auto *OrigLHS = LHS;
12720 auto *OrigFoundLHS = FoundLHS;
12721 LHS = GetOpFromSExt(LHS);
12722 FoundLHS = GetOpFromSExt(FoundLHS);
12723
12724 // Whether the SGT predicate can be proved trivially or using the found context.
12725 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12726 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12727 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12728 FoundRHS, Depth + 1);
12729 };
12730
12731 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12732 // We want to avoid creation of any new non-constant SCEV. Since we are
12733 // going to compare the operands to RHS, we should be certain that we don't
12734 // need any size extensions for this. So let's decline all cases when the
12735 // sizes of types of LHS and RHS do not match.
12736 // TODO: Maybe try to get RHS from sext to catch more cases?
12737 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12738 return false;
12739
12740 // Should not overflow.
12741 if (!LHSAddExpr->hasNoSignedWrap())
12742 return false;
12743
12744 auto *LL = LHSAddExpr->getOperand(0);
12745 auto *LR = LHSAddExpr->getOperand(1);
12746 auto *MinusOne = getMinusOne(RHS->getType());
12747
12748 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12749 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12750 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12751 };
12752 // Try to prove the following rule:
12753 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12754 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12755 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12756 return true;
12757 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12758 Value *LL, *LR;
12759 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12760
12761 using namespace llvm::PatternMatch;
12762
12763 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12764 // Rules for division.
12765 // We are going to perform some comparisons with Denominator and its
12766 // derivative expressions. In the general case, creating a SCEV for it may
12767 // lead to a complex analysis of the entire graph, and in particular it
12768 // can request trip count recalculation for the same loop; the result would
12769 // then be cached as SCEVCouldNotCompute to avoid infinite recursion. To avoid
12770 // this, we only want to create SCEVs that are constants in this section.
12771 // So we bail if Denominator is not a constant.
12772 if (!isa<ConstantInt>(LR))
12773 return false;
12774
12775 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12776
12777 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12778 // then a SCEV for the numerator already exists and matches with FoundLHS.
12779 auto *Numerator = getExistingSCEV(LL);
12780 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12781 return false;
12782
12783 // Make sure that the numerator matches with FoundLHS and the denominator
12784 // is positive.
12785 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12786 return false;
12787
12788 auto *DTy = Denominator->getType();
12789 auto *FRHSTy = FoundRHS->getType();
12790 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12791 // One of types is a pointer and another one is not. We cannot extend
12792 // them properly to a wider type, so let us just reject this case.
12793 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12794 // to avoid this check.
12795 return false;
12796
12797 // Given that:
12798 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12799 auto *WTy = getWiderType(DTy, FRHSTy);
12800 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12801 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12802
12803 // Try to prove the following rule:
12804 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12805 // For example, given that FoundLHS > 2, FoundLHS is at least 3. Dividing
12806 // it by a Denominator < 4 then yields at least 1.
12807 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12808 if (isKnownNonPositive(RHS) &&
12809 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12810 return true;
12811
12812 // Try to prove the following rule:
12813 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12814 // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12815 // If we divide it by Denominator > 2, then:
12816 // 1. If FoundLHS is negative, then the result is 0.
12817 // 2. If FoundLHS is non-negative, then the result is non-negative.
12818 // Either way, the result is non-negative.
12819 auto *MinusOne = getMinusOne(WTy);
12820 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12821 if (isKnownNegative(RHS) &&
12822 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12823 return true;
12824 }
12825 }
12826
12827 // If our expression contained SCEVUnknown Phis, and we split it down and now
12828 // need to prove something for them, try to prove the predicate for every
12829 // possible incoming values of those Phis.
12830 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12831 return true;
12832
12833 return false;
12834}
12835
12836 bool ScalarEvolution::isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
12837 const SCEV *RHS) {
12838 // zext x u<= sext x, sext x s<= zext x
12839 const SCEV *Op;
12840 switch (Pred) {
12841 case ICmpInst::ICMP_SGE:
12842 std::swap(LHS, RHS);
12843 [[fallthrough]];
12844 case ICmpInst::ICMP_SLE: {
12845 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12846 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
12848 }
12849 case ICmpInst::ICMP_UGE:
12850 std::swap(LHS, RHS);
12851 [[fallthrough]];
12852 case ICmpInst::ICMP_ULE: {
12853 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12854 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12855 match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12856 }
12857 default:
12858 return false;
12859 };
12860 llvm_unreachable("unhandled case");
12861}
12862
12863bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12864 const SCEV *LHS,
12865 const SCEV *RHS) {
12866 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12867 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12868 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12869 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12870 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12871}
12872
12873bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12874 const SCEV *LHS,
12875 const SCEV *RHS,
12876 const SCEV *FoundLHS,
12877 const SCEV *FoundRHS) {
12878 switch (Pred) {
12879 default:
12880 llvm_unreachable("Unexpected CmpPredicate value!");
12881 case ICmpInst::ICMP_EQ:
12882 case ICmpInst::ICMP_NE:
12883 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12884 return true;
12885 break;
12886 case ICmpInst::ICMP_SLT:
12887 case ICmpInst::ICMP_SLE:
12888 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12889 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12890 return true;
12891 break;
12892 case ICmpInst::ICMP_SGT:
12893 case ICmpInst::ICMP_SGE:
12894 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12895 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12896 return true;
12897 break;
12898 case ICmpInst::ICMP_ULT:
12899 case ICmpInst::ICMP_ULE:
12900 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12901 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12902 return true;
12903 break;
12904 case ICmpInst::ICMP_UGT:
12905 case ICmpInst::ICMP_UGE:
12906 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12907 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12908 return true;
12909 break;
12910 }
12911
12912 // Maybe it can be proved via operations?
12913 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12914 return true;
12915
12916 return false;
12917}
12918
12919bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12920 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12921 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12922 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12923 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12924 // reduce the compile time impact of this optimization.
12925 return false;
12926
12927 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12928 if (!Addend)
12929 return false;
12930
12931 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12932
12933 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12934 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12935 ConstantRange FoundLHSRange =
12936 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12937
12938 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12939 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12940
12941 // We can also compute the range of values for `LHS` that satisfy the
12942 // consequent, "`LHS` `Pred` `RHS`":
12943 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12944 // The antecedent implies the consequent if every value of `LHS` that
12945 // satisfies the antecedent also satisfies the consequent.
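  // Illustrative example: from "FoundLHS u< 8" and LHS = FoundLHS + 2, LHSRange
  // is [2, 10); that implies "LHS u< 16" but not "LHS u< 9" (LHS could be 9).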
12946 return LHSRange.icmp(Pred, ConstRHS);
12947}
12948
12949bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12950 bool IsSigned) {
12951 assert(isKnownPositive(Stride) && "Positive stride expected!");
12952
12953 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12954 const SCEV *One = getOne(Stride->getType());
12955
12956 if (IsSigned) {
12957 APInt MaxRHS = getSignedRangeMax(RHS);
12958 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12959 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12960
12961 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12962 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12963 }
12964
12965 APInt MaxRHS = getUnsignedRangeMax(RHS);
12966 APInt MaxValue = APInt::getMaxValue(BitWidth);
12967 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12968
12969 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
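  // Illustrative example: for an i8 IV, if UMaxRHS = 250 and the stride may be
  // as large as 10, then 250 + 9 > 255 and the IV may wrap before reaching RHS.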
12970 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12971}
12972
12973bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12974 bool IsSigned) {
12975
12976 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12977 const SCEV *One = getOne(Stride->getType());
12978
12979 if (IsSigned) {
12980 APInt MinRHS = getSignedRangeMin(RHS);
12981 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12982 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12983
12984 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12985 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12986 }
12987
12988 APInt MinRHS = getUnsignedRangeMin(RHS);
12989 APInt MinValue = APInt::getMinValue(BitWidth);
12990 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12991
12992 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12993 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12994}
12995
12997 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12998 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12999 // expression fixes the case of N=0.
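  // Illustrative arithmetic: N=7, D=3 gives umin(7,1)=1 plus (7-1) /u 3 = 2,
  // i.e. 3 = ceil(7/3); N=0 gives umin(0,1)=0 plus 0 /u D = 0.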
13000 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13001 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13002 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13003}
13004
13005const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13006 const SCEV *Stride,
13007 const SCEV *End,
13008 unsigned BitWidth,
13009 bool IsSigned) {
13010 // The logic in this function assumes we can represent a positive stride.
13011 // If we can't, the backedge-taken count must be zero.
13012 if (IsSigned && BitWidth == 1)
13013 return getZero(Stride->getType());
13014
13015 // The code below has only been closely audited for negative strides in the
13016 // unsigned comparison case; it may be correct for signed comparisons, but
13017 // that needs to be established.
13018 if (IsSigned && isKnownNegative(Stride))
13019 return getCouldNotCompute();
13020
13021 // Calculate the maximum backedge count based on the range of values
13022 // permitted by Start, End, and Stride.
13023 APInt MinStart =
13024 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13025
13026 APInt MinStride =
13027 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13028
13029 // We assume either the stride is positive, or the backedge-taken count
13030 // is zero. So force StrideForMaxBECount to be at least one.
13031 APInt One(BitWidth, 1);
13032 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13033 : APIntOps::umax(One, MinStride);
13034
13035 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13036 : APInt::getMaxValue(BitWidth);
13037 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13038
13039 // Although End can be a MAX expression we estimate MaxEnd considering only
13040 // the case End = RHS of the loop termination condition. This is safe because
13041 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13042 // taken count.
13043 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13044 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13045
13046 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13047 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13048 : APIntOps::umax(MaxEnd, MinStart);
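  // Illustrative arithmetic: MinStart=0, MaxEnd=10, StrideForMaxBECount=3
  // yields ceil(10 /u 3) = 4 as the constant maximum backedge-taken count.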
13049
13050 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13051 getConstant(StrideForMaxBECount) /* Step */);
13052}
13053
13055ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13056 const Loop *L, bool IsSigned,
13057 bool ControlsOnlyExit, bool AllowPredicates) {
13059
13061 bool PredicatedIV = false;
13062 if (!IV) {
13063 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13064 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13065 if (AR && AR->getLoop() == L && AR->isAffine()) {
13066 auto canProveNUW = [&]() {
13067 // We can use the comparison to infer no-wrap flags only if it fully
13068 // controls the loop exit.
13069 if (!ControlsOnlyExit)
13070 return false;
13071
13072 if (!isLoopInvariant(RHS, L))
13073 return false;
13074
13075 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13076 // We need the sequence defined by AR to strictly increase in the
13077 // unsigned integer domain for the logic below to hold.
13078 return false;
13079
13080 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13081 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13082 // If RHS <=u Limit, then there must exist a value V in the sequence
13083 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13084 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13085 // overflow occurs. This limit also implies that a signed comparison
13086 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13087 // the high bits on both sides must be zero.
13088 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13089 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13090 Limit = Limit.zext(OuterBitWidth);
13091 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13092 };
13093 auto Flags = AR->getNoWrapFlags();
13094 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13095 Flags = setFlags(Flags, SCEV::FlagNUW);
13096
13097 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13098 if (AR->hasNoUnsignedWrap()) {
13099 // Emulate what getZeroExtendExpr would have done during construction
13100 // if we'd been able to infer the fact just above at that time.
13101 const SCEV *Step = AR->getStepRecurrence(*this);
13102 Type *Ty = ZExt->getType();
13103 auto *S = getAddRecExpr(
13105 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13107 }
13108 }
13109 }
13110 }
13111
13112
13113 if (!IV && AllowPredicates) {
13114 // Try to make this an AddRec using runtime tests, in the first X
13115 // iterations of this loop, where X is the SCEV expression found by the
13116 // algorithm below.
13117 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13118 PredicatedIV = true;
13119 }
13120
13121 // Avoid weird loops
13122 if (!IV || IV->getLoop() != L || !IV->isAffine())
13123 return getCouldNotCompute();
13124
13125 // A precondition of this method is that the condition being analyzed
13126 // reaches an exiting branch which dominates the latch. Given that, we can
13127 // assume that an increment which violates the nowrap specification and
13128 // produces poison must cause undefined behavior when the resulting poison
13129 // value is branched upon and thus we can conclude that the backedge is
13130 // taken no more often than would be required to produce that poison value.
13131 // Note that a well defined loop can exit on the iteration which violates
13132 // the nowrap specification if there is another exit (either explicit or
13133 // implicit/exceptional) which causes the loop to execute before the
13134 // exiting instruction we're analyzing would trigger UB.
13135 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13136 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13138
13139 const SCEV *Stride = IV->getStepRecurrence(*this);
13140
13141 bool PositiveStride = isKnownPositive(Stride);
13142
13143 // Avoid negative or zero stride values.
13144 if (!PositiveStride) {
13145 // We can compute the correct backedge taken count for loops with unknown
13146 // strides if we can prove that the loop is not an infinite loop with side
13147 // effects. Here's the loop structure we are trying to handle -
13148 //
13149 // i = start
13150 // do {
13151 // A[i] = i;
13152 // i += s;
13153 // } while (i < end);
13154 //
13155 // The backedge taken count for such loops is evaluated as -
13156 // (max(end, start + stride) - start - 1) /u stride
13157 //
13158 // The additional preconditions that we need to check to prove correctness
13159 // of the above formula are as follows -
13160 //
13161 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13162 // NoWrap flag).
13163 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13164 // no side effects within the loop)
13165 // c) loop has a single static exit (with no abnormal exits)
13166 //
13167 // Precondition a) implies that if the stride is negative, this is a single
13168 // trip loop. The backedge taken count formula reduces to zero in this case.
13169 //
13170 // Preconditions b) and c) combine to imply that if rhs is invariant in L,
13171 // then a zero stride means the backedge can't be taken without executing
13172 // undefined behavior.
13173 //
13174 // The positive stride case is the same as isKnownPositive(Stride) returning
13175 // true (original behavior of the function).
13176 //
13177 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13179 return getCouldNotCompute();
13180
13181 if (!isKnownNonZero(Stride)) {
13182 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13183 // if it might eventually be greater than start and if so, on which
13184 // iteration. We can't even produce a useful upper bound.
13185 if (!isLoopInvariant(RHS, L))
13186 return getCouldNotCompute();
13187
13188 // We allow a potentially zero stride, but we need to divide by stride
13189 // below. Since the loop can't be infinite and this check must control
13190 // the sole exit, we can infer the exit must be taken on the first
13191 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13192 // we know the numerator in the divides below must be zero, so we can
13193 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13194 // and produce the right result.
13195 // FIXME: Handle the case where Stride is poison?
13196 auto wouldZeroStrideBeUB = [&]() {
13197 // Proof by contradiction. Suppose the stride were zero. If we can
13198 // prove that the backedge *is* taken on the first iteration, then since
13199 // we know this condition controls the sole exit, we must have an
13200 // infinite loop. We can't have a (well defined) infinite loop per
13201 // check just above.
13202 // Note: The (Start - Stride) term is used to get the start' term from
13203 // (start' + stride,+,stride). Remember that we only care about the
13204 // result of this expression when stride == 0 at runtime.
13205 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13206 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13207 };
13208 if (!wouldZeroStrideBeUB()) {
13209 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13210 }
13211 }
13212 } else if (!NoWrap) {
13213 // Avoid proven overflow cases: this will ensure that the backedge taken
13214 // count will not generate any unsigned overflow.
13215 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13216 return getCouldNotCompute();
13217 }
13218
13219 // On all paths just preceding, we established the following invariant:
13220 // IV can be assumed not to overflow up to and including the exiting
13221 // iteration. We proved this in one of two ways:
13222 // 1) We can show overflow doesn't occur before the exiting iteration
13223 // 1a) canIVOverflowOnLT, and b) step of one
13224 // 2) We can show that if overflow occurs, the loop must execute UB
13225 // before any possible exit.
13226 // Note that we have not yet proved RHS invariant (in general).
13227
13228 const SCEV *Start = IV->getStart();
13229
13230 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13231 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13232 // Use integer-typed versions for actual computation; we can't subtract
13233 // pointers in general.
13234 const SCEV *OrigStart = Start;
13235 const SCEV *OrigRHS = RHS;
13236 if (Start->getType()->isPointerTy()) {
13237 Start = getLosslessPtrToIntExpr(Start);
13238 if (isa<SCEVCouldNotCompute>(Start))
13239 return Start;
13240 }
13241 if (RHS->getType()->isPointerTy()) {
13242 RHS = getLosslessPtrToIntExpr(RHS);
13243 if (isa<SCEVCouldNotCompute>(RHS))
13244 return RHS;
13245 }
13246
13247 const SCEV *End = nullptr, *BECount = nullptr,
13248 *BECountIfBackedgeTaken = nullptr;
13249 if (!isLoopInvariant(RHS, L)) {
13250 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13251 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13252 RHSAddRec->getNoWrapFlags()) {
13253 // The structure of loop we are trying to calculate backedge count of:
13254 //
13255 // left = left_start
13256 // right = right_start
13257 //
13258 // while(left < right){
13259 // ... do something here ...
13260 // left += s1; // stride of left is s1 (s1 > 0)
13261 // right += s2; // stride of right is s2 (s2 < 0)
13262 // }
13263 //
13264
13265 const SCEV *RHSStart = RHSAddRec->getStart();
13266 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13267
13268 // If Stride - RHSStride is positive and does not overflow, we can write
13269 // backedge count as ->
13270 // ceil((End - Start) /u (Stride - RHSStride))
13271 // Where, End = max(RHSStart, Start)
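      // Illustrative example: left_start=0, s1=2, right_start=10, s2=-3 gives
      // Stride - RHSStride = 5 and a backedge-taken count of ceil(10 /u 5) = 2.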
13272
13273 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13274 if (isKnownNegative(RHSStride) &&
13275 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13276 RHSStride)) {
13277
13278 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13279 if (isKnownPositive(Denominator)) {
13280 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13281 : getUMaxExpr(RHSStart, Start);
13282
13283 // We can do this because End >= Start, as End = max(RHSStart, Start)
13284 const SCEV *Delta = getMinusSCEV(End, Start);
13285
13286 BECount = getUDivCeilSCEV(Delta, Denominator);
13287 BECountIfBackedgeTaken =
13288 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13289 }
13290 }
13291 }
13292 if (BECount == nullptr) {
13293 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13294 // given the start, stride and max value for the end bound of the
13295 // loop (RHS), and the fact that IV does not overflow (which is
13296 // checked above).
13297 const SCEV *MaxBECount = computeMaxBECountForLT(
13298 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13299 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13300 MaxBECount, false /*MaxOrZero*/, Predicates);
13301 }
13302 } else {
13303 // We use the expression (max(End,Start)-Start)/Stride to describe the
13304 // backedge count, as if the backedge is taken at least once
13305 // max(End,Start) is End and so the result is as above, and if not
13306 // max(End,Start) is Start so we get a backedge count of zero.
13307 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13308 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13309 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13310 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13311 // Can we prove max(RHS,Start) > Start - Stride?
13312 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13313 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13314 // In this case, we can use a refined formula for computing backedge
13315 // taken count. The general formula remains:
13316 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13317 // We want to use the alternate formula:
13318 // "((End - 1) - (Start - Stride)) /u Stride"
13319 // Let's do a quick case analysis to show these are equivalent under
13320 // our precondition that max(RHS,Start) > Start - Stride.
13321 // * For RHS <= Start, the backedge-taken count must be zero.
13322 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13323 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13324 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13325 // of Stride. For 0 stride, we've use umin(1,Stride) above,
13326 // reducing this to the stride of 1 case.
13327 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13328 // Stride".
13329 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13330 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13331 // "((RHS - (Start - Stride) - 1) /u Stride".
13332 // Our preconditions trivially imply no overflow in that form.
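        // Illustrative arithmetic: Start=2, Stride=3, RHS=10 gives
        // ((10 - 1) - (2 - 3)) /u 3 = 10 /u 3 = 3, matching ceil((10 - 2) / 3).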
13333 const SCEV *MinusOne = getMinusOne(Stride->getType());
13334 const SCEV *Numerator =
13335 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13336 BECount = getUDivExpr(Numerator, Stride);
13337 }
13338
13339 if (!BECount) {
13340 auto canProveRHSGreaterThanEqualStart = [&]() {
13341 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13342 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13343 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13344
13345 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13346 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13347 return true;
13348
13349 // (RHS > Start - 1) implies RHS >= Start.
13350 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13351 // "Start - 1" doesn't overflow.
13352 // * For signed comparison, if Start - 1 does overflow, it's equal
13353 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13354 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13355 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13356 //
13357 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13358 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13359 auto *StartMinusOne =
13360 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13361 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13362 };
13363
13364 // If we know that RHS >= Start in the context of loop, then we know
13365 // that max(RHS, Start) = RHS at this point.
13366 if (canProveRHSGreaterThanEqualStart()) {
13367 End = RHS;
13368 } else {
13369 // If RHS < Start, the backedge will be taken zero times. So in
13370 // general, we can write the backedge-taken count as:
13371 //
13372 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13373 //
13374 // We convert it to the following to make it more convenient for SCEV:
13375 //
13376 // ceil(max(RHS, Start) - Start) / Stride
13377 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13378
13379 // See what would happen if we assume the backedge is taken. This is
13380 // used to compute MaxBECount.
13381 BECountIfBackedgeTaken =
13382 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13383 }
13384
13385 // At this point, we know:
13386 //
13387 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13388 // 2. The index variable doesn't overflow.
13389 //
13390 // Therefore, we know N exists such that
13391 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13392 // doesn't overflow.
13393 //
13394 // Using this information, try to prove whether the addition in
13395 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13396 const SCEV *One = getOne(Stride->getType());
13397 bool MayAddOverflow = [&] {
13398 if (isKnownToBeAPowerOfTwo(Stride)) {
13399 // Suppose Stride is a power of two, and Start/End are unsigned
13400 // integers. Let UMAX be the largest representable unsigned
13401 // integer.
13402 //
13403 // By the preconditions of this function, we know
13404 // "(Start + Stride * N) >= End", and this doesn't overflow.
13405 // As a formula:
13406 //
13407 // End <= (Start + Stride * N) <= UMAX
13408 //
13409 // Subtracting Start from all the terms:
13410 //
13411 // End - Start <= Stride * N <= UMAX - Start
13412 //
13413 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13414 //
13415 // End - Start <= Stride * N <= UMAX
13416 //
13417 // Stride * N is a multiple of Stride. Therefore,
13418 //
13419 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13420 //
13421 // Since Stride is a power of two, UMAX + 1 is divisible by
13422 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13423 // write:
13424 //
13425 // End - Start <= Stride * N <= UMAX - Stride - 1
13426 //
13427 // Dropping the middle term:
13428 //
13429 // End - Start <= UMAX - Stride - 1
13430 //
13431 // Adding Stride - 1 to both sides:
13432 //
13433 // (End - Start) + (Stride - 1) <= UMAX
13434 //
13435 // In other words, the addition doesn't have unsigned overflow.
13436 //
13437 // A similar proof works if we treat Start/End as signed values.
13438 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13439 // to use signed max instead of unsigned max. Note that we're
13440 // trying to prove a lack of unsigned overflow in either case.
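          // (Concrete 8-bit instance: Stride=4, UMAX=255 gives
          //  End - Start <= 252, and adding Stride - 1 = 3 cannot wrap.)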
13441 return false;
13442 }
13443 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13444 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13445 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13446 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13447 // 1 <s End.
13448 //
13449 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13450 // End.
13451 return false;
13452 }
13453 return true;
13454 }();
13455
13456 const SCEV *Delta = getMinusSCEV(End, Start);
13457 if (!MayAddOverflow) {
13458 // floor((D + (S - 1)) / S)
13459 // We prefer this formulation if it's legal because it's fewer
13460 // operations.
13461 BECount =
13462 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13463 } else {
13464 BECount = getUDivCeilSCEV(Delta, Stride);
13465 }
13466 }
13467 }
13468
13469 const SCEV *ConstantMaxBECount;
13470 bool MaxOrZero = false;
13471 if (isa<SCEVConstant>(BECount)) {
13472 ConstantMaxBECount = BECount;
13473 } else if (BECountIfBackedgeTaken &&
13474 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13475 // If we know exactly how many times the backedge will be taken if it's
13476 // taken at least once, then the backedge count will either be that or
13477 // zero.
13478 ConstantMaxBECount = BECountIfBackedgeTaken;
13479 MaxOrZero = true;
13480 } else {
13481 ConstantMaxBECount = computeMaxBECountForLT(
13482 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13483 }
13484
13485 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13486 !isa<SCEVCouldNotCompute>(BECount))
13487 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13488
13489 const SCEV *SymbolicMaxBECount =
13490 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13491 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13492 Predicates);
13493}
13494
13495ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13496 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13497 bool ControlsOnlyExit, bool AllowPredicates) {
13499 // We handle only IV > Invariant
13500 if (!isLoopInvariant(RHS, L))
13501 return getCouldNotCompute();
13502
13503 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13504 if (!IV && AllowPredicates)
13505 // Try to make this an AddRec using runtime tests, in the first X
13506 // iterations of this loop, where X is the SCEV expression found by the
13507 // algorithm below.
13508 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13509
13510 // Avoid weird loops
13511 if (!IV || IV->getLoop() != L || !IV->isAffine())
13512 return getCouldNotCompute();
13513
13514 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13515 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13517
13518 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13519
13520 // Avoid negative or zero stride values
13521 if (!isKnownPositive(Stride))
13522 return getCouldNotCompute();
13523
13524 // Avoid proven overflow cases: this will ensure that the backedge taken count
13525 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13526 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13527 // behaviors like the case of C language.
13528 if (!Stride->isOne() && !NoWrap)
13529 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13530 return getCouldNotCompute();
13531
13532 const SCEV *Start = IV->getStart();
13533 const SCEV *End = RHS;
13534 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13535 // If we know that Start >= RHS in the context of loop, then we know that
13536 // min(RHS, Start) = RHS at this point.
13538 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13539 End = RHS;
13540 else
13541 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13542 }
13543
13544 if (Start->getType()->isPointerTy()) {
13545 Start = getLosslessPtrToIntExpr(Start);
13546 if (isa<SCEVCouldNotCompute>(Start))
13547 return Start;
13548 }
13549 if (End->getType()->isPointerTy()) {
13550 End = getLosslessPtrToIntExpr(End);
13551 if (isa<SCEVCouldNotCompute>(End))
13552 return End;
13553 }
13554
13555 // Compute ((Start - End) + (Stride - 1)) / Stride.
13556 // FIXME: This can overflow. Holding off on fixing this for now;
13557 // howManyGreaterThans will hopefully be gone soon.
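  // Illustrative arithmetic: Start=10, End=2, Stride=3 gives
  // ((10 - 2) + (3 - 1)) /u 3 = 10 /u 3 = 3 = ceil((10 - 2) / 3).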
13558 const SCEV *One = getOne(Stride->getType());
13559 const SCEV *BECount = getUDivExpr(
13560 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13561
13562 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13564
13565 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13566 : getUnsignedRangeMin(Stride);
13567
13568 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13569 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13570 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13571
13572 // Although End can be a MIN expression we estimate MinEnd considering only
13573 // the case End = RHS. This is safe because in the other case (Start - End)
13574 // is zero, leading to a zero maximum backedge taken count.
13575 APInt MinEnd =
13576 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13577 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13578
13579 const SCEV *ConstantMaxBECount =
13580 isa<SCEVConstant>(BECount)
13581 ? BECount
13582 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13583 getConstant(MinStride));
13584
13585 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13586 ConstantMaxBECount = BECount;
13587 const SCEV *SymbolicMaxBECount =
13588 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13589
13590 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13591 Predicates);
13592}
13593
13595 ScalarEvolution &SE) const {
13596 if (Range.isFullSet()) // Infinite loop.
13597 return SE.getCouldNotCompute();
13598
13599 // If the start is a non-zero constant, shift the range to simplify things.
13600 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13601 if (!SC->getValue()->isZero()) {
13603 Operands[0] = SE.getZero(SC->getType());
13604 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13606 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13607 return ShiftedAddRec->getNumIterationsInRange(
13608 Range.subtract(SC->getAPInt()), SE);
13609 // This is strange and shouldn't happen.
13610 return SE.getCouldNotCompute();
13611 }
13612
13613 // The only time we can solve this is when we have all constant indices.
13614 // Otherwise, we cannot determine the overflow conditions.
13615 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13616 return SE.getCouldNotCompute();
13617
13618 // Okay at this point we know that all elements of the chrec are constants and
13619 // that the start element is zero.
13620
13621 // First check to see if the range contains zero. If not, the first
13622 // iteration exits.
13623 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13624 if (!Range.contains(APInt(BitWidth, 0)))
13625 return SE.getZero(getType());
13626
13627 if (isAffine()) {
13628 // If this is an affine expression then we have this situation:
13629 // Solve {0,+,A} in Range === Ax in Range
13630
13631 // We know that zero is in the range. If A is positive then we know that
13632 // the upper value of the range must be the first possible exit value.
13633 // If A is negative then the lower of the range is the last possible loop
13634 // value. Also note that we already checked for a full range.
13635 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13636 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13637
13638 // The exit value should be (End+A)/A.
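    // Illustrative arithmetic: A=2 and Range=[0,10) give End=9 and
    // ExitVal = (9+2) /u 2 = 5; 2*5 = 10 is the first value outside the range.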
13639 APInt ExitVal = (End + A).udiv(A);
13640 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13641
13642 // Evaluate at the exit value. If we really did fall out of the valid
13643 // range, then we computed our trip count, otherwise wrap around or other
13644 // things must have happened.
13645 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13646 if (Range.contains(Val->getValue()))
13647 return SE.getCouldNotCompute(); // Something strange happened
13648
13649 // Ensure that the previous value is in the range.
13650 assert(Range.contains(
13652 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13653 "Linear scev computation is off in a bad way!");
13654 return SE.getConstant(ExitValue);
13655 }
13656
13657 if (isQuadratic()) {
13658 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13659 return SE.getConstant(*S);
13660 }
13661
13662 return SE.getCouldNotCompute();
13663}
13664
13665const SCEVAddRecExpr *
13667 assert(getNumOperands() > 1 && "AddRec with zero step?");
13668 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13669 // but in this case we cannot guarantee that the value returned will be an
13670 // AddRec because SCEV does not have a fixed point where it stops
13671 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13672 // may happen if we reach arithmetic depth limit while simplifying. So we
13673 // construct the returned value explicitly.
13675 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13676 // (this + Step) is {A+B,+,B+C,+...,+,N}.
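  // Illustrative example: the post-increment form of {1,+,2,+,3} is
  // {1+2,+,2+3,+,3} = {3,+,5,+,3}.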
13677 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13678 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13679 // We know that the last operand is not a constant zero (otherwise it would
13680 // have been popped out earlier). This guarantees us that if the result has
13681 // the same last operand, then it will also not be popped out, meaning that
13682 // the returned value will be an AddRec.
13683 const SCEV *Last = getOperand(getNumOperands() - 1);
13684 assert(!Last->isZero() && "Recurrency with zero step?");
13685 Ops.push_back(Last);
13688}
13689
13690 // Return true when S contains at least one undef or poison value.
13692 return SCEVExprContains(
13693 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13694}
13695
13696// Return true when S contains a value that is a nullptr.
13698 return SCEVExprContains(S, [](const SCEV *S) {
13699 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13700 return SU->getValue() == nullptr;
13701 return false;
13702 });
13703}
13704
13705/// Return the size of an element read or written by Inst.
13707 Type *Ty;
13708 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13709 Ty = Store->getValueOperand()->getType();
13710 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13711 Ty = Load->getType();
13712 else
13713 return nullptr;
13714
13716 return getSizeOfExpr(ETy, Ty);
13717}
13718
13719//===----------------------------------------------------------------------===//
13720// SCEVCallbackVH Class Implementation
13721//===----------------------------------------------------------------------===//
13722
13724 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13725 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13726 SE->ConstantEvolutionLoopExitValue.erase(PN);
13727 SE->eraseValueFromMap(getValPtr());
13728 // this now dangles!
13729}
13730
13731void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13732 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13733
13734 // Forget all the expressions associated with users of the old value,
13735 // so that future queries will recompute the expressions using the new
13736 // value.
13737 SE->forgetValue(getValPtr());
13738 // this now dangles!
13739}
13740
13741ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13742 : CallbackVH(V), SE(se) {}
13743
13744//===----------------------------------------------------------------------===//
13745// ScalarEvolution Class Implementation
13746//===----------------------------------------------------------------------===//
13747
13750 LoopInfo &LI)
13751 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13752 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13753 LoopDispositions(64), BlockDispositions(64) {
13754 // To use guards for proving predicates, we need to scan every instruction in
13755 // relevant basic blocks, and not just terminators. Doing this is a waste of
13756 // time if the IR does not actually contain any calls to
13757 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13758 //
13759 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13760 // to _add_ guards to the module when there weren't any before, and wants
13761 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13762 // efficient in lieu of being smart in that rather obscure case.
13763
13764 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13765 F.getParent(), Intrinsic::experimental_guard);
13766 HasGuards = GuardDecl && !GuardDecl->use_empty();
13767}
13768
13770 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13771 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13772 ValueExprMap(std::move(Arg.ValueExprMap)),
13773 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13774 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13775 PendingMerges(std::move(Arg.PendingMerges)),
13776 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13777 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13778 PredicatedBackedgeTakenCounts(
13779 std::move(Arg.PredicatedBackedgeTakenCounts)),
13780 BECountUsers(std::move(Arg.BECountUsers)),
13781 ConstantEvolutionLoopExitValue(
13782 std::move(Arg.ConstantEvolutionLoopExitValue)),
13783 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13784 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13785 LoopDispositions(std::move(Arg.LoopDispositions)),
13786 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13787 BlockDispositions(std::move(Arg.BlockDispositions)),
13788 SCEVUsers(std::move(Arg.SCEVUsers)),
13789 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13790 SignedRanges(std::move(Arg.SignedRanges)),
13791 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13792 UniquePreds(std::move(Arg.UniquePreds)),
13793 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13794 LoopUsers(std::move(Arg.LoopUsers)),
13795 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13796 FirstUnknown(Arg.FirstUnknown) {
13797 Arg.FirstUnknown = nullptr;
13798}
13799
13801 // Iterate through all the SCEVUnknown instances and call their
13802 // destructors, so that they release their references to their values.
13803 for (SCEVUnknown *U = FirstUnknown; U;) {
13804 SCEVUnknown *Tmp = U;
13805 U = U->Next;
13806 Tmp->~SCEVUnknown();
13807 }
13808 FirstUnknown = nullptr;
13809
13810 ExprValueMap.clear();
13811 ValueExprMap.clear();
13812 HasRecMap.clear();
13813 BackedgeTakenCounts.clear();
13814 PredicatedBackedgeTakenCounts.clear();
13815
13816 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13817 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13818 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13819 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13820 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13821}
13822
13826
13827/// When printing a top-level SCEV for trip counts, it's helpful to include
13828/// a type for constants which are otherwise hard to disambiguate.
13829static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13830 if (isa<SCEVConstant>(S))
13831 OS << *S->getType() << " ";
13832 OS << *S;
13833}
13834
13836 const Loop *L) {
13837 // Print all inner loops first
13838 for (Loop *I : *L)
13839 PrintLoopInfo(OS, SE, I);
13840
13841 OS << "Loop ";
13842 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13843 OS << ": ";
13844
13845 SmallVector<BasicBlock *, 8> ExitingBlocks;
13846 L->getExitingBlocks(ExitingBlocks);
13847 if (ExitingBlocks.size() != 1)
13848 OS << "<multiple exits> ";
13849
13850 auto *BTC = SE->getBackedgeTakenCount(L);
13851 if (!isa<SCEVCouldNotCompute>(BTC)) {
13852 OS << "backedge-taken count is ";
13853 PrintSCEVWithTypeHint(OS, BTC);
13854 } else
13855 OS << "Unpredictable backedge-taken count.";
13856 OS << "\n";
13857
13858 if (ExitingBlocks.size() > 1)
13859 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13860 OS << " exit count for " << ExitingBlock->getName() << ": ";
13861 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13862 PrintSCEVWithTypeHint(OS, EC);
13863 if (isa<SCEVCouldNotCompute>(EC)) {
13864 // Retry with predicates.
13866 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13867 if (!isa<SCEVCouldNotCompute>(EC)) {
13868 OS << "\n predicated exit count for " << ExitingBlock->getName()
13869 << ": ";
13870 PrintSCEVWithTypeHint(OS, EC);
13871 OS << "\n Predicates:\n";
13872 for (const auto *P : Predicates)
13873 P->print(OS, 4);
13874 }
13875 }
13876 OS << "\n";
13877 }
13878
13879 OS << "Loop ";
13880 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13881 OS << ": ";
13882
13883 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13884 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13885 OS << "constant max backedge-taken count is ";
13886 PrintSCEVWithTypeHint(OS, ConstantBTC);
13888 OS << ", actual taken count either this or zero.";
13889 } else {
13890 OS << "Unpredictable constant max backedge-taken count. ";
13891 }
13892
13893 OS << "\n"
13894 "Loop ";
13895 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13896 OS << ": ";
13897
13898 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13899 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13900 OS << "symbolic max backedge-taken count is ";
13901 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13903 OS << ", actual taken count either this or zero.";
13904 } else {
13905 OS << "Unpredictable symbolic max backedge-taken count. ";
13906 }
13907 OS << "\n";
13908
13909 if (ExitingBlocks.size() > 1)
13910 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13911 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13912 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13914 PrintSCEVWithTypeHint(OS, ExitBTC);
13915 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13916 // Retry with predicates.
13918 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13920 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13921 OS << "\n predicated symbolic max exit count for "
13922 << ExitingBlock->getName() << ": ";
13923 PrintSCEVWithTypeHint(OS, ExitBTC);
13924 OS << "\n Predicates:\n";
13925 for (const auto *P : Predicates)
13926 P->print(OS, 4);
13927 }
13928 }
13929 OS << "\n";
13930 }
13931
13933 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13934 if (PBT != BTC) {
13935 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13936 OS << "Loop ";
13937 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13938 OS << ": ";
13939 if (!isa<SCEVCouldNotCompute>(PBT)) {
13940 OS << "Predicated backedge-taken count is ";
13941 PrintSCEVWithTypeHint(OS, PBT);
13942 } else
13943 OS << "Unpredictable predicated backedge-taken count.";
13944 OS << "\n";
13945 OS << " Predicates:\n";
13946 for (const auto *P : Preds)
13947 P->print(OS, 4);
13948 }
13949 Preds.clear();
13950
13951 auto *PredConstantMax =
13952     SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
13953 if (PredConstantMax != ConstantBTC) {
13954 assert(!Preds.empty() &&
13955 "different predicated constant max BTC but no predicates");
13956 OS << "Loop ";
13957 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13958 OS << ": ";
13959 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13960 OS << "Predicated constant max backedge-taken count is ";
13961 PrintSCEVWithTypeHint(OS, PredConstantMax);
13962 } else
13963 OS << "Unpredictable predicated constant max backedge-taken count.";
13964 OS << "\n";
13965 OS << " Predicates:\n";
13966 for (const auto *P : Preds)
13967 P->print(OS, 4);
13968 }
13969 Preds.clear();
13970
13971 auto *PredSymbolicMax =
13972     SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13973 if (SymbolicBTC != PredSymbolicMax) {
13974 assert(!Preds.empty() &&
13975 "Different predicated symbolic max BTC, but no predicates");
13976 OS << "Loop ";
13977 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13978 OS << ": ";
13979 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13980 OS << "Predicated symbolic max backedge-taken count is ";
13981 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13982 } else
13983 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13984 OS << "\n";
13985 OS << " Predicates:\n";
13986 for (const auto *P : Preds)
13987 P->print(OS, 4);
13988 }
13989
13990 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13991 OS << "Loop ";
13992 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13993 OS << ": ";
13994 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13995 }
13996}
13997
13998namespace llvm {
13999// Note: these overloaded operators need to be in the llvm namespace for them
14000// to be resolved correctly. If we put them outside the llvm namespace, the
14001//
14002// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14003//
14005 // code below "breaks" and starts printing raw enum values as opposed to the
14005// string values.
14006raw_ostream &operator<<(raw_ostream &OS,
14007                         ScalarEvolution::LoopDisposition LD) {
14008 switch (LD) {
14009 case ScalarEvolution::LoopVariant:
14010 OS << "Variant";
14011 break;
14012 case ScalarEvolution::LoopInvariant:
14013 OS << "Invariant";
14014 break;
14015 case ScalarEvolution::LoopComputable:
14016 OS << "Computable";
14017 break;
14018 }
14019 return OS;
14020}
14021
14022raw_ostream &operator<<(raw_ostream &OS,
14023                         ScalarEvolution::BlockDisposition BD) {
14024 switch (BD) {
14025 case ScalarEvolution::DoesNotDominateBlock:
14026 OS << "DoesNotDominate";
14027 break;
14028 case ScalarEvolution::DominatesBlock:
14029 OS << "Dominates";
14030 break;
14031 case ScalarEvolution::ProperlyDominatesBlock:
14032 OS << "ProperlyDominates";
14033 break;
14034 }
14035 return OS;
14036}
14037} // namespace llvm
14038
14039void ScalarEvolution::print(raw_ostream &OS) const {
14040 // ScalarEvolution's implementation of the print method is to print
14041 // out SCEV values of all instructions that are interesting. Doing
14042 // this potentially causes it to create new SCEV objects though,
14043 // which technically conflicts with the const qualifier. This isn't
14044 // observable from outside the class though, so casting away the
14045 // const isn't dangerous.
14046 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14047
14048 if (ClassifyExpressions) {
14049 OS << "Classifying expressions for: ";
14050 F.printAsOperand(OS, /*PrintType=*/false);
14051 OS << "\n";
14052 for (Instruction &I : instructions(F))
14053 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14054 OS << I << '\n';
14055 OS << " --> ";
14056 const SCEV *SV = SE.getSCEV(&I);
14057 SV->print(OS);
14058 if (!isa<SCEVCouldNotCompute>(SV)) {
14059 OS << " U: ";
14060 SE.getUnsignedRange(SV).print(OS);
14061 OS << " S: ";
14062 SE.getSignedRange(SV).print(OS);
14063 }
14064
14065 const Loop *L = LI.getLoopFor(I.getParent());
14066
14067 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14068 if (AtUse != SV) {
14069 OS << " --> ";
14070 AtUse->print(OS);
14071 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14072 OS << " U: ";
14073 SE.getUnsignedRange(AtUse).print(OS);
14074 OS << " S: ";
14075 SE.getSignedRange(AtUse).print(OS);
14076 }
14077 }
14078
14079 if (L) {
14080 OS << "\t\t" "Exits: ";
14081 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14082 if (!SE.isLoopInvariant(ExitValue, L)) {
14083 OS << "<<Unknown>>";
14084 } else {
14085 OS << *ExitValue;
14086 }
14087
14088 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14089 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14090 OS << LS;
14091 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14092 OS << ": " << SE.getLoopDisposition(SV, Iter);
14093 }
14094
14095 for (const auto *InnerL : depth_first(L)) {
14096 if (InnerL == L)
14097 continue;
14098 OS << LS;
14099 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14100 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14101 }
14102
14103 OS << " }";
14104 }
14105
14106 OS << "\n";
14107 }
14108 }
14109
14110 OS << "Determining loop execution counts for: ";
14111 F.printAsOperand(OS, /*PrintType=*/false);
14112 OS << "\n";
14113 for (Loop *I : LI)
14114 PrintLoopInfo(OS, &SE, I);
14115}
14116
14117ScalarEvolution::LoopDisposition
14118ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14119 auto &Values = LoopDispositions[S];
14120 for (auto &V : Values) {
14121 if (V.getPointer() == L)
14122 return V.getInt();
14123 }
14124 Values.emplace_back(L, LoopVariant);
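// Record a conservative default first. computeLoopDisposition may itself
// populate LoopDispositions and invalidate references into it, so the entry
// is re-found below before the computed disposition is stored.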
14125 LoopDisposition D = computeLoopDisposition(S, L);
14126 auto &Values2 = LoopDispositions[S];
14127 for (auto &V : llvm::reverse(Values2)) {
14128 if (V.getPointer() == L) {
14129 V.setInt(D);
14130 break;
14131 }
14132 }
14133 return D;
14134}
14135
14136ScalarEvolution::LoopDisposition
14137ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14138 switch (S->getSCEVType()) {
14139 case scConstant:
14140 case scVScale:
14141 return LoopInvariant;
14142 case scAddRecExpr: {
14143 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14144
14145 // If L is the addrec's loop, it's computable.
14146 if (AR->getLoop() == L)
14147 return LoopComputable;
14148
14149 // Add recurrences are never invariant in the function-body (null loop).
14150 if (!L)
14151 return LoopVariant;
14152
14153 // Everything that is not defined at loop entry is variant.
14154 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14155 return LoopVariant;
14156 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14157 " dominate the contained loop's header?");
14158
14159 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14160 if (AR->getLoop()->contains(L))
14161 return LoopInvariant;
14162
14163 // This recurrence is variant w.r.t. L if any of its operands
14164 // are variant.
14165 for (const auto *Op : AR->operands())
14166 if (!isLoopInvariant(Op, L))
14167 return LoopVariant;
14168
14169 // Otherwise it's loop-invariant.
14170 return LoopInvariant;
14171 }
14172 case scTruncate:
14173 case scZeroExtend:
14174 case scSignExtend:
14175 case scPtrToInt:
14176 case scAddExpr:
14177 case scMulExpr:
14178 case scUDivExpr:
14179 case scUMaxExpr:
14180 case scSMaxExpr:
14181 case scUMinExpr:
14182 case scSMinExpr:
14183 case scSequentialUMinExpr: {
14184 bool HasVarying = false;
14185 for (const auto *Op : S->operands()) {
14186       LoopDisposition D = getLoopDisposition(Op, L);
14187 if (D == LoopVariant)
14188 return LoopVariant;
14189 if (D == LoopComputable)
14190 HasVarying = true;
14191 }
14192 return HasVarying ? LoopComputable : LoopInvariant;
14193 }
14194 case scUnknown:
14195 // All non-instruction values are loop invariant. All instructions are loop
14196 // invariant if they are not contained in the specified loop.
14197 // Instructions are never considered invariant in the function body
14198 // (null loop) because they are defined within the "loop".
14199 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14200 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14201 return LoopInvariant;
14202 case scCouldNotCompute:
14203 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14204 }
14205 llvm_unreachable("Unknown SCEV kind!");
14206}
14207
14208bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14209 return getLoopDisposition(S, L) == LoopInvariant;
14210}
14211
14212bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14213 return getLoopDisposition(S, L) == LoopComputable;
14214}
14215
14216ScalarEvolution::BlockDisposition
14217ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14218 auto &Values = BlockDispositions[S];
14219 for (auto &V : Values) {
14220 if (V.getPointer() == BB)
14221 return V.getInt();
14222 }
14223 Values.emplace_back(BB, DoesNotDominateBlock);
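// Same caching pattern as getLoopDisposition: seed a conservative default,
// then re-find the entry after computeBlockDisposition, which may have grown
// the cache and invalidated the reference obtained above.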
14224 BlockDisposition D = computeBlockDisposition(S, BB);
14225 auto &Values2 = BlockDispositions[S];
14226 for (auto &V : llvm::reverse(Values2)) {
14227 if (V.getPointer() == BB) {
14228 V.setInt(D);
14229 break;
14230 }
14231 }
14232 return D;
14233}
14234
14235ScalarEvolution::BlockDisposition
14236ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14237 switch (S->getSCEVType()) {
14238 case scConstant:
14239 case scVScale:
14240     return ProperlyDominatesBlock;
14241 case scAddRecExpr: {
14242 // This uses a "dominates" query instead of "properly dominates" query
14243 // to test for proper dominance too, because the instruction which
14244 // produces the addrec's value is a PHI, and a PHI effectively properly
14245 // dominates its entire containing block.
14246 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14247 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14248 return DoesNotDominateBlock;
14249
14250 // Fall through into SCEVNAryExpr handling.
14251 [[fallthrough]];
14252 }
14253 case scTruncate:
14254 case scZeroExtend:
14255 case scSignExtend:
14256 case scPtrToInt:
14257 case scAddExpr:
14258 case scMulExpr:
14259 case scUDivExpr:
14260 case scUMaxExpr:
14261 case scSMaxExpr:
14262 case scUMinExpr:
14263 case scSMinExpr:
14264 case scSequentialUMinExpr: {
14265 bool Proper = true;
14266 for (const SCEV *NAryOp : S->operands()) {
14268 if (D == DoesNotDominateBlock)
14269 return DoesNotDominateBlock;
14270 if (D == DominatesBlock)
14271 Proper = false;
14272 }
14273 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14274 }
14275 case scUnknown:
14276 if (Instruction *I =
14277 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14278 if (I->getParent() == BB)
14279 return DominatesBlock;
14280 if (DT.properlyDominates(I->getParent(), BB))
14281         return ProperlyDominatesBlock;
14282 return DoesNotDominateBlock;
14283 }
14284     return ProperlyDominatesBlock;
14285 case scCouldNotCompute:
14286 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14287 }
14288 llvm_unreachable("Unknown SCEV kind!");
14289}
14290
14291bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14292 return getBlockDisposition(S, BB) >= DominatesBlock;
14293}
14294
14295bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14296 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14297}
14298
14299bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14300 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14301}
14302
14303void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14304 bool Predicated) {
14305 auto &BECounts =
14306 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14307 auto It = BECounts.find(L);
14308 if (It != BECounts.end()) {
14309 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14310 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14311 if (!isa<SCEVConstant>(S)) {
14312 auto UserIt = BECountUsers.find(S);
14313 assert(UserIt != BECountUsers.end());
14314 UserIt->second.erase({L, Predicated});
14315 }
14316 }
14317 }
14318 BECounts.erase(It);
14319 }
14320}
14321
14322void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14323 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14324 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
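// Transitively add every SCEV that uses one of the expressions being
// forgotten, so that all dependent cached results are dropped as well.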
14325
14326 while (!Worklist.empty()) {
14327 const SCEV *Curr = Worklist.pop_back_val();
14328 auto Users = SCEVUsers.find(Curr);
14329 if (Users != SCEVUsers.end())
14330 for (const auto *User : Users->second)
14331 if (ToForget.insert(User).second)
14332 Worklist.push_back(User);
14333 }
14334
14335 for (const auto *S : ToForget)
14336 forgetMemoizedResultsImpl(S);
14337
14338 for (auto I = PredicatedSCEVRewrites.begin();
14339 I != PredicatedSCEVRewrites.end();) {
14340 std::pair<const SCEV *, const Loop *> Entry = I->first;
14341 if (ToForget.count(Entry.first))
14342 PredicatedSCEVRewrites.erase(I++);
14343 else
14344 ++I;
14345 }
14346}
14347
14348void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14349 LoopDispositions.erase(S);
14350 BlockDispositions.erase(S);
14351 UnsignedRanges.erase(S);
14352 SignedRanges.erase(S);
14353 HasRecMap.erase(S);
14354 ConstantMultipleCache.erase(S);
14355
14356 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14357 UnsignedWrapViaInductionTried.erase(AR);
14358 SignedWrapViaInductionTried.erase(AR);
14359 }
14360
14361 auto ExprIt = ExprValueMap.find(S);
14362 if (ExprIt != ExprValueMap.end()) {
14363 for (Value *V : ExprIt->second) {
14364 auto ValueIt = ValueExprMap.find_as(V);
14365 if (ValueIt != ValueExprMap.end())
14366 ValueExprMap.erase(ValueIt);
14367 }
14368 ExprValueMap.erase(ExprIt);
14369 }
14370
14371 auto ScopeIt = ValuesAtScopes.find(S);
14372 if (ScopeIt != ValuesAtScopes.end()) {
14373 for (const auto &Pair : ScopeIt->second)
14374 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14375 llvm::erase(ValuesAtScopesUsers[Pair.second],
14376 std::make_pair(Pair.first, S));
14377 ValuesAtScopes.erase(ScopeIt);
14378 }
14379
14380 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14381 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14382 for (const auto &Pair : ScopeUserIt->second)
14383 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14384 ValuesAtScopesUsers.erase(ScopeUserIt);
14385 }
14386
14387 auto BEUsersIt = BECountUsers.find(S);
14388 if (BEUsersIt != BECountUsers.end()) {
14389 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14390 auto Copy = BEUsersIt->second;
14391 for (const auto &Pair : Copy)
14392 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14393 BECountUsers.erase(BEUsersIt);
14394 }
14395
14396 auto FoldUser = FoldCacheUser.find(S);
14397 if (FoldUser != FoldCacheUser.end())
14398 for (auto &KV : FoldUser->second)
14399 FoldCache.erase(KV);
14400 FoldCacheUser.erase(S);
14401}
14402
14403void
14404ScalarEvolution::getUsedLoops(const SCEV *S,
14405 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14406 struct FindUsedLoops {
14407 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14408 : LoopsUsed(LoopsUsed) {}
14409 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14410 bool follow(const SCEV *S) {
14411 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14412 LoopsUsed.insert(AR->getLoop());
14413 return true;
14414 }
14415
14416 bool isDone() const { return false; }
14417 };
14418
14419 FindUsedLoops F(LoopsUsed);
14420 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14421}
14422
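// Collects the set of blocks reachable from the entry, pruning conditional
// branch successors whose condition is a constant or is decided by the known
// constant ranges of the compared SCEVs.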
14423void ScalarEvolution::getReachableBlocks(
14424     SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14425 SmallVector<BasicBlock *> Worklist;
14426 Worklist.push_back(&F.getEntryBlock());
14427 while (!Worklist.empty()) {
14428 BasicBlock *BB = Worklist.pop_back_val();
14429 if (!Reachable.insert(BB).second)
14430 continue;
14431
14432 Value *Cond;
14433 BasicBlock *TrueBB, *FalseBB;
14434 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14435 m_BasicBlock(FalseBB)))) {
14436 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14437 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14438 continue;
14439 }
14440
14441 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14442 const SCEV *L = getSCEV(Cmp->getOperand(0));
14443 const SCEV *R = getSCEV(Cmp->getOperand(1));
14444 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14445 Worklist.push_back(TrueBB);
14446 continue;
14447 }
14448 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14449 R)) {
14450 Worklist.push_back(FalseBB);
14451 continue;
14452 }
14453 }
14454 }
14455
14456 append_range(Worklist, successors(BB));
14457 }
14458}
14459
14460void ScalarEvolution::verify() const {
14461 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14462 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14463
14464 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14465
14466 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14467 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14468 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14469
14470 const SCEV *visitConstant(const SCEVConstant *Constant) {
14471 return SE.getConstant(Constant->getAPInt());
14472 }
14473
14474 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14475 return SE.getUnknown(Expr->getValue());
14476 }
14477
14478 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14479 return SE.getCouldNotCompute();
14480 }
14481 };
14482
14483 SCEVMapper SCM(SE2);
14484 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14485 SE2.getReachableBlocks(ReachableBlocks, F);
14486
14487 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14488 if (containsUndefs(Old) || containsUndefs(New)) {
14489 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14490 // not propagate undef aggressively). This means we can (and do) fail
14491 // verification in cases where a transform makes a value go from "undef"
14492 // to "undef+1" (say). The transform is fine, since in both cases the
14493 // result is "undef", but SCEV thinks the value increased by 1.
14494 return nullptr;
14495 }
14496
14497 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14498 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14499 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14500 return nullptr;
14501
14502 return Delta;
14503 };
14504
14505 while (!LoopStack.empty()) {
14506 auto *L = LoopStack.pop_back_val();
14507 llvm::append_range(LoopStack, *L);
14508
14509 // Only verify BECounts in reachable loops. For an unreachable loop,
14510 // any BECount is legal.
14511 if (!ReachableBlocks.contains(L->getHeader()))
14512 continue;
14513
14514 // Only verify cached BECounts. Computing new BECounts may change the
14515 // results of subsequent SCEV uses.
14516 auto It = BackedgeTakenCounts.find(L);
14517 if (It == BackedgeTakenCounts.end())
14518 continue;
14519
14520 auto *CurBECount =
14521 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14522 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14523
14524 if (CurBECount == SE2.getCouldNotCompute() ||
14525 NewBECount == SE2.getCouldNotCompute()) {
14526 // NB! This situation is legal, but is very suspicious -- whatever pass
14527 // changed the loop to make a trip count go from could not compute to
14528 // computable or vice-versa *should have* invalidated SCEV. However, we
14529 // choose not to assert here (for now) since we don't want false
14530 // positives.
14531 continue;
14532 }
14533
14534 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14535 SE.getTypeSizeInBits(NewBECount->getType()))
14536 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14537 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14538 SE.getTypeSizeInBits(NewBECount->getType()))
14539 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14540
14541 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14542 if (Delta && !Delta->isZero()) {
14543 dbgs() << "Trip Count for " << *L << " Changed!\n";
14544 dbgs() << "Old: " << *CurBECount << "\n";
14545 dbgs() << "New: " << *NewBECount << "\n";
14546 dbgs() << "Delta: " << *Delta << "\n";
14547 std::abort();
14548 }
14549 }
14550
14551 // Collect all valid loops currently in LoopInfo.
14552 SmallPtrSet<Loop *, 32> ValidLoops;
14553 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14554 while (!Worklist.empty()) {
14555 Loop *L = Worklist.pop_back_val();
14556 if (ValidLoops.insert(L).second)
14557 Worklist.append(L->begin(), L->end());
14558 }
14559 for (const auto &KV : ValueExprMap) {
14560#ifndef NDEBUG
14561 // Check for SCEV expressions referencing invalid/deleted loops.
14562 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14563 assert(ValidLoops.contains(AR->getLoop()) &&
14564 "AddRec references invalid loop");
14565 }
14566#endif
14567
14568 // Check that the value is also part of the reverse map.
14569 auto It = ExprValueMap.find(KV.second);
14570 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14571 dbgs() << "Value " << *KV.first
14572 << " is in ValueExprMap but not in ExprValueMap\n";
14573 std::abort();
14574 }
14575
14576 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14577 if (!ReachableBlocks.contains(I->getParent()))
14578 continue;
14579 const SCEV *OldSCEV = SCM.visit(KV.second);
14580 const SCEV *NewSCEV = SE2.getSCEV(I);
14581 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14582 if (Delta && !Delta->isZero()) {
14583 dbgs() << "SCEV for value " << *I << " changed!\n"
14584 << "Old: " << *OldSCEV << "\n"
14585 << "New: " << *NewSCEV << "\n"
14586 << "Delta: " << *Delta << "\n";
14587 std::abort();
14588 }
14589 }
14590 }
14591
14592 for (const auto &KV : ExprValueMap) {
14593 for (Value *V : KV.second) {
14594 const SCEV *S = ValueExprMap.lookup(V);
14595 if (!S) {
14596 dbgs() << "Value " << *V
14597 << " is in ExprValueMap but not in ValueExprMap\n";
14598 std::abort();
14599 }
14600 if (S != KV.first) {
14601 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14602 << *KV.first << "\n";
14603 std::abort();
14604 }
14605 }
14606 }
14607
14608 // Verify integrity of SCEV users.
14609 for (const auto &S : UniqueSCEVs) {
14610 for (const auto *Op : S.operands()) {
14611 // We do not store dependencies of constants.
14612 if (isa<SCEVConstant>(Op))
14613 continue;
14614 auto It = SCEVUsers.find(Op);
14615 if (It != SCEVUsers.end() && It->second.count(&S))
14616 continue;
14617 dbgs() << "Use of operand " << *Op << " by user " << S
14618 << " is not being tracked!\n";
14619 std::abort();
14620 }
14621 }
14622
14623 // Verify integrity of ValuesAtScopes users.
14624 for (const auto &ValueAndVec : ValuesAtScopes) {
14625 const SCEV *Value = ValueAndVec.first;
14626 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14627 const Loop *L = LoopAndValueAtScope.first;
14628 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14629 if (!isa<SCEVConstant>(ValueAtScope)) {
14630 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14631 if (It != ValuesAtScopesUsers.end() &&
14632 is_contained(It->second, std::make_pair(L, Value)))
14633 continue;
14634 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14635 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14636 std::abort();
14637 }
14638 }
14639 }
14640
14641 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14642 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14643 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14644 const Loop *L = LoopAndValue.first;
14645 const SCEV *Value = LoopAndValue.second;
14647 auto It = ValuesAtScopes.find(Value);
14648 if (It != ValuesAtScopes.end() &&
14649 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14650 continue;
14651 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14652 << *ValueAtScope << " missing in ValuesAtScopes\n";
14653 std::abort();
14654 }
14655 }
14656
14657 // Verify integrity of BECountUsers.
14658 auto VerifyBECountUsers = [&](bool Predicated) {
14659 auto &BECounts =
14660 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14661 for (const auto &LoopAndBEInfo : BECounts) {
14662 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14663 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14664 if (!isa<SCEVConstant>(S)) {
14665 auto UserIt = BECountUsers.find(S);
14666 if (UserIt != BECountUsers.end() &&
14667 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14668 continue;
14669 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14670 << " missing from BECountUsers\n";
14671 std::abort();
14672 }
14673 }
14674 }
14675 }
14676 };
14677 VerifyBECountUsers(/* Predicated */ false);
14678 VerifyBECountUsers(/* Predicated */ true);
14679
14680 // Verify integrity of the loop disposition cache.
14681 for (auto &[S, Values] : LoopDispositions) {
14682 for (auto [Loop, CachedDisposition] : Values) {
14683 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14684 if (CachedDisposition != RecomputedDisposition) {
14685 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14686 << " is incorrect: cached " << CachedDisposition << ", actual "
14687 << RecomputedDisposition << "\n";
14688 std::abort();
14689 }
14690 }
14691 }
14692
14693 // Verify integrity of the block disposition cache.
14694 for (auto &[S, Values] : BlockDispositions) {
14695 for (auto [BB, CachedDisposition] : Values) {
14696 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14697 if (CachedDisposition != RecomputedDisposition) {
14698 dbgs() << "Cached disposition of " << *S << " for block %"
14699 << BB->getName() << " is incorrect: cached " << CachedDisposition
14700 << ", actual " << RecomputedDisposition << "\n";
14701 std::abort();
14702 }
14703 }
14704 }
14705
14706 // Verify FoldCache/FoldCacheUser caches.
14707 for (auto [FoldID, Expr] : FoldCache) {
14708 auto I = FoldCacheUser.find(Expr);
14709 if (I == FoldCacheUser.end()) {
14710 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14711 << "!\n";
14712 std::abort();
14713 }
14714 if (!is_contained(I->second, FoldID)) {
14715 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14716 std::abort();
14717 }
14718 }
14719 for (auto [Expr, IDs] : FoldCacheUser) {
14720 for (auto &FoldID : IDs) {
14721 const SCEV *S = FoldCache.lookup(FoldID);
14722 if (!S) {
14723 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14724 << "!\n";
14725 std::abort();
14726 }
14727 if (S != Expr) {
14728 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14729 << " != " << *Expr << "!\n";
14730 std::abort();
14731 }
14732 }
14733 }
14734
14735 // Verify that ConstantMultipleCache computations are correct. We check that
14736 // cached multiples and recomputed multiples are multiples of each other to
14737 // verify correctness. It is possible that a recomputed multiple is different
14738 // from the cached multiple due to strengthened no wrap flags or changes in
14739 // KnownBits computations.
14740 for (auto [S, Multiple] : ConstantMultipleCache) {
14741 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14742 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14743 Multiple.urem(RecomputedMultiple) != 0 &&
14744 RecomputedMultiple.urem(Multiple) != 0)) {
14745 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14746 << *S << " : Computed " << RecomputedMultiple
14747 << " but cache contains " << Multiple << "!\n";
14748 std::abort();
14749 }
14750 }
14751}
14752
14753bool ScalarEvolution::invalidate(
14754 Function &F, const PreservedAnalyses &PA,
14755 FunctionAnalysisManager::Invalidator &Inv) {
14756 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14757 // of its dependencies is invalidated.
14758 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14759 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14760 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14761 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14762 Inv.invalidate<LoopAnalysis>(F, PA);
14763}
14764
14765AnalysisKey ScalarEvolutionAnalysis::Key;
14766
14767ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14768                                              FunctionAnalysisManager &AM) {
14769 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14770 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14771 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14772 auto &LI = AM.getResult<LoopAnalysis>(F);
14773 return ScalarEvolution(F, TLI, AC, DT, LI);
14774}
14775
14776PreservedAnalyses
14777ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14778 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14779 return PreservedAnalyses::all();
14780}
14781
14782PreservedAnalyses
14783ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14784 // For compatibility with opt's -analyze feature under legacy pass manager
14785 // which was not ported to NPM. This keeps tests using
14786 // update_analyze_test_checks.py working.
14787 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14788 << F.getName() << "':\n";
14789 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14790 return PreservedAnalyses::all();
14791}
14792
14793INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14794                       "Scalar Evolution Analysis", false, true)
14795INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14796INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14797INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14798INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14799INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14800                     "Scalar Evolution Analysis", false, true)
14801
14802char ScalarEvolutionWrapperPass::ID = 0;
14803
14804ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {}
14805
14806bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14807 SE.reset(new ScalarEvolution(
14808       F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14809 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14810       getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14811 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14812 return false;
14813}
14814
14815void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14816
14817void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14818 SE->print(OS);
14819}
14820
14821void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14822 if (!VerifySCEV)
14823 return;
14824
14825 SE->verify();
14826}
14827
14828void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14829 AU.setPreservesAll();
14830 AU.addRequiredTransitive<AssumptionCacheTracker>();
14831 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14832 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14833 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14834}
14835
14836const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14837 const SCEV *RHS) {
14838 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14839}
14840
14841const SCEVPredicate *
14842ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14843 const SCEV *LHS, const SCEV *RHS) {
14844   FoldingSetNodeID ID;
14845 assert(LHS->getType() == RHS->getType() &&
14846 "Type mismatch between LHS and RHS");
14847 // Unique this node based on the arguments
14848 ID.AddInteger(SCEVPredicate::P_Compare);
14849 ID.AddInteger(Pred);
14850 ID.AddPointer(LHS);
14851 ID.AddPointer(RHS);
14852 void *IP = nullptr;
14853 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14854 return S;
14855 SCEVComparePredicate *Eq = new (SCEVAllocator)
14856 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14857 UniquePreds.InsertNode(Eq, IP);
14858 return Eq;
14859}
14860
14861const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14862 const SCEVAddRecExpr *AR,
14863     SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14864   FoldingSetNodeID ID;
14865 // Unique this node based on the arguments
14866 ID.AddInteger(SCEVPredicate::P_Wrap);
14867 ID.AddPointer(AR);
14868 ID.AddInteger(AddedFlags);
14869 void *IP = nullptr;
14870 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14871 return S;
14872 auto *OF = new (SCEVAllocator)
14873 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14874 UniquePreds.InsertNode(OF, IP);
14875 return OF;
14876}
14877
14878namespace {
14879
14880class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14881public:
14882
14883 /// Rewrites \p S in the context of a loop L and the SCEV predication
14884 /// infrastructure.
14885 ///
14886 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14887 /// equivalences present in \p Pred.
14888 ///
14889 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14890 /// \p NewPreds such that the result will be an AddRecExpr.
14891 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14892                              SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14893 const SCEVPredicate *Pred) {
14894 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14895 return Rewriter.visit(S);
14896 }
14897
14898 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14899 if (Pred) {
14900 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14901 for (const auto *Pred : U->getPredicates())
14902 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14903 if (IPred->getLHS() == Expr &&
14904 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14905 return IPred->getRHS();
14906 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14907 if (IPred->getLHS() == Expr &&
14908 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14909 return IPred->getRHS();
14910 }
14911 }
14912 return convertToAddRecWithPreds(Expr);
14913 }
14914
14915 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14916 const SCEV *Operand = visit(Expr->getOperand());
14917 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14918 if (AR && AR->getLoop() == L && AR->isAffine()) {
14919 // This couldn't be folded because the operand didn't have the nuw
14920 // flag. Add the nusw flag as an assumption that we could make.
14921 const SCEV *Step = AR->getStepRecurrence(SE);
14922 Type *Ty = Expr->getType();
14923 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14924 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14925 SE.getSignExtendExpr(Step, Ty), L,
14926 AR->getNoWrapFlags());
14927 }
14928 return SE.getZeroExtendExpr(Operand, Expr->getType());
14929 }
14930
14931 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14932 const SCEV *Operand = visit(Expr->getOperand());
14933 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14934 if (AR && AR->getLoop() == L && AR->isAffine()) {
14935 // This couldn't be folded because the operand didn't have the nsw
14936 // flag. Add the nssw flag as an assumption that we could make.
14937 const SCEV *Step = AR->getStepRecurrence(SE);
14938 Type *Ty = Expr->getType();
14939 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14940 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14941 SE.getSignExtendExpr(Step, Ty), L,
14942 AR->getNoWrapFlags());
14943 }
14944 return SE.getSignExtendExpr(Operand, Expr->getType());
14945 }
14946
14947private:
14948 explicit SCEVPredicateRewriter(
14949 const Loop *L, ScalarEvolution &SE,
14950 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14951 const SCEVPredicate *Pred)
14952 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14953
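// When no NewPreds vector is supplied, the rewriter runs in query-only mode:
// an assumption is usable only if the given predicate already implies it.
// Otherwise the assumption is recorded so the caller can emit a runtime check.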
14954 bool addOverflowAssumption(const SCEVPredicate *P) {
14955 if (!NewPreds) {
14956 // Check if we've already made this assumption.
14957 return Pred && Pred->implies(P, SE);
14958 }
14959 NewPreds->push_back(P);
14960 return true;
14961 }
14962
14963 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14964                              SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14965 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14966 return addOverflowAssumption(A);
14967 }
14968
14969 // If \p Expr represents a PHINode, we try to see if it can be represented
14970 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14971 // to add this predicate as a runtime overflow check, we return the AddRec.
14972 // If \p Expr does not meet these conditions (is not a PHI node, or we
14973 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14974 // return \p Expr.
14975 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14976 if (!isa<PHINode>(Expr->getValue()))
14977 return Expr;
14978 std::optional<
14979 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14980 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14981 if (!PredicatedRewrite)
14982 return Expr;
14983 for (const auto *P : PredicatedRewrite->second){
14984 // Wrap predicates from outer loops are not supported.
14985 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14986 if (L != WP->getExpr()->getLoop())
14987 return Expr;
14988 }
14989 if (!addOverflowAssumption(P))
14990 return Expr;
14991 }
14992 return PredicatedRewrite->first;
14993 }
14994
14995 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
14996 const SCEVPredicate *Pred;
14997 const Loop *L;
14998};
14999
15000} // end anonymous namespace
15001
15002const SCEV *
15003ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
15004 const SCEVPredicate &Preds) {
15005 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15006}
15007
15008const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
15009 const SCEV *S, const Loop *L,
15010     SmallVectorImpl<const SCEVPredicate *> &Preds) {
15011   SmallVector<const SCEVPredicate *, 4> TransformPreds;
15012 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15013 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15014
15015 if (!AddRec)
15016 return nullptr;
15017
15018 // Check if any of the transformed predicates is known to be false. In that
15019 // case, it doesn't make sense to convert to a predicated AddRec, as the
15020 // versioned loop will never execute.
15021 for (const SCEVPredicate *Pred : TransformPreds) {
15022 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15023 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15024 continue;
15025
15026 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15027 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15028 if (isa<SCEVCouldNotCompute>(ExitCount))
15029 continue;
15030
15031 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15032 if (!Step->isOne())
15033 continue;
15034
15035 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15036 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15037 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15038 return nullptr;
15039 }
15040
15041 // Since the transformation was successful, we can now transfer the SCEV
15042 // predicates.
15043 Preds.append(TransformPreds.begin(), TransformPreds.end());
15044
15045 return AddRec;
15046}
15047
15048/// SCEV predicates
15049SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
15050                              SCEVPredicateKind Kind)
15051     : FastID(ID), Kind(Kind) {}
15052
15053SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15054 const ICmpInst::Predicate Pred,
15055 const SCEV *LHS, const SCEV *RHS)
15056 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15057 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15058 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15059}
15060
15061bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15062 ScalarEvolution &SE) const {
15063 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15064
15065 if (!Op)
15066 return false;
15067
15068 if (Pred != ICmpInst::ICMP_EQ)
15069 return false;
15070
15071 return Op->LHS == LHS && Op->RHS == RHS;
15072}
15073
15074bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15075
15076void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15077 if (Pred == ICmpInst::ICMP_EQ)
15078 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15079 else
15080 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15081 << *RHS << "\n";
15082
15083}
15084
15085SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15086 const SCEVAddRecExpr *AR,
15087 IncrementWrapFlags Flags)
15088 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15089
15090const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15091
15092bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15093 ScalarEvolution &SE) const {
15094 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15095 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15096 return false;
15097
15098 if (Op->AR == AR)
15099 return true;
15100
15101 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15102       Flags != SCEVWrapPredicate::IncrementNUSW)
15103 return false;
15104
15105 const SCEV *Start = AR->getStart();
15106 const SCEV *OpStart = Op->AR->getStart();
15107 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15108 return false;
15109
15110 // Reject pointers to different address spaces.
15111 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15112 return false;
15113
15114 const SCEV *Step = AR->getStepRecurrence(SE);
15115 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15116 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15117 return false;
15118
15119 // If both steps are positive, this predicate implies N if N's start and step
15120 // are ULE/SLE (for NUSW/NSSW) compared to this predicate's.
15121 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15122 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15123 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15124
15125 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15126 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15127 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15128 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15129 : SE.getNoopOrSignExtend(Start, WiderTy);
15130   ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15131 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15132 SE.isKnownPredicate(Pred, OpStart, Start);
15133}
15134
15135bool SCEVWrapPredicate::isAlwaysTrue() const {
15136 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15137 IncrementWrapFlags IFlags = Flags;
15138
15139 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15140 IFlags = clearFlags(IFlags, IncrementNSSW);
15141
15142 return IFlags == IncrementAnyWrap;
15143}
15144
15145void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15146 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15147   if (SCEVWrapPredicate::IncrementNUSW == setFlags(Flags, IncrementNUSW))
15148 OS << "<nusw>";
15149   if (SCEVWrapPredicate::IncrementNSSW == setFlags(Flags, IncrementNSSW))
15150 OS << "<nssw>";
15151 OS << "\n";
15152}
15153
15154SCEVWrapPredicate::IncrementWrapFlags
15155SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15156 ScalarEvolution &SE) {
15157 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15158 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15159
15160 // We can safely transfer the NSW flag as NSSW.
15161 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15162 ImpliedFlags = IncrementNSSW;
15163
15164 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15165 // If the increment is positive, the SCEV NUW flag will also imply the
15166 // WrapPredicate NUSW flag.
15167 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15168 if (Step->getValue()->getValue().isNonNegative())
15169 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15170 }
15171
15172 return ImpliedFlags;
15173}
15174
15175/// Union predicates don't get cached, so create a dummy set ID for them.
15176SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15177 ScalarEvolution &SE)
15178 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15179 for (const auto *P : Preds)
15180 add(P, SE);
15181}
15182
15183bool SCEVUnionPredicate::isAlwaysTrue() const {
15184 return all_of(Preds,
15185 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15186}
15187
15189 ScalarEvolution &SE) const {
15190 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15191 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15192 return this->implies(I, SE);
15193 });
15194
15195 return any_of(Preds,
15196 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15197}
15198
15199void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15200 for (const auto *Pred : Preds)
15201 Pred->print(OS, Depth);
15202}
15203
15204void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15205 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15206 for (const auto *Pred : Set->Preds)
15207 add(Pred, SE);
15208 return;
15209 }
15210
15211 // Implication checks are quadratic in the number of predicates. Stop doing
15212 // them once there are many predicates, as they become too expensive to be
15213 // worthwhile at that point.
15214 bool CheckImplies = Preds.size() < 16;
15215
15216 // Only add predicate if it is not already implied by this union predicate.
15217 if (CheckImplies && implies(N, SE))
15218 return;
15219
15220 // Build a new vector containing the current predicates, except the ones that
15221 // are implied by the new predicate N.
15222   SmallVector<const SCEVPredicate *> PrunedPreds;
15223 for (auto *P : Preds) {
15224 if (CheckImplies && N->implies(P, SE))
15225 continue;
15226 PrunedPreds.push_back(P);
15227 }
15228 Preds = std::move(PrunedPreds);
15229 Preds.push_back(N);
15230}
15231
15232PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15233 Loop &L)
15234 : SE(SE), L(L) {
15235   SmallVector<const SCEVPredicate *, 4> Empty;
15236 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15237}
15238
15239void ScalarEvolution::registerUser(const SCEV *User,
15240                                    ArrayRef<const SCEV *> Ops) {
15241 for (const auto *Op : Ops)
15242 // We do not expect that forgetting cached data for SCEVConstants will ever
15243 // open any prospects for sharpening or introduce any correctness issues,
15244 // so we don't bother storing their dependencies.
15245 if (!isa<SCEVConstant>(Op))
15246 SCEVUsers[Op].insert(User);
15247}
15248
15249const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15250 const SCEV *Expr = SE.getSCEV(V);
15251 return getPredicatedSCEV(Expr);
15252}
15253
15254const SCEV *PredicatedScalarEvolution::getPredicatedSCEV(const SCEV *Expr) {
15255 RewriteEntry &Entry = RewriteMap[Expr];
15256
15257 // If we already have an entry and the version matches, return it.
15258 if (Entry.second && Generation == Entry.first)
15259 return Entry.second;
15260
15261 // We found an entry but it's stale. Rewrite the stale entry
15262 // according to the current predicate.
15263 if (Entry.second)
15264 Expr = Entry.second;
15265
15266 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15267 Entry = {Generation, NewSCEV};
15268
15269 return NewSCEV;
15270}
15271
15272const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15273 if (!BackedgeCount) {
15274     SmallVector<const SCEVPredicate *, 4> Preds;
15275 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15276 for (const auto *P : Preds)
15277 addPredicate(*P);
15278 }
15279 return BackedgeCount;
15280}
15281
15282const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15283 if (!SymbolicMaxBackedgeCount) {
15284     SmallVector<const SCEVPredicate *, 4> Preds;
15285 SymbolicMaxBackedgeCount =
15286 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15287 for (const auto *P : Preds)
15288 addPredicate(*P);
15289 }
15290 return SymbolicMaxBackedgeCount;
15291}
15292
15293unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15294 if (!SmallConstantMaxTripCount) {
15295     SmallVector<const SCEVPredicate *, 4> Preds;
15296 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15297 for (const auto *P : Preds)
15298 addPredicate(*P);
15299 }
15300 return *SmallConstantMaxTripCount;
15301}
15302
15303void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15304 if (Preds->implies(&Pred, SE))
15305 return;
15306
15307 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15308 NewPreds.push_back(&Pred);
15309 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15310 updateGeneration();
15311}
15312
15313const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15314 return *Preds;
15315}
15316
15317void PredicatedScalarEvolution::updateGeneration() {
15318 // If the generation number wrapped, recompute everything.
15319 if (++Generation == 0) {
15320 for (auto &II : RewriteMap) {
15321 const SCEV *Rewritten = II.second.second;
15322 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15323 }
15324 }
15325}
15326
15327void PredicatedScalarEvolution::setNoOverflow(
15328     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15329 const SCEV *Expr = getSCEV(V);
15330 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15331
15332 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15333
15334 // Clear the statically implied flags.
15335 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15336 addPredicate(*SE.getWrapPredicate(AR, Flags));
15337
15338 auto II = FlagsMap.insert({V, Flags});
15339 if (!II.second)
15340 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15341}
15342
15345 const SCEV *Expr = getSCEV(V);
15346 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15347
15349 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15350
15351 auto II = FlagsMap.find(V);
15352
15353 if (II != FlagsMap.end())
15354 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15355
15357}
15358
15360 const SCEV *Expr = this->getSCEV(V);
15362 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15363
15364 if (!New)
15365 return nullptr;
15366
15367 for (const auto *P : NewPreds)
15368 addPredicate(*P);
15369
15370 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15371 return New;
15372}
15373
15376 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15377 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15378 SE)),
15379 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15380 for (auto I : Init.FlagsMap)
15381 FlagsMap.insert(I);
15382}
15383
15385 // For each block.
15386 for (auto *BB : L.getBlocks())
15387 for (auto &I : *BB) {
15388 if (!SE.isSCEVable(I.getType()))
15389 continue;
15390
15391 auto *Expr = SE.getSCEV(&I);
15392 auto II = RewriteMap.find(Expr);
15393
15394 if (II == RewriteMap.end())
15395 continue;
15396
15397 // Don't print things that are not interesting.
15398 if (II->second.second == Expr)
15399 continue;
15400
15401 OS.indent(Depth) << "[PSE]" << I << ":\n";
15402 OS.indent(Depth + 2) << *Expr << "\n";
15403 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15404 }
15405}
15406
15409 BasicBlock *Header = L->getHeader();
15410 BasicBlock *Pred = L->getLoopPredecessor();
15411 LoopGuards Guards(SE);
15412 if (!Pred)
15413 return Guards;
15414   SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15415 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15416 return Guards;
15417}
15418
15419void ScalarEvolution::LoopGuards::collectFromPHI(
15420     ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15421     const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15422     SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15423 unsigned Depth) {
15424 if (!SE.isSCEVable(Phi.getType()))
15425 return;
15426
15427 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15428 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15429 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15430 if (!VisitedBlocks.insert(InBlock).second)
15431 return {nullptr, scCouldNotCompute};
15432
15433 // Avoid analyzing unreachable blocks so that we don't get trapped
15434 // traversing cycles with ill-formed dominance or infinite cycles
15435 if (!SE.DT.isReachableFromEntry(InBlock))
15436 return {nullptr, scCouldNotCompute};
15437
15438 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15439 if (Inserted)
15440 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15441 Depth + 1);
15442 auto &RewriteMap = G->second.RewriteMap;
15443 if (RewriteMap.empty())
15444 return {nullptr, scCouldNotCompute};
15445 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15446 if (S == RewriteMap.end())
15447 return {nullptr, scCouldNotCompute};
15448 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15449 if (!SM)
15450 return {nullptr, scCouldNotCompute};
15451 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15452 return {C0, SM->getSCEVType()};
15453 return {nullptr, scCouldNotCompute};
15454 };
15455 auto MergeMinMaxConst = [](MinMaxPattern P1,
15456 MinMaxPattern P2) -> MinMaxPattern {
15457 auto [C1, T1] = P1;
15458 auto [C2, T2] = P2;
15459 if (!C1 || !C2 || T1 != T2)
15460 return {nullptr, scCouldNotCompute};
15461 switch (T1) {
15462 case scUMaxExpr:
15463 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15464 case scSMaxExpr:
15465 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15466 case scUMinExpr:
15467 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15468 case scSMinExpr:
15469 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15470 default:
15471 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15472 }
15473 };
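// When guards arrive from several incoming edges, keep the weaker bound: the
// smaller constant for max expressions and the larger constant for min
// expressions, so the merged guard holds on every incoming path.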
15474 auto P = GetMinMaxConst(0);
15475 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15476 if (!P.first)
15477 break;
15478 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15479 }
15480 if (P.first) {
15481 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15482     SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15483 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15484 Guards.RewriteMap.insert({LHS, RHS});
15485 }
15486}
15487
15488// Return a new SCEV that modifies \p Expr to the closest value that is
15489// divisible by \p Divisor and less than or equal to \p Expr. For now, only
15490// handle constant \p Expr.
15491static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15492 const APInt &DivisorVal,
15493 ScalarEvolution &SE) {
15494 const APInt *ExprVal;
15495 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15496 DivisorVal.isNonPositive())
15497 return Expr;
15498 APInt Rem = ExprVal->urem(DivisorVal);
15499 // return the SCEV: Expr - Expr % Divisor
15500 return SE.getConstant(*ExprVal - Rem);
15501}
15502
15503// Return a new SCEV that modifies \p Expr to the closest value that is
15504// divisible by \p Divisor and greater than or equal to \p Expr. For now,
15505// only handle constant \p Expr.
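// For example, with a constant Expr of 7 and Divisor 3 this returns 9, while
// the previous-divisible helper above returns 6.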
15506static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15507 const APInt &DivisorVal,
15508 ScalarEvolution &SE) {
15509 const APInt *ExprVal;
15510 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15511 DivisorVal.isNonPositive())
15512 return Expr;
15513 APInt Rem = ExprVal->urem(DivisorVal);
15514 if (Rem.isZero())
15515 return Expr;
15516 // return the SCEV: Expr + Divisor - Expr % Divisor
15517 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15518}
15519
15520static bool collectDivisibilityInformation(
15521 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15522     DenseMap<const SCEV *, const SCEV *> &DivInfo,
15523     DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15524 // If we have LHS == 0, check if LHS is computing a property of some unknown
15525 // SCEV %v which we can rewrite %v to express explicitly.
15526   if (!match(RHS, m_scev_Zero()))
15527 return false;
15528 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15529 // explicitly express that.
15530 const SCEVUnknown *URemLHS = nullptr;
15531 const SCEV *URemRHS = nullptr;
15532 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15533 return false;
15534
15535 const SCEV *Multiple =
15536 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15537 DivInfo[URemLHS] = Multiple;
15538 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15539 Multiples[URemLHS] = C->getAPInt();
15540 return true;
15541}
15542
15543// Check if the condition is a divisibility guard (A % B == 0).
15544static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15545 ScalarEvolution &SE) {
15546 const SCEV *X, *Y;
15547 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15548}
15549
15550// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15551// recursively. This is done by aligning up/down the constant value to the
15552// Divisor.
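// For example, with Divisor 4, umin(9, %x) has its constant aligned down to
// 8, while umax(9, %x) has it aligned up to 12.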
15553static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15554 APInt Divisor,
15555 ScalarEvolution &SE) {
15556 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15557 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15558 // the non-constant operand and in \p LHS the constant operand.
15559 auto IsMinMaxSCEVWithNonNegativeConstant =
15560 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15561 const SCEV *&RHS) {
15562 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15563 if (MinMax->getNumOperands() != 2)
15564 return false;
15565 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15566 if (C->getAPInt().isNegative())
15567 return false;
15568 SCTy = MinMax->getSCEVType();
15569 LHS = MinMax->getOperand(0);
15570 RHS = MinMax->getOperand(1);
15571 return true;
15572 }
15573 }
15574 return false;
15575 };
15576
15577 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15578 SCEVTypes SCTy;
15579 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15580 MinMaxRHS))
15581 return MinMaxExpr;
15582 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15583 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15584 auto *DivisibleExpr =
15585 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15586 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15588 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15589 return SE.getMinMaxExpr(SCTy, Ops);
15590}
15591
15592void ScalarEvolution::LoopGuards::collectFromBlock(
15593 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15594 const BasicBlock *Block, const BasicBlock *Pred,
15595 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15596
15598
15599 SmallVector<const SCEV *> ExprsToRewrite;
15600 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15601 const SCEV *RHS,
15602 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15603 const LoopGuards &DivGuards) {
15604 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15605 // replacement SCEV which isn't directly implied by the structure of that
15606 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15607 // legal. See the scoping rules for flags in the header to understand why.
15608
15609 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15610 // create this form when combining two checks of the form (X u< C2 + C1) and
15611 // (X >=u C1).
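// For instance, x u>= 5 && x u< 13 is combined into (-5 + x) u< 8, from which
// the bounds [5, 12] for x are recovered below.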
15612 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15613 &ExprsToRewrite]() {
15614 const SCEVConstant *C1;
15615 const SCEVUnknown *LHSUnknown;
15616 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15617 if (!match(LHS,
15618 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15619 !C2)
15620 return false;
15621
15622 auto ExactRegion =
15623 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15624 .sub(C1->getAPInt());
15625
15626 // Bail out, unless we have a non-wrapping, monotonic range.
15627 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15628 return false;
15629 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15630 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15631 I->second = SE.getUMaxExpr(
15632 SE.getConstant(ExactRegion.getUnsignedMin()),
15633 SE.getUMinExpr(RewrittenLHS,
15634 SE.getConstant(ExactRegion.getUnsignedMax())));
15635 ExprsToRewrite.push_back(LHSUnknown);
15636 return true;
15637 };
15638 if (MatchRangeCheckIdiom())
15639 return;
15640
15641 // Do not apply information for constants or if RHS contains an AddRec.
15642     if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15643 return;
15644
15645 // If RHS is SCEVUnknown, make sure the information is applied to it.
15646     if (isa<SCEVUnknown>(RHS)) {
15647 std::swap(LHS, RHS);
15648       Predicate = CmpInst::getSwappedPredicate(Predicate);
15649 }
15650
15651 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15652 // and \p FromRewritten are the same (i.e. there has been no rewrite
15653 // registered for \p From), then puts this value in the list of rewritten
15654 // expressions.
15655 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15656 const SCEV *To) {
15657 if (From == FromRewritten)
15658 ExprsToRewrite.push_back(From);
15659 RewriteMap[From] = To;
15660 };
15661
15662 // Checks whether \p S has already been rewritten. In that case returns the
15663 // existing rewrite because we want to chain further rewrites onto the
15664 // already rewritten value. Otherwise returns \p S.
15665 auto GetMaybeRewritten = [&](const SCEV *S) {
15666 return RewriteMap.lookup_or(S, S);
15667 };
15668
15669 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15670 // Apply divisibility information when computing the constant multiple.
15671 const APInt &DividesBy =
15672 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15673
15674 // Collect rewrites for LHS and its transitive operands based on the
15675 // condition.
15676 // For min/max expressions, also apply the guard to its operands:
15677 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15678 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15679 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15680 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15681
15682 // We cannot express strict predicates in SCEV, so instead we replace them
15683 // with non-strict ones against plus or minus one of RHS depending on the
15684 // predicate.
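// For example, X u< 8 is handled as X u<= 7 and X s> 3 as X s>= 4, before the
// bound is further aligned to any known divisor of X.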
15685 const SCEV *One = SE.getOne(RHS->getType());
15686 switch (Predicate) {
15687 case CmpInst::ICMP_ULT:
15688 if (RHS->getType()->isPointerTy())
15689 return;
15690 RHS = SE.getUMaxExpr(RHS, One);
15691 [[fallthrough]];
15692 case CmpInst::ICMP_SLT: {
15693 RHS = SE.getMinusSCEV(RHS, One);
15694 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15695 break;
15696 }
15697 case CmpInst::ICMP_UGT:
15698 case CmpInst::ICMP_SGT:
15699 RHS = SE.getAddExpr(RHS, One);
15700 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15701 break;
15702 case CmpInst::ICMP_ULE:
15703 case CmpInst::ICMP_SLE:
15704 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15705 break;
15706 case CmpInst::ICMP_UGE:
15707 case CmpInst::ICMP_SGE:
15708 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15709 break;
15710 default:
15711 break;
15712 }
15713
15714 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15715 SmallPtrSet<const SCEV *, 16> Visited;
15716
15717 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15718 append_range(Worklist, S->operands());
15719 };
15720
15721 while (!Worklist.empty()) {
15722 const SCEV *From = Worklist.pop_back_val();
15723 if (isa<SCEVConstant>(From))
15724 continue;
15725 if (!Visited.insert(From).second)
15726 continue;
15727 const SCEV *FromRewritten = GetMaybeRewritten(From);
15728 const SCEV *To = nullptr;
15729
15730 switch (Predicate) {
15731 case CmpInst::ICMP_ULT:
15732 case CmpInst::ICMP_ULE:
15733 To = SE.getUMinExpr(FromRewritten, RHS);
15734 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15735 EnqueueOperands(UMax);
15736 break;
15737 case CmpInst::ICMP_SLT:
15738 case CmpInst::ICMP_SLE:
15739 To = SE.getSMinExpr(FromRewritten, RHS);
15740 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15741 EnqueueOperands(SMax);
15742 break;
15743 case CmpInst::ICMP_UGT:
15744 case CmpInst::ICMP_UGE:
15745 To = SE.getUMaxExpr(FromRewritten, RHS);
15746 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15747 EnqueueOperands(UMin);
15748 break;
15749 case CmpInst::ICMP_SGT:
15750 case CmpInst::ICMP_SGE:
15751 To = SE.getSMaxExpr(FromRewritten, RHS);
15752 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15753 EnqueueOperands(SMin);
15754 break;
15755 case CmpInst::ICMP_EQ:
15756 if (isa<SCEVConstant>(RHS))
15757 To = RHS;
15758 break;
15759 case CmpInst::ICMP_NE:
15760 if (match(RHS, m_scev_Zero())) {
15761 const SCEV *OneAlignedUp =
15762 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15763 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15764 } else {
15765 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15766 // but creating the subtraction eagerly is expensive. Track the
15767 // inequalities in a separate map, and materialize the rewrite lazily
15768 // when encountering a suitable subtraction while re-writing.
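// For example, a guard (%a != %b) only records the pair {%a, %b} here;
// later, when the rewriter visits an add matching (%a - %b) (e.g. in a
// backedge-taken count), it replaces it with umax(1, %a - %b).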
15769 if (LHS->getType()->isPointerTy()) {
15773 break;
15774 }
15775 const SCEVConstant *C;
15776 const SCEV *A, *B;
15779 RHS = A;
15780 LHS = B;
15781 }
15782 if (LHS > RHS)
15783 std::swap(LHS, RHS);
15784 Guards.NotEqual.insert({LHS, RHS});
15785 continue;
15786 }
15787 break;
15788 default:
15789 break;
15790 }
15791
15792 if (To)
15793 AddRewrite(From, FromRewritten, To);
15794 }
15795 };
15796
15797 SmallVector<std::pair<Value *, bool>> Terms;
15798 // First, collect information from assumptions dominating the loop.
15799 for (auto &AssumeVH : SE.AC.assumptions()) {
15800 if (!AssumeVH)
15801 continue;
15802 auto *AssumeI = cast<CallInst>(AssumeVH);
15803 if (!SE.DT.dominates(AssumeI, Block))
15804 continue;
15805 Terms.emplace_back(AssumeI->getOperand(0), true);
15806 }
15807
15808 // Second, collect information from llvm.experimental.guards dominating the loop.
15809 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15810 SE.F.getParent(), Intrinsic::experimental_guard);
15811 if (GuardDecl)
15812 for (const auto *GU : GuardDecl->users())
15813 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15814 if (Guard->getFunction() == Block->getParent() &&
15815 SE.DT.dominates(Guard, Block))
15816 Terms.emplace_back(Guard->getArgOperand(0), true);
15817
15818 // Third, collect conditions from dominating branches. Starting at the loop
15819 // predecessor, climb up the predecessor chain as long as we can find
15820 // predecessors that have a unique successor leading to the original
15821 // header.
15822 // TODO: share this logic with isLoopEntryGuardedByCond.
15823 unsigned NumCollectedConditions = 0;
15825 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15826 for (; Pair.first;
15827 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15828 VisitedBlocks.insert(Pair.second);
15829 const BranchInst *LoopEntryPredicate =
15830 dyn_cast<BranchInst>(Pair.first->getTerminator());
15831 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15832 continue;
15833
15834 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15835 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15836 NumCollectedConditions++;
15837
15838 // If we are recursively collecting guards, stop after 2
15839 // conditions to limit compile-time impact for now.
15840 if (Depth > 0 && NumCollectedConditions == 2)
15841 break;
15842 }
15843 // Finally, if we stopped climbing the predecessor chain because
15844 // there wasn't a unique one to continue, try to collect conditions
15845 // for PHINodes by recursively following all of their incoming
15846 // blocks and try to merge the found conditions to build a new one
15847 // for the Phi.
15848 if (Pair.second->hasNPredecessorsOrMore(2) &&
15849 Depth < MaxLoopGuardCollectionDepth) {
15850 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15851 for (auto &Phi : Pair.second->phis())
15852 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15853 }
15854
15855 // Now apply the information from the collected conditions to
15856 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15857 // earliest condition is processed first, except guards with divisibility
15858 // information, which are moved to the back. This ensures the SCEVs with the
15859 // shortest dependency chains are constructed first.
15860 SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15861 GuardsToProcess;
15862 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15863 SmallVector<Value *, 8> Worklist;
15864 SmallPtrSet<Value *, 8> Visited;
15865 Worklist.push_back(Term);
15866 while (!Worklist.empty()) {
15867 Value *Cond = Worklist.pop_back_val();
15868 if (!Visited.insert(Cond).second)
15869 continue;
15870
15871 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15872 auto Predicate =
15873 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15874 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15875 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15876 // If LHS is a constant, apply information to the other expression.
15877 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15878 // can improve results.
15879 if (isa<SCEVConstant>(LHS)) {
15880 std::swap(LHS, RHS);
15881 Predicate = CmpInst::getSwappedPredicate(Predicate);
15882 }
15883 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
15884 continue;
15885 }
15886
15887 Value *L, *R;
15888 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15889 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15890 Worklist.push_back(L);
15891 Worklist.push_back(R);
15892 }
15893 }
15894 }
15895
15896 // Process divisibility guards in reverse order to populate DivGuards early.
15897 DenseMap<const SCEV *, APInt> Multiples;
15898 LoopGuards DivGuards(SE);
15899 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15900 if (!isDivisibilityGuard(LHS, RHS, SE))
15901 continue;
15902 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15903 Multiples, SE);
15904 }
15905
15906 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15907 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15908
15909 // Apply divisibility information last. This ensures it is applied to the
15910 // outermost expression after other rewrites for the given value.
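// For example, if %n is known to be divisible by 16 and an (%n != 0) guard
// already rewrote it to (16 umax %n), the final map entry is roughly
// (((16 umax %n) /u 16) * 16), keeping the divisibility fact visible on
// the outermost expression.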
15911 for (const auto &[K, Divisor] : Multiples) {
15912 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15913 Guards.RewriteMap[K] =
15914 SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15915 Guards.rewrite(K), Divisor, SE),
15916 DivisorSCEV),
15917 DivisorSCEV);
15918 ExprsToRewrite.push_back(K);
15919 }
15920
15921 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15922 // the replacement expressions are contained in the ranges of the replaced
15923 // expressions.
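// For example, rewriting %x to (%x umin 9) typically keeps the unsigned
// range within that of %x, so NUW can still be preserved; a replacement
// whose range is not contained forces the rewriter to drop the flag.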
15924 Guards.PreserveNUW = true;
15925 Guards.PreserveNSW = true;
15926 for (const SCEV *Expr : ExprsToRewrite) {
15927 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15928 Guards.PreserveNUW &=
15929 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15930 Guards.PreserveNSW &=
15931 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15932 }
15933
15934 // Now that all rewrite information is collected, rewrite the collected
15935 // expressions with the information in the map. This applies information to
15936 // sub-expressions.
15937 if (ExprsToRewrite.size() > 1) {
15938 for (const SCEV *Expr : ExprsToRewrite) {
15939 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15940 Guards.RewriteMap.erase(Expr);
15941 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15942 }
15943 }
15944}
15945
15946const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15947 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15948 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15949 /// replacement is loop invariant in the loop of the AddRec.
15950 class SCEVLoopGuardRewriter
15951 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15952 const DenseMap<const SCEV *, const SCEV *> &Map;
15953 const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
15954
15955 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15956
15957 public:
15958 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15959 const ScalarEvolution::LoopGuards &Guards)
15960 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
15961 NotEqual(Guards.NotEqual) {
15962 if (Guards.PreserveNUW)
15963 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15964 if (Guards.PreserveNSW)
15965 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15966 }
15967
15968 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15969
15970 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15971 return Map.lookup_or(Expr, Expr);
15972 }
15973
15974 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15975 if (const SCEV *S = Map.lookup(Expr))
15976 return S;
15977
15978 // If we didn't find the exact ZExt expr in the map, check if there's
15979 // an entry for a smaller ZExt we can use instead.
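// For example, when looking up (zext i8 %x to i64) and the map only has an
// entry for (zext i8 %x to i32), that narrower entry is reused and
// zero-extended to i64.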
15980 Type *Ty = Expr->getType();
15981 const SCEV *Op = Expr->getOperand(0);
15982 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15983 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15984 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15985 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15986 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15987 if (const SCEV *S = Map.lookup(NarrowExt))
15988 return SE.getZeroExtendExpr(S, Ty);
15989 Bitwidth = Bitwidth / 2;
15990 }
15991
15992 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15993 Expr);
15994 }
15995
15996 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15997 if (const SCEV *S = Map.lookup(Expr))
15998 return S;
15999 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16000 Expr);
16001 }
16002
16003 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16004 if (const SCEV *S = Map.lookup(Expr))
16005 return S;
16006 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16007 }
16008
16009 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16010 if (const SCEV *S = Map.lookup(Expr))
16011 return S;
16012 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16013 }
16014
16015 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16016 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16017 // return UMax(S, 1).
16018 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16019 const SCEV *LHS, *RHS;
16020 if (MatchBinarySub(S, LHS, RHS)) {
16021 if (LHS > RHS)
16022 std::swap(LHS, RHS);
16023 if (NotEqual.contains({LHS, RHS})) {
16024 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16025 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16026 return SE.getUMaxExpr(OneAlignedUp, S);
16027 }
16028 }
16029 return nullptr;
16030 };
16031
16032 // Check if Expr itself is a subtraction pattern with guard info.
16033 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16034 return Rewritten;
16035
16036 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16037 // (Const + A + B). There may be guard info for A + B, and if so, apply
16038 // it.
16039 // TODO: Could more generally apply guards to Add sub-expressions.
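// For example, for (-1 + %a + (-1 * %b)) with a recorded (%a != %b) guard,
// the inner (%a + (-1 * %b)) is a subtraction and is rewritten to
// umax(1, %a - %b) before re-adding the constant operand.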
16040 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16041 Expr->getNumOperands() == 3) {
16042 const SCEV *Add =
16043 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16044 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16045 return SE.getAddExpr(
16046 Expr->getOperand(0), Rewritten,
16047 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16048 if (const SCEV *S = Map.lookup(Add))
16049 return SE.getAddExpr(Expr->getOperand(0), S);
16050 }
16051 SmallVector<const SCEV *, 2> Operands;
16052 bool Changed = false;
16053 for (const auto *Op : Expr->operands()) {
16054 Operands.push_back(
16055 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16056 Changed |= Op != Operands.back();
16057 }
16058 // We are only replacing operands with equivalent values, so transfer the
16059 // flags from the original expression.
16060 return !Changed ? Expr
16061 : SE.getAddExpr(Operands,
16062 ScalarEvolution::maskFlags(
16063 Expr->getNoWrapFlags(), FlagMask));
16064 }
16065
16066 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16067 SmallVector<const SCEV *, 2> Operands;
16068 bool Changed = false;
16069 for (const auto *Op : Expr->operands()) {
16070 Operands.push_back(
16071 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16072 Changed |= Op != Operands.back();
16073 }
16074 // We are only replacing operands with equivalent values, so transfer the
16075 // flags from the original expression.
16076 return !Changed ? Expr
16077 : SE.getMulExpr(Operands,
16078 ScalarEvolution::maskFlags(
16079 Expr->getNoWrapFlags(), FlagMask));
16080 }
16081 };
16082
16083 if (RewriteMap.empty() && NotEqual.empty())
16084 return Expr;
16085
16086 SCEVLoopGuardRewriter Rewriter(SE, *this);
16087 return Rewriter.visit(Expr);
16088}
16089
16090const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16091 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16092}
16093
16094const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16095 const LoopGuards &Guards) {
16096 return Guards.rewrite(Expr);
16097}
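// Usage sketch (illustrative, not part of this file): a client holding a
// ScalarEvolution &SE and a Loop *L can tighten a loop bound with:
//   const SCEV *BTC = SE.getSymbolicMaxBackedgeTakenCount(L);
//   const SCEV *Tight = SE.applyLoopGuards(BTC, L);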
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
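A minimal sketch of the collect/rewrite pairing above, assuming L and SE are valid; the helper name is made up for illustration.

  #include "llvm/Analysis/ScalarEvolution.h"
  using namespace llvm;

  // Illustrative helper: refine the backedge-taken count of L using any
  // conditions guarding entry to the loop.
  static const SCEV *guardedBackedgeTakenCount(const Loop *L,
                                               ScalarEvolution &SE) {
    const SCEV *BTC = SE.getBackedgeTakenCount(L);
    if (isa<SCEVCouldNotCompute>(BTC))
      return BTC;
    auto Guards = ScalarEvolution::LoopGuards::collect(L, SE);
    return Guards.rewrite(BTC);
  }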
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Check whether the operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which cannot resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
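A minimal sketch of building a recurrence with getAddRecExpr, assuming IdxTy is an integer type; the helper name and the byte stride are illustrative.

  #include "llvm/Analysis/ScalarEvolution.h"
  using namespace llvm;

  // Illustrative helper: build the canonical recurrence {0,+,4}<L>, i.e. an
  // index starting at 0 that advances by 4 each iteration.
  static const SCEV *makeByteOffsetRecurrence(const Loop *L,
                                              ScalarEvolution &SE,
                                              Type *IdxTy) {
    const SCEV *Start = SE.getZero(IdxTy);
    const SCEV *Step = SE.getConstant(IdxTy, 4);
    // FlagAnyWrap is conservative; pass FlagNUW/FlagNSW only when wrapping is
    // provably impossible.
    return SE.getAddRecExpr(Start, Step, L, SCEV::FlagAnyWrap);
  }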
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
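A minimal sketch of the trip-count queries above, assuming SE has analyzed the function containing L; the helper name is illustrative.

  #include "llvm/Analysis/ScalarEvolution.h"
  using namespace llvm;

  // Illustrative helper: prefer an exact small-constant trip count, otherwise
  // fall back to the exact and constant-max backedge-taken counts.
  static void reportTripCount(const Loop *L, ScalarEvolution &SE) {
    if (unsigned TC = SE.getSmallConstantTripCount(L)) {
      // Exact trip count known, e.g. 8 for `for (i = 0; i < 8; ++i)`.
      (void)TC;
      return;
    }
    const SCEV *Exact = SE.getBackedgeTakenCount(L);          // may be CouldNotCompute
    const SCEV *Max = SE.getConstantMaxBackedgeTakenCount(L); // upper bound
    (void)Exact;
    (void)Max;
  }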
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
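A minimal sketch of evaluatePredicate, assuming CmpPredicate converts from ICmpInst::Predicate as usual; the helper name and return convention are illustrative.

  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/IR/Instructions.h"
  #include <optional>
  using namespace llvm;

  // Illustrative helper: classify LHS <u RHS as provably true (1), provably
  // false (-1), or unknown (0).
  static int classifyUnsignedLess(const SCEV *LHS, const SCEV *RHS,
                                  ScalarEvolution &SE) {
    if (std::optional<bool> Res =
            SE.evaluatePredicate(ICmpInst::ICMP_ULT, LHS, RHS))
      return *Res ? 1 : -1;
    return 0;
  }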
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and return the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:723
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:754
TypeSize getSizeInBits() const
Definition DataLayout.h:734
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:296
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:293
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:300
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:293
Use & Op()
Definition User.h:197
Value * getOperand(unsigned i) const
Definition User.h:233
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1106
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2257
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2262
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2267
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2823
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2272
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
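A one-line usage sketch of the APIntOps helper above; the function name is illustrative.

  #include "llvm/ADT/APInt.h"
  using namespace llvm;

  // Illustrative helper: gcd(12, 18) == 6 at 32-bit width.
  static APInt gcdExample() {
    APInt A(32, 12), B(32, 18);
    return APIntOps::GreatestCommonDivisor(A, B);
  }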
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
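A minimal sketch of the SCEV pattern matchers listed above, assuming they come from the ScalarEvolutionPatternMatch.h header; the helper name is illustrative.

  #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
  using namespace llvm;
  using namespace llvm::SCEVPatternMatch;

  // Illustrative helper: recognize a zero-based affine recurrence {0,+,C}<L>
  // with a constant step, binding the step to StepC.
  static bool isZeroBasedConstantStride(const SCEV *S, const APInt *&StepC) {
    return match(S, m_scev_AffineAddRec(m_scev_Zero(), m_scev_APInt(StepC)));
  }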
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2106
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1737
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2198
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2101
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0s from the least significant bit to the most significant bit, stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of WO's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
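A minimal sketch of matchSimpleRecurrence, which recognizes the IR shape ScalarEvolution models as an add recurrence; the helper name is illustrative.

  #include "llvm/Analysis/ValueTracking.h"
  #include "llvm/IR/Instructions.h"
  using namespace llvm;

  // Illustrative helper: recognize the cycle
  //   iv = phi [Start, preheader], [iv.next, latch]; iv.next = iv op Step
  static bool isSimpleRecurrencePHI(const PHINode *P) {
    BinaryOperator *BO = nullptr;
    Value *Start = nullptr, *Step = nullptr;
    return matchSimpleRecurrence(P, BO, Start, Step);
  }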
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2190
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1744
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
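A minimal sketch of computeKnownBits with its default optional arguments, the same kind of fact ScalarEvolution uses when bounding SCEVUnknowns; the helper name is illustrative.

  #include "llvm/Analysis/ValueTracking.h"
  #include "llvm/IR/DataLayout.h"
  #include "llvm/IR/Value.h"
  #include "llvm/Support/KnownBits.h"
  using namespace llvm;

  // Illustrative helper: true if the least significant bit of V must be zero.
  static bool lowBitKnownZero(const Value *V, const DataLayout &DL) {
    KnownBits Known(DL.getTypeSizeInBits(V->getType()).getFixedValue());
    computeKnownBits(V, Known, DL);
    return Known.Zero[0];
  }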
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:361
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2002
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2078
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1915
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2009
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1945
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:304
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:199
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:148
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:132
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
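A minimal sketch of the KnownBits transfer functions above; the helper name and the bit pattern are illustrative.

  #include "llvm/ADT/APInt.h"
  #include "llvm/Support/KnownBits.h"
  using namespace llvm;

  // Illustrative helper: known bits of X << 1 where X is known to match the
  // 8-bit pattern 0000?01? ('?' = unknown bit).
  static KnownBits shiftKnownPatternLeft() {
    KnownBits X(8);
    X.Zero = 0xF4; // bits 2 and 4..7 known zero
    X.One = 0x02;  // bit 1 known one
    KnownBits ShAmt = KnownBits::makeConstant(APInt(8, 1));
    return KnownBits::shl(X, ShAmt);
  }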
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.